Tez is a super-simple and lightweight Trainer for PyTorch. It also comes with many utils that you can use to tackle over 90% of deep learning projects in PyTorch.

Overview

Tez: a simple pytorch trainer

NOTE: Currently, we are not accepting any pull requests! All PRs will be closed. If you want a feature or something doesn't work, please create an issue.

tez (तेज़ / تیز) means sharp, fast & active. This is a simple, to-the-point library to make your pytorch training easy.

This library is currently at a very early stage! So, there might be breaking changes.

The idea around tez is simple:

  • keep things as simple as possible
  • make it as customizable as possible
  • clean code
  • faster prototyping
  • production ready

Currently, tez supports CPU and GPU training. More coming soon!

Using tez is super-easy. We don't want you to be far away from pytorch, so you do everything on your own and just use tez to make a few things simpler.

Training using Tez:

  • To train a model, define a dataset and model. The dataset class is the same old class you would write when writing pytorch models.

  • Create your model class. Instead of inheriting from nn.Module, import tez and inherit from tez.Model as shown in the following example.

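import tez
import torch
import torch.nn as nn
import transformers
from sklearn import metrics


class MyModel(tez.Model):
    def __init__(self):
        super().__init__()
        # example layers: a BERT encoder returning tuples, dropout and a single-logit head
        self.bert = transformers.BertModel.from_pretrained(
            "bert-base-uncased", return_dict=False
        )
        self.bert_drop = nn.Dropout(0.3)
        self.out = nn.Linear(768, 1)
        # tell when to step the scheduler
        self.step_scheduler_after = "batch"

    def monitor_metrics(self, outputs, targets):
        if targets is None:
            return {}
        outputs = torch.sigmoid(outputs).cpu().detach().numpy() >= 0.5
        targets = targets.cpu().detach().numpy()
        accuracy = metrics.accuracy_score(targets, outputs)
        return {"accuracy": accuracy}

    def fetch_scheduler(self):
        # create your own scheduler (assumes tez has already set self.optimizer);
        # it is stepped after every batch here, so step_size counts batches
        return torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1000, gamma=0.9)

    def fetch_optimizer(self):
        # create your own optimizer
        return torch.optim.AdamW(self.parameters(), lr=3e-5)

    def forward(self, ids, mask, token_type_ids, targets=None):
        _, o_2 = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        b_o = self.bert_drop(o_2)
        output = self.out(b_o)

        # calculate loss here (there are no targets at inference time, so no loss)
        loss = None
        if targets is not None:
            loss = nn.BCEWithLogitsLoss()(output, targets.view(-1, 1).float())

        # calculate the metric dictionary here
        metric_dict = self.monitor_metrics(output, targets)
        return output, loss, metric_dict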

Everything is super-intuitive!
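The model above takes ids, mask, token_type_ids and targets as forward arguments. Assuming tez passes each item of a dataset's dict to forward as a keyword argument (an assumption about tez internals that the README does not spell out), a matching dataset only has to return a dict with exactly those keys. A minimal sketch with pre-tokenized inputs; the class name TextDataset and its padding logic are hypothetical:

```python
import torch
from torch.utils.data import Dataset


class TextDataset(Dataset):
    """Hypothetical dataset returning the keys that MyModel.forward expects."""

    def __init__(self, token_ids, targets, max_len=16, pad_id=0):
        self.token_ids = token_ids  # list of already-tokenized id lists
        self.targets = targets      # list of 0/1 labels
        self.max_len = max_len
        self.pad_id = pad_id

    def __len__(self):
        return len(self.token_ids)

    def __getitem__(self, item):
        ids = self.token_ids[item][: self.max_len]
        mask = [1] * len(ids)
        # right-pad ids and mask to a fixed length so samples can be batched
        padding = self.max_len - len(ids)
        ids = ids + [self.pad_id] * padding
        mask = mask + [0] * padding
        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(mask, dtype=torch.long),
            "token_type_ids": torch.zeros(self.max_len, dtype=torch.long),
            "targets": torch.tensor(self.targets[item], dtype=torch.float),
        }


ds = TextDataset([[101, 2023, 102], [101, 2003, 1037, 102]], [1, 0], max_len=8)
sample = ds[0]
print({key: tuple(value.shape) for key, value in sample.items()})
```

In a real project the token ids would come from a tokenizer, and a class like this would take the place of SomeTrainDataset / SomeValidDataset in the training code.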

  • Now you can train your model!
import tez

# init datasets (placeholders for your own torch.utils.data.Dataset classes)
train_dataset = SomeTrainDataset()
valid_dataset = SomeValidDataset()

# init model
model = MyModel()

# init callbacks, you can also write your own callback
es = tez.callbacks.EarlyStopping(monitor="valid_loss", model_path="model.bin")

# train model. a familiar api!
model.fit(
    train_dataset,
    valid_dataset=valid_dataset,
    train_bs=32,
    device="cuda",
    epochs=50,
    callbacks=[es],
    fp16=True,
)

# save model (with optimizer and scheduler for future!)
model.save("model.bin")

You can check out examples in examples/

Comments
  • ValueError: operands could not be broadcast together with shapes (256,256,4) (3,) (256,256,4)

    I am trying to use this package, and it throws the error below. I am using the same pipeline as in the cassava leaf detection problem, but on a different dataset where the image size is (256, 256).

    Could you please help?

    Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b4-6ed6700e.pth 100% 74.4M/74.4M [00:00<00:00, 107MB/s]

    Loaded pretrained weights for efficientnet-b4 0%| | 0/51 [00:00<?, ?it/s]

    ValueError Traceback (most recent call last)
    in ()
         11 epochs=10,
         12 callbacks=[es],
    ---> 13 fp16=True,
         14 )
         15 model.save("model.bin")

    6 frames
    /usr/local/lib/python3.6/dist-packages/tez/model/model.py in fit(self, train_dataset, valid_dataset, train_sampler, valid_sampler, device, epochs, train_bs, valid_bs, n_jobs, callbacks, fp16)
        295 self.train_state = enums.TrainingState.EPOCH_START
        296 self.train_state = enums.TrainingState.TRAIN_EPOCH_START
    --> 297 train_loss = self.train_one_epoch(self.train_loader, device)
        298 self.train_state = enums.TrainingState.TRAIN_EPOCH_END
        299 if self.valid_loader:

    /usr/local/lib/python3.6/dist-packages/tez/model/model.py in train_one_epoch(self, data_loader, device)
        176 losses = AverageMeter()
        177 tk0 = tqdm(data_loader, total=len(data_loader))
    --> 178 for b_idx, data in enumerate(tk0):
        179 self.train_state = enums.TrainingState.TRAIN_STEP_START
        180 loss, metrics = self.train_one_step(data, device)

    /usr/local/lib/python3.6/dist-packages/tqdm/std.py in __iter__(self)
       1102 fp_write=getattr(self.fp, 'write', sys.stderr.write))
       1103
    -> 1104 for obj in iterable:
       1105 yield obj
       1106 # Update and possibly print the progressbar.

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
        433 if self._sampler_iter is None:
        434 self._reset()
    --> 435 data = self._next_data()
        436 self._num_yielded += 1
        437 if self._dataset_kind == _DatasetKind.Iterable and \

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
       1083 else:
       1084 del self._task_info[idx]
    -> 1085 return self._process_data(data)
       1086
       1087 def _try_put_index(self):

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
       1109 self._try_put_index()
       1110 if isinstance(data, ExceptionWrapper):
    -> 1111 data.reraise()
       1112 return data
       1113

    /usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
        426 # have message field
        427 raise self.exc_type(message=msg)
    --> 428 raise self.exc_type(msg)
        429
        430

    ValueError: Caught ValueError in DataLoader worker process 0.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/usr/local/lib/python3.6/dist-packages/tez/datasets/image_classification.py", line 48, in __getitem__
        augmented = self.augmentations(image=image)
      File "/usr/local/lib/python3.6/dist-packages/albumentations/core/composition.py", line 171, in __call__
        data = t(**data)
      File "/usr/local/lib/python3.6/dist-packages/albumentations/core/transforms_interface.py", line 38, in __call__
        res[key] = target_function(arg, **dict(params, **target_dependencies))
      File "/usr/local/lib/python3.6/dist-packages/albumentations/augmentations/transforms.py", line 808, in apply
        return F.normalize(image, self.mean, self.std, self.max_pixel_value)
      File "/usr/local/lib/python3.6/dist-packages/albumentations/augmentations/functional.py", line 93, in normalize
        img -= mean
    ValueError: operands could not be broadcast together with shapes (256,256,4) (3,) (256,256,4)

    opened by nvnvashisth 10
  • zero_grad for accumulation_steps = 1 not working as expected

    As far as I know, in the normal execution flow we first call zero_grad for each batch and then do the forward pass. However, when I investigated the code, this is not what happens when accumulation_steps = 1 and batch = 1: the forward pass executes first, without calling zero_grad.

    I tried to reproduce it, and it behaves exactly as I explained above.

    Also, I think we can fix this by removing the condition on lines 330 and 331 of the tez.py file.

    opened by abdurrehman11 9
  • Can it work without CUDA

    I get an error when I execute the code with a CPU configuration.

    Traceback (most recent call last):
      File "c:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\recommender.py", line 88, in <module>
        train()
      File "c:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\recommender.py", line 82, in train
        model.fit(
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\tez\model\model.py", line 309, in fit
        self._init_model(
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\tez\model\model.py", line 93, in _init_model
        self.to(self.device)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 852, in to
        return self._apply(convert)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 530, in _apply
        module._apply(fn)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 552, in _apply
        param_applied = fn(param)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 850, in convert
        return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\cuda\__init__.py", line 166, in _lazy_init
        raise AssertionError("Torch not compiled with CUDA enabled")
    AssertionError: Torch not compiled with CUDA enabled

    opened by hemanthh17 7
  • Documentation improvement - How is tez faster?

    Great to see a nice PyTorch training library.

    I think it would help users to show what kind of performance improvements Tez provides out of the box. For example, comparing how fp16 is enabled in tez vs. vanilla PyTorch could be informative, or just a quick list of optimisations that are easy to do with Tez, such as fp16.

    opened by swartchris8 5
  • Is it possible to set variable Lr per epoch

    @abhishekkrthakur I am finding this framework great and easy to use. But as I am fairly new to it, I was wondering whether there is a way to use a variable LR during training, for example a different one for every epoch.

    Also, is there a way to continue training from a particular epoch if, say, the local system crashed or was interrupted during training?

    opened by gauravbrills 3
  • Applying metrics after the epoch

    Hi all, I am using tez to classify melanoma images (Kaggle SIIM binary classification). With wtfml it is possible to get AUC ~ 0.85, but with tez I only get AUC ~ 0.6. I saw that this happens in tez when using metrics.roc_auc_score(...) inside the monitor_metrics method. This raises ValueError exceptions, which must be handled by returning auc = 0.5 (the error occurs when the data contain only one class).

    In wtfml, metrics.roc_auc_score(...) is used only after Engine.evaluate. In that case the data always contain both classes (because the stratified k-fold split guarantees that).

    I am wondering whether it is possible in tez to apply metrics.roc_auc_score(...) only after the epoch, and not on every batch of train_bs samples. That way the data would always contain both classes, avoiding the ValueError exceptions.

    PS.

    1. In the class __init__ I am using:
           self.step_scheduler_after = "epoch"
           self.step_scheduler_metric = "valid_auc"
    2. In the monitor_metrics method:
           try:
               auc = metrics.roc_auc_score(targets, outputs.ravel())
           except ValueError:
               auc = 0.5
           return {"auc": auc}
    3. My model.fit is defined as:
           model.fit(train_dataset, valid_dataset=valid_dataset, train_bs=32, valid_bs=16, device="cuda", epochs=50, callbacks=[es], fp16=False, n_jobs=2)

    opened by waldcarl 2
  • Issue while using Auc metric on imbalanced dataset like melanoma(ValueError: Only one class present in y_true. ROC AUC score is not defined in that case)

    This problem occurs because the metric is calculated on each batch (a running metric calculation).

    I found this explanation on Stack Overflow:

    You cannot have an ROC curve without both positive and negative examples in your dataset. With only one class in the dataset, you cannot measure your false-positive rate, and therefore cannot plot an ROC curve. This is why you get this error message.

    How to handle this problem?

    opened by IamSantoshKumar 2
  • Error in Multiclass TypeError: dropout(): argument 'input' (position 1) must be Tensor, not str

    /usr/local/lib/python3.7/dist-packages/torch/cuda/amp/grad_scaler.py:116: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.
      warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
      0%| | 0/2939 [00:00<?, ?it/s]/usr/local/lib/python3.7/dist-packages/torch/cuda/amp/autocast_mode.py:118: UserWarning: torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available. Disabling.
      warnings.warn("torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available. Disabling.")

    TypeError Traceback (most recent call last)
    in ()
        143 epochs=3,
        144 callbacks=[tb_logger, es],
    --> 145 fp16=True,
        146 )
        147 model.save("model.bin")

    8 frames
    /usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in dropout(input, p, training, inplace)
       1074 if p < 0.0 or p > 1.0:
       1075 raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
    -> 1076 return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
       1077
       1078

    TypeError: dropout(): argument 'input' (position 1) must be Tensor, not str

    opened by gokulguptanew 1
  • Text classification examples - Tokenizer is defined twice

    The tokenizer is defined both in the model and the dataset in the BERT text classification examples.

    multi_class.py, line 50: self.tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

    opened by obesp 1
  • Small error in image_classification.py

    If augmentations is None, we get an error because the variable augmented is referenced before assignment: UnboundLocalError: local variable 'augmented' referenced before assignment

    elif self.backend == "cv2":
                image = cv2.imread(self.image_paths[item])
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                if self.resize is not None:
                    image = cv2.resize(
                        image,
                        (self.resize[1], self.resize[0]),
                        interpolation=cv2.INTER_CUBIC,
                    )
                if self.augmentations is not None:
                    augmented = self.augmentations(image=image)
                    image = augmented["image"]
    

    Fixing the indentation solves this error.

    opened by VpkPrasanna 1
  • Small error in model.py

    Hi! Love this library.
    In tez/model/model.py there is probably a mistake in line 90:

    self.train_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=train_bs,
                    num_workers=n_jobs,
                    sampler=valid_sampler,
                    shuffle=True,
                )
    

    I guess train_sampler is meant to be used here, not valid_sampler.

    opened by hocop 1
  • run example code error

    When I run the example code:

    accelerate launch   imdb_sentiment_classification.py
    

    After running for some epochs, I get this error:

    INFO:tez.callbacks.early_stopping:EarlyStopping counter: 4/5
    [train] accuracy=0.9915, loss=0.0269 [valid] accuracy=0.8953, loss=0.4287 [e=5 steps=2112]                                                                                                 
     30%|████████████████████████████████▍                                                                           | 2112/7040 [05:45<06:40, 12.32it/s, accuracy=0.991, epoch=5, loss=0.0269]2022-09-17 07:55:02,832 INFO EarlyStopping counter: 5/5
    INFO:tez.callbacks.early_stopping:EarlyStopping counter: 5/5
     30%|████████████████████████████████▍                                                                           | 2112/7040 [05:47<13:31,  6.07it/s, accuracy=0.991, epoch=5, loss=0.0269]
    
    
    
    
    [E ProcessGroupNCCL.cpp:719] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808970 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:719] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808984 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:719] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1809275 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1809275 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808970 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808984 milliseconds before timing out.
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/resource_sharer.py", line 138, in _serve
        with self._listener.accept() as conn:
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 470, in accept
        deliver_challenge(c, self._authkey)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 745, in deliver_challenge
        response = connection.recv_bytes(256)        # reject large message
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 221, in recv_bytes
        buf = self._recv_bytes(maxlength)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 419, in _recv_bytes
        buf = self._recv(4)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 384, in _recv
        chunk = read(handle, remaining)
    ConnectionResetError: [Errno 104] Connection reset by peer
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 113654 closing signal SIGTERM
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/resource_sharer.py", line 138, in _serve
        with self._listener.accept() as conn:
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 470, in accept
        deliver_challenge(c, self._authkey)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 745, in deliver_challenge
        response = connection.recv_bytes(256)        # reject large message
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 221, in recv_bytes
        buf = self._recv_bytes(maxlength)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 419, in _recv_bytes
        buf = self._recv(4)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 384, in _recv
        chunk = read(handle, remaining)
    ConnectionResetError: [Errno 104] Connection reset by peer
    ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 2 (pid: 113655) of binary: /root/miniconda3/envs/lightning/bin/python
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/bin/torchrun", line 33, in <module>
        sys.exit(load_entry_point('torch==1.11.0', 'console_scripts', 'torchrun')())
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
        return f(*args, **kwargs)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/run.py", line 724, in main
        run(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/run.py", line 715, in run
        elastic_launch(
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
        return launch_agent(self._config, self._entrypoint, list(args))
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
        raise ChildFailedError(
    torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
    =======================================================
    imdb_sentiment_classification.py FAILED
    -------------------------------------------------------
    Failures:
    [1]:
      time      : 2022-09-17_08:25:22
      host      : dy-a100-779-tlzrv
      rank      : 3 (local_rank: 3)
      exitcode  : -6 (pid: 113656)
      error_file: <N/A>
      traceback : Signal 6 (SIGABRT) received by PID 113656
    -------------------------------------------------------
    Root Cause (first observed failure):
    [0]:
      time      : 2022-09-17_08:25:22
      host      : dy-a100-779-tlzrv
      rank      : 2 (local_rank: 2)
      exitcode  : -6 (pid: 113655)
      error_file: <N/A>
      traceback : Signal 6 (SIGABRT) received by PID 113655
    =======================================================
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/bin/accelerate", line 33, in <module>
        sys.exit(load_entry_point('accelerate==0.12.0.dev0', 'console_scripts', 'accelerate')())
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/accelerate_cli.py", line 43, in main
        args.func(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/launch.py", line 734, in launch_command
        multi_gpu_launcher(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/launch.py", line 374, in multi_gpu_launcher
        raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
    subprocess.CalledProcessError: Command '['torchrun', '--nproc_per_node', '4', 'imdb_sentiment_classification.py']' returned non-zero exit status 1.
    
    opened by bestpredicts 0
  • Getting error while importing enums from tez.

    Traceback (most recent call last):
      File "/content/tez/tez/model/model.py", line 12, in <module>
        from tez import enums
      File "/content/tez/tez/model/tez.py", line 11, in <module>
        from tez import enums
    ImportError: cannot import name 'enums' from 'tez' (/content/tez/tez/model/tez.py)

    Waiting for a positive reply.

    opened by VikasRathod314 3
  • Saving validation score

    Is it possible to somehow save a list of the validation scores (per epoch or per batch) after training? I have problems with the output on my server (it usually gets deleted), but I really need the validation scores to compare models. It would be really convenient if I could get them in one file, for example.

    opened by 25icecreamflavors 0
  • Saving after training an epoch

    How can I save the model after each training epoch? I use the fit method for 5 epochs and do not really understand how to save after each one, not only after the last one.

    opened by 25icecreamflavors 2
  • How can we access the input_ids/attention mask in each train batch loop?

    I tried using a train step callback but I am not sure how to get access to the dataloader input_ids and attention mask during each train step. Is this possible?

    BTW Thanks for the library!

    opened by tkmaker 0
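Regarding the "operands could not be broadcast together with shapes (256,256,4) (3,) (256,256,4)" issue: the shapes suggest (an inference from the traceback, not a confirmed diagnosis) that the new images are 4-channel RGBA files, while albumentations' Normalize subtracts a 3-value per-channel mean. Converting every image to 3 channels before the augmentations should avoid it; when loading with PIL, Image.open(path).convert("RGB") does this. A numpy-only sketch, where to_rgb is a hypothetical helper:

```python
import numpy as np


def to_rgb(image: np.ndarray) -> np.ndarray:
    """Return an (H, W, 3) array from a grayscale, RGB or RGBA image array."""
    if image.ndim == 2:
        # grayscale (H, W): repeat the single channel three times
        return np.stack([image] * 3, axis=-1)
    if image.shape[-1] == 4:
        # RGBA (H, W, 4): drop the alpha channel
        return image[..., :3]
    return image


# ImageNet-style mean, scaled to 0-255 pixel values
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) * 255.0
rgba = np.zeros((256, 256, 4), dtype=np.float32)

try:
    img = rgba.copy()
    img -= mean  # the same in-place subtraction that fails inside normalize
except ValueError as err:
    print("RGBA fails:", err)

img = to_rgb(rgba).copy()  # copy so the in-place op does not touch a view of rgba
img -= mean
print(img.shape)  # (256, 256, 3)
```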
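On the "zero_grad for accumulation_steps = 1" issue: the forward pass never reads or writes .grad, so running it before zero_grad is harmless as long as zero_grad happens before the next backward(). Whether tez's loop meets that condition is not verified here. This is a generic plain-PyTorch accumulation loop (not tez's code); the check at the end compares two accumulated half-batches against one full batch:

```python
import torch
import torch.nn as nn


def train_epoch(model, optimizer, batches, accumulation_steps=1):
    """Plain-PyTorch gradient accumulation: gradients are cleared once per optimizer step."""
    loss_fn = nn.MSELoss()
    optimizer.zero_grad()
    for step, (x, y) in enumerate(batches, start=1):
        # the forward pass does not touch .grad; what matters is that zero_grad
        # ran before this backward() starts accumulating into the gradients
        loss = loss_fn(model(x), y) / accumulation_steps
        loss.backward()
        if step % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()


torch.manual_seed(0)
x, y = torch.randn(8, 3), torch.randn(8, 1)

# one update from the full batch of 8 ...
full = nn.Linear(3, 1)
# ... versus an identical copy updated once from two accumulated halves of 4
accum = nn.Linear(3, 1)
accum.load_state_dict(full.state_dict())

train_epoch(full, torch.optim.SGD(full.parameters(), lr=0.1), [(x, y)], accumulation_steps=1)
train_epoch(
    accum,
    torch.optim.SGD(accum.parameters(), lr=0.1),
    [(x[:4], y[:4]), (x[4:], y[4:])],
    accumulation_steps=2,
)

print(torch.allclose(full.weight, accum.weight, atol=1e-6))
```

Dividing each loss by accumulation_steps is what makes the accumulated gradient equal the full-batch mean.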
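For "Can it work without CUDA": the traceback ends in "Torch not compiled with CUDA enabled", which PyTorch raises when a CPU-only build is asked to move a model to "cuda", so passing device="cuda" to fit is the likely culprit. Choosing the device at runtime is a simple workaround. The sketch assumes, based on the GradScaler/autocast warnings quoted in the dropout issue, that tez's fp16 flag relies on CUDA mixed precision, so it is only enabled on the GPU:

```python
import torch

# pick the GPU only when this PyTorch build and machine can actually use it
device = "cuda" if torch.cuda.is_available() else "cpu"
# assumption: tez's fp16 uses CUDA mixed precision, so enable it only on the GPU
use_fp16 = device == "cuda"
print(f"device={device}, fp16={use_fp16}")

# then, for example:
# model.fit(train_dataset, valid_dataset=valid_dataset, device=device, fp16=use_fp16, epochs=10)
```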
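For the documentation request comparing fp16 in tez with vanilla PyTorch: in plain PyTorch, mixed precision means running the forward pass under torch.autocast and routing backward and the optimizer step through a GradScaler. The warnings in the dropout issue suggest tez's fp16=True does something along these lines, but that is an inference, not documented behavior. A minimal plain-PyTorch sketch with a toy model and random data, falling back to full precision on CPU:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # mixed precision is only enabled on the GPU here

model = nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.BCEWithLogitsLoss()
# GradScaler scales the loss so small fp16 gradients do not underflow; a no-op when disabled
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x = torch.randn(32, 16, device=device)
y = torch.randint(0, 2, (32, 1), device=device).float()

for _ in range(5):
    optimizer.zero_grad()
    # autocast runs eligible ops (e.g. matmuls) in float16 on CUDA
    with torch.autocast(device_type=device, enabled=use_amp):
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # unscales gradients, skips the step if they contain inf/NaN
    scaler.update()

print(f"final loss: {loss.item():.4f}")
```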
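For "Is it possible to set variable Lr per epoch": in tez this would presumably mean returning a per-epoch scheduler from fetch_scheduler and setting self.step_scheduler_after = "epoch" (the "epoch" value appears in the "Applying metrics after the epoch" issue; the README itself only shows "batch"). This sketch is plain PyTorch: LambdaLR driven by a hypothetical list of per-epoch rates, plus the usual checkpoint of model, optimizer, scheduler and epoch that lets training resume after a crash. EPOCH_LRS, build and train_one_epoch are illustrative stand-ins:

```python
import os
import tempfile

import torch
import torch.nn as nn

# hypothetical per-epoch learning rates; the last value is reused for later epochs
EPOCH_LRS = [1e-3, 5e-4, 2.5e-4, 1e-4]


def make_scheduler(optimizer):
    # LambdaLR sets lr = initial lr * factor(epoch); it is stepped once per epoch here
    return torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda epoch: EPOCH_LRS[min(epoch, len(EPOCH_LRS) - 1)] / EPOCH_LRS[0],
    )


def build():
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=EPOCH_LRS[0])
    return model, optimizer, make_scheduler(optimizer)


def train_one_epoch(model, optimizer):
    # stand-in for a real epoch: a single step on random data
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


ckpt_path = os.path.join(tempfile.gettempdir(), "checkpoint.pt")
seen_lrs = []

model, optimizer, scheduler = build()
for epoch in range(2):  # pretend the run crashes after two epochs
    seen_lrs.append(optimizer.param_groups[0]["lr"])
    train_one_epoch(model, optimizer)
    scheduler.step()
    # save everything needed to resume, not just the weights
    torch.save(
        {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        ckpt_path,
    )

# a fresh run: rebuild the objects, then restore their saved state
model, optimizer, scheduler = build()
ckpt = torch.load(ckpt_path)
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])
for epoch in range(ckpt["epoch"] + 1, 4):
    seen_lrs.append(optimizer.param_groups[0]["lr"])
    train_one_epoch(model, optimizer)
    scheduler.step()

print(seen_lrs)
```

Restoring the optimizer and scheduler state, not only the weights, is what lets the resumed run continue with the learning rate it would have had.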
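For "Applying metrics after the epoch": ROC AUC is not additive, so the mean of per-batch AUCs differs from the AUC over the whole epoch, and single-class batches make it undefined. Collecting outputs and targets over the epoch and scoring once avoids both problems. How to hook this into tez (for example from a callback at epoch end) is not covered by the README, so the EpochAUC helper is hypothetical and framework-agnostic:

```python
import numpy as np
from sklearn import metrics


class EpochAUC:
    """Hypothetical helper: collect predictions batch by batch, score once per epoch."""

    def __init__(self):
        self.targets, self.outputs = [], []

    def update(self, targets, outputs):
        # store flat numpy copies of each batch
        self.targets.append(np.asarray(targets).ravel())
        self.outputs.append(np.asarray(outputs).ravel())

    def compute(self):
        # score all batches together, then reset for the next epoch
        auc = metrics.roc_auc_score(np.concatenate(self.targets), np.concatenate(self.outputs))
        self.targets, self.outputs = [], []
        return auc


auc_meter = EpochAUC()
# three "batches"; the first two contain only one class each and could not be scored alone
auc_meter.update([0, 0], [0.1, 0.3])
auc_meter.update([1, 1], [0.8, 0.4])
auc_meter.update([0, 1], [0.2, 0.9])
epoch_auc = auc_meter.compute()
print(epoch_auc)
```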
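For the "Only one class present in y_true" issue: if a per-batch AUC is unavoidable, returning NaN for single-class batches and averaging with np.nanmean skips those batches, whereas substituting 0.5 (the workaround in the "Applying metrics after the epoch" issue) pulls the average toward chance, which may account for part of the gap reported there. A sketch with a hypothetical batch_auc helper:

```python
import numpy as np
from sklearn import metrics


def batch_auc(targets, outputs):
    """ROC AUC for one batch, or NaN when the batch contains a single class."""
    targets = np.asarray(targets).ravel()
    if np.unique(targets).size < 2:
        # ROC AUC is undefined without both positives and negatives
        return float("nan")
    return metrics.roc_auc_score(targets, np.asarray(outputs).ravel())


scores = [
    batch_auc([0, 0], [0.1, 0.3]),
    batch_auc([0, 1], [0.2, 0.9]),
    batch_auc([1, 0], [0.7, 0.6]),
]
print(scores)
print("nanmean:", np.nanmean(scores))  # skips the undefined batch
print("0.5 fill:", np.mean([0.5 if np.isnan(s) else s for s in scores]))  # biased toward 0.5
```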
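For "dropout(): argument 'input' (position 1) must be Tensor, not str": the model code is not shown, but one common cause (an assumption here) is the README's `_, o_2 = self.bert(...)` pattern with transformers v4, where BertModel returns a ModelOutput by default. ModelOutput subclasses OrderedDict, so tuple-unpacking it yields the dict keys, and o_2 becomes the string "pooler_output". Passing return_dict=False to BertModel.from_pretrained, or reading the .pooler_output attribute, avoids it. The sketch uses a plain OrderedDict as a stand-in so it runs without downloading BERT:

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# stand-in for a transformers ModelOutput, which is an OrderedDict subclass
bert_output = OrderedDict(
    last_hidden_state=torch.randn(2, 8, 16),
    pooler_output=torch.randn(2, 16),
)

_, o_2 = bert_output  # iterating a dict yields its KEYS, not its values
print(repr(o_2))

drop = nn.Dropout(0.3)
try:
    drop(o_2)  # a string reaches dropout, as in the issue's traceback
except TypeError as err:
    print("TypeError:", err)

# fix: take the tensor by name (with a real ModelOutput, bert_output.pooler_output also works)
pooled = bert_output["pooler_output"]
print(drop(pooled).shape)
```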
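For the UnboundLocalError in image_classification.py: the fix is to read augmented only inside the branch that creates it. A minimal sketch of that pattern in a hypothetical in-memory dataset (ArrayImageDataset is not the tez class); the lambda stands in for an albumentations pipeline, which is called the same way, augmentations(image=image)["image"]:

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class ArrayImageDataset(Dataset):
    """Hypothetical dataset over in-memory HWC images with optional augmentations."""

    def __init__(self, images, targets, augmentations=None):
        self.images = images
        self.targets = targets
        self.augmentations = augmentations

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        image = self.images[item]
        if self.augmentations is not None:
            # `augmented` exists only in this branch, so it is only read here
            augmented = self.augmentations(image=image)
            image = augmented["image"]
        # HWC -> CHW float32, made contiguous (flips can leave negative strides)
        image = np.ascontiguousarray(np.transpose(image, (2, 0, 1)), dtype=np.float32)
        return {
            "image": torch.tensor(image),
            "targets": torch.tensor(self.targets[item], dtype=torch.long),
        }


images = [np.zeros((4, 4, 3), dtype=np.uint8) for _ in range(2)]
plain = ArrayImageDataset(images, [0, 1])
flipped = ArrayImageDataset(images, [0, 1], augmentations=lambda image: {"image": image[:, ::-1]})
print(plain[0]["image"].shape, flipped[1]["targets"])
```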
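For the sampler mix-up in model.py: besides using train_sampler, note that torch.utils.data.DataLoader raises a ValueError when a sampler is combined with shuffle=True, so swapping only the sampler name would still fail as soon as a sampler is actually passed. A sketch of one way to build the loader (make_train_loader is a hypothetical helper, not tez's code):

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset


def make_train_loader(train_dataset, train_bs, n_jobs=0, train_sampler=None):
    # use the *train* sampler; DataLoader rejects shuffle=True together with a sampler,
    # so shuffle only when no sampler is given
    return DataLoader(
        train_dataset,
        batch_size=train_bs,
        num_workers=n_jobs,
        sampler=train_sampler,
        shuffle=train_sampler is None,
    )


dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))
loader = make_train_loader(dataset, train_bs=4, train_sampler=RandomSampler(dataset))
batch_sizes = [batch[0].shape[0] for batch in loader]
print(batch_sizes)  # [4, 4, 2]
```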
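For "cannot import name 'enums' from 'tez'": the message says `tez` itself resolved to /content/tez/tez/model/tez.py, which suggests (an inference from the traceback) that the script was run from inside tez/model/, so that directory came first on sys.path and its tez.py shadowed the package. Running from the repository root, or installing the package, should avoid this. A small diagnostic using importlib.util.find_spec; module_origin is a hypothetical helper:

```python
import importlib.util
import sys


def module_origin(name):
    """Return the file a top-level module name would be imported from, or None."""
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


# when running `python script.py`, sys.path[0] is the script's directory,
# so a tez.py there shadows the installed tez package
print("sys.path[0]:", sys.path[0])
print("tez resolves to:", module_origin("tez"))
```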
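For "Saving validation score" and "Saving after training an epoch": a callback that runs at the end of each epoch can do both. The sketch assumes, without confirmation from the README, that tez calls a callback's on_epoch_end(model) once per epoch (in tez you would probably subclass tez.callbacks.Callback). Where the epoch scores live on the model is also not documented here, so they are read through a user-supplied read_scores function; the source of tez's EarlyStopping callback, which reads valid_loss, is the place to check. EpochSaver and FakeModel are hypothetical:

```python
import csv
import os
import tempfile


class EpochSaver:
    """Hypothetical callback: save a checkpoint and append scores after every epoch."""

    def __init__(self, out_dir, read_scores):
        self.out_dir = out_dir
        self.read_scores = read_scores  # callable: model -> dict of scores
        self.epoch = 0
        self.log_path = os.path.join(out_dir, "scores.csv")

    def on_epoch_end(self, model):
        self.epoch += 1
        # model.save(path) is the save method shown in the tez README
        model.save(os.path.join(self.out_dir, f"model_epoch{self.epoch}.bin"))
        row = {"epoch": self.epoch, **self.read_scores(model)}
        write_header = not os.path.exists(self.log_path)
        with open(self.log_path, "a", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=list(row))
            if write_header:
                writer.writeheader()
            writer.writerow(row)


class FakeModel:
    """Stand-in model so the sketch runs without tez."""

    def __init__(self):
        self.valid_loss = 1.0

    def save(self, path):
        with open(path, "w") as f:
            f.write("weights")


out_dir = tempfile.mkdtemp()
saver = EpochSaver(out_dir, read_scores=lambda m: {"valid_loss": m.valid_loss})
model = FakeModel()
for _ in range(3):
    model.valid_loss /= 2
    saver.on_epoch_end(model)

print(sorted(os.listdir(out_dir)))
with open(saver.log_path) as f:
    print(f.read())
```

With tez it would be passed alongside the early-stopping callback, e.g. callbacks=[es, saver].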
Releases(v0.1.8)
Owner
abhishek thakur
Kaggle: www.kaggle.com/abhishek