Tez is a super-simple and lightweight Trainer for PyTorch. It also comes with many utils that you can use to tackle over 90% of deep learning projects in PyTorch.

Overview

Tez: a simple pytorch trainer

NOTE: Currently, we are not accepting any pull requests! All PRs will be closed. If you want a feature or something doesn't work, please create an issue.

tez (तेज़ / تیز) means sharp, fast & active. This is a simple, to-the-point library to make your PyTorch training easy.

This library is currently at a very early stage, so there might be breaking changes.

The idea behind tez is simple:

  • keep things as simple as possible
  • make it as customizable as possible
  • clean code
  • faster prototyping
  • production ready

Currently, tez supports CPU and GPU training. More coming soon!

Using tez is super-easy. We don't want you to be far away from PyTorch, so you do everything on your own and just use tez to make a few things simpler.

Training using Tez:

  • To train a model, define a dataset and a model. The dataset class is the same ordinary class you would write for any PyTorch model.

  • Create your model class. Instead of inheriting from nn.Module, import tez and inherit from tez.Model as shown in the following example.

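import torch
import torch.nn as nn
import transformers
from sklearn import metrics

import tez


class MyModel(tez.Model):
    def __init__(self):
        super().__init__()
        # define your layers (here: a BERT-based binary classifier)
        self.bert = transformers.BertModel.from_pretrained("bert-base-uncased")
        self.bert_drop = nn.Dropout(0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, 1)
        # tell when to step the scheduler: after every "batch" or every "epoch"
        self.step_scheduler_after = "batch"

    def monitor_metrics(self, outputs, targets):
        if targets is None:
            return {}
        outputs = torch.sigmoid(outputs).cpu().detach().numpy().ravel() >= 0.5
        targets = targets.cpu().detach().numpy()
        accuracy = metrics.accuracy_score(targets, outputs)
        return {"accuracy": accuracy}

    def fetch_scheduler(self):
        # create your own scheduler (self.optimizer is already set at this point);
        # example: decay the learning rate by 10% every 1000 steps
        return torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1000, gamma=0.9)

    def fetch_optimizer(self):
        # create your own optimizer
        return torch.optim.AdamW(self.parameters(), lr=3e-5)

    def forward(self, ids, mask, token_type_ids, targets=None):
        # return_dict=False makes BERT return a tuple that can be unpacked
        _, o_2 = self.bert(
            ids, attention_mask=mask, token_type_ids=token_type_ids, return_dict=False
        )
        b_o = self.bert_drop(o_2)
        output = self.out(b_o)

        # calculate loss here (there is no loss at prediction time, when targets is None)
        loss = None
        if targets is not None:
            loss = nn.BCEWithLogitsLoss()(output, targets.view(-1, 1).float())

        # calculate the metric dictionary here
        metric_dict = self.monitor_metrics(output, targets)
        return output, loss, metric_dict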
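The dataset from the first step is an ordinary torch.utils.data.Dataset. Below is a minimal sketch for the BERT model above, assuming that tez passes each batch dict to forward() as keyword arguments, so the keys must match the parameter names ids, mask, token_type_ids and targets. The class name is hypothetical, and the texts are assumed to be already tokenized into fixed-length id lists.

```python
import torch
from torch.utils.data import Dataset


class BertClassificationDataset(Dataset):
    """Hypothetical dataset for the BERT classifier above.

    Assumes pre-tokenized, equal-length inputs. The dict keys are chosen to
    match the forward(ids, mask, token_type_ids, targets) signature.
    """

    def __init__(self, input_ids, attention_masks, token_type_ids, targets):
        self.input_ids = input_ids
        self.attention_masks = attention_masks
        self.token_type_ids = token_type_ids
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, item):
        return {
            "ids": torch.tensor(self.input_ids[item], dtype=torch.long),
            "mask": torch.tensor(self.attention_masks[item], dtype=torch.long),
            "token_type_ids": torch.tensor(self.token_type_ids[item], dtype=torch.long),
            # float targets, as BCEWithLogitsLoss expects
            "targets": torch.tensor(self.targets[item], dtype=torch.float),
        }
```

PyTorch's default collate function stacks these dicts key by key, so a DataLoader over this dataset yields one dict of batched tensors per step.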

Everything is super-intuitive!

  • Now you can train your model!
# init datasets
train_dataset = SomeTrainDataset()
valid_dataset = SomeValidDataset()

# init model
model = MyModel()


# init callbacks, you can also write your own callback
es = tez.callbacks.EarlyStopping(monitor="valid_loss", model_path="model.bin")

# train model. a familiar api!
model.fit(
    train_dataset,
    valid_dataset=valid_dataset,
    train_bs=32,
    device="cuda",
    epochs=50,
    callbacks=[es],
    fp16=True,
)

# save the model (together with the optimizer and scheduler states, for later use)
model.save("model.bin")
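The fit call above hard-codes device="cuda" and fp16=True, which fails on a PyTorch build without CUDA (see the "Can it work without CUDA" issue below). A minimal sketch of choosing both at runtime, assuming fit() accepts "cpu" as a device string, as the README's CPU-support note suggests; the helper name is hypothetical:

```python
import torch


def pick_device_and_precision():
    """Return a (device, fp16) pair suitable for this machine.

    Mixed precision via torch.cuda.amp only takes effect on CUDA,
    so fp16 is switched off when falling back to the CPU.
    """
    if torch.cuda.is_available():
        return "cuda", True
    return "cpu", False


device, fp16 = pick_device_and_precision()
# model.fit(train_dataset, valid_dataset=valid_dataset, train_bs=32,
#           device=device, epochs=50, callbacks=[es], fp16=fp16)
```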

You can check out more examples in the examples/ folder.

Comments
  • ValueError: operands could not be broadcast together with shapes (256,256,4) (3,) (256,256,4)

    I am trying to use this package, and it throws the error below. I am using the same pipeline as for the cassava leaf detection problem, but on a different dataset where the image size is (256, 256).

    Could you please help here?

    Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b4-6ed6700e.pth 100% 74.4M/74.4M [00:00<00:00, 107MB/s]

    Loaded pretrained weights for efficientnet-b4 0%| | 0/51 [00:00<?, ?it/s]

    ValueError Traceback (most recent call last)
    in <module>()
         11 epochs=10,
         12 callbacks=[es],
    ---> 13 fp16=True,
         14 )
         15 model.save("model.bin")

    6 frames
    /usr/local/lib/python3.6/dist-packages/tez/model/model.py in fit(self, train_dataset, valid_dataset, train_sampler, valid_sampler, device, epochs, train_bs, valid_bs, n_jobs, callbacks, fp16)
        295 self.train_state = enums.TrainingState.EPOCH_START
        296 self.train_state = enums.TrainingState.TRAIN_EPOCH_START
    --> 297 train_loss = self.train_one_epoch(self.train_loader, device)
        298 self.train_state = enums.TrainingState.TRAIN_EPOCH_END
        299 if self.valid_loader:

    /usr/local/lib/python3.6/dist-packages/tez/model/model.py in train_one_epoch(self, data_loader, device)
        176 losses = AverageMeter()
        177 tk0 = tqdm(data_loader, total=len(data_loader))
    --> 178 for b_idx, data in enumerate(tk0):
        179 self.train_state = enums.TrainingState.TRAIN_STEP_START
        180 loss, metrics = self.train_one_step(data, device)

    /usr/local/lib/python3.6/dist-packages/tqdm/std.py in __iter__(self)
       1102 fp_write=getattr(self.fp, 'write', sys.stderr.write))
       1103
    -> 1104 for obj in iterable:
       1105 yield obj
       1106 # Update and possibly print the progressbar.

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
        433 if self._sampler_iter is None:
        434 self._reset()
    --> 435 data = self._next_data()
        436 self._num_yielded += 1
        437 if self._dataset_kind == _DatasetKind.Iterable and \

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
       1083 else:
       1084 del self._task_info[idx]
    -> 1085 return self._process_data(data)
       1086
       1087 def _try_put_index(self):

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
       1109 self._try_put_index()
       1110 if isinstance(data, ExceptionWrapper):
    -> 1111 data.reraise()
       1112 return data
       1113

    /usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
        426 # have message field
        427 raise self.exc_type(message=msg)
    --> 428 raise self.exc_type(msg)
        429
        430

    ValueError: Caught ValueError in DataLoader worker process 0.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/usr/local/lib/python3.6/dist-packages/tez/datasets/image_classification.py", line 48, in __getitem__
        augmented = self.augmentations(image=image)
      File "/usr/local/lib/python3.6/dist-packages/albumentations/core/composition.py", line 171, in __call__
        data = t(**data)
      File "/usr/local/lib/python3.6/dist-packages/albumentations/core/transforms_interface.py", line 38, in __call__
        res[key] = target_function(arg, **dict(params, **target_dependencies))
      File "/usr/local/lib/python3.6/dist-packages/albumentations/augmentations/transforms.py", line 808, in apply
        return F.normalize(image, self.mean, self.std, self.max_pixel_value)
      File "/usr/local/lib/python3.6/dist-packages/albumentations/augmentations/functional.py", line 93, in normalize
        img -= mean
    ValueError: operands could not be broadcast together with shapes (256,256,4) (3,) (256,256,4)

    opened by nvnvashisth 10
  • zero_grad for accumulation_steps = 1 not working as expected

    As far as I know, the normal execution flow is to call zero_grad for each batch first and then do the forward pass. But when I investigated the code, this is not what happens when accumulation_steps = 1 and batch = 1: the forward pass executes first, without zero_grad being called.

    I tried to reproduce it, and it behaves exactly as I described above.

    Also, I think we can fix this by removing the condition on lines 330 and 331 of tez.py.

    opened by abdurrehman11 9
  • Can it work without CUDA

    I get an error when I run the code with a CPU configuration.

    Traceback (most recent call last):
      File "c:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\recommender.py", line 88, in <module>
        train()
      File "c:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\recommender.py", line 82, in train
        model.fit(
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\tez\model\model.py", line 309, in fit
        self._init_model(
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\tez\model\model.py", line 93, in _init_model
        self.to(self.device)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 852, in to
        return self._apply(convert)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 530, in _apply
        module._apply(fn)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 552, in _apply
        param_applied = fn(param)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 850, in convert
        return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\cuda\__init__.py", line 166, in _lazy_init
        raise AssertionError("Torch not compiled with CUDA enabled")
    AssertionError: Torch not compiled with CUDA enabled

    opened by hemanthh17 7
  • Documentation improvement - How is tez faster?

    Great to see a nice PyTorch training library.

    I think it would help users to show what kind of performance improvements come out of the box with Tez. For example, comparing how fp16 is enabled in tez vs. vanilla PyTorch could be informative, or just a quick list of optimisations that are easy to do with Tez, such as fp16.

    opened by swartchris8 5
  • Is it possible to set variable Lr per epoch

    @abhishekkrthakur I am finding this framework great and easy to use. But as I am fairly new to it, I was wondering if there is a way to pass a variable LR during training, for example a different one for every epoch.

    Also, is there a way to continue training from a particular epoch if, say, the local system crashed or was interrupted during training?

    opened by gauravbrills 3
  • Applying metrics after the epoch

    Dear all, I am using tez to classify melanoma images (Kaggle SIIM binary classification). With wtfml it is possible to get AUC ~ 0.85; with tez, I only get AUC ~ 0.6. I saw that this happens in tez when using metrics.roc_auc_score(...) inside the monitor_metrics method. This raises ValueError exceptions, which must be handled by returning auc = 0.5 (the error occurs when the data contain only one class).

    In wtfml, metrics.roc_auc_score(...) is used only after Engine.evaluate. In that case the data always contain both classes (because the stratified K-fold split guarantees it).

    I am wondering whether it is possible in tez to apply metrics.roc_auc_score(...) only after the epoch, and not on each batch of train_bs samples. That way the data would always contain both classes, avoiding the ValueError exceptions.

    PS.

    1. In the class definition, in __init__, I am using:
           self.step_scheduler_after = "epoch"
           self.step_scheduler_metric = "valid_auc"
    2. In the monitor_metrics method:
           try:
               auc = metrics.roc_auc_score(targets, outputs.ravel())
           except ValueError:
               auc = 0.5
           return {"auc": auc}
    3. My model.fit is defined as:
           model.fit(train_dataset, valid_dataset=valid_dataset, train_bs=32, valid_bs=16, device="cuda", epochs=50, callbacks=[es], fp16=False, n_jobs=2)

    opened by waldcarl 2
  • Issue while using Auc metric on imbalanced dataset like melanoma(ValueError: Only one class present in y_true. ROC AUC score is not defined in that case)

    This problem occurs because of the running (per-batch) metric calculation.

    I found this explanation on Stack Overflow:

    You cannot have an ROC curve without both positive and negative examples in your dataset. With only one class in the dataset, you cannot measure your false-positive rate, and therefore cannot plot an ROC curve. This is why you get this error message.

    How to handle this problem?

    opened by IamSantoshKumar 2
  • Error in Multiclass TypeError: dropout(): argument 'input' (position 1) must be Tensor, not str

    /usr/local/lib/python3.7/dist-packages/torch/cuda/amp/grad_scaler.py:116: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.
      warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
    0%| | 0/2939 [00:00<?, ?it/s]/usr/local/lib/python3.7/dist-packages/torch/cuda/amp/autocast_mode.py:118: UserWarning: torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available. Disabling.
      warnings.warn("torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available. Disabling.")

    TypeError Traceback (most recent call last)
    in <module>()
        143 epochs=3,
        144 callbacks=[tb_logger, es],
    --> 145 fp16=True,
        146 )
        147 model.save("model.bin")

    8 frames
    /usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in dropout(input, p, training, inplace)
       1074 if p < 0.0 or p > 1.0:
       1075 raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
    -> 1076 return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
       1077
       1078

    TypeError: dropout(): argument 'input' (position 1) must be Tensor, not str

    opened by gokulguptanew 1
  • Text classification examples - Tokenizer is defined twice

    The tokenizer is defined both in the model and the dataset in the BERT text classification examples.

    multi_class.py, line 50:

        self.tokenizer = transformers.BertTokenizer.from_pretrained(
            "bert-base-uncased", do_lower_case=True
        )

    opened by obesp 1
  • Small error in image_classification.py

    If augmentations is None, we get an error because the variable augmented is referenced before assignment:

    UnboundLocalError: local variable 'augmented' referenced before assignment

    elif self.backend == "cv2":
        image = cv2.imread(self.image_paths[item])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.resize is not None:
            image = cv2.resize(
                image,
                (self.resize[1], self.resize[0]),
                interpolation=cv2.INTER_CUBIC,
            )
        if self.augmentations is not None:
            augmented = self.augmentations(image=image)
            image = augmented["image"]

    If the indentation is fixed, we can solve this error.

    opened by VpkPrasanna 1
  • Small error in model.py

    Hi! Love this library.
    In tez/model/model.py there is probably a mistake in line 90:

    self.train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_bs,
        num_workers=n_jobs,
        sampler=valid_sampler,
        shuffle=True,
    )

    I guess train_sampler is meant to be used here, not valid_sampler.

    opened by hocop 1
  • run example code error

    When I run the example code:

    accelerate launch imdb_sentiment_classification.py

    after running for a few epochs, I get the following error:

    INFO:tez.callbacks.early_stopping:EarlyStopping counter: 4/5
    [train] accuracy=0.9915, loss=0.0269 [valid] accuracy=0.8953, loss=0.4287 [e=5 steps=2112]                                                                                                 
     30%|████████████████████████████████▍                                                                           | 2112/7040 [05:45<06:40, 12.32it/s, accuracy=0.991, epoch=5, loss=0.0269]2022-09-17 07:55:02,832 INFO EarlyStopping counter: 5/5
    INFO:tez.callbacks.early_stopping:EarlyStopping counter: 5/5
     30%|████████████████████████████████▍                                                                           | 2112/7040 [05:47<13:31,  6.07it/s, accuracy=0.991, epoch=5, loss=0.0269]
    
    
    
    
    [E ProcessGroupNCCL.cpp:719] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808970 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:719] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808984 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:719] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1809275 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1809275 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808970 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808984 milliseconds before timing out.
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/resource_sharer.py", line 138, in _serve
        with self._listener.accept() as conn:
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 470, in accept
        deliver_challenge(c, self._authkey)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 745, in deliver_challenge
        response = connection.recv_bytes(256)        # reject large message
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 221, in recv_bytes
        buf = self._recv_bytes(maxlength)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 419, in _recv_bytes
        buf = self._recv(4)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 384, in _recv
        chunk = read(handle, remaining)
    ConnectionResetError: [Errno 104] Connection reset by peer
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 113654 closing signal SIGTERM
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/resource_sharer.py", line 138, in _serve
        with self._listener.accept() as conn:
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 470, in accept
        deliver_challenge(c, self._authkey)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 745, in deliver_challenge
        response = connection.recv_bytes(256)        # reject large message
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 221, in recv_bytes
        buf = self._recv_bytes(maxlength)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 419, in _recv_bytes
        buf = self._recv(4)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 384, in _recv
        chunk = read(handle, remaining)
    ConnectionResetError: [Errno 104] Connection reset by peer
    ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 2 (pid: 113655) of binary: /root/miniconda3/envs/lightning/bin/python
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/bin/torchrun", line 33, in <module>
        sys.exit(load_entry_point('torch==1.11.0', 'console_scripts', 'torchrun')())
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
        return f(*args, **kwargs)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/run.py", line 724, in main
        run(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/run.py", line 715, in run
        elastic_launch(
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
        return launch_agent(self._config, self._entrypoint, list(args))
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
        raise ChildFailedError(
    torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
    =======================================================
    imdb_sentiment_classification.py FAILED
    -------------------------------------------------------
    Failures:
    [1]:
      time      : 2022-09-17_08:25:22
      host      : dy-a100-779-tlzrv
      rank      : 3 (local_rank: 3)
      exitcode  : -6 (pid: 113656)
      error_file: <N/A>
      traceback : Signal 6 (SIGABRT) received by PID 113656
    -------------------------------------------------------
    Root Cause (first observed failure):
    [0]:
      time      : 2022-09-17_08:25:22
      host      : dy-a100-779-tlzrv
      rank      : 2 (local_rank: 2)
      exitcode  : -6 (pid: 113655)
      error_file: <N/A>
      traceback : Signal 6 (SIGABRT) received by PID 113655
    =======================================================
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/bin/accelerate", line 33, in <module>
        sys.exit(load_entry_point('accelerate==0.12.0.dev0', 'console_scripts', 'accelerate')())
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/accelerate_cli.py", line 43, in main
        args.func(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/launch.py", line 734, in launch_command
        multi_gpu_launcher(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/launch.py", line 374, in multi_gpu_launcher
        raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
    subprocess.CalledProcessError: Command '['torchrun', '--nproc_per_node', '4', 'imdb_sentiment_classification.py']' returned non-zero exit status 1.
    
    opened by bestpredicts 0
  • Getting error while importing enums from tez.

    Traceback (most recent call last):
      File "/content/tez/tez/model/model.py", line 12, in <module>
        from tez import enums
      File "/content/tez/tez/model/tez.py", line 11, in <module>
        from tez import enums
    ImportError: cannot import name 'enums' from 'tez' (/content/tez/tez/model/tez.py)

    Looking forward to a reply.

    opened by VikasRathod314 3
  • Saving validation score

    Is it possible to somehow save a list of the validation scores (per epoch or per batch) after training? I have some problems with the output on my server (it usually gets deleted), but I really need the validation scores to compare models. It would be really convenient if I could get them in one file, for example.

    opened by 25icecreamflavors 0
  • Saving after training an epoch

    How do I save the model after each training epoch? I use the fit method for 5 epochs and don't really understand how to save after each one, not only after the last one.

    opened by 25icecreamflavors 2
  • How can we access the input_ids/attention mask in each train batch loop?

    I tried using a train step callback, but I am not sure how to access the dataloader's input_ids and attention mask during each train step. Is this possible?

    BTW, thanks for the library!

    opened by tkmaker 0
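In the first issue above, the image has four channels (RGBA, shape (256, 256, 4)) while the normalization mean has three values, so NumPy cannot broadcast (256, 256, 4) against (3,). A minimal NumPy-only sketch of the usual fix, dropping the alpha channel before augmentation; the helper name is hypothetical, and when loading with PIL one would typically call Image.open(path).convert("RGB") instead:

```python
import numpy as np


def to_rgb(image):
    """Return an H x W x 3 array, dropping an alpha channel or expanding grayscale."""
    if image.ndim == 2:  # grayscale: repeat the single channel three times
        return np.stack([image] * 3, axis=-1)
    if image.shape[-1] == 4:  # RGBA: keep only the colour channels
        return image[..., :3]
    return image


mean = np.array([0.485, 0.456, 0.406])  # ImageNet mean, one value per RGB channel
rgba = np.zeros((256, 256, 4), dtype=np.float32)

try:
    rgba - mean  # (256, 256, 4) vs (3,): the shapes cannot be broadcast
except ValueError as err:
    print(err)

normalized = to_rgb(rgba) - mean  # (256, 256, 3) vs (3,) broadcasts fine
print(normalized.shape)  # (256, 256, 3)
```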
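For the zero_grad issue: in a standard PyTorch gradient-accumulation loop, gradients are cleared once before the first batch and again after every optimizer step, so with accumulation_steps = 1 every forward/backward pass starts from cleared gradients. A minimal plain-PyTorch sketch of that ordering (not tez's internal code); the function name and the toy regression loss are hypothetical:

```python
import torch
import torch.nn as nn


def train_with_accumulation(model, batches, optimizer, accumulation_steps=1):
    """Plain PyTorch training loop with gradient accumulation."""
    loss_fn = nn.MSELoss()
    model.train()
    optimizer.zero_grad()  # start from clean gradients before the first forward pass
    for step, (x, y) in enumerate(batches, start=1):
        loss = loss_fn(model(x), y)
        # divide so the accumulated gradient is the mean over the accumulated batches
        (loss / accumulation_steps).backward()
        if step % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()  # clear before the next group of batches
```

With equal-sized batches, accumulating two half-batches gives the same update as one full batch. In this sketch a trailing group smaller than accumulation_steps is never stepped; real trainers usually step once more at the end of the epoch.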
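The "How is tez faster?" issue asks how fp16 in tez compares with vanilla PyTorch. The warnings quoted in the multiclass issue show that fp16=True goes through torch.cuda.amp's GradScaler and autocast, so a plain-PyTorch mixed-precision step looks roughly like the sketch below. This is a generic AMP training step, not tez's code; the toy model, data and function name are hypothetical.

```python
import torch
import torch.nn as nn


def amp_train_step(model, optimizer, scaler, x, y, device_type):
    """One mixed-precision (fp16) training step in plain PyTorch."""
    optimizer.zero_grad()
    # autocast runs eligible ops (such as matmuls) in float16 when enabled on CUDA
    with torch.autocast(device_type=device_type, enabled=scaler.is_enabled()):
        loss = nn.functional.mse_loss(model(x), y)
    # scale the loss so small fp16 gradients do not underflow;
    # scaler.step() unscales them and skips the step if they contain inf/NaN
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"
model = nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# with enabled=False the scaler is a pass-through, so the same code runs on CPU
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
x = torch.randn(8, 4, device=device)
y = torch.randn(8, 1, device=device)
loss = amp_train_step(model, optimizer, scaler, x, y, device_type=device)
```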
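On the question of setting the learning rate per epoch: since the model class defines fetch_scheduler and sets step_scheduler_after (both shown in the README), one option is to return a scheduler that changes the rate per epoch and set step_scheduler_after = "epoch". A minimal sketch with PyTorch's LambdaLR, assuming tez then calls scheduler.step() once per epoch; the schedule values and the helper name are hypothetical:

```python
import torch

# Hypothetical per-epoch learning rates; later epochs keep the last value.
EPOCH_LRS = [1e-3, 5e-4, 1e-4]


def make_per_epoch_scheduler(optimizer):
    """LambdaLR that sets the lr of epoch e to EPOCH_LRS[e] (clamped to the last entry).

    LambdaLR multiplies each param group's initial lr by the returned factor, so the
    optimizer should be created with lr=EPOCH_LRS[0]. The lambda's argument counts
    scheduler.step() calls, which equals the epoch number when stepped once per epoch.
    """
    base_lr = EPOCH_LRS[0]

    def factor(epoch):
        return EPOCH_LRS[min(epoch, len(EPOCH_LRS) - 1)] / base_lr

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=factor)


# Assumed usage inside the model class from the README:
#     in __init__:  self.step_scheduler_after = "epoch"
#     def fetch_optimizer(self):
#         return torch.optim.AdamW(self.parameters(), lr=EPOCH_LRS[0])
#     def fetch_scheduler(self):
#         return make_per_epoch_scheduler(self.optimizer)
```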
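Both AUC issues come from computing roc_auc_score per batch: a batch from an imbalanced dataset such as melanoma can contain a single class, for which ROC AUC is undefined. A common workaround is to collect targets and scores over the whole validation epoch and compute AUC once. A minimal framework-agnostic sketch, assuming 0/1 targets and that you can feed it each batch's outputs (the class name is hypothetical; with scikit-learn available, metrics.roc_auc_score on the concatenated arrays should give the same value):

```python
class EpochAUC:
    """Accumulates 0/1 targets and scores across batches; computes ROC AUC at epoch end."""

    def __init__(self):
        self.targets = []
        self.scores = []

    def update(self, targets, scores):
        self.targets.extend(int(t) for t in targets)
        self.scores.extend(float(s) for s in scores)

    def compute(self):
        n_pos = sum(self.targets)
        n_neg = len(self.targets) - n_pos
        if n_pos == 0 or n_neg == 0:
            raise ValueError("ROC AUC needs both classes in the epoch's targets")
        # rank all scores (1-based), giving tied scores their average rank
        order = sorted(range(len(self.scores)), key=lambda i: self.scores[i])
        ranks = [0.0] * len(order)
        i = 0
        while i < len(order):
            j = i
            while j + 1 < len(order) and self.scores[order[j + 1]] == self.scores[order[i]]:
                j += 1
            for k in range(i, j + 1):
                ranks[order[k]] = (i + j) / 2 + 1
            i = j + 1
        # Mann-Whitney U statistic of the positives, normalised to [0, 1]
        rank_sum_pos = sum(r for r, t in zip(ranks, self.targets) if t == 1)
        return (rank_sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

    def reset(self):
        self.targets.clear()
        self.scores.clear()
```

For example, feeding a batch with only negatives and a batch with only positives still yields a defined AUC, because the score is computed over the whole epoch.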
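On the model.py issue: besides passing valid_sampler where train_sampler was meant, PyTorch's DataLoader rejects a custom sampler combined with shuffle=True (it raises ValueError, since the sampler already decides the order). A minimal sketch of building the training loader so that shuffling is used only when no sampler is given; the function name is hypothetical and this is not tez's code:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler


def make_train_loader(train_dataset, train_bs, n_jobs=0, train_sampler=None):
    """Build the training DataLoader, using train_sampler when one is given."""
    return DataLoader(
        train_dataset,
        batch_size=train_bs,
        num_workers=n_jobs,
        sampler=train_sampler,
        # a custom sampler already controls the order, and DataLoader raises
        # ValueError if shuffle=True is combined with a sampler
        shuffle=train_sampler is None,
    )


# usage: draw 6 samples per epoch with a weighted sampler instead of shuffling
dataset = TensorDataset(torch.arange(10, dtype=torch.float32))
sampler = WeightedRandomSampler(weights=[1.0] * 10, num_samples=6)
loader = make_train_loader(dataset, train_bs=4, train_sampler=sampler)
```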
Releases(v0.1.8)
Owner
abhishek thakur
Kaggle: www.kaggle.com/abhishek
TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards

TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards. It can reduce GPU memory and scale up the training when the model has massive linear layers (e.g., ViT, BERT and

Kaiyu Yue 275 Nov 22, 2022
Model summary in PyTorch similar to `model.summary()` in Keras

Keras style model.summary() in PyTorch Keras has a neat API to view the visualization of the model which is very helpful while debugging your network.

Shubham Chandel 3.7k Dec 29, 2022
Reformer, the efficient Transformer, in Pytorch

Reformer, the Efficient Transformer, in Pytorch This is a Pytorch implementation of Reformer https://openreview.net/pdf?id=rkgNKkHtvB It includes LSH

Phil Wang 1.8k Jan 06, 2023
A tutorial on "Bayesian Compression for Deep Learning" published at NIPS (2017).

Code release for "Bayesian Compression for Deep Learning" In "Bayesian Compression for Deep Learning" we adopt a Bayesian view for the compression of

Karen Ullrich 190 Dec 30, 2022
S3-plugin is a high performance PyTorch dataset library to efficiently access datasets stored in S3 buckets.

S3-plugin is a high performance PyTorch dataset library to efficiently access datasets stored in S3 buckets.

Amazon Web Services 138 Jan 03, 2023
pip install antialiased-cnns to improve stability and accuracy

Antialiased CNNs [Project Page] [Paper] [Talk] Making Convolutional Networks Shift-Invariant Again Richard Zhang. In ICML, 2019. Quick & easy start Ru

Adobe, Inc. 1.6k Dec 28, 2022
You like pytorch? You like micrograd? You love tinygrad! ❤️

For something in between a pytorch and a karpathy/micrograd This may not be the best deep learning framework, but it is a deep learning framework. Due

George Hotz 9.7k Jan 05, 2023
Tutorial for surrogate gradient learning in spiking neural networks

SpyTorch A tutorial on surrogate gradient learning in spiking neural networks Version: 0.4 This repository contains tutorial files to get you started

Friedemann Zenke 203 Nov 28, 2022
TorchSSL: A PyTorch-based Toolbox for Semi-Supervised Learning

TorchSSL: A PyTorch-based Toolbox for Semi-Supervised Learning

1k Dec 28, 2022
PyTorch implementation of Glow, Generative Flow with Invertible 1x1 Convolutions

glow-pytorch PyTorch implementation of Glow, Generative Flow with Invertible 1x1 Convolutions

Kim Seonghyeon 433 Dec 27, 2022
3D-RETR: End-to-End Single and Multi-View3D Reconstruction with Transformers

3D-RETR: End-to-End Single and Multi-View 3D Reconstruction with Transformers (BMVC 2021) Zai Shi*, Zhao Meng*, Yiran Xing, Yunpu Ma, Roger Wattenhofe

Zai Shi 36 Dec 21, 2022
Kaldi-compatible feature extraction with PyTorch, supporting CUDA, batch processing, chunk processing, and autograd

Kaldi-compatible feature extraction with PyTorch, supporting CUDA, batch processing, chunk processing, and autograd

Fangjun Kuang 119 Jan 03, 2023
PyTorch Implementation of [1611.06440] Pruning Convolutional Neural Networks for Resource Efficient Inference

PyTorch implementation of [1611.06440 Pruning Convolutional Neural Networks for Resource Efficient Inference] This demonstrates pruning a VGG16 based

Jacob Gildenblat 836 Dec 26, 2022
Bunch of optimizer implementations in PyTorch

Bunch of optimizer implementations in PyTorch

Hyeongchan Kim 76 Jan 03, 2023
PyTorch Extension Library of Optimized Autograd Sparse Matrix Operations

PyTorch Sparse This package consists of a small extension library of optimized sparse matrix operations with autograd support. This package currently

Matthias Fey 757 Jan 04, 2023
This is a differentiable pytorch implementation of SIFT patch descriptor.

This is a differentiable pytorch implementation of SIFT patch descriptor. It is very slow for describing one patch, but quite fast for batch. It can

Dmytro Mishkin 150 Dec 24, 2022
A lightweight wrapper for PyTorch that provides a simple declarative API for context switching between devices, distributed modes, mixed-precision, and PyTorch extensions.

A lightweight wrapper for PyTorch that provides a simple declarative API for context switching between devices, distributed modes, mixed-precision, and PyTorch extensions.

Fidelity Investments 56 Sep 13, 2022
The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.

News March 3: v0.9.97 has various bug fixes and improvements: Bug fixes for NTXentLoss Efficiency improvement for AccuracyCalculator, by using torch i

Kevin Musgrave 5k Jan 02, 2023
Tacotron 2 - PyTorch implementation with faster-than-realtime inference

Tacotron 2 (without wavenet) PyTorch implementation of Natural TTS Synthesis By Conditioning Wavenet On Mel Spectrogram Predictions. This implementati

NVIDIA Corporation 4.1k Jan 03, 2023