TorchMD-net

TorchMD-Net provides state-of-the-art graph neural network and equivariant transformer neural network potentials for learning molecular potentials. It offers an efficient and fast implementation and is integrated with GPU-accelerated molecular dynamics codes such as ACEMD and OpenMM. See the full paper at https://arxiv.org/abs/2202.02541.

Installation

Create a new conda environment using Python 3.8 via

conda create --name torchmd python=3.8
conda activate torchmd

Install PyTorch

Then, install PyTorch according to your hardware specifications (see the installation instructions on the PyTorch website), e.g. for CUDA 11.1 and the most recent version of PyTorch use

conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c nvidia

Install PyTorch Geometric

Install pytorch-geometric with its dependencies through

conda install pytorch-geometric -c rusty1s -c conda-forge

To install PyTorch Geometric via pip or for PyTorch < 1.8, see https://github.com/rusty1s/pytorch_geometric#installation.

Install TorchMD-Net

Download and install the torchmd-net repository via

git clone https://github.com/compsciencelab/torchmd-net.git
pip install -e torchmd-net/

Performance

The TorchMD-net equivariant Transformer (ET) is competitive with previous methods on the MD17 benchmark dataset.


Usage

Training arguments can be specified either in a configuration YAML file or directly as command line arguments. An example configuration file for a TorchMD Graph Network can be found at examples/graph-network.yaml. For an example of how to train the network on the QM9 dataset, see examples/train_GN_QM9.sh. GPUs can be selected by listing their device IDs (as reported by nvidia-smi) in the CUDA_VISIBLE_DEVICES environment variable. Otherwise, the argument --ngpus can be used to set the number of GPUs to train on (-1 uses all available GPUs, or all of those listed in CUDA_VISIBLE_DEVICES).

mkdir output
CUDA_VISIBLE_DEVICES=0 python torchmd-net/scripts/torchmd_train.py --conf torchmd-net/examples/graph-network.yaml --dataset QM9 --log-dir output/

Creating a new dataset

If you want to train on custom data, first have a look at torchmdnet.datasets.Custom, which provides functionality for loading a NumPy dataset consisting of atom types and coordinates, with energies, forces, or both as labels. Alternatively, you can implement a custom class following the torch-geometric approach to implementing a dataset: derive from the Dataset or InMemoryDataset class and implement the necessary functions (see the PyTorch Geometric documentation on creating your own datasets). The dataset must return torch-geometric Data objects containing at least the keys z (atom types) and pos (atomic coordinates), as well as y (label), dy (derivative of the label w.r.t. the atomic coordinates), or both.
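As a rough illustration of the data contract described above, the following standard-library sketch checks that a single sample provides z and pos, plus at least one of y or dy, and that the shapes agree. The function name validate_sample and the plain-dict representation are hypothetical. Real datasets return torch-geometric Data objects rather than dicts.

```python
def validate_sample(sample: dict) -> None:
    """Raise ValueError if a sample does not follow the expected layout.

    Hypothetical helper: samples are plain dicts here, standing in for
    torch-geometric Data objects with the keys z, pos, y and/or dy.
    """
    for key in ("z", "pos"):
        if key not in sample:
            raise ValueError(f"missing required key {key!r}")
    if "y" not in sample and "dy" not in sample:
        raise ValueError("need a label: y, dy, or both")

    n_atoms = len(sample["z"])
    # pos: one (x, y, z) triple per atom
    if len(sample["pos"]) != n_atoms or any(len(p) != 3 for p in sample["pos"]):
        raise ValueError("pos must have shape (n_atoms, 3)")
    # dy: derivative of the label w.r.t. each atomic coordinate, same shape as pos
    if "dy" in sample and (
        len(sample["dy"]) != n_atoms or any(len(d) != 3 for d in sample["dy"])
    ):
        raise ValueError("dy must have shape (n_atoms, 3)")


# Usage with toy values for a three-atom sample:
sample = {"z": [8, 1, 1],
          "pos": [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]],
          "y": 0.0}
validate_sample(sample)  # passes silently
```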

Custom prior models

In addition to implementing a custom dataset class, it is also possible to add a custom prior model to the model. To do so, implement a new prior model class in torchmdnet.priors and pass its name with the argument --prior-model. As an example, have a look at torchmdnet.priors.Atomref.
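The snippet below is a framework-free sketch of the idea behind an atom-reference prior such as torchmdnet.priors.Atomref. It assumes the prior adds a fixed reference value per atom type to each atom's prediction before the per-atom contributions are summed. The function names and the reference table are hypothetical and are not torchmd-net's API. The numbers are toy values.

```python
from typing import List, Sequence


def apply_atomref_prior(per_atom: Sequence[float], z: Sequence[int],
                        atomref: Sequence[float]) -> List[float]:
    """Add a per-atom-type reference value to each atomic prediction."""
    if len(per_atom) != len(z):
        raise ValueError("per_atom and z must have the same length")
    return [x + atomref[t] for x, t in zip(per_atom, z)]


def total_energy(per_atom: Sequence[float], z: Sequence[int],
                 atomref: Sequence[float]) -> float:
    # The molecular energy is the sum of the corrected atomic contributions.
    return sum(apply_atomref_prior(per_atom, z, atomref))


# Toy reference table indexed by atom type (here: atomic number).
toy_atomref = [0.0] * 9
toy_atomref[1] = -1.0   # toy value for H
toy_atomref[8] = -10.0  # toy value for O
energy = total_energy([0.1, 0.2, 0.3], [8, 1, 1], toy_atomref)
```

Keeping the prior as a per-atom additive term means the network only has to learn the residual on top of the reference values.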

Multi-Node Training

Multi-node training currently does not work with the most recent PyTorch Lightning version; it has been tested up to pytorch-lightning==1.2.10.

To train models on multiple nodes, some environment variables have to be set that provide PyTorch Lightning with all the necessary information. Below is an example bash script that starts training on two machines with two GPUs each. The script has to be started once on each node. Once train.py is running on all nodes, a network connection between the nodes is established using NCCL.

In addition to the environment variables, the argument --num-nodes has to be set to the number of nodes involved in training.

export NODE_RANK=0
export MASTER_ADDR=hostname1
export MASTER_PORT=12910

mkdir -p output
CUDA_VISIBLE_DEVICES=0,1 python torchmd-net/scripts/train.py --conf torchmd-net/examples/graph-network.yaml --num-nodes 2 --log-dir output/
  • NODE_RANK : Integer indicating the node index. Must be 0 for the main node and incremented by one for each additional node.
  • MASTER_ADDR : Hostname or IP address of the main node. The same for all involved nodes.
  • MASTER_PORT : A free network port for communication between nodes. PyTorch Lightning suggests port 12910 as a default.

Known Limitations

  • Due to the way PyTorch Lightning calculates the number of required DDP processes, all nodes must use the same number of GPUs. Otherwise, training will either not start or crash.
  • We observe a 50x decrease in performance when mixing nodes with different GPU architectures (tested with RTX 2080 Ti and RTX 3090).
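The following sketch shows why the first limitation arises. It enumerates global ranks in the usual DDP layout, where each node runs one process per GPU and global_rank = NODE_RANK * gpus_per_node + local_rank. This is a simplified model that assumes a single gpus_per_node value shared by every node, which is exactly the assumption that breaks when nodes have different GPU counts. The function is hypothetical and is not part of torchmd-net or PyTorch Lightning.

```python
from typing import Dict, Tuple


def global_ranks(num_nodes: int, gpus_per_node: int) -> Dict[Tuple[int, int], int]:
    """Map (node_rank, local_rank) -> global rank for a uniform DDP layout."""
    ranks = {}
    for node_rank in range(num_nodes):
        for local_rank in range(gpus_per_node):
            ranks[(node_rank, local_rank)] = node_rank * gpus_per_node + local_rank
    return ranks


# Two nodes with two GPUs each, as in the example script above.
layout = global_ranks(num_nodes=2, gpus_per_node=2)
world_size = len(layout)  # num_nodes * gpus_per_node processes in total
```

Under this formula, if node 0 actually started a third process, its global rank (2) would collide with node 1's first process. This is consistent with training failing to start or crashing.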
Comments
  • Pre-trained model


    We are writing a paper about NNP/MM in ACEMD. So far, we have used ANI-2x for protein-ligand simulations, but to demonstrate general utility, it would be good to include one more NNP.

    Would it be possible to have a pre-trained TorchMD-NET model?

    opened by raimis 32
  • Base learning rate decay on training loss


    This changes it to base learning rate decay on training loss rather than validation loss. That gives a much cleaner signal for whether it is still learning. The practical effect is that you can use a smaller value for lr_patience, which leads to faster training.

    In general, training loss tells you whether it is learning, and the difference between training loss and validation loss tells you whether it is overfitting. If the training loss stops decreasing, that means you need to reduce the learning rate. If the training loss is still decreasing but the validation loss stops going down, that means it is overfitting and you should stop. Reducing the learning rate won't help.
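The proposal can be illustrated with a minimal plateau scheduler that watches the training loss instead of the validation loss. This is a simplified, framework-free sketch that mirrors the lr_factor, lr_patience and lr_min options found in torchmd-net configuration files. It is not the code from this PR. PyTorch's torch.optim.lr_scheduler.ReduceLROnPlateau has further options (threshold, cooldown) that are ignored here.

```python
class TrainLossPlateau:
    """Reduce the learning rate when the training loss stops improving."""

    def __init__(self, lr: float, factor: float = 0.8, patience: int = 10,
                 min_lr: float = 1e-7):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, train_loss: float) -> float:
        """Record one epoch's training loss and return the learning rate to use."""
        if train_loss < self.best:
            self.best = train_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                # Training loss has plateaued: decay, but never below min_lr.
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.bad_epochs = 0
        return self.lr


# Usage: the loss stalls at 0.9, so the rate halves after patience is exceeded.
sched = TrainLossPlateau(lr=1e-3, factor=0.5, patience=2)
lrs = [sched.step(loss) for loss in [1.0, 0.9, 0.9, 0.9, 0.9]]
```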

    opened by peastman 28
  • Creating a custom dataset


    I want to train a model on a custom dataset. I'm trying to follow the example at https://github.com/torchmd/torchmd-cg/blob/master/tutorial/Chignolin_Coarse-Grained_Tutorial.ipynb, but my data is different enough that it isn't quite clear how I should format it.

    My datasets consist of many molecules of different sizes. For each molecule I have

    • an array of atom type indices
    • an array of atom coordinates
    • a potential energy
    • (optional) an array of forces on atoms

    This differs from the tutorial in a few critical ways. My molecules are all different sizes, so I can't just put everything into rectangular arrays. And the training data is different: sometimes I will have only energies, and sometimes I will have both forces and energies which should be trained on together. The example trains only on forces with no energies.

    Any guidance would be appreciated!
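One common way to handle molecules of different sizes is to concatenate all atoms into flat arrays and record, for each atom, which molecule it belongs to. This is roughly what torch-geometric does when it batches Data objects. The sketch below shows that layout with plain Python lists. The pack_molecules function and its dictionary keys are hypothetical. Missing forces are padded with zeros and flagged in has_forces, so that a loss could skip them.

```python
from typing import List, Optional, Sequence

Coords = Sequence[Sequence[float]]


def pack_molecules(types: Sequence[Sequence[int]],
                   coords: Sequence[Coords],
                   energies: Sequence[float],
                   forces: Optional[Sequence[Optional[Coords]]] = None) -> dict:
    """Concatenate variable-size molecules into flat per-atom lists.

    Hypothetical layout: atoms of molecule i occupy rows
    offsets[i] .. offsets[i + 1] - 1 of z, pos and forces.
    """
    z: List[int] = []
    pos: List[List[float]] = []
    batch: List[int] = []        # molecule index of every atom
    offsets: List[int] = [0]
    flat_forces: List[List[float]] = []
    has_forces: List[bool] = []  # one flag per molecule
    for i, (t, xyz) in enumerate(zip(types, coords)):
        if len(t) != len(xyz):
            raise ValueError(f"molecule {i}: {len(t)} types but {len(xyz)} positions")
        f = forces[i] if forces is not None else None
        z.extend(t)
        pos.extend([float(c) for c in p] for p in xyz)
        batch.extend([i] * len(t))
        offsets.append(offsets[-1] + len(t))
        has_forces.append(f is not None)
        # Zero placeholders keep the per-atom arrays aligned when forces are missing.
        rows = f if f is not None else [[0.0, 0.0, 0.0]] * len(t)
        flat_forces.extend([float(c) for c in row] for row in rows)
    return {"z": z, "pos": pos, "batch": batch, "offsets": offsets,
            "energy": list(energies), "forces": flat_forces,
            "has_forces": has_forces}


# Usage: a 3-atom molecule with energy only and a 5-atom molecule with forces.
packed = pack_molecules(
    types=[[8, 1, 1], [6, 1, 1, 1, 1]],
    coords=[[[0.0, 0.0, 0.0]] * 3, [[0.0, 0.0, 0.0]] * 5],
    energies=[0.0, 0.0],
    forces=[None, [[0.1, 0.0, 0.0]] * 5],
)
```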

    opened by peastman 20
  • Unable to fit model


    I've been trying to train a model on an early subset of the SPICE dataset. All my efforts so far have been unsuccessful. I must be doing something wrong, but I really don't know what. I'm hoping someone else can spot the problem. My configuration file is given below. Here's the HDF5 file for the dataset.

    I've tried training with or without derivatives. I've tried a range of initial learning rates, with or without warmup. I've tried varying model parameters. I've tried restricting it to only molecules with no formal charges. Nothing makes any difference. In all cases, the loss starts out at about 2e7 and never decreases.

    The dataset consists of all SPICE calculations that had been completed when I started working on this a couple of weeks ago. I converted the units so positions are in Angstroms and energies in kJ/mol. I also subtracted off per-atom energies. Each atom type is a combination of element and formal charge. Here's the mapping:

    typeDict = {('Br', -1): 0, ('Br', 0): 1, ('C', -1): 2, ('C', 0): 3, ('C', 1): 4, ('Ca', 2): 5, ('Cl', -1): 6,
                ('Cl', 0): 7, ('F', -1): 8, ('F', 0): 9, ('H', 0): 10, ('I', -1): 11, ('I', 0): 12, ('K', 1): 13,
                ('Li', 1): 14, ('Mg', 2): 15, ('N', -1): 16, ('N', 0): 17, ('N', 1): 18, ('Na', 1): 19, ('O', -1): 20,
                ('O', 0): 21, ('O', 1): 22, ('P', 0): 23, ('P', 1): 24, ('S', -1): 25, ('S', 0): 26, ('S', 1): 27}
    

    If anyone can provide insight, I'll be very grateful!

    activation: silu
    atom_filter: -1
    batch_size: 128
    cutoff_lower: 0.0
    cutoff_upper: 8.0
    dataset: HDF5
    dataset_root: SPICE-corrected.hdf5
    derivative: false
    distributed_backend: ddp
    early_stopping_patience: 40
    embedding_dimension: 64
    energy_weight: 1.0
    force_weight: 0.001
    inference_batch_size: 128
    lr: 1.e-4
    lr_factor: 0.8
    lr_min: 1.e-7
    lr_patience: 10
    lr_warmup_steps: 5000
    max_num_neighbors: 80
    max_z: 28
    model: equivariant-transformer
    neighbor_embedding: true
    ngpus: -1
    num_epochs: 1000
    num_heads: 8
    num_layers: 5
    num_nodes: 1
    num_rbf: 64
    num_workers: 4
    rbf_type: expnorm
    save_interval: 5
    seed: 1
    test_interval: 10
    test_size: 0.01
    trainable_rbf: true
    val_size: 0.05
    weight_decay: 0.0
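A quick sanity check for a custom type mapping like the one above: assuming, as in the torchmd-net representation models, that atom types index an embedding table with max_z rows, every type index must be smaller than max_z. The snippet only verifies that arithmetic for the typeDict and max_z values from this issue.

```python
# typeDict copied from the issue; max_z taken from the configuration above.
typeDict = {('Br', -1): 0, ('Br', 0): 1, ('C', -1): 2, ('C', 0): 3, ('C', 1): 4, ('Ca', 2): 5, ('Cl', -1): 6,
            ('Cl', 0): 7, ('F', -1): 8, ('F', 0): 9, ('H', 0): 10, ('I', -1): 11, ('I', 0): 12, ('K', 1): 13,
            ('Li', 1): 14, ('Mg', 2): 15, ('N', -1): 16, ('N', 0): 17, ('N', 1): 18, ('Na', 1): 19, ('O', -1): 20,
            ('O', 0): 21, ('O', 1): 22, ('P', 0): 23, ('P', 1): 24, ('S', -1): 25, ('S', 0): 26, ('S', 1): 27}
max_z = 28

largest = max(typeDict.values())
n_types = len(typeDict)
# An embedding table with max_z rows accepts indices 0 .. max_z - 1.
assert largest < max_z, f"type index {largest} does not fit in max_z={max_z}"
print(f"{n_types} types, largest index {largest}, max_z {max_z}: OK")
```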
    
    opened by peastman 18
  • NaN when fitting with derivative


    I'm trying to fit an equivariant transformer model. If I specify derivative: true in the configuration file to use derivatives in fitting, then after only a few training steps the model output becomes nan. This happens even if I also specify force_weight: 0.0. The derivatives shouldn't affect the loss at all in that case, yet it still causes fitting to fail. The obvious explanation would be if I had a nan in the training data somewhere, since that would cause the loss to also be nan even after multiplying by 0. But I verified that's not the case. Immediately after it computes the loss

    https://github.com/torchmd/torchmd-net/blob/b9785d203d0f1e30798db44c1abfeeb7f5dc2eac/torchmdnet/module.py#L89

    I added

    print(loss_dy, torch.all(torch.isfinite(deriv)), torch.all(torch.isfinite(batch.dy)))
    

    Here's the relevant output from the log.

    Epoch 0:   1%|          | 31/5483 [00:06<19:36,  4.63it/s, loss=1.28e+07, v_num=_]tensor(11670.3730, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(True, device='cuda:0') tensor(True, device='cuda:0')
    Epoch 0:   1%|          | 32/5483 [00:06<19:32,  4.65it/s, loss=1.25e+07, v_num=_]tensor(273794.6562, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(True, device='cuda:0') tensor(True, device='cuda:0')
    Epoch 0:   1%|          | 33/5483 [00:07<19:28,  4.67it/s, loss=1.25e+07, v_num=_]tensor(nan, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(False, device='cuda:0') tensor(True, device='cuda:0')
    Epoch 0:   1%|          | 34/5483 [00:07<19:25,  4.68it/s, loss=nan, v_num=_]     tensor(nan, device='cuda:0', grad_fn=<MseLossBackward0>) tensor(False, device='cuda:0') tensor(True, device='cuda:0')
    

    batch.dy never contains a non-finite value.

    Any idea what could be causing this?
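The reasoning in this report rests on IEEE 754 arithmetic: NaN multiplied by zero is still NaN, so a force_weight of 0.0 cannot mask a non-finite force loss. The sketch below reproduces this with plain floats. The weighted-sum form of total_loss is an assumption based on the energy_weight and force_weight options. It is not a copy of module.py.

```python
import math


def total_loss(loss_y: float, loss_dy: float,
               energy_weight: float = 1.0, force_weight: float = 0.0) -> float:
    """Weighted sum of energy and force losses (assumed form)."""
    return energy_weight * loss_y + force_weight * loss_dy


finite = total_loss(1.5, 2.0)                # 1.0 * 1.5 + 0.0 * 2.0
poisoned = total_loss(1.5, float("nan"))     # force_weight is still 0.0
print(math.isfinite(finite), math.isnan(poisoned))  # True True
```

So once the derivative becomes non-finite, as the isfinite(deriv) output in the log shows at step 33, the total loss is NaN whatever the force weight.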

    opened by peastman 15
  • Running Inference for a large organic molecule


    I'm trying to run inference on an organic molecule of 47 atoms, using the equivariant transformer pretrained on the ANI-1 dataset. I passed the atomic numbers and the coordinates as tensors and want to get the energy. Running the code gives the following error:

    Traceback (most recent call last):
      File "//test1.py", line 35, in <module>
        energy, forces = model(z, pos)
      File "/opt/conda/envs/torchmd-net/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "/torchmdnet/models/model.py", line 171, in forward
        x, v, z, pos, batch = self.representation_model(z, pos, batch, q=q, s=s)
      File "/opt/conda/envs/torchmd-net/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "/torchmdnet/models/torchmd_et.py", line 161, in forward
        edge_index, edge_weight, edge_vec = self.distance(pos, batch)
      File "/opt/conda/envs/torchmd-net/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "/torchmdnet/models/utils.py", line 221, in forward
        assert not (
    AssertionError: The neighbor search missed some atoms due to max_num_neighbors being too low. Please increase this parameter to include the maximum number of atoms within the cutoff.
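The assertion means that some atom has more neighbors inside the cutoff than max_num_neighbors allows. One way to choose a safe value is to count, for the molecule in question, the largest number of other atoms within cutoff_upper of any atom. The brute-force O(N^2) sketch below does this with the standard library. The function name is hypothetical. Positions and cutoff are assumed to be in the same length unit, and whether the library counts a distance exactly equal to the cutoff may differ. For a single molecule the count can never exceed the number of atoms minus one.

```python
import math
from typing import Sequence


def max_neighbors_within_cutoff(pos: Sequence[Sequence[float]], cutoff: float) -> int:
    """Largest number of other atoms within `cutoff` of any single atom."""
    best = 0
    for i, pi in enumerate(pos):
        count = sum(1 for j, pj in enumerate(pos)
                    if j != i and math.dist(pi, pj) <= cutoff)
        best = max(best, count)
    return best


# Usage: five atoms on a line, 1.0 apart.
chain = [[float(i), 0.0, 0.0] for i in range(5)]
short_range = max_neighbors_within_cutoff(chain, 1.5)   # interior atoms see 2
long_range = max_neighbors_within_cutoff(chain, 10.0)   # every atom sees the other 4
```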
    
    opened by nishi-acog 11
  • Consistent naming between dy and forces


    This PR renames y to energy and dy to forces in places where we mean forces and not gradients. PyTorch-Geometric Data objects are now constructed as Data(energy=<energy>, ...) or Data(energy=<energy>, forces=<forces>, ...). This should not change any behavior but just concerns naming conventions. Linked to #116

    opened by PhilippThoelke 11
  • Loading jittable models


    Hello. After catching up to main, I am no longer able to load my models after training them. When calling torch.load() on a .pt file, I get the following error:

    ModuleNotFoundError: No module named 'CFConvJittable_07e26a'

    Is there a new procedure for loading models for prediction/simulation?

    help wanted 
    opened by nec4 11
  • radius graph defaults


    Hello everyone! The tools are great so far. I noticed a small but important corner case in the usage of radius_graph from torch_cluster. When searching for neighbors, its behavior is governed by the cutoff distance r and by max_num_neighbors (see the docs here: https://github.com/rusty1s/pytorch_cluster#radius-graph). The latter defaults to a maximum of 32 neighbors per node. If, for example, the user sets a large cutoff distance intending to return all neighbors, the neighbors will still be truncated at 32, even if the user expects more. Furthermore, I'm not sure how radius_graph decides which extra neighbors to reject, or how example shuffling during training affects this; for my use case it seems to make a big difference in training and inference. Because the SchNet philosophy is to operate on cutoff distances, not maximum neighbor counts, would it make sense to add a kwarg to TorchMD_GN.__init__() to raise the max-neighbors limit for this operation?

    Of course, most users probably will not run into this problem if they stick to small cutoffs because they will never hit the upper neighbor ceiling. However, I would be happy to branch, implement this, write tests, and make a PR if it seems like a good idea.

    opened by nec4 11
  • Implement ZBL potential


    This is the first piece of #26. It isn't fully tested yet, but I think it's ready for a first round of comments.

    I created a mechanism for passing arguments to the prior. prior_args is now the option specified by the user in the config file. prior_init_args stores the value returned by get_init_args(), which contains the arguments needed for reconstructing it from a checkpoint.

    This prior requires the dataset to provide several pieces of information. To keep things general, the HDF5 format allows the file to contain a _metadata group which can store arbitrary pieces of information. Most of the other dataset classes should be able to hardcode the necessary values, since they aren't intended to be general.
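For reference, the Ziegler-Biersack-Littmark (ZBL) universal potential that this PR adds as a prior is a screened Coulomb repulsion between nuclei. The sketch below uses the standard published ZBL screening coefficients, with energies in eV and distances in Angstrom. It omits any cutoff or switching function a production prior would likely apply, and it is not the code in this PR.

```python
import math

BOHR_ANGSTROM = 0.529177          # Bohr radius in Angstrom
COULOMB_EV_ANGSTROM = 14.399645   # e^2 / (4 pi eps0) in eV * Angstrom

# Coefficients and exponents of the ZBL universal screening function.
_ZBL_C = (0.18175, 0.50986, 0.28022, 0.02817)
_ZBL_D = (3.19980, 0.94229, 0.40290, 0.20162)


def zbl_screening(x: float) -> float:
    """Universal screening function phi(x); phi(0) = 1."""
    return sum(c * math.exp(-d * x) for c, d in zip(_ZBL_C, _ZBL_D))


def zbl_energy(z1: int, z2: int, r: float) -> float:
    """ZBL pair energy in eV for nuclear charges z1, z2 at distance r (Angstrom)."""
    # Universal screening length.
    a = 0.8854 * BOHR_ANGSTROM / (z1 ** 0.23 + z2 ** 0.23)
    return COULOMB_EV_ANGSTROM * z1 * z2 / r * zbl_screening(r / a)


# Example: repulsion between two carbon nuclei 1 Angstrom apart, in eV.
e_cc = zbl_energy(6, 6, 1.0)
```

The prior needs the atomic numbers of each atom (z1, z2 above), which is one reason it depends on information supplied by the dataset.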

    opened by peastman 10
  • Cannot create env with mamba


    I have been trying to use torchmd-net and torchmd-cg. I first tried installing torchmd-net following the instructions given. When I tried to create the environment with mamba, I encountered the following error. Help?

    Encountered problems while solving:
      - nothing provides requested nnpops 0.2
      - nothing provides requested pytorch_cluster 1.5.9
    
    help wanted 
    opened by geemi725 10
  • Created Coulomb prior


    This adds a Coulomb interaction. At the moment it requires partial charges to be specified by the dataset. In the future we'll want to let the model predict charges, but that will come later.

    The interaction is applied to all pairs without any cutoff. The point of this prior is to provide accurate long range interactions. It does get scaled by erf(alpha*r), which reduces the effect at short ranges and prevents it from diverging at r=0. That also happens to be the scale factor of the reciprocal space term in Ewald summation. Eventually we'll want to apply this to large systems that are partly modeled with ML and partly with a conventional force field. The idea is that you'll be able to include the Coulomb prior simply by including the ML region in the reciprocal space PME calculation.
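The following is a minimal sketch of the damped pair energy described here: E = k Σ_{i<j} q_i q_j erf(α r_ij) / r_ij, summed over all pairs with no cutoff. As r goes to 0, erf(αr)/r tends to 2α/√π, which is why the term stays finite. The default Coulomb constant of about 138.935 kJ/mol·nm/e² (the value used by OpenMM) assumes positions in nm and charges in elementary charges. The function name and the α values are illustrative and not taken from the PR.

```python
import math
from typing import Sequence


def damped_coulomb_energy(pos: Sequence[Sequence[float]], q: Sequence[float],
                          alpha: float, k: float = 138.935456) -> float:
    """Sum k * q_i * q_j * erf(alpha * r) / r over all pairs i < j, no cutoff."""
    energy = 0.0
    n = len(pos)
    for i in range(n):
        for j in range(i + 1, n):
            r = math.dist(pos[i], pos[j])
            if r > 0.0:
                factor = math.erf(alpha * r) / r
            else:
                # Limit of erf(alpha * r) / r as r -> 0.
                factor = 2.0 * alpha / math.sqrt(math.pi)
            energy += k * q[i] * q[j] * factor
    return energy


# Opposite unit charges 10 nm apart, with k = 1 for readability:
# at long range erf(alpha * r) -> 1, so this approaches plain Coulomb, -1/10.
e_far = damped_coulomb_energy([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]], [1.0, -1.0],
                              alpha=1.0, k=1.0)
```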

    opened by peastman 1
  • Don't overwrite logs when resuming training


    I often want to resume training from a checkpoint with --load-model. When I do that, I don't want to lose all the information in the log and metrics.csv files. The obvious way to do that is to create a new log directory for the continuation and use --log-dir and --redirect to tell it to put all new files in the new directory. But it doesn't work. Instead, it ignores those options and uses the same log directory as the original training run, deleting and overwriting the existing logs in the process. To prevent that, you first need to copy your existing log directory to a new location. I have lost work several times by forgetting to do that.

    How about making it so that --load-model does not override --log-dir and --redirect? That's just telling it what model to load. It wouldn't prevent you from saving logs to a different directory.

    opened by peastman 3
  • Error resuming training


    I just encountered an error I've never seen before. I used the --load-model command line argument to resume training from a checkpoint. At first everything seemed to be working correctly, but after completing four epochs it exited with this error.

    Traceback (most recent call last):
      File "/global/homes/p/peastman/torchmd-net/scripts/train.py", line 164, in <module>
        main()
      File "/global/homes/p/peastman/torchmd-net/scripts/train.py", line 160, in main
        trainer.test(model, data)
      File "/global/homes/p/peastman/miniconda3/envs/torchmd/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 936, in test
        return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
      File "/global/homes/p/peastman/miniconda3/envs/torchmd/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 721, in _call_and_handle_interrupt
        return trainer_fn(*args, **kwargs)
      File "/global/homes/p/peastman/miniconda3/envs/torchmd/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 983, in _test_impl
        results = self._run(model, ckpt_path=self.ckpt_path)
      File "/global/homes/p/peastman/miniconda3/envs/torchmd/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1222, in _run
        self._log_hyperparams()
      File "/global/homes/p/peastman/miniconda3/envs/torchmd/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1277, in _log_hyperparams
        raise MisconfigurationException(
    pytorch_lightning.utilities.exceptions.MisconfigurationException: Error while merging hparams: the keys ['load_model'] are present in both the LightningModule's and LightningDataModule's hparams but have different values.
    
    opened by peastman 1
  • Clarifications of the method


    Hello, after reading the paper, I had several questions regarding your approach. Thanks a lot in advance for taking the time to answer them.

    Your embedding layer is more complex than usual: your initial node representation already seems to depend on its neighbours’ representations.

    • Is this beneficial? Have you done experiments to show it?

    Graph construction: you use a smooth cutoff function and describe some of its benefits. You describe a Transformer but still use a cutoff value.

    • Is that statement correct? Why? Does this mean long-range dependencies are not captured? Is the smooth cutoff beneficial, i.e. have you seen something empirically that either motivates it or shows its benefits?

    You say the feature vectors are passed through a normalization layer.

    • Can you explain this, perhaps including the motivation?

    An intermediate node embedding ($y_i$) utilising attention scores is created and impacts the final $x_i$ and $v_i$ embeddings. This step weights a projection of each neighbor’s representation by the attention score, roughly $a_{ij} (W \cdot RBF(d_{ij}) \cdot \vec{V}_j)$.

    • You use interatomic distances twice, don’t you? Is weighting only by attention not enough theoretically?

    The equivariant message $m_{ij}$ (a term of the sum that gives $w_i$) is obtained by multiplying $s_{ij}^{2}$ (i.e. $v_j$ scaled by $RBF(d_{ij})$) by the directional information $r_{ij}$, then adding $s_{ij}^{1}$ (i.e. $v_j$ scaled by $RBF(d_{ij})$) multiplied again by $v_j$.

    • Do you think that multiplying the message sequentially by distance information and directional information is the best way to embed both types of information? Why not, for instance, concatenate $r_{ij} = r_i - r_j$ and the distance $d_{ij} = \lVert r_{ij} \rVert$ and use a single operation?

    • Is multiplying $s_{ij}^{1}$ by $v_j$ (again) necessary? ($v_j$ enters first through $s_{ij}$, then again through the element-wise product of $s_{ij}$ with $v_j$.)

    • IMPORTANT. $r_{ij}$ has dimension 3 while $s_{ij}^{2}$ has dimension F. In Eq. (11), how can you apply an element-wise multiplication? Is it a typo? How exactly do you combine these two quantities? What’s your take on the best way to combine 3D information (a directional vector) with an existing embedding? This is a genuine question I am interested in, so any references or insights on this point would be welcome.

    The invariant representation involves the scalar product of the equivariant vector $v_i$ projected with matrix $U_1$ and the projection $U_2 v_i$, i.e. $\langle U_1 v_i, U_2 v_i \rangle$.

    • What is the real benefit or aim of this scalar product? Is a single projection not enough?
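On the dimension question about Eq. (11): in PaiNN-style equivariant models, vector features are commonly stored as a 3 × F array, with one F-vector per Cartesian component. Under that assumption, "multiplying" an F-dimensional filter s by a 3-dimensional direction r is a broadcast (outer) product that yields a 3 × F array. The standard-library sketch below illustrates this. It also checks that the channel-wise dot product of two such vector features is unchanged by a rotation, which is the property an invariant scalar product such as $\langle U_1 v_i, U_2 v_i \rangle$ relies on. All names are hypothetical, and this is not claimed to be the torchmd-net code.

```python
import math
from typing import List, Sequence

Vec3F = List[List[float]]  # shape (3, F): one row per Cartesian component


def outer(r: Sequence[float], s: Sequence[float]) -> Vec3F:
    """Broadcast a 3-vector against an F-vector: out[k][f] = r[k] * s[f]."""
    return [[rk * sf for sf in s] for rk in r]


def channel_dot(a: Vec3F, b: Vec3F) -> List[float]:
    """Per-channel dot product over the spatial axis (rotation invariant)."""
    n_channels = len(a[0])
    return [sum(a[k][f] * b[k][f] for k in range(3)) for f in range(n_channels)]


def rotate_z(v: Vec3F, theta: float) -> Vec3F:
    """Rotate every channel of a (3, F) vector feature about the z axis."""
    c, s = math.cos(theta), math.sin(theta)
    x, y, z = v
    return [[c * xi - s * yi for xi, yi in zip(x, y)],
            [s * xi + c * yi for xi, yi in zip(x, y)],
            list(z)]


# Usage: a unit direction times an F = 4 filter gives a 3 x 4 vector feature.
r_hat = [0.6, 0.8, 0.0]
m = outer(r_hat, [1.0, 2.0, -0.5, 3.0])
v = outer([0.0, 0.6, 0.8], [0.5, -1.0, 2.0, 1.0])
before = channel_dot(m, v)
after = channel_dot(rotate_z(m, 0.7), rotate_z(v, 0.7))
```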
    question 
    opened by AlexDuvalinho 1
Releases (0.2.1)
Owner
TorchMD
TorchMD: A deep learning framework for molecular simulations