Master License: MIT Python 3.7+ Code style: black

MBRL-Lib

mbrl-lib is a toolbox for facilitating development of Model-Based Reinforcement Learning algorithms. It provides easily interchangeable modeling and planning components, and a set of utility functions that allow writing model-based RL algorithms with only a few lines of code.

See also our companion paper (arXiv:2104.10159).

Getting Started

Installation

mbrl-lib is a Python 3.7+ library. To install it, clone the repository,

git clone https://github.com/facebookresearch/mbrl-lib.git

then run

cd mbrl-lib
pip install -e .

If you are interested in contributing, please install the developer tools as well

pip install -e ".[dev]"

Finally, make sure your Python environment has PyTorch (>= 1.7) installed with the appropriate CUDA configuration for your system.

To test your installation, run

python -m pytest tests/core
python -m pytest tests/algorithms

Mujoco

Mujoco is a popular library for testing RL methods. Installing Mujoco is not required to use most of the components and utilities in MBRL-Lib, but if you have a working Mujoco installation (and license) and want to test MBRL-Lib on it, please run

pip install -r requirements/mujoco.txt

and to test our mujoco-related utilities, run

python -m pytest tests/mujoco

Basic example

As a starting point, check out our tutorial notebook on how to write the PETS algorithm (Chua et al., NeurIPS 2018) using our toolbox and run it on a continuous version of the cartpole environment.

Provided algorithm implementations

MBRL-Lib provides implementations of popular MBRL algorithms as examples of how to use this library. You can find them in the mbrl/algorithms folder. Currently, we have implemented PETS and MBPO, and we plan to keep increasing this list in the near future.

The implementations rely on Hydra to handle configuration. You can see the configuration files in this folder. The overrides subfolder contains environment-specific configurations for each environment, overriding the default configurations with the best hyperparameter values we have found so far for each combination of algorithm and environment. You can run training by passing the desired override option via the command line. For example, to run MBPO on the gym version of HalfCheetah, call

python main.py algorithm=mbpo overrides=mbpo_halfcheetah 

By default, all algorithms will save results in a csv file called results.csv, inside a folder whose path looks like ./exp/mbpo/default/gym___HalfCheetah-v2/yyyy.mm.dd/hhmmss; you can change the root directory (./exp) by passing root_dir=path-to-your-dir, and the experiment sub-folder (default) by passing experiment=your-name. The logger will also save a file called model_train.csv with training information for the dynamics model.
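The following is a minimal standard-library sketch for summarizing a run's results.csv. It assumes the file has an `episode_reward` column, as in the PETS output (`env_step,episode_reward,step`) quoted in the issues below. Column names for other algorithms may differ. `summarize_rows` and `summarize_results` are hypothetical helpers and are not part of mbrl-lib.

```python
import csv
from pathlib import Path


def summarize_rows(lines):
    """Return (num_episodes, last_reward, best_reward) from CSV lines with a header.

    Assumes an `episode_reward` column; `lines` is any iterable of strings.
    """
    rewards = [float(row["episode_reward"]) for row in csv.DictReader(lines)]
    if not rewards:
        raise ValueError("no data rows found")
    return len(rewards), rewards[-1], max(rewards)


def summarize_results(results_path):
    """Convenience wrapper that reads a results.csv file from disk."""
    with Path(results_path).open(newline="") as f:
        return summarize_rows(f)
```

For example, call `summarize_results("./exp/mbpo/default/gym___HalfCheetah-v2/yyyy.mm.dd/hhmmss/results.csv")`, substituting the actual date and time folder names.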

Beyond the override defaults, you can also change other configuration options, such as the type of dynamics model (e.g., dynamics_model=basic_ensemble), or the number of models in the ensemble (e.g., dynamics_model.model.ensemble_size=some-number). To learn more about all the available options, take a look at the provided configuration files.

Note that running the provided examples and main.py requires Mujoco, but you can try out the library components (and algorithms) on other environments by creating your own entry script and Hydra configuration.

Visualization tools

Our library also contains a set of visualization tools, meant to facilitate diagnostics and development of models and controllers. These currently require Mujoco installation, but we are planning to add more support and extensions in the future. Currently, the following tools are provided:

  • Visualizer: Creates a video to qualitatively assess model predictions over a rolling horizon. Specifically, it runs a user-specified policy in a given environment and, at each time step, computes the model's predicted observations/rewards over a lookahead horizon for the same policy. The predictions are plotted as line plots, one for each observation dimension (blue lines) and reward (red line), along with the result of applying the same policy to the real environment (black lines). The model's uncertainty is visualized by plotting lines for the maximum and minimum predictions at each time step. The model and policy are specified by passing directories containing configuration files for each; they can be trained independently. The following gif shows an example of 200 steps of a pre-trained MBPO policy on the Inverted Pendulum environment.

    Example of Visualizer

  • DatasetEvaluator: Loads a pre-trained model and a dataset (can be loaded from separate directories), and computes predictions of the model for each output dimension. The evaluator then creates a scatter plot for each dimension comparing the ground truth output vs. the model's prediction. If the model is an ensemble, the plot shows the mean prediction as well as the individual predictions of each ensemble member.

    Example of DatasetEvaluator

  • FineTuner: Can be used to train a model on a dataset produced by a given agent/controller. The model and agent can be loaded from separate directories, and the fine tuner will roll out the environment for some number of steps using actions obtained from the controller. The final model and dataset will then be saved under the directory "model_dir/diagnostics/subdir", where subdir is provided by the user.

  • True Dynamics Multi-CPU Controller: This script can run a trajectory optimizer agent on the true environment using Python's multiprocessing. Each environment runs on its own CPU, which can significantly speed up costly sampling algorithms such as CEM. The controller will also save a video if the render argument is passed. Below is an example on HalfCheetah-v2 using CEM for trajectory optimization.

    Control Half-Cheetah True Dynamics
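The multi-CPU controller helps because CEM evaluates many candidate action sequences per iteration. Below is a minimal CEM sketch, not the library's implementation, in which cost evaluation goes through a pluggable `map_fn`. The built-in `map` runs serially, and a `multiprocessing.Pool().map` could be swapped in to spread evaluations across CPUs. The toy quadratic cost is a hypothetical stand-in for rolling out the true environment.

```python
import numpy as np


def cem_optimize(cost_fn, horizon, action_dim, lower, upper, num_samples=64,
                 num_elites=8, num_iters=10, alpha=0.1, seed=0, map_fn=map):
    """Cross-entropy method over action sequences of shape (horizon, action_dim).

    `cost_fn` maps one action sequence to a scalar cost (lower is better).
    `map_fn` evaluates all candidates; pass e.g. multiprocessing.Pool().map to
    parallelize (cost_fn must then be picklable, i.e. defined at module level).
    """
    rng = np.random.default_rng(seed)
    mu = np.zeros((horizon, action_dim)) + (lower + upper) / 2.0
    var = np.ones((horizon, action_dim)) * ((upper - lower) / 4.0) ** 2
    for _ in range(num_iters):
        # Sample candidates around the current mean and clip them to the action bounds.
        samples = rng.normal(mu, np.sqrt(var), size=(num_samples, horizon, action_dim))
        samples = np.clip(samples, lower, upper)
        costs = np.array(list(map_fn(cost_fn, list(samples))))
        elites = samples[np.argsort(costs)[:num_elites]]
        # Move the sampling distribution toward the elites, keeping a fraction alpha of the old one.
        mu = alpha * mu + (1 - alpha) * elites.mean(axis=0)
        var = alpha * var + (1 - alpha) * elites.var(axis=0)
    return mu


def toy_cost(actions):
    # Hypothetical stand-in for a true-environment rollout: prefer actions near 0.5.
    return float(np.sum((actions - 0.5) ** 2))
```

To parallelize, one could write `with multiprocessing.Pool() as pool: cem_optimize(toy_cost, 5, 2, -1.0, 1.0, map_fn=pool.map)`. The result is the same because sampling happens in the parent process.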

Note that the tools above require Mujoco installation, and are specific to models of type OneDimTransitionRewardModel. We are planning to extend this in the future; if you have useful suggestions don't hesitate to raise an issue or submit a pull request!
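Both the Visualizer's uncertainty band and the DatasetEvaluator's ensemble plot reduce a stack of per-member predictions to a few statistics. Below is a minimal NumPy sketch of that reduction. It assumes predictions are stored as an array of shape (ensemble_size, num_points, output_dim). The function name and array layout are illustrative and do not describe the library's internals.

```python
import numpy as np


def summarize_ensemble(predictions, targets=None):
    """Reduce ensemble predictions of shape (E, N, D) to per-point statistics.

    Returns the mean prediction and the min/max envelope across members, plus the
    per-dimension MSE of the mean prediction when `targets` of shape (N, D) are given.
    """
    predictions = np.asarray(predictions, dtype=float)
    stats = {
        "mean": predictions.mean(axis=0),
        "min": predictions.min(axis=0),
        "max": predictions.max(axis=0),
    }
    if targets is not None:
        err = stats["mean"] - np.asarray(targets, dtype=float)
        stats["mse_per_dim"] = (err ** 2).mean(axis=0)
    return stats
```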

Documentation

Please check out our documentation and don't hesitate to raise issues or contribute if anything is unclear!

License

mbrl-lib is released under the MIT license. See LICENSE for additional details about it. See also our Terms of Use and Privacy Policy.

Citing

If you use this project in your research, please cite:

@Article{Pineda2021MBRL,
  author  = {Luis Pineda and Brandon Amos and Amy Zhang and Nathan O. Lambert and Roberto Calandra},
  journal = {Arxiv},
  title   = {MBRL-Lib: A Modular Library for Model-based Reinforcement Learning},
  year    = {2021},
  url     = {https://arxiv.org/abs/2104.10159},
}
Comments
  • Feature pybullet

    Continuation of the incomplete PR from https://github.com/facebookresearch/mbrl-lib/pull/87. This is my first time contributing to an open-source project, so any advice is welcome, technical or otherwise.

    Types of changes

    • [x] Docs change / refactoring / dependency upgrade
    • [ ] Bug fix (non-breaking change which fixes an issue)
    • [x] New feature (non-breaking change which adds functionality)
    • [x] Breaking change (fix or feature that would cause existing functionality to change)

    Motivation and Context / Related issue

    This adds support for PyBullet, an open-source alternative to MuJoCo. MuJoCo-compatible and RobotSchool environments are supported via pybullet-gym.

    How Has This Been Tested (if it applies)

    python -m pytest tests/pybullet

    Checklist

    • [x] The documentation is up-to-date with the changes I made.
    • [x] I have read the CONTRIBUTING document and completed the CLA (see CONTRIBUTING).
    • [ ] All tests passed, and additional code has been covered with new tests.
    CLA Signed 
    opened by dtch1997 44
  • Add trajectory-based dynamics model

    TODO for this WIP PR:

    • [x] New PID based / linear feedback agent(s)
    • [ ] Make PID accept vector inputs
    • [x] Training example
    • [ ] Migrate example to colab
    • [ ] Add tests

    Types of changes

    • [ ] Docs change / refactoring / dependency upgrade
    • [ ] Bug fix (non-breaking change which fixes an issue)
    • [x] New feature (non-breaking change which adds functionality)
    • [ ] Breaking change (fix or feature that would cause existing functionality to change)

    Motivation and Context / Related issue

    I'm collaborating with some folks at Berkeley who are looking to apply the trajectory-based model to real-world robotics, so I wanted to integrate it into this library to give it more longevity.

    The paper is here. The core of the paper is a proposed dynamics model focused on long-term prediction. The parametrization is:

    $$ s_{t+1} = f_\theta(s_0, t, \phi),$$

    where $\phi$ are closed-form control parameters (e.g., the gains of a PID controller).
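Under this parametrization, each training input pairs the initial state $s_0$ with a time index $t$ and the control parameters $\phi$, and the target is $s_{t+1}$. Below is a minimal NumPy sketch that turns one recorded trajectory into such pairs. The function name and the flat, concatenated input layout are assumptions made for illustration.

```python
import numpy as np


def trajectory_training_pairs(states, phi):
    """Build (input, target) pairs for s_{t+1} = f_theta(s_0, t, phi).

    states: array of shape (T + 1, state_dim) from one rollout, states[0] = s_0.
    phi: 1-D array of closed-form control parameters used for the whole rollout.
    Returns inputs of shape (T, state_dim + 1 + len(phi)) and targets (T, state_dim).
    """
    states = np.asarray(states, dtype=float)
    phi = np.asarray(phi, dtype=float)
    num_steps = states.shape[0] - 1
    s0 = np.repeat(states[:1], num_steps, axis=0)     # s_0 repeated for every t
    t = np.arange(num_steps, dtype=float)[:, None]     # time index t = 0, ..., T-1
    phis = np.repeat(phi[None, :], num_steps, axis=0)  # same phi for the whole rollout
    inputs = np.concatenate([s0, t, phis], axis=1)
    targets = states[1:]                               # s_{t+1}
    return inputs, targets
```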

    This potentially relates to #66. I think we will need to modify the replay buffer to:

    • store control parameter vector
    • store time indices (which may be close with the trajectory formulation)

    How Has This Been Tested (if it applies)

    I am going to build a notebook to validate and demonstrate it; currently it is a fork of the PETS example. I will iterate on it.

    Checklist

    • [ ] The documentation is up-to-date with the changes I made.
    • [x] I have read the CONTRIBUTING document and completed the CLA (see CONTRIBUTING).
    • [ ] All tests passed, and additional code has been covered with new tests.
    CLA Signed 
    opened by natolambert 19
  • MBPO cannot work on HumanoidTruncatedObsEnv and original Humanoid Env[Bug]

    Steps to reproduce

    1. I tried to run MBPO on HumanoidTruncatedObsEnv with the default parameters in this repo, but the final reward is around 180 (it looks like a random policy; training does not work).
    2. I tried to run MBPO on the original Humanoid env (without truncated observations), and it still does not work.

    I have also tried different seeds, and none of them work.

    Observed Results

    • Episode reward results:

    image

    Expected Results

    • The expected episode reward is around 6k.
    bug 
    opened by jity16 18
  • [Bug] PETS not working

    Steps to reproduce

    1. Install mbrl with Python 3.8 and mujoco_py 2.0.2.0
    2. python -m mbrl.examples.main algorithm=pets overrides=pets_halfcheetah

    Observed Results

    env_step,episode_reward,step
    1000.0,-224.74164192363065,1
    2000.0,-216.55716608141833,2
    3000.0,-23.61229154142554,3
    4000.0,-226.04264782442579,4
    5000.0,299.97272326884257,5
    6000.0,-424.2352836475372,6
    7000.0,-605.4988140825888,7
    8000.0,-276.8960448750668,8
    9000.0,-570.0111469500497,9
    10000.0,-510.15227529837796,10
    11000.0,-521.2191905188236,11
    12000.0,-380.6738015630948,12
    13000.0,-401.0656166902861,13
    14000.0,-342.89326195274214,14
    15000.0,-387.0973047072805,15
    16000.0,271.654545187927,16
    17000.0,-357.9662191309233,17
    18000.0,-144.4911364581224,18
    19000.0,-227.65608581868534,19
    20000.0,-270.1466421280269,20
    21000.0,-218.2495164661332,21
    22000.0,-291.59770272027646,22
    23000.0,5.605493817390425,23
    24000.0,-260.5804876267262,24
    25000.0,-311.1006996761441,25
    26000.0,-87.68273024315891,26
    27000.0,-224.6058292677028,27
    28000.0,-243.66672977662145,28
    29000.0,-417.3611859069211,29
    30000.0,-205.45597669987774,30
    31000.0,-220.6631462332176,31
    32000.0,-306.92107250798256,32
    33000.0,-321.6192194136308,33
    34000.0,156.56899647240394,34
    35000.0,-373.6946869809165,35
    36000.0,-297.54081355112413,36
    37000.0,-403.86887923659464,37
    38000.0,-394.61809157238,38
    39000.0,-397.597218596027,39
    40000.0,-270.5546716816992,40
    41000.0,-275.0500238719418,41
    42000.0,-339.1503604637613,42
    43000.0,-394.371951392158,43
    44000.0,-284.8456374765922,44
    45000.0,-230.30455468451476,45
    46000.0,-452.69669066476587,46
    47000.0,-369.8052064885858,47
    48000.0,-277.8216601977107,48
    49000.0,83.44271984210994,49
    50000.0,-165.98679718221237,50
    51000.0,-286.4235189537889,51
    52000.0,-420.1238034618763,52
    53000.0,-348.4956325925755,53
    54000.0,-262.9499726805828,54
    55000.0,-82.70856034802993,55
    56000.0,-283.44756999937294,56
    57000.0,-296.14589401299133,57
    58000.0,-310.71395667647914,58
    59000.0,-92.32547170477757,59
    60000.0,-343.62926472041903,60
    61000.0,194.0718436837866,61
    62000.0,-449.34500076620725,62
    63000.0,-317.03787784175205,63
    64000.0,-203.2571831873085,64
    65000.0,-90.52911874178189,65
    66000.0,-188.53310534801767,66
    67000.0,-131.71672373665217,67
    68000.0,-241.95741966590174,68
    69000.0,-329.25808904770525,69
    70000.0,-146.0802349071957,70
    71000.0,-474.47665284478336,71
    72000.0,-191.43021635327702,72

    Expected Results

    Results similar to those in #97.

    bug 
    opened by sofan110 18
  • Pddm

    (WIP) PDDM implementation

    • [x] Docs change / refactoring / dependency upgrade
    • [x] New feature (non-breaking change which adds functionality)

    Motivation and Context / Related issue

    PR for PDDM's MPPI planner, support for sequenced batches, and in the near future proper settings and benchmarks for MuJoCo environments.
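MPPI replaces CEM's hard elite selection with an exponentially weighted average of all sampled action sequences. Below is a minimal NumPy sketch of that core update. It assumes returns are to be maximized and that `kappa` is the temperature. PDDM's planner also filters and smooths the sampled noise, which this sketch omits, so it is not the implementation in this PR.

```python
import numpy as np


def mppi_update(action_sequences, returns, kappa=1.0):
    """Reward-weighted mean of sampled action sequences (core MPPI step).

    action_sequences: (num_samples, horizon, action_dim); returns: (num_samples,).
    A higher return gives a larger weight; kappa controls how sharply the weights peak.
    """
    returns = np.asarray(returns, dtype=float)
    # Subtract the max before exponentiating for numerical stability.
    weights = np.exp(kappa * (returns - returns.max()))
    weights /= weights.sum()
    return np.tensordot(weights, action_sequences, axes=1)
```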

    Checklist

    • [x] The documentation is up-to-date with the changes I made.
    • [x] I have read the CONTRIBUTING document and completed the CLA (see CONTRIBUTING).
    • [x] MPPI planner
    • [x] MPPI refinement iterations
    • [x] PDDM
    • [x] Support for sequenced batches
    • [x] Multistage Gaussian MLP loss
    • [x] Testing for MPPI planner and PDDM
    • [ ] Benchmarks/Tuning and comparisons with the original implementation
    CLA Signed 
    opened by freiberg-roman 13
  • Training browser

    Types of changes

    Adds a simple browser to chart training results from multiple runs

    • [ ] Docs change / refactoring / dependency upgrade
    • [ ] Bug fix (non-breaking change which fixes an issue)
    • [X] New feature (non-breaking change which adds functionality)
    • [ ] Breaking change (fix or feature that would cause existing functionality to change)

    Motivation and Context / Related issue

    Adds a quick and easy way to browse/compare results

    How Has This Been Tested (if it applies)

    I ran a few different training runs with different algorithms and used this to compare them.

    Checklist

    • [ ] The documentation is up-to-date with the changes I made.
    • [X] I have read the CONTRIBUTING document and completed the CLA (see CONTRIBUTING).
    • [ ] All tests passed, and additional code has been covered with new tests.
    CLA Signed 
    opened by a3ahmad 12
  • Support pybullet-based Gym Environments

    Don't accept this yet -- this is still a work-in-progress. Remaining work:

    General-purpose environment loader:

    • [ ] Agree on interface
    • [ ] Refactor mujoco.py

    Add support for freezing environments:

    • [X] Locomotors
    • [ ] Manipulators
    • [ ] Pendula

    Add documentation for:

    • [X] Installing/using PyBullet
    • [ ] Various functions in mujoco.py
    • [ ] Comparing RobotSchool and MuJoCo-compatible PyBullet environments.

    Tests:

    • [X] Freezing environments.
    • [ ] Comparison between MuJoCo-compatible PyBullet and actual MuJoCo environments.

    Other:

    • [ ] Gracefully handle case that PyBullet is not installed.
    • [ ] Properly package pybullet-gym
      • [ ] setup.py needs to copy 3d assets as well.
      • [ ] (Optional) Put it on Pip

    Types of changes

    • [X] Docs change / refactoring / dependency upgrade
    • [ ] Bug fix (non-breaking change which fixes an issue)
    • [X] New feature (non-breaking change which adds functionality)
    • [ ] Breaking change (fix or feature that would cause existing functionality to change)

    Motivation and Context / Related issue

    This adds support for PyBullet, an open-source alternative to MuJoCo. MuJoCo-compatible and RobotSchool environments are supported via pybullet-gym.

    How Has This Been Tested (if it applies)

    Using this for research.

    Checklist

    • [ ] The documentation is up-to-date with the changes I made.
    • [ ] I have read the CONTRIBUTING document and completed the CLA (see CONTRIBUTING).
    • [ ] All tests passed, and additional code has been covered with new tests.
    CLA Signed 
    opened by gauravmm 9
  • Difference in PETS implementation from the original TF version.

    This follows from the conversation in #98. I have noticed some discrepancy between the TF and mbrl-lib implementation of PETS.

    Difference in normalization.

    https://github.com/kchua/handful-of-trials/blob/master/dmbrl/modeling/utils/TensorStandardScaler.py#L45

    In the original version, the normalization is guarded against observation dimensions with small stddev by setting the dimensions with small stddev to 1. This prevents the normalized inputs from exploding when the stddev is small. This happens in environments such as Reacher or Pusher where some observation dimensions consist of goals. In that situation, it seems that the goal is never changing during an episode and the stddev will be 0. Hence setting the small stddev to be 1.0 would be helpful in that case.

    Another very subtle thing happening in the above code is that the normalization is performed with NumPy instead of in TF, and I think the inputs here are in float64. In that case, the stddev computation is more accurate than those in float32, so the threshold 1e-12 is sensible. Using PyTorch to perform normalization, for example, would require changes to the threshold. I think some values like 1e-5 would be more appropriate in that case (not backed up by any numerical analysis).
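    Below is a minimal NumPy sketch of the guarded normalization described above. Statistics are computed in float64. Any dimension whose standard deviation falls below a threshold is assigned a standard deviation of 1, so it passes through unscaled instead of exploding. The class name is hypothetical. The default threshold of 1e-12 mirrors the TF logic as described in this issue; this is not mbrl-lib's API.

```python
import numpy as np


class GuardedStandardScaler:
    """Standardize inputs, treating near-constant dimensions as already scaled."""

    def __init__(self, min_std=1e-12):
        self.min_std = min_std
        self.mean = None
        self.std = None

    def fit(self, data):
        data = np.asarray(data, dtype=np.float64)  # float64 keeps tiny stds meaningful
        self.mean = data.mean(axis=0)
        std = data.std(axis=0)
        # Dimensions that never change (e.g. a fixed goal) would otherwise divide by ~0.
        std[std < self.min_std] = 1.0
        self.std = std
        return self

    def transform(self, data):
        return (np.asarray(data, dtype=np.float64) - self.mean) / self.std

    def inverse_transform(self, data):
        return np.asarray(data, dtype=np.float64) * self.std + self.mean
```

    If the statistics were computed in float32 instead, the issue suggests a larger threshold such as 1e-5 would be more appropriate.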

    Difference in activation function

    The original implementation uses the swish activation function whereas in mbrl-lib we use silu. I am confused about the choice of silu in mbrl-lib and would love to know more about the difference in empirical performance.
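    The standard definitions are given below for reference. This assumes the TF code's swish is the common β = 1 form, x · sigmoid(x), which is worth confirming against the original source.

```latex
\mathrm{swish}_\beta(x) = x\,\sigma(\beta x), \qquad
\mathrm{silu}(x) = x\,\sigma(x) = \mathrm{swish}_1(x), \qquad
\sigma(x) = \frac{1}{1 + e^{-x}}
```

    Under that assumption, the two activations would be the same function.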

    Difference in CEM stopping criteria

    In the TF implementation, the CEM optimizer uses an additional termination criterion on the variance: https://github.com/kchua/handful-of-trials/blob/77fd8802cc30b7683f0227c90527b5414c0df34c/dmbrl/misc/optimizers/cem.py#L71 I doubt that criterion is ever satisfied during training but I am mentioning this here for completeness.

    Difference in optimizer weight decay

    The original TF implementation uses a carefully selected set of weight decays for different layers of the dynamics model, whereas the decay in mbrl-lib is the same for all layers. In addition, the original implementation does not apply weight decay to the biases. See

    https://github.com/kchua/handful-of-trials/blob/master/dmbrl/modeling/layers/FC.py#L219

    In PyTorch, Adam by default applies weight decay to every parameter it is given. That also means the decay is applied to max_logvar and min_logvar, whereas in the TF version the only regularization on the max/min logvars is through the var_loss.

    As a side note, have the authors tried using AdamW instead of Adam for the weight decay? I recently learned that naive weight decay in Adam does not behave as you may expect. See https://arxiv.org/abs/1711.05101
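    Below is a minimal sketch of excluding biases and the log-variance bounds from weight decay by building optimizer parameter groups. It only inspects parameter names, so it works with any iterable of `(name, param)` pairs, such as the output of `torch.nn.Module.named_parameters()`. The name keywords (`bias`, `min_logvar`, `max_logvar`) are assumptions about how the model's parameters are named, and the function name is hypothetical.

```python
def split_weight_decay_groups(named_params, weight_decay,
                              no_decay_keywords=("bias", "min_logvar", "max_logvar")):
    """Return optimizer param groups: decayed weights vs. undecayed everything else.

    Any parameter whose name contains one of `no_decay_keywords` gets weight_decay=0.
    """
    decay, no_decay = [], []
    for name, param in named_params:
        if any(key in name for key in no_decay_keywords):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```

    With PyTorch, the returned list can be passed as the first argument of an optimizer, e.g. `torch.optim.AdamW(groups, lr=1e-3)` with illustrative values. Per-group `weight_decay` entries override the optimizer-level default.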

    Difference in optimizer parameters

    The default epsilon in TensorFlow's Adam is 1e-7, https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam Scratch this, they are 1e-8 in TF 1 https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/keras/optimizers/Adam.

    Anyway, I am mentioning these here after a thorough look at both mbrl-lib and TF PETS while debugging my own JAX implementation. It turns out my mistake was in the MPC code. I hope these notes are useful, since the author mentions that the current implementation does not get good performance on Half-Cheetah. Maybe it's because of one of these details; if not, fingers crossed the difference can be spotted by someone else :D

    opened by ethanluoyc 8
  • Using Wrapper Class for Custom GYM Env

    I have a custom OpenAI Gym env and I am trying to use the mbrl wrapper, but I get the error name 'model_env_args' is not defined. I am trying to follow the example here: https://arxiv.org/pdf/2104.10159.pdf. Here's my code:

    import gym
    import mbrl.models as models
    import numpy as np

    net = models.GaussianMLP(in_size=14, out_size=12, device="cpu")
    wrapper = models.OneDTransitionRewardModel(net, target_is_delta=True, learned_rewards=True)
    model_env = models.ModelEnv(wrapper, *model_env_args, term_fn=hopper)

    opened by MishraIN 7
  • [Feature Request] Logging of custom training metrics

    🚀 Feature Request

    When training a model with ModelTrainer, it would be nice to be able to log some custom metrics (ideally in tensorboard), defined by the model (e.g., the values of the individual loss terms if the loss of the model is a sum of multiple terms). Right now one can only access the overall loss of the model.

    Motivation

    Is your feature request related to a problem? Please describe.

    At the moment I am working on a model that optimizes a sum of a reconstruction loss, a reward prediction loss, and a KL divergence term. For debugging purposes it would be nice to monitor how the individual losses evolve over time. This logging cannot be done by the model class on its own, since it needs some information from the RL algorithm (e.g., the current iteration of the algorithm or the number of samples drawn from the environment) for the logged values to be meaningful.

    Pitch

    Describe the solution you'd like

    The simplest solution is to allow passing kwargs to ModelTrainer.train(), which are passed through to Model.update(). This would allow passing a custom logging function or object that then logs values passed by the model implementation. This is of course not the most elegant solution, but the kwargs could also be used for other purposes (e.g., passing additional information to Model.update() if a model implementation requires it).

    Describe alternatives you've considered

    An alternative to this would be to let Model.update() return a dictionary of metrics in addition to the loss. This dictionary could then be returned by ModelTrainer.train() or it could be processed by the callback passed to the function. This would of course cause breaking changes since the method signature of Model would need to be changed.
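    Below is a minimal sketch of the alternative just described. All names (`ToyModel`, `train_epoch`) are hypothetical. `update()` returns the loss together with a dictionary of per-term metrics. The training loop averages each metric over the epoch and passes the result to a callback, which could then write it to TensorBoard or any other logger.

```python
from collections import defaultdict


class ToyModel:
    """Stand-in model whose loss is a sum of named terms (values are fixed here)."""

    def update(self, batch):
        terms = {"reconstruction": 1.0 * len(batch), "reward": 0.5, "kl": 0.1}
        return sum(terms.values()), terms


def train_epoch(model, batches, callback=None):
    """Run one epoch; return the mean loss and the mean of each metric."""
    totals = defaultdict(float)
    total_loss = 0.0
    for batch in batches:
        loss, metrics = model.update(batch)
        total_loss += loss
        for name, value in metrics.items():
            totals[name] += value
    n = max(len(batches), 1)
    mean_metrics = {name: value / n for name, value in totals.items()}
    if callback is not None:
        callback(total_loss / n, mean_metrics)
    return total_loss / n, mean_metrics
```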

    Are you willing to open a pull request? (See CONTRIBUTING) Yes

    enhancement 
    opened by jan1854 7
  • pets_example.ipynb problem

    I ran pets_example.ipynb and got the following errors.

    I am not sure whether this is a package-compatibility problem on my side, so I am not sure whether the following errors are bugs. Environment: Python 3.7.10, NumPy 1.20.1, matplotlib 3.4.2, torch 1.7.1 (py3.7_cuda10.1.243_cudnn7.6.3_0).

    TypeError: normal() received an invalid combination of arguments when running the main loop. I found that the model_env argument rng is np.random.default_rng(seed=0) rather than a torch generator, which is what torch.normal expects.

    # Create a gym-like environment to encapsulate the model
    #model_env = models.ModelEnv(env, dynamics_model, term_fn, reward_fn, rng)
    

    TypeError: can't convert cuda:0 device type tensor to numpy. This happens in the plotting part when running on the GPU: val_score is a tensor such as tensor(0.0023, device='cuda:0'), which causes the error when plotting.

    def train_callback(_model, _total_calls, _epoch, tr_loss, val_score, _best_val):
       train_losses.append(tr_loss)
       #val_scores.append(val_score.mean())   # this returns val score per ensemble model
    
    opened by app1ep1e 7
  • [Bug] Centering, scaling and clamping the population in iCEM

    Steps to reproduce

    1. Run any example configuration using iCEM as action optimizer, e.g. python -m mbrl.examples.main algorithm=mbpo overrides=pets_icem_cartpole

    Observed Results

    After sampling according to a power-law PSD in iCEM, the population is centered on the mean, scaled by the standard deviation, and clamped to be within the action space. This process uses the dummy variable population2. However, it appears that the result is not assigned back to the population variable, so it is ignored during the rest of the optimization procedure. As a result, I believe that the population is not correctly sampled, and the objective function can be evaluated on actions that potentially do not belong to the action space.

    Expected Results

    Centering, scaling and clamping should be applied directly to population instead of population2.

    Relevant Code

    The relevant lines are L438-L441 in mbrl/planning/trajectory_opt.py

    https://github.com/facebookresearch/mbrl-lib/blob/f90a29743894fd6db05e73445af0ed83baa845bc/mbrl/planning/trajectory_opt.py#L438-L441

    which I believe could be changed to

              population = torch.minimum(
                  population * torch.sqrt(var) + mu, self.upper_bound
              )
              population = torch.maximum(population, self.lower_bound)
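    For context, below is a self-contained NumPy sketch of the sampling step described in this issue. It draws temporally correlated ("colored") noise with a power-law spectrum $1/f^\beta$, then centers, scales, and clamps it, assigning each step back to `population` as the proposed fix does. The FFT-based noise generator was written for illustration. It is not mbrl-lib's code or necessarily the exact generator iCEM uses.

```python
import numpy as np


def powerlaw_noise(beta, num_samples, horizon, action_dim, rng):
    """Sample noise whose power spectrum along time falls off like 1 / f**beta.

    Returns an array of shape (num_samples, horizon, action_dim); each sequence is
    normalized to zero mean and unit standard deviation over time.
    """
    freqs = np.fft.rfftfreq(horizon)            # 0, 1/H, ..., up to 0.5
    amplitude = np.zeros_like(freqs)             # drop f = 0; the mean is removed anyway
    amplitude[1:] = freqs[1:] ** (-beta / 2.0)   # power ~ amplitude**2 ~ f**(-beta)
    shape = (num_samples, action_dim, freqs.size)
    spectrum = (rng.standard_normal(shape) + 1j * rng.standard_normal(shape)) * amplitude
    noise = np.fft.irfft(spectrum, n=horizon, axis=-1)
    noise = noise - noise.mean(axis=-1, keepdims=True)
    noise = noise / (noise.std(axis=-1, keepdims=True) + 1e-8)
    return np.transpose(noise, (0, 2, 1))


def sample_population(mu, var, lower, upper, beta, num_samples, rng):
    """Center, scale, and clamp colored noise, assigning each step to `population`."""
    horizon, action_dim = mu.shape
    population = powerlaw_noise(beta, num_samples, horizon, action_dim, rng)
    population = np.minimum(population * np.sqrt(var) + mu, upper)
    population = np.maximum(population, lower)
    return population
```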
    
    bug 
    opened by marbaga 0
  • [WIP] HF Hub Integration

    Working towards closing #169

    Things to do (roughly):

    • Verify base functionality,
    • Colab example for loading / saving / visualizing models,
    • Upload pretrained models to hub from @luisenp.
    CLA Signed 
    opened by natolambert 1
  • [Feature Request] Upload Dynamics Models to the HuggingFace Hub

    🚀 Feature Request

    Add functionality to upload dynamics models/policies to the HF Hub at the end of training or during training, for sharing and fine-tuning.

    This would look like:

    model.from_pretrained("mbrl/cheetah.bin")
    model.save_pretrained("mbrl/hopper.bin")
    

    Motivation

    We want to be able to reuse computation and make it easier to build demos showcasing this library.

    Happy to help with this.

    enhancement 
    opened by natolambert 6
  • hyperparameters optimization

    🚀 Feature Request

    I would like to optimize the hyperparameters on a custom environment for PE-TS and other algorithms.

    Motivation

    How did you find the optimal hyperparameters for the algorithms, for example PE-TS on cartpole?

    Pitch

    For the PE-TS example, I did a grid search over four parameters: horizon_size, alpha, number of hidden layers, and hidden layer dimension.

    Question: which parameters are the most crucial to optimize?

    Do you have a Bayesian optimization script for hyperparameters?

    Describe alternatives you've considered: I can make a pull request for the PE-TS grid search and/or Bayesian optimization with the Optuna library.
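    Below is a minimal grid-search sketch over the four PETS hyperparameters named above. `evaluate` is a hypothetical placeholder for training and evaluating one configuration, for example by launching a run with the matching command-line overrides and reading back its reward. The candidate values are illustrative only.

```python
import itertools


def grid_search(evaluate, grid):
    """Evaluate every combination in `grid` (dict of name -> list of values).

    Returns (best_config, best_score, all_results), maximizing `evaluate(config)`.
    """
    names = list(grid)
    results = []
    for values in itertools.product(*(grid[name] for name in names)):
        config = dict(zip(names, values))
        results.append((config, evaluate(config)))
    best_config, best_score = max(results, key=lambda item: item[1])
    return best_config, best_score, results


# Illustrative grid; the names mirror the parameters listed above.
example_grid = {
    "horizon_size": [15, 30],
    "alpha": [0.1, 0.25],
    "num_layers": [3, 4],
    "hid_size": [200, 400],
}
```

    A Bayesian optimizer such as the Optuna library suggested above would replace the exhaustive `itertools.product` loop with adaptively proposed configurations.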

    enhancement 
    opened by ss555 1
  • [Feature Request] Output Normalization / Scaling

    🚀 Feature Request

    When training non-delta-state models, the outputs of dynamics models can take large values (well outside a unit Gaussian). In the past I have tried using output scalers so that the outputs learn something close to a unit Gaussian rather than variables with diverse scales.

    Motivation

    Is your feature request related to a problem? Please describe. I think it would help the PR for the trajectory-based model, #158.

    Pitch

    Describe the solution you'd like: I think there could be an optional output scaler that works analogously to the input one.

    Are you willing to open a pull request? (See CONTRIBUTING) Sure.

    enhancement 
    opened by natolambert 4
  • [Feature Request] Add option to use `functorch` for `BasicEnsemble`

    🚀 Feature Request

    Change BasicEnsemble to optionally use functorch.vmap.

    Motivation and Pitch

    Is your feature request related to a problem? Please describe.

    BasicEnsemble lets the user provide arbitrary models, which are stacked together using a very naive loop-based implementation. We should be able to do this more efficiently now using functorch.
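    The difference between a loop and a vectorized evaluation can be illustrated without functorch. If every member shares the same architecture, the members' weights can be stacked along a leading ensemble dimension and evaluated in one batched operation. Below is a minimal NumPy sketch for a single linear layer. It illustrates the idea only and is not `functorch.vmap` or `BasicEnsemble` itself.

```python
import numpy as np


def ensemble_forward_loop(weights, biases, x):
    """Naive version: one matmul per member. weights (E, in, out), biases (E, out), x (E, N, in)."""
    return np.stack([x[i] @ weights[i] + biases[i] for i in range(len(weights))])


def ensemble_forward_batched(weights, biases, x):
    """Same computation as one batched contraction over the ensemble dimension."""
    return np.einsum("eni,eio->eno", x, weights) + biases[:, None, :]
```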

    enhancement good first issue 
    opened by luisenp 2
Releases (v0.1.5)
  • v0.1.5 (Jan 14, 2022)

    • Fixes an important bug in v0.1.4 that was causing PETS to break.
    • The Model.reset() and Model.sample() signatures have changed. They no longer receive TransitionBatch objects, and they both return a dictionary of strings to tensors representing a model state that should be passed to sample() to simulate transitions. This dictionary can contain things like previous actions, predicted observations, latent states, beliefs, and any other such quantity that the model needs to maintain to simulate trajectories when using ModelEnv.
    • Ensemble class and sub-classes are assumed to operate on 1-D models.
  • v0.1.4 (Sep 27, 2021)

    This version adds two new optimizers as alternatives to CEM:

    • Improved CEM as described here.
    • MPPI as used in PDDM.

    It also makes the following changes:

    • Changed the config structure so that the action optimizer is passed as a separate config file.
    • Added a new iterator for sequences that returns a fixed number of random batches in every loop.
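    The idea behind such a sequence iterator can be sketched as follows: on every pass, sample a fixed number of random batches of fixed-length, contiguous windows from stored trajectories. The class name and arguments are hypothetical, and this is not the iterator added in this release. Each window is cut from a single trajectory so that it never crosses an episode boundary.

```python
import numpy as np


class RandomSequenceBatches:
    """Yield `num_batches` batches of random length-`seq_len` windows per iteration.

    trajectories: list of arrays of shape (T_i, dim); trajectories shorter than
    seq_len are skipped, and windows never cross a trajectory boundary.
    """

    def __init__(self, trajectories, seq_len, batch_size, num_batches, seed=0):
        self.trajectories = [np.asarray(t) for t in trajectories if len(t) >= seq_len]
        if not self.trajectories:
            raise ValueError("no trajectory is long enough for seq_len")
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.num_batches = num_batches
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        for _ in range(self.num_batches):
            batch = []
            for _ in range(self.batch_size):
                traj = self.trajectories[self.rng.integers(len(self.trajectories))]
                start = self.rng.integers(len(traj) - self.seq_len + 1)
                batch.append(traj[start:start + self.seq_len])
            yield np.stack(batch)  # shape (batch_size, seq_len, dim)

    def __len__(self):
        return self.num_batches
```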
  • v0.1.3 (Jul 24, 2021)

    This version changes the Model API so that loss, eval_score and update methods return a metadata dictionary that can be used for logging. It also adds the option to use double precision for normalization.

  • v0.1.2 (Jul 19, 2021)

Owner
Facebook Research