torchbearer: A model fitting library for PyTorch

Overview

Note: We're moving to PyTorch Lightning! Read about the move here. From the end of February, torchbearer will no longer be actively maintained. We'll continue to fix bugs when they are found and ensure that torchbearer runs on new versions of pytorch. However, we won't plan or implement any new functionality (if there's something you'd like to see in a training library, consider creating an issue on PyTorch Lightning).


Supports Python 2.7, 3.5, 3.6 and 3.7, and PyTorch 1.0.0, 1.1.0, 1.2.0, 1.3.0 and 1.4.0.


A PyTorch model fitting library designed for use by researchers (or anyone really) working in deep learning or differentiable programming. Specifically, we aim to dramatically reduce the amount of boilerplate code you need to write without limiting the functionality and openness of PyTorch.

Examples

General

Quickstart: Get up and running with torchbearer, training a simple CNN on CIFAR-10.
Callbacks: A detailed exploration of callbacks in torchbearer, with some useful visualisations.
Imaging: A detailed exploration of the imaging sub-package in torchbearer, useful for showing visualisations during training.
Serialization: This guide gives an introduction to serializing and restarting training in torchbearer.
History and Replay: This guide gives an introduction to the history returned by a trial and the ability to replay training.
Custom Data Loaders: This guide gives an introduction to running custom data loaders in torchbearer.
Data Parallel: This guide gives an introduction to using torchbearer with DataParallel.
LiveLossPlot: A demonstration of the LiveLossPlot callback included in torchbearer.
PyCM: A demonstration of the PyCM callback included in torchbearer for generating confusion matrices.
NVIDIA Apex: A guide showing how to perform half and mixed precision training in torchbearer with NVIDIA Apex.

Deep Learning

Training a VAE: A demonstration of how to train (and do a simple visualisation of) a Variational Auto-Encoder (VAE) on MNIST with torchbearer.
Training a GAN: A demonstration of how to train (and do a simple visualisation of) a Generative Adversarial Network (GAN) on MNIST with torchbearer.
Generating Adversarial Examples: A demonstration of how to perform a simple adversarial attack with torchbearer.
Transfer Learning with Torchbearer: A demonstration of how to perform transfer learning on STL10 with torchbearer.
Regularisers in Torchbearer: A demonstration of how to use all of the built-in regularisers in torchbearer (Mixup, CutOut, CutMix, Random Erase, Label Smoothing and Sample Pairing).
Manifold Mixup: A demonstration of how to use the Manifold Mixup callback in Torchbearer.
Class Appearance Model: A demonstration of the Class Appearance Model (CAM) callback in torchbearer.
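Two of the regularisers listed above, label smoothing and mixup, are simple enough to sketch without a framework. The sketch below uses plain Python. It shows the standard label smoothing formula, (1 - epsilon) * one_hot + epsilon / K. It also shows mixup, which blends two inputs and their targets with a weight lambda drawn from Beta(alpha, alpha). The helper names `smooth_labels` and `mixup` are hypothetical. They are not torchbearer's callbacks, which work on tensors inside a Trial.

```python
import random


def smooth_labels(target, num_classes, epsilon=0.1):
    """Hypothetical helper: smoothed one-hot target for integer class `target`.

    Standard label smoothing: (1 - epsilon) * one_hot + epsilon / num_classes.
    """
    if not 0 <= target < num_classes:
        raise ValueError("target must be in [0, num_classes)")
    if not 0.0 <= epsilon <= 1.0:
        raise ValueError("epsilon must be in [0, 1]")
    dist = [epsilon / num_classes] * num_classes  # mass spread over every class
    dist[target] += 1.0 - epsilon                 # remaining mass on the true class
    return dist


def mixup(x1, y1, x2, y2, alpha=1.0, lam=None):
    """Hypothetical helper: convex combination of two samples and their targets.

    If `lam` is not given it is drawn from Beta(alpha, alpha), as in mixup.
    Inputs and targets are plain lists of floats of matching length.
    """
    if lam is None:
        lam = random.betavariate(alpha, alpha)
    x = [lam * a + (1.0 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1.0 - lam) * b for a, b in zip(y1, y2)]
    return x, y, lam


print(smooth_labels(2, num_classes=4, epsilon=0.1))
print(mixup([0.0, 4.0], [1.0, 0.0], [4.0, 0.0], [0.0, 1.0], lam=0.25))
```

Both functions return soft targets that still sum to one. This is why a loss that accepts a probability distribution over classes, rather than an integer label, is needed when either regulariser is used.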

Differentiable Programming

Optimising Functions: An example (and some fun visualisations) showing how torchbearer can be used for the purpose of optimising functions with respect to their parameters using gradient descent.
Linear SVM: Train a linear support vector machine (SVM) using torchbearer, with an interactive visualisation!
Breaking Adam: The Adam optimiser doesn't always converge; in this example we reimplement some of the function optimisations from the AMSGrad paper, showing this empirically.
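The non-convergence mentioned in the Breaking Adam entry can be reproduced in a few lines of plain Python. The sketch below uses the synthetic online problem from the AMSGrad paper: f_t(x) = C*x when t mod 3 = 1 and -x otherwise, constrained to [-1, 1]. The gradients sum to C - 2 > 0 over each three-step cycle, so the best fixed choice is x = -1.

Several settings are illustrative assumptions:
- C = 3 and alpha = 0.1, with a step size of alpha/sqrt(t).
- beta1 = 0 and beta2 = 1/(1 + C^2).
- Bias correction is omitted.

This is not the code from the torchbearer example. With these settings, the sketch's Adam drifts to x = +1. AMSGrad divides by the running maximum of the second-moment estimate instead, and drifts to x = -1.

```python
import math


def online_grad(t, C=3.0):
    """Gradient of the toy loss f_t(x) = C*x if t % 3 == 1 else -x."""
    return C if t % 3 == 1 else -1.0


def optimise(amsgrad, steps=30000, alpha=0.1, C=3.0, eps=1e-8):
    """Projected Adam (amsgrad=False) or AMSGrad (amsgrad=True) on the toy problem.

    Illustrative assumptions: beta1 = 0, beta2 = 1 / (1 + C**2), step size
    alpha / sqrt(t), no bias correction. Returns the final x and the history
    of the second-moment value actually used in the denominator.
    """
    beta2 = 1.0 / (1.0 + C * C)
    x, v, v_hat = 0.0, 0.0, 0.0
    denom_history = []
    for t in range(1, steps + 1):
        g = online_grad(t, C)
        v = beta2 * v + (1.0 - beta2) * g * g  # exponential moving average of g^2
        v_hat = max(v_hat, v)                  # AMSGrad keeps the running maximum
        v_used = v_hat if amsgrad else v
        denom_history.append(v_used)
        # beta1 = 0, so the first-moment estimate is just the current gradient
        x -= (alpha / math.sqrt(t)) * g / (math.sqrt(v_used) + eps)
        x = min(1.0, max(-1.0, x))             # project back onto [-1, 1]
    return x, denom_history


adam_x, _ = optimise(amsgrad=False)
ams_x, _ = optimise(amsgrad=True)
print(f"Adam: x = {adam_x:.4f}, AMSGrad: x = {ams_x:.4f}")
```

The two branches differ only in the `max`. In plain Adam, the large gradient C is quickly forgotten by the moving average, so the two small steps that follow are comparatively large and win out. AMSGrad's non-decreasing denominator prevents this.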

Install

The easiest way to install torchbearer is with pip:

pip install torchbearer

Alternatively, build from source with:

pip install git+https://github.com/pytorchbearer/torchbearer

Citing Torchbearer

If you find that torchbearer is useful to your research then please consider citing our preprint: Torchbearer: A Model Fitting Library for PyTorch, with the following BibTeX entry:

@article{torchbearer2018,
  author = {Ethan Harris and Matthew Painter and Jonathon Hare},
  title = {Torchbearer: A Model Fitting Library for PyTorch},
  journal  = {arXiv preprint arXiv:1809.03363},
  year = {2018}
}

Related

Torchbearer isn't the only library for training PyTorch models. Here are a few others that might better suit your needs (this is by no means a complete list, see the awesome pytorch list or the incredible pytorch for more):

  • skorch, a model wrapper that enables use with scikit-learn (cross-validation etc. can be very useful)
  • PyToune, a simple Keras-style API
  • ignite, advanced model training from the makers of PyTorch; can require a lot of code for advanced functions (e.g. TensorBoard)
  • TorchNetTwo (TNT), can be complex to use but well established; somewhat replaced by ignite
  • Inferno, training utilities and convenience classes for PyTorch
  • PyTorch Lightning, a lightweight wrapper on top of PyTorch with advanced multi-GPU and cluster support
  • Pywick, a high-level training framework based on torchsample, with support for various segmentation models
Comments
  • RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle

    Dear all,

    It seems that torchbearer does not want to work for me. I am trying to simply classify images using resnet. You can find my code here (https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Template/tree/feature/cuda-error), the main training logic is:

    import time
    from comet_ml import Experiment
    import torchbearer
    import torch.optim as optim
    import torch.nn as nn
    from torchsummary import summary
    from Project import Project
    from data import get_dataloaders
    from data.transformation import train_transform, val_transform
    from models import MyCNN, resnet18
    from utils import device, show_dl
    from torchbearer import Trial
    from torchbearer.callbacks import CSVLogger, ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
    from callbacks import CometCallback
    from logger import logging
    
    if __name__ == '__main__':
        project = Project()
        # our hyperparameters
        params = {
            'lr': 0.001,
            'batch_size': 64,
            'epochs': 1,
            'model': 'resnet18-finetune',
            'id': time.time()
        }
    
        logging.info(f'Using device={device} 🚀')
        # everything starts with the data
        train_dl, val_dl, test_dl = get_dataloaders(
            project.data_dir,
            val_transform=val_transform,
            train_transform=train_transform,
            batch_size=params['batch_size'],
            num_workers=4,
        )
        # it is always good practice to visualise some of the train and val images to be sure data-aug
        # is applied properly
        # show_dl(train_dl)
        # show_dl(test_dl)
        # define our comet experiment
        experiment = Experiment(api_key='YOUR_COMET_API_KEY',
                                project_name="dl-pytorch-template", workspace="francescosaveriozuppichini")
        experiment.log_parameters(params)
        # create our special resnet18
        cnn = resnet18(n_classes=2).to(device)
        loss = nn.CrossEntropyLoss()
        # print the model summary to show useful information
        # torchsummary's summary() prints the table itself and returns None
        summary(cnn, (3, 224, 224))
        # define a custom optimizer and instantiate the trainer `Model`
        optimizer = optim.Adam(cnn.parameters(), lr=params['lr'])
        # create our Trial object to train and evaluate the model
        trial = Trial(cnn, optimizer, loss, metrics=['acc', 'loss'],
                      callbacks=[
                          CometCallback(experiment),
                          ReduceLROnPlateau(monitor='val_loss',
                                            factor=0.1, patience=5),
                          EarlyStopping(monitor='val_acc', patience=5, mode='max'),
                          CSVLogger(str(project.checkpoint_dir / 'history.csv')),
                          ModelCheckpoint(str(project.checkpoint_dir / f'{params["id"]}-best.pt'), monitor='val_acc', mode='max')
        ]).to(device)
        trial.with_generators(train_generator=train_dl,
                              val_generator=val_dl, test_generator=test_dl)
        history = trial.run(epochs=params['epochs'], verbose=1)
        logging.info(history)
        preds = trial.evaluate(data_key=torchbearer.TEST_DATA)
        logging.info(f'test preds=({preds})')
        # experiment.log_metric('test_acc', test_acc)
    
    

    I am running the same logic (same model) with poutyne and I have no problems. I really would like to switch to torchbearer

    Error is:

    2020-02-03 13:32:03,386 - [INFO] - None
      0%|                                                                                                                                                             | 0/1 [00:00<?, ?it/s]C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [2,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [13,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [17,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [20,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [21,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [22,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [23,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [25,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [29,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [30,0,0] Assertion `t >= 0 && t < n_classes` failed.
    Traceback (most recent call last):
      File "c:/Users/Francesco/Documents/PyTorch-Deep-Learning-Template/main.py", line 64, in <module>
        history = trial.run(epochs=params['epochs'], verbose=1)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 133, in wrapper
        res = func(self, *args, **kwargs)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 988, in run
        final_metrics = self._fit_pass(state)[torchbearer.METRICS]
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 298, in wrapper
        res = func(self, *args, **kwargs)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 1033, in _fit_pass
        state[torchbearer.OPTIMIZER].step(lambda: self.closure(state))
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torch\optim\adam.py", line 58, in step
        loss = closure()
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 1033, in <lambda>
        state[torchbearer.OPTIMIZER].step(lambda: self.closure(state))
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\bases.py", line 382, in closure
        state[loss].backward(**state[torchbearer.BACKWARD_ARGS])
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\comet_ml\monkey_patching.py", line 246, in wrapper
        return_value = original(*args, **kwargs)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torch\tensor.py", line 195, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torch\autograd\__init__.py", line 99, in backward
        allow_unreachable=True)  # allow_unreachable flag
    RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`
    

    Does your library work for you? Do you use it in your daily workflow?

    Thank you.

    Cheers,

    Francesco Saverio

    opened by FrancescoSaverioZuppichini 7
  • Accuracy computation for seq2seq model

    0/10(t): 100%|██████████| 1000/1000 [02:30<00:00,  6.75it/s, running_loss=0.0338, running_acc=0.326, loss=0.689, loss_std=1.17, acc=35.9, acc_std=0]
    0/10(v): 100%|██████████| 20/20 [00:01<00:00, 19.90it/s, val_loss=0.0341, val_loss_std=0.0122, val_acc=42, val_acc_std=0]
    1/10(t): 100%|██████████| 1000/1000 [02:30<00:00,  6.76it/s, running_loss=0.00997, running_acc=0.327, loss=0.019, loss_std=0.0166, acc=41.8, acc_std=0]
    1/10(v): 100%|██████████| 20/20 [00:01<00:00, 19.98it/s, val_loss=0.0126, val_loss_std=0.00798, val_acc=42.1, val_acc_std=0]
    2/10(t): 100%|██████████| 1000/1000 [02:30<00:00,  6.75it/s, running_loss=0.00493, running_acc=0.328, loss=0.00837, loss_std=0.00938, acc=41.8, acc_std=0]
    2/10(v): 100%|██████████| 20/20 [00:01<00:00, 19.89it/s, val_loss=0.00783, val_loss_std=0.00716, val_acc=42.2, val_acc_std=0]
    3/10(t):  45%|████▌     | 454/1000 [01:08<01:21,  6.73it/s, running_loss=0.00458, running_acc=0.316]
    

    Are the accuracies correct? (running_acc=.326, acc=35.9?)

    I may be misunderstanding something, but shouldn't running_acc and acc be the same at the end of each epoch?

    bug 
    opened by kl2792 6
  • Tqdm for Jupyter Notebook

    Each iteration of TQDM starts a new line in Jupyter Notebook -- is there any way to integrate one of the suggested fixes into torchbearer?

    (ref: https://github.com/tqdm/tqdm/issues/375, https://stackoverflow.com/a/47200571)

    bug 
    opened by kl2792 6
  • ReduceLROnPlateau

    Dear all,

    first of all, I love this library.

    The ReduceLROnPlateau callback is not working when I call trial.evaluate.

    ...
        trial = Trial(cnn, optimizer, loss, metrics=['acc', 'loss'],
                      callbacks=[
                        #   CometCallback(experiment),
                          ReduceLROnPlateau(monitor='val_loss',
                                            factor=0.1, patience=5),
                        #   EarlyStopping(monitor='val_acc', patience=5, mode='max'),
                        #   CSVLogger('history.csv'),
                        #   ModelCheckpoint('best.pt', monitor='val_acc', mode='max')
        ]).to(device)
        trial.with_generators(train_generator=train_dl,
                              val_generator=val_dl, test_generator=test_dl)
        # history = trial.run(params['epochs'], verbose=1)
        preds = trial.evaluate(data_key=torchbearer.TEST_DATA)
    

    error:

    0/1(e): 100%|███████████████████████████████████| 1/1 [00:00<00:00,  2.18it/s, test_acc=0.4667, test_loss=0.6516]
    Traceback (most recent call last):
      File "c:/Users/Francesco/Documents/PyTorch-Deep-Learning-Template/main.py", line 62, in <module>
        preds = trial.evaluate(data_key=torchbearer.TEST_DATA)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 298, in wrapper
        res = func(self, *args, **kwargs)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 133, in wrapper
        res = func(self, *args, **kwargs)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 1131, in evaluate
        state[torchbearer.CALLBACK_LIST].on_end_epoch(state)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\callbacks\callbacks.py", line 221, in on_end_epoch
        self._for_list(lambda callback: callback.on_end_epoch(state))
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 105, in _for_list
        function(self.callback_list)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\callbacks\callbacks.py", line 221, in <lambda>
        self._for_list(lambda callback: callback.on_end_epoch(state))
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\callbacks\callbacks.py", line 221, in on_end_epoch
        self._for_list(lambda callback: callback.on_end_epoch(state))
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\callbacks\callbacks.py", line 66, in _for_list
        function(callback)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\callbacks\callbacks.py", line 221, in <lambda>
        self._for_list(lambda callback: callback.on_end_epoch(state))
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\callbacks\torch_scheduler.py", line 32, in on_end_epoch
        self._scheduler.step(state[torchbearer.METRICS][self._monitor], epoch=state[torchbearer.EPOCH])
    KeyError: 'val_loss'
    

    Probably you need to disable the callback when evaluating (or just check whether the monitored metric is in state['metrics']).

    Thank you!

    Best Regards,

    Francesco Saverio

    opened by FrancescoSaverioZuppichini 5
  • Trial predict fails with the given example

    After training the model, I want to get the predictions on the test set, not the accuracy. I know the accuracy comes from Trial.evaluate(), and that works well, so for predictions I used Trial.predict(). Is that right?

    But I get the error AttributeError: 'dict' object has no attribute 'data'.

    I read the Trial documentation, which provides this example:

    
    # Simple trial to predict on some validation and test data
    >>> import torch
    >>> import torchbearer
    >>> from torchbearer import Trial
    >>> val_data = torch.rand(5, 5)
    >>> test_data = torch.rand(5, 5)
    >>> t = Trial(None).with_test_data(test_data)
    >>> test_predictions = t.predict(data_key=torchbearer.TEST_DATA)
    
    

    I ran it but got an error AttributeError: 'NoneType' object has no attribute 'eval'

    So, is there any problem in this method?

    bug 
    opened by danielhuoo 5
  • Lean Model Checkpointing

    Hi,

    I ran a Trial that saves my model to a file, model.pt, using torchbearer.callbacks.checkpointers.Best.

    When I load the file with torch.load and try to make a forward pass with it, I get the following error:

    model = MyModule()
    state_dict = torch.load('vae.pt')
    model.load_state_dict(state_dict) # <== I get the error here
    

    AttributeError: 'StateKey' object has no attribute 'startswith'

    I get that the model is being saved so that it can be recovered and be ready for torchbearer, but how can we save a lean copy of just the model?

    It seems like here, the model is only saved for reusability by torchbearer.

    Thanks a lot!

    docs 
    opened by dorukhansergin 5
  • loss_std resulting in complex number and breaking Tensorboard

    I'm using torchbearer with PyTorch 0.4 and TensorboardX 1.2. Previously, I was using PyTorch 0.4.1, but I had to downgrade to use TensorboardX because of an incompatibility between them. After adding the Tensorboard callback, the following error is raised after training for some time:

    {TypeError}can't convert complex to float

    When debugging, I noticed that add_scalar() in TensorboardX tried to convert the scalar to float and, somehow, val_loss_std was a complex number. Is there an error in how the std is calculated that could result in a complex number?

    bug 
    opened by fernandocamargoai 5
  • Support multi input and output

    Right now, it's not possible to:

    • Have a Module with multiple inputs (e.g. forward(x1, x2)).
    • Have a Module with multiple outputs (returning a tuple).

    I worked around the first problem by creating a Module with a single input and indexing each individual input. But the second problem makes it impossible to use the TripletMarginLoss, for example, since it expects 3 outputs from the Module.

    opened by fernandocamargoai 3
  • Model checkpointers save_weights_only

    As per the discussion in #504 it would be good if the checkpointers had an option to just save the model state dict, rather than the trial one. Not sure what the argument should be, something like save_model_only / save_weights_only? @dorukhansergin @MattPainter01 any thoughts on this?

    enhancement 
    opened by ethanwharris 3
  • Running indefinitely?

    Currently there is no way to ask torchbearer to run until stopped. This would be useful for reinforcement learning where we don't know how long an episode will be.

    enhancement 
    opened by ethanwharris 3
  • lr_scheduler order changed since PyTorch 1.1.0

    Many thanks for the wonderful library.

    A warning message emerged when a scheduler was used in the callbacks: scheduler = torchbearer.callbacks.torch_scheduler.StepLR(step_size=5, gamma=0.1)

    Hope it could be considered in later update if not yet included.

    Python37\lib\site-packages\torch\optim\lr_scheduler.py:100: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule.See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
      "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
    
    bug 
    opened by yongduek 2
  • Question about training loop

    Hi! I'm trying to fork the repo and add some functionality for an experiment. But that requires an addition in the training loop. I've read the documentation and the code but I can't seem to understand where the training loop itself is defined. Can somebody point me in the right direction?

    Thanks in advance!

    opened by AnabetsyR 4
  • GradientNormClipping callback error

    When I insert this callback in the trial I get the following error. Is this some kind of bug? It seems like the gradients are not passed in the callback.

    File "/home/dimitris/.local/lib/python3.6/site-packages/torch/nn/utils/clip_grad.py", line 30, in clip_grad_norm_
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type)
    RuntimeError: stack expects a non-empty TensorList

    opened by dimimal 0
  • Native automatic mixed precision for torchbearer

    Native automatic mixed precision support (torch.cuda.amp) is now in master: https://pytorch.org/docs/master/amp.html https://pytorch.org/docs/master/notes/amp_examples.html

    Not sure if you ever tried Nvidia's (our) experimental Apex Amp, but I know it has many pain points (extension builds, forward/backward compatibility, DataParallel support, flaky checkpointing, I don't even know if it can be hacked to handle double backward/gradient penalty, others…). torch.cuda.amp fixes all these, the interface is more flexible and intuitive, and the tighter integration with PyTorch brings more future optimizations into scope.

    I think the torch.cuda.amp API is a good fit for a higher-level library because its style is more functional (as in, it doesn't statefully alter anything outside itself). The necessary torch.cuda.amp calls don't have silent/weird effects elsewhere.

    If you want to talk about adding torch.cuda.amp to torchbearer, with an eye towards it becoming the future-proof source of mixed precision, message me on Pytorch slack anytime (or ask me for invites if you're not signed up). I'll check this issue periodically but I'm on Pytorch slack a greater fraction of the time than I care to admit.

    opened by mcarilli 0
  • Y_pred tuple behaviour changed

    When y_pred is a tuple (i.e. the model returns multiple outputs), the criterion now receives the tuple unpacked. This should either be reverted or more clearly documented.

    bug 
    opened by ethanwharris 0
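This note concerns the `loss_std` report above. One plausible explanation, not a confirmed diagnosis of torchbearer's code, is numerical cancellation. A standard deviation computed as the square root of E[x^2] - E[x]^2 can come out slightly negative through floating-point cancellation. In Python 3, raising a negative float to the power 0.5 returns a complex number instead of raising an error.

The sketch below contrasts that naive formula with Welford's running algorithm, whose accumulated sum of squared deviations does not go meaningfully negative. Its optional `unbiased` flag applies Bessel's correction, matching the option described in the 0.3.0 release notes. The function names are hypothetical.

```python
import math


def naive_std(values):
    """Population std via sqrt(E[x^2] - E[x]^2).

    Prone to cancellation: the difference can round to a tiny negative
    number, and in Python 3 a negative float raised to 0.5 is a complex
    number, not an error.
    """
    n = len(values)
    mean = sum(values) / n
    mean_sq = sum(v * v for v in values) / n
    return (mean_sq - mean * mean) ** 0.5


def welford_std(values, unbiased=True):
    """Numerically stable std using Welford's running algorithm.

    unbiased=True applies Bessel's correction (divide by n - 1);
    unbiased=False gives the population std (divide by n).
    """
    n = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the mean
    for x in values:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)  # each increment is (up to rounding) >= 0
    divisor = n - 1 if unbiased else n
    if divisor <= 0:
        raise ValueError("not enough values for the requested estimate")
    # clamp to guard against a tiny negative value left by rounding
    return math.sqrt(max(m2, 0.0) / divisor)


# A negative float to a fractional power silently becomes complex in Python 3,
# the kind of value that later fails a float() conversion.
print(type((-1e-18) ** 0.5))  # <class 'complex'>
print(welford_std([1.0, 2.0, 3.0, 4.0]))
```

The running form also suits metrics that are updated batch by batch: it never stores the individual values.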
Releases(0.5.3)
  • 0.5.3(Jan 31, 2020)

  • 0.5.2(Jan 28, 2020)

    [0.5.2] - 2020-01-28

    Added

    • Added option to use mixup loss with cutmix
    • Support for PyTorch 1.4.0

    Changed

    • Changed PyCM save methods to use *args and **kwargs

    Deprecated

    Removed

    Fixed

    • Fixed a bug where the PyCM callback would fail when saving
  • 0.5.1(Nov 6, 2019)

    [0.5.1] - 2019-11-06

    Added

    • Added BCPlus callback for between-class learning
    • Added support for PyTorch 1.3
    • Added a show flag to the ImagingCallback.to_pyplot method; set it to False to stop it from calling plt.show
    • Added manifold mixup

    Changed

    • Changed the default behaviour of ImagingCallback.to_pyplot to turn off the axis

    Deprecated

    Removed

    Fixed

    • Fixed a bug when resuming an old state dict with tqdm enabled
    • Fixed a bug in imaging where passing a title to to_pyplot was not possible
  • 0.5.0(Sep 17, 2019)

    [0.5.0] - 2019-09-17

    Added

    • Added PyTorch CyclicLR scheduler

    Changed

    • Torchbearer now supports Modules with multiple inputs and multiple outputs

    Deprecated

    Removed

    • Removed the Cyclic LR callback in favour of the PyTorch CyclicLR scheduler
    • Removed support for PyTorch 0.4.x

    Fixed

    • Fixed bug where aggregate predictions couldn't handle empty list
    • Fixed a bug where Runtime Errors on forward weren't handled properly
    • Fixed a bug where exceptions on forward wouldn't print the traceback properly
    • Fixed a documentation mistake whereby ReduceLROnPlateau was said to increase learning rate
  • 0.4.0(Sep 17, 2019)

    [0.4.0] - 2019-07-05

    Added

    • Added with_loader trial method that allows running of custom batch loaders
    • Added a Mock Model which is set when None is passed as the model to a Trial. Mock Model always returns None.
    • Added __call__(state) to StateKey so that they can now be used as losses
    • Added a callback to do cutout regularisation
    • Added a with_data trial method that allows passing of train, val and test data in one call
    • Added the missing on_init callback decorator
    • Added a step_on_batch flag to the early stopping callback
    • Added multi image support to imaging
    • Added a callback to unpack state into torchbearer.X at sample time for specified keys and update state after the forward pass based on model outputs. This is useful for using DataParallel, which passes the main state dict directly.
    • Added callback for generating confusion matrices with PyCM
    • Added a mixup callback with associated loss
    • Added Label Smoothing Regularisation (LSR) callback
    • Added CutMix regularisation
    • Added default metric from paper for when Mixup loss is used

    Changed

    • Changed history to now just be a list of records
    • Categorical Accuracy metric now also accepts tensors of size (B, C) and gets the max over C for the target class

    Deprecated

    Removed

    • Removed the variational sub-package, this will now be packaged separately
    • Removed verbose argument from the early stopping callback

    Fixed

    • Fixed a bug where list or dictionary metrics would cause the tensorboard callback to error
    • Fixed a bug where running a trial without training steps would error
    • Fixed a bug where the caching imaging callback didn't reset data so couldn't be run in multiple trials
    • Fixed a bug in the ClassAppearanceModel callback
    • Fixed a bug where the state given to predict was not a State object
    • Fixed a bug with Cutout on gpu
    • Fixed a bug where MakeGrid callback wasn't passing all arguments correctly
    • Fixed a bug in ImagingCallback that would sometimes cause make_grid to throw an error
    • Fixed a bug where the verbose argument would not work unless given as a keyword argument
    • Fixed a bug where the data_key argument would sometimes not work as expected
    • Fixed a bug where cutout required a seed
    • Fixed a bug where cutmix wasn't sending the beta distribution sample to the device
  • 0.3.2(May 28, 2019)

    [0.3.2] - 2019-05-28

    Added

    Changed

    Deprecated

    Removed

    Fixed

    • Fixed a bug where for_steps would sometimes not work as expected if called in the wrong order
    • Fixed a bug where torchbearer installed via pip would crash on import
  • 0.3.1(May 24, 2019)

    [0.3.1] - 2019-05-24

    Added

    • Added cyclic learning rate finder
    • Added on_init callback hook to run at the end of trial init
    • Added callbacks for weight initialisation in torchbearer.callbacks.init
    • Added with_closure trial method that allows running of custom closures
    • Added base_closure function to bases that allows creation of standard training loop closures
    • Added ImagingCallback class for callbacks which produce images that can be sent to tensorboard, visdom or a file
    • Added CachingImagingCallback and MakeGrid callback to make a grid of images
    • Added the option to give the only_if callback decorator a function of self and state rather than just state
    • Added Layer-sequential unit-variance (LSUV) initialization
    • Added ClassAppearanceModel callback and example page for visualising CNNs
    • Added on_checkpoint callback decorator
    • Added support for PyTorch 1.1.0

    Changed

    • No_grad and enable_grad decorators are now also context managers

    Deprecated

    Removed

    • Removed the fluent decorator, just use return self
    • Removed install dependency on torchvision, still required for some functionality

    Fixed

    • Fixed bug where replay errored when train or val steps were None
    • Fixed a bug where mock optimiser wouldn't call its closure
    • Fixed a bug where the notebook check raised ModuleNotFoundError when IPython not installed
    • Fixed a memory leak with metrics that causes issues with very long epochs
    • Fixed a bug with the once and once_per_epoch decorators
    • Fixed a bug where the test criterion wouldn't accept a function of state
    • Fixed a bug where type inference would not work correctly when chaining Trial methods
    • Fixed a bug where checkpointers would error when they couldn't find the old checkpoint to overwrite
    • Fixed a bug where the 'test' label would sometimes not populate correctly in the default accuracy metric
  • 0.3.0(Feb 28, 2019)

    [0.3.0] - 2019-02-28

    Added

    • Added torchbearer.variational, a sub-package for implementations of state of the art variational auto-encoders
    • Added SimpleUniform and SimpleExponential distributions
    • Added a decorator which can be used to cite a research article as part of a doc string
    • Added an optional dimension argument to the mean, std and running_mean metric aggregators
    • Added a var metric and decorator which can be used to calculate the variance of a metric
    • Added an unbiased flag to the std and var metrics to optionally not apply Bessel's correction (consistent with torch.std / torch.var)
    • Added support for rounding 1D lists to the Tqdm callback
    • Added SimpleWeibull distribution
    • Added support for Python 2.7
    • Added SimpleWeibullSimpleWeibullKL
    • Added SimpleExponentialSimpleExponentialKL
    • Added the option to save only the model parameters in Checkpointers.
    • Added documentation about serialization.
    • Added support for indefinite data loading. Iterators can now be run until complete independent of epochs or iterators can be refreshed during an epoch if complete.
    • Added support for batch intervals in interval checkpointer
    • Added line magic %torchbearer notebook
    • Added 'accuracy' variants of 'acc' default metrics
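The indefinite-loading idea can be sketched in pure Python, assuming the loader is any re-iterable object (a list, or a PyTorch `DataLoader`): the iterator is refreshed when it runs out, so the number of steps per epoch need not match the dataset length. `refreshing_batches` and `run_epochs` are hypothetical names, not torchbearer functions.

```python
def refreshing_batches(loader):
    """Yield batches indefinitely, re-iterating `loader` each time it is exhausted."""
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:  # an empty loader would otherwise loop forever
            return


def run_epochs(loader, steps_per_epoch, epochs):
    # One shared iterator: epoch boundaries do not restart the data.
    batches = refreshing_batches(loader)
    return [[next(batches) for _ in range(steps_per_epoch)] for _ in range(epochs)]


print(run_epochs(["a", "b", "c"], steps_per_epoch=2, epochs=3))
# [['a', 'b'], ['c', 'a'], ['b', 'c']]
```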

    Changed

    • Changed the default behaviour of the std metric to compute the sample std, in line with torch.std
    • Tqdm precision argument now rounds to decimal places rather than significant figures
    • Trial will now simply infer if the model has an argument called 'state'
    • Torchbearer now infers if inside a notebook and will use the appropriate tqdm module if not set
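The difference between the sample std (Bessel's correction, dividing by n - 1, the default in `torch.std`) and the population std (dividing by n) is shown in the plain-Python sketch below. The `std` function and its `unbiased` flag here are illustrative and are not torchbearer's metric.

```python
import math


def std(values, unbiased=True):
    """Standard deviation; unbiased=True applies Bessel's correction (divide by n - 1)."""
    n = len(values)
    mean = sum(values) / n
    squared_deviations = sum((v - mean) ** 2 for v in values)
    return math.sqrt(squared_deviations / (n - 1 if unbiased else n))


xs = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
print(std(xs, unbiased=False))  # 2.0  (population std, divides by n)
print(round(std(xs), 4))        # 2.1381  (sample std, divides by n - 1)
```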

    Deprecated

    Removed

    • Removed the old Model API (deprecated since version 0.2.0)
    • Removed the 'pass_state' argument from Trial; this is now inferred
    • Removed the 'std' decorator from the default metrics

    Fixed

    • Fixed a bug in the weight decay callback which would result in potentially negative decay (now just uses torch.norm)
    • Fixed a bug in the cite decorator causing the citation to not show up correctly
    • Fixed a memory leak in the mse primitive metric
  • 0.2.6.1(Feb 25, 2019)

  • 0.2.6(Dec 19, 2018)

    [0.2.6] - 2018-12-19

    Added

    Changed

    • Y_PRED, Y_TRUE and X can now equivalently be accessed as PREDICTION, TARGET and INPUT respectively

    Deprecated

    Removed

    Fixed

    • Fixed a bug where the LiveLossPlot callback would trigger an error if run and evaluate were called separately
    • Fixed a bug where state key errors would report to the wrong stack level
    • Fixed a bug where the user would wrongly get a state key error in some cases
  • 0.2.5(Dec 19, 2018)

    [0.2.5] - 2018-12-19

    Added

    • Added flag to replay to replay only a single batch per epoch
    • Added support for PyTorch 1.0.0 and Python 3.7
    • MetricTree can now unpack dictionaries from root; this is useful if you want to get a mean of a metric. However, it should be used with caution, as it extracts only the first value in the dict and ignores the rest.
    • Added a callback for the livelossplot visualisation tool for notebooks

    Changed

    • All error / accuracy metrics can now optionally take state keys for predictions and targets as arguments

    Deprecated

    Removed

    Fixed

    • Fixed a bug with the EpochLambda metric which required y_true / y_pred to have specific forms
  • 0.2.4(Nov 16, 2018)

    [0.2.4] - 2018-11-16

    Added

    • Added metric functionality to state keys so that they can be used as metrics if desired
    • Added customizable precision to the printer callbacks
    • Added threshold to binary accuracy. Now it will appropriately handle any values in [0, 1]
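A thresholded binary accuracy can be sketched in plain Python. Thresholding the targets as well as the predictions is one way to handle soft labels in [0, 1]; treat that detail as an assumption rather than a description of torchbearer's `BinaryAccuracy`.

```python
def binary_accuracy(y_pred, y_true, threshold=0.5):
    """Fraction of samples where thresholded predictions match thresholded targets."""
    # Assumption: targets are thresholded too, so soft labels in [0, 1] become 0 or 1.
    matches = [(p > threshold) == (t > threshold) for p, t in zip(y_pred, y_true)]
    return sum(matches) / len(matches)


print(binary_accuracy([0.9, 0.2, 0.6, 0.4], [1.0, 0.0, 0.0, 1.0]))  # 0.5
```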

    Changed

    • Changed the default printer precision to 4 significant figures
    • Tqdm on_epoch now shows metrics immediately when resuming
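Rounding to significant figures (the default here) and rounding to decimal places (used by the Tqdm precision argument from 0.3.0 onward) behave very differently on small and large values. The two helper functions below are illustrative names, not part of torchbearer.

```python
def round_decimal_places(x, places=4):
    """Round to a fixed number of digits after the decimal point."""
    return round(x, places)


def round_significant_figures(x, figures=4):
    """Round to a fixed number of significant digits via the 'g' format specifier."""
    return float(f"{x:.{figures}g}")


for value in (0.000123456, 1234.5678):
    print(round_decimal_places(value), round_significant_figures(value))
# 0.0001 0.0001235
# 1234.5678 1235.0
```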

    Deprecated

    Removed

    Fixed

    • Fixed a bug which would incorrectly trigger version warnings when loading in models
    • Fixed bugs where the Trial would not fail gracefully if required objects were not in state
    • Fixed a bug where none criterion didn't work with the add_to_loss callback
    • Fixed a bug where tqdm on_epoch always started at 0
  • 0.2.3(Oct 12, 2018)

    [0.2.3] - 2018-10-12

    Added

    • Added string representation of Trial to give summary
    • Added option to log Trial summary to TensorboardText
    • Added a callback point ('on_checkpoint') which can be used for model checkpointing after the history is updated

    Changed

    • When resuming training, checkpointers no longer delete the state file the trial was loaded from
    • Changed the metric eval to include a data_key which tells us what data we are evaluating on

    Deprecated

    Removed

    Fixed

    • Fixed a bug where callbacks weren't handled correctly in the predict and evaluate methods of Trial
    • Fixed a bug where the history wasn't updated when new metrics were calculated with the evaluate method of Trial
    • Fixed a bug where tensorboard writers couldn't be reused
    • Fixed a bug where the none criterion didn't require gradient
    • Fixed a bug where tqdm wouldn't get the correct iterator length when evaluating on the test generator
    • Fixed a bug where evaluating before training tried to update history before it existed
    • Fixed a bug where the metrics would output 'val_acc' even if evaluating on test or train data
    • Fixed a bug where roc metric didn't detach y_pred before sending to numpy
    • Fixed a bug where resuming from a checkpoint saved with one of the callbacks didn't populate the epoch number correctly
  • 0.2.2(Sep 18, 2018)

    [0.2.2] - 2018-09-18

    Added

    • The default_for_key metric decorator can now be used to pass arguments to the init of the inner metric
    • The default metric for the key 'top_10_acc' is now the TopKCategoricalAccuracy metric with k set to 10
    • Added global verbose flag for trial that can be overridden by run, evaluate, predict
    • Added an LR metric which retrieves the current learning rate from the optimizer, default for key 'lr'
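Top-k categorical accuracy counts a sample as correct when its target class is among the k highest-scoring classes. A plain-Python sketch, using lists in place of tensors; `top_k_accuracy` is an illustrative name, not torchbearer's `TopKCategoricalAccuracy` implementation.

```python
def top_k_accuracy(scores, targets, k=5):
    """Fraction of samples whose target class is among the k highest-scoring classes."""
    hits = 0
    for row, target in zip(scores, targets):
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        hits += target in ranked[:k]
    return hits / len(targets)


scores = [[0.1, 0.5, 0.4], [0.7, 0.2, 0.1]]
print(top_k_accuracy(scores, [2, 2], k=2))  # 0.5
print(top_k_accuracy(scores, [2, 2], k=1))  # 0.0
```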

    Fixed

    • Fixed a bug where the DefaultAccuracy metric would not put the inner metric in eval mode if the first call to reset was after the call to eval
    • Fixed a bug where trying to load a state dict in a different session to where it was saved didn't work properly
    • Fixed a bug where the empty criterion would trigger an error if no Y_TRUE was put in state
  • 0.2.1(Sep 11, 2018)

    [0.2.1] - 2018-09-11

    Added

    • Evaluation and prediction can now be done on any data using the data_key keyword argument
    • Text tensorboard/visdom logger that writes epoch/batch metrics to text

    Changed

    • TensorboardX, Numpy, Scikit-learn and Scipy are no longer dependencies and are only required if using the tensorboard callbacks or roc metric

    Deprecated

    Removed

    Fixed

    • Fixed the Model class setting the generator incorrectly, which led to stop iterations
    • Argument ordering is now consistent in Trial.with_generators and Trial.__init__
    • Added a state dict for the early stopping callback
    • Fixed visdom parameters not getting set in some cases
  • 0.2.0(Aug 21, 2018)

    See [NEW!] in README.md for new key features

    [0.2.0] - 2018-08-21

    Added

    • Added the ability to pass custom arguments to the tqdm callback
    • Added an ignore_index flag to the categorical accuracy metric, similar to nn.CrossEntropyLoss. Usage: metrics=[CategoricalAccuracyFactory(ignore_index=0)]
    • Added TopKCategoricalAccuracy metric (default for key: top_5_acc)
    • Added BinaryAccuracy metric (default for key: binary_acc)
    • Added MeanSquaredError metric (default for key: mse)
    • Added DefaultAccuracy metric (use with 'acc' or 'accuracy') - infers accuracy from the criterion
    • New Trial api torchbearer.Trial to replace the Model api. Trial api is more atomic and uses the fluent pattern to allow chaining of methods.
    • torchbearer.Trial has with_x_generator and with_x_data methods to add training/validation/testing generators to the trial. There is a with_generators method to allow passing of all generators in one call.
    • torchbearer.Trial has for_x_steps and for_steps to allow running of trials without explicit generators or data tensors
    • torchbearer.Trial keeps a history of run calls which tracks number of epochs ran and the final metrics at each epoch. This allows seamless resuming of trial running.
    • torchbearer.Trial.state_dict now returns the trial history and callback list state allowing for full resuming of trials
    • torchbearer.Trial has a replay method that can replay training (with callbacks and display) from the history. This is useful when loading trials from state.
    • The backward call can now be passed args by setting state[torchbearer.BACKWARD_ARGS]
    • torchbearer.Trial implements the forward pass, loss calculation and backward call as an optimizer closure
    • Metrics are now explicitly calculated with no gradient
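The ignore_index flag excludes samples whose target equals the given index from both the numerator and the denominator, mirroring nn.CrossEntropyLoss (whose default ignore_index is -100). A plain-Python sketch; `categorical_accuracy` here is illustrative, not torchbearer's metric class.

```python
def categorical_accuracy(scores, targets, ignore_index=-100):
    """Argmax accuracy that leaves out samples whose target equals ignore_index."""
    correct = total = 0
    for row, target in zip(scores, targets):
        if target == ignore_index:
            continue  # skipped in both the numerator and the denominator
        predicted = max(range(len(row)), key=lambda i: row[i])
        correct += predicted == target
        total += 1
    return correct / total if total else 0.0


scores = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]
targets = [1, 1, 0]
print(categorical_accuracy(scores, targets))                  # 2 of 3 correct
print(categorical_accuracy(scores, targets, ignore_index=0))  # 0.5
```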

    Changed

    • Callback decorators can now be chained to allow construction with multiple methods filled
    • Callbacks can now implement state_dict and load_state_dict to allow callbacks to resume with state
    • State dictionary now accepts StateKey objects which are unique and generated through torchbearer.state.get_state
    • State dictionary now warns when accessed with strings as this allows for collisions
    • Checkpointer callbacks will now resume from a state dict when resume=True in Trial
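Unique key objects prevent two callbacks that pick the same string from silently overwriting each other's entries. The pure-Python sketch below is hypothetical: `StateKeySketch` and `StateSketch` are not torchbearer classes, and the de-duplication scheme (appending `_1`, `_2`, ...) is an assumption rather than a description of `torchbearer.state`.

```python
import warnings

_used_names = set()


class StateKeySketch:
    """Hypothetical state key: each instance gets a unique name and hashes by identity."""

    def __init__(self, name):
        unique, suffix = name, 1
        while unique in _used_names:  # de-duplicate: 'loss', 'loss_1', 'loss_2', ...
            unique = f"{name}_{suffix}"
            suffix += 1
        _used_names.add(unique)
        self.name = unique

    def __repr__(self):
        return self.name


class StateSketch(dict):
    """dict that warns whenever a plain string is used as a key."""

    def _check(self, key):
        if isinstance(key, str):
            warnings.warn(f"string key {key!r} may collide with keys used elsewhere")

    def __getitem__(self, key):
        self._check(key)
        return super().__getitem__(key)

    def __setitem__(self, key, value):
        self._check(key)
        super().__setitem__(key, value)


a, b = StateKeySketch("y_pred"), StateKeySketch("y_pred")
state = StateSketch()
state[a], state[b] = 1, 2
print(a, b, state[a], state[b])  # y_pred y_pred_1 1 2
```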

    Deprecated

    • torchbearer.Model has been deprecated in favour of the new torchbearer.Trial api

    Removed

    • Removed the MetricFactory class. Decorators still work in the same way but the Factory is no longer needed.

    Fixed

  • 0.1.7(Aug 14, 2018)

    [0.1.7] - 2018-08-14

    Added

    • Added visdom logging support to tensorboard callbacks
    • Added option to choose tqdm module (tqdm, tqdm_notebook, ...) to Tqdm callback
    • Added some new decorators to simplify custom callbacks that must only run under certain conditions (or even just once).
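Conditional and run-once decorators of this kind can be sketched with plain closures. `once_sketch` and `only_if_sketch` are hypothetical, and they wrap plain functions of a `state` dict rather than torchbearer callback objects.

```python
import functools


def once_sketch(fn):
    """Run fn on its first call only; later calls return None without running it."""
    has_run = False

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal has_run
        if has_run:
            return None
        has_run = True
        return fn(*args, **kwargs)

    return wrapper


def only_if_sketch(predicate):
    """Run the decorated fn(state) only when predicate(state) is true."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(state):
            return fn(state) if predicate(state) else None
        return wrapper
    return decorator


calls = []


@once_sketch
def announce(state):
    calls.append(("once", state["epoch"]))


@only_if_sketch(lambda state: state["epoch"] % 2 == 0)
def even_epochs(state):
    calls.append(("even", state["epoch"]))


for epoch in range(4):
    announce({"epoch": epoch})
    even_epochs({"epoch": epoch})

print(calls)  # [('once', 0), ('even', 0), ('even', 2)]
```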

    Changed

    • Instantiation of Model will now trigger a warning pending the new Trial API in the next version
    • TensorboardX dependency now version 1.4

    Deprecated

    Removed

    Fixed

    • Mean and standard deviation calculations now work correctly for network outputs with many dimensions
    • Callback list no longer shared between fit calls, now a new copy is made each fit
  • 0.1.6(Aug 10, 2018)

    [0.1.6] - 2018-08-10

    Added

    • Added a verbose level (options are now 0,1,2) which will print progress for the entire fit call, updating every epoch. Useful when doing dynamic programming with little data.
    • Added support for dictionary outputs of dataloader
    • Added abstract superclass for building TensorBoardX based callbacks

    Changed

    • Timer callback can now also be used as a metric which allows display of specified timings to printers and has been moved to metrics.
    • The loss_criterion is renamed to criterion in torchbearer.Model arguments.
    • The criterion in torchbearer.Model is now optional and will provide a zero loss tensor if it is not given.
    • TensorBoard callbacks refactored to be based on a common super class
    • TensorBoard callbacks refactored to use a common SummaryWriter for each log directory

    Deprecated

    Removed

    Fixed

    • Standard deviation calculation now returns 0 instead of complex value when given very close samples
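One common way a std calculation yields a complex value is computing the variance as E[x^2] - E[x]^2: floating-point cancellation can make it a tiny negative number for near-identical samples, and in Python 3 a negative float raised to the power 0.5 is complex. Whether this was the exact cause here is an assumption; the sketch below (population std, illustrative name) simply clamps the variance at zero.

```python
import math


def std_from_moments(values):
    """Population std computed as sqrt(E[x^2] - E[x]^2), clamped at zero."""
    n = len(values)
    mean = sum(values) / n
    mean_of_squares = sum(v * v for v in values) / n
    variance = mean_of_squares - mean * mean
    # Cancellation can leave `variance` slightly below zero for near-identical samples;
    # `variance ** 0.5` would then be complex, so clamp before taking the square root.
    return math.sqrt(max(variance, 0.0))


print(std_from_moments([1.0, 3.0]))  # 1.0
```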
  • 0.1.5(Jul 30, 2018)

    [0.1.5] - 2018-07-30

    Added

    • Added an on_validation_criterion callback hook
    • Added a DatasetValidationSplitter which can be used to create a validation split if required for datasets like Cifar10 or MNIST
    • Added simple timer callback
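A validation split of this kind amounts to shuffling the indices of a dataset and holding out a fraction of them. The sketch below is illustrative (`validation_split` is a hypothetical helper, not the `DatasetValidationSplitter` API); the resulting index lists could be wrapped with `torch.utils.data.Subset`.

```python
import random


def validation_split(num_examples, validation_fraction, seed=None):
    """Return disjoint (train_indices, val_indices) drawn from range(num_examples)."""
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)  # a fixed seed gives a reproducible split
    num_val = int(num_examples * validation_fraction)
    return indices[num_val:], indices[:num_val]


train_idx, val_idx = validation_split(10, 0.2, seed=0)
print(len(train_idx), len(val_idx))  # 8 2
```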

    Changed

    Deprecated

    Removed

    Fixed

    • Fixed a bug where checkpointers would not save the model in some cases
    • Fixed a bug with the ROC metric causing it to not work
  • 0.1.4(Jul 23, 2018)

    [0.1.4] - 2018-07-23

    Added

    • Added a decorator API for metrics which allows decorators to be used for metric construction
    • Added a default_for_key decorator which can be used to associate a string with a given metric in metric lists
    • Added a decorator API for callbacks which allows decorators to be used for simple callback construction
    • Added an add_to_loss callback decorator which allows quicker construction of callbacks that add values to the loss
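The decorator idea turns a plain function of state into a callback object that adds the function's result to the loss after the criterion runs. `AddToLossSketch` and `add_to_loss_sketch` are hypothetical, and plain string keys are used only to keep the sketch short; the `on_criterion` hook name is taken from the callback API described in this changelog.

```python
class AddToLossSketch:
    """Hypothetical callback: adds fn(state) to state['loss'] once the criterion has run."""

    def __init__(self, fn):
        self.fn = fn

    def on_criterion(self, state):
        state["loss"] = state["loss"] + self.fn(state)


def add_to_loss_sketch(fn):
    """Decorator turning a plain function of state into a callback object."""
    return AddToLossSketch(fn)


@add_to_loss_sketch
def l2_penalty(state):
    return 0.01 * sum(w * w for w in state["weights"])


state = {"loss": 1.0, "weights": [1.0, 2.0]}
l2_penalty.on_criterion(state)
print(state["loss"])  # loss now includes the 0.01 * (1 + 4) penalty
```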

    Changed

    • Changed the API for running metrics and aggregators to no longer wrap a metric but instead receive input

    Deprecated

    Removed

    Fixed

  • 0.1.3(Jul 18, 2018)

    [0.1.3] - 2018-07-17

    Added

    • Added a flag (step_on_batch) to the LR Scheduler callbacks which allows for step() to be called on each iteration instead of each epoch
    • Added on_sample_validation and on_forward_validation calls for validation callbacks
    • Added GradientClipping callback which simply clips the absolute gradient of the model parameters
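Assuming "clips the absolute gradient" means clamping each gradient value's magnitude to a threshold, the operation looks like the plain-Python sketch below. For tensors, PyTorch provides `torch.nn.utils.clip_grad_value_`, which does this in place on the parameters' gradients.

```python
def clip_gradients_by_value(gradients, clip_value):
    """Clamp every gradient entry into [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in gradients]


print(clip_gradients_by_value([-3.0, 0.5, 2.0], 1.0))  # [-1.0, 0.5, 1.0]
```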

    Changed

    • Changed the order of the arguments to the lambda function in the EpochLambda metric for consistency with pytorch and other metrics
    • Checkpointers now create the directory for the save path if it doesn't exist
    • Changed the 'on_forward_criterion' callback method to 'on_criterion'
    • Changed epoch number in printer callbacks to be consistent with the rest of torchbearer

    Deprecated

    Removed

    Fixed

    • Fixed tests which were failing as of version 0.1.2
    • Fixed validation_steps not being added to state
    • Fixed checkpointer bug when path contained only filename and no directory path
    • Fixed console printer bug not printing validation statistics
    • Fixed console printer bug calling final_metrics before they existed in state
  • v0.1.2(Jun 8, 2018)

    [0.1.2] - 2018-06-08

    Added

    • Added support for tuple outputs from generators; bink expects the output to have length 2. Specifically, x, y = next() is possible, where x and y can be tuples of arbitrary size or depth
    • Added support for torch dtypes in bink Model.to(...)
    • Added pickle_module and pickle_protocol to checkpointers for consistency with torch.save

    Changed

    • Changed the learning rate scheduler callbacks to no longer require an optimizer and to have the proper arguments

    Deprecated

    Removed

    Fixed

    • Fixed an issue in GradientNormClipping which raised a warning in PyTorch >= 0.4
  • v0.1.1(May 30, 2018)

Owner
The torchbearer project, by @ecs-vlc