PyTorch-centric library for evaluating and enhancing the robustness of AI technologies


Responsible AI Toolbox

A library that provides high-quality, PyTorch-centric tools for evaluating and enhancing both the robustness and the explainability of AI models.

Check out our documentation for more information.

The rAI-toolbox works great with PyTorch Lightning and Hydra 🐉. Check out rai_toolbox.mushin to see how we use these frameworks to create efficient, configurable, and reproducible ML workflows with minimal boilerplate code.


Using rai_toolbox for your research? Please cite the following publication:

  title={Tools and Practices for Responsible AI Engineering},
  author={Soklaski, Ryan and Goodwin, Justin and Brown, Olivia and Yee, Michael and Matterer, Jason},
  journal={arXiv preprint arXiv:2201.05647},


If you would like to contribute to this repo, please refer to our document.


DISTRIBUTION STATEMENT A. Approved for public release. Distribution is unlimited.


  • Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014)
  • SPDX-License-Identifier: MIT

This material is based upon work supported by the Under Secretary of Defense for Research and Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the Under Secretary of Defense for Research and Engineering.

A portion of this research was sponsored by the United States Air Force Research Laboratory and the United States Air Force Artificial Intelligence Accelerator and was accomplished under Cooperative Agreement Number FA8750-19-2-1000. The views and conclusions contained in this document are those of the authors and should not be interpreted as representing the official policies, either expressed or implied, of the United States Air Force or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for Government purposes notwithstanding any copyright notation herein.

The software/firmware is provided to you on an As-Is basis.

  • Update workflows

    See example use here:

    • [x] Create base class for workflows
    • [x] Update docs
    • [x] Create tests for workflows
  • Strange computational graph issue with `gradient_ascent` and `LightningModule`

    First, here's a simple example of running gradient_ascent that works without error:

    from functools import partial
    import torch as tr
    from torchvision import models
    from rai_toolbox.optim import L2ProjectedOptim
    from rai_toolbox.perturbations.solvers import gradient_ascent
    model = models.resnet18()
    data = tr.rand(10, 3, 100, 100, dtype=tr.float)
    target = tr.randint(0, 2, size=(10,))
    pert = partial(
        gradient_ascent, optimizer=L2ProjectedOptim, epsilon=1.0, steps=1, lr=1.0
    )
    # run gradient ascent
    pert(model=model, data=data, target=target)

    Now set up and run the same thing using Trainer.predict:

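    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset
    class Lit(pl.LightningModule):
        def __init__(self):
            super().__init__()  # required before assigning submodules
            self.model = model
            self.pert = pert
        def predict_step(self, batch, *args, **kwargs):
            data, target = batch
            data = self.pert(model=self.model, data=data, target=target)
            logits = self.model(data)
            return logits.sum()
    trainer = pl.Trainer()
    # a single batch of (data, target)
    trainer.predict(Lit(), dataloaders=DataLoader(TensorDataset(data, target), batch_size=10))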

    Here we get the following error:

    /tmp/ipykernel_74682/ in predict_step(self, batch, *args, **kwargs)
         27     def predict_step(self, batch, *args, **kwargs):
         28         data, target = batch
    ---> 29         data = self.pert(model=self, data=data, target=target)
         30         logits = self.model(data)
         31         return logits.sum()
    ~/projects/raiden/rai_toolbox/src/rai_toolbox/perturbations/ in gradient_ascent(model, data, target, optimizer, steps, perturbation_model, targeted, use_best, criterion, reduction_fn, **optim_kwargs)
        277             # Update the perturbation
        278             optim.zero_grad(set_to_none=True)
    --> 279             loss.backward()
        280             optim.step()
    ~/.conda/envs/rai_md/lib/python3.8/site-packages/torch/ in backward(self, gradient, retain_graph, create_graph, inputs)
        394                 create_graph=create_graph,
        395                 inputs=inputs)
    --> 396         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
        398     def register_hook(self, hook):
    ~/.conda/envs/rai_md/lib/python3.8/site-packages/torch/autograd/ in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
        171     # some Python versions print out the first line of a multi-line function
        172     # calls in the traceback and some print out the last line
    --> 173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
        174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
        175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
    RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

    If I enter the debugger, everything seems to be set up correctly, except that pmodel(data) does not return a tensor with a grad_fn!

    # pdb at `loss.backward()` line
    > tr.is_grad_enabled()
    > + data
    ... # tensor output without `grad_fn`
    # try reinitializing
    > perturbation_model(data)(data)
    ... # tensor output WITH `grad_fn`

    I have no idea how to debug this and find out what is wrong.

    @rsokl do you get this error in your environment?

  • Docs: perturbation explanation

    Starting an explanation of our approach to data perturbations. I still intend to add more to this today, but feel free to take a look and let me know your thoughts on how it's going so far, especially on what should or shouldn't be included.

  • CIFAR10-Adversarial-Perturbations.ipynb -- Standard rai-toolbox[mushin] install doesn't include dill module

    The CIFAR10-Adversarial-Perturbations.ipynb example Jupyter notebook fails when loading the pretrained models with ModuleNotFoundError: No module named 'dill'. The dill module is not included in the standard rai-toolbox[mushin] install.

    Full error traceback below:

    ModuleNotFoundError                       Traceback (most recent call last)
    Input In [9], in <cell line: 3>()
          1 # Load pretrained model that was trained using a robust approach (i.e., adversarial training)
          2 ckpt_robust = ""
    ----> 3 model_robust = load_model(ckpt_robust)
          4 model_robust.eval();
          6 # Load pretrained model that was trained with standard approach
    Input In [7], in load_model(ckpt)
          2 def load_model(ckpt):
    ----> 3     base_model = load_from_checkpoint(
          4         model = resnet50(),
          5         ckpt = ckpt,
          6         weights_key="state_dict",
          7     )
          9     normalizer = transforms.Normalize(
         10         mean=[0.4914, 0.4822, 0.4465],
         11         std=[0.2023, 0.1994, 0.2010],
         12     )
         14     model = nn.Sequential(normalizer, base_model)
    File ~/dev/rai-toolbox-james/responsible-ai-toolbox/src/rai_toolbox/mushin/, in load_from_checkpoint(model, ckpt, weights_key, weights_key_strip, model_attr)
         57     ckpt = Path.home() / ".torch" / "models" / ckpt
         58"Loading model checkpoint from {ckpt}")
    ---> 60 ckpt_data: Dict[str, Any] = torch.load(ckpt, map_location="cpu")
         62 if weights_key is not None:
         63     assert weights_key in ckpt_data
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/, in load(f, map_location, pickle_module, **pickle_load_args)
        711             return torch.jit.load(opened_file)
        712         return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    --> 713 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/, in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
        928 unpickler = UnpicklerWrapper(f, **pickle_load_args)
        929 unpickler.persistent_load = persistent_load
    --> 930 result = unpickler.load()
        932 deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
        934 offset = f.tell() if f_should_read_directly else None
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/, in _legacy_load.<locals>.UnpicklerWrapper.find_class(self, mod_name, name)
        744     except KeyError:
        745         pass
    --> 746 return super().find_class(mod_name, name)
    ModuleNotFoundError: No module named 'dill'

    Installing dill (pip install dill) in the Python environment fixes this error.
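Until dill ships with the mushin extra, a notebook could fail fast with an actionable message instead of surfacing the error deep inside torch.load. The sketch below is hypothetical (require_module is not part of rai_toolbox); it only checks whether a module is importable, using the standard-library importlib.util.find_spec.

```python
import importlib.util


def require_module(name: str) -> None:
    """Hypothetical helper: raise an actionable error if `name` is not importable.

    A checkpoint whose pickle references `dill` can only be unpickled by
    `torch.load` when `dill` is installed in the current environment.
    """
    if importlib.util.find_spec(name) is None:
        raise ModuleNotFoundError(
            f"Loading this checkpoint requires {name!r}; install it with `pip install {name}`."
        )


# Usage, before calling load_model / torch.load in the notebook:
# require_module("dill")
```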

  • test for ensuring hydra ddp raises is raising for the wrong reason

    In the following test:

    launch(Config, pl_main_task) raises TypeError because pl_main_task doesn't accept a single config (pyright warned me about this). I doubt this is what you meant to exercise in this test.

    I am confused about what this test is doing. Config = make_config(trainer=trainer, wrong_config_name=module, devices=2) makes it seem like we are making sure that launch fails for a config with a bad field name, but it seems like the test should be exercising DDP.

  • CIFAR10-Adversarial-Perturbations.ipynb -- Load pretrained CIFAR-10 models not included and incorrectly named

    The CIFAR10-Adversarial-Perturbations.ipynb example Jupyter notebook and tutorial reference and as pretrained CIFAR-10 models. These models are not included in the standard rai-toolbox[mushin] install (perhaps due to licensing, or a desire to have the most up-to-date models?).

    Models downloaded from the URLs on the robustness GitHub are named and and will cause the following error on In[10] of CIFAR10-Adversarial-Perturbations.ipynb:

    FileNotFoundError                         Traceback (most recent call last)
    Input In [8], in <cell line: 3>()
          1 # Load pretrained model that was trained using a robust approach (i.e., adversarial training)
          2 ckpt_robust = ""
    ----> 3 model_robust = load_model(ckpt_robust)
          4 model_robust.eval();
          6 # Load pretrained model that was trained with standard approach
    Input In [7], in load_model(ckpt)
          2 def load_model(ckpt):
    ----> 3     base_model = load_from_checkpoint(
          4         model = resnet50(),
          5         ckpt = ckpt,
          6         weights_key="state_dict",
          7     )
          9     normalizer = transforms.Normalize(
         10         mean=[0.4914, 0.4822, 0.4465],
         11         std=[0.2023, 0.1994, 0.2010],
         12     )
         14     model = nn.Sequential(normalizer, base_model)
    File ~/dev/rai-toolbox-james/responsible-ai-toolbox/src/rai_toolbox/mushin/, in load_from_checkpoint(model, ckpt, weights_key, weights_key_strip, model_attr)
         57     ckpt = Path.home() / ".torch" / "models" / ckpt
         58"Loading model checkpoint from {ckpt}")
    ---> 60 ckpt_data: Dict[str, Any] = torch.load(ckpt, map_location="cpu")
         62 if weights_key is not None:
         63     assert weights_key in ckpt_data
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/, in load(f, map_location, pickle_module, **pickle_load_args)
        696 if 'encoding' not in pickle_load_args.keys():
        697     pickle_load_args['encoding'] = 'utf-8'
    --> 699 with _open_file_like(f, 'rb') as opened_file:
        700     if _is_zipfile(opened_file):
        701         # The zipfile reader is going to advance the current file position.
        702         # If we want to actually tail call to torch.jit.load, we need to
        703         # reset back to the original position.
        704         orig_position = opened_file.tell()
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/, in _open_file_like(name_or_buffer, mode)
        229 def _open_file_like(name_or_buffer, mode):
        230     if _is_path(name_or_buffer):
    --> 231         return _open_file(name_or_buffer, mode)
        232     else:
        233         if 'w' in mode:
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/, in _open_file.__init__(self, name, mode)
        211 def __init__(self, name, mode):
    --> 212     super(_open_file, self).__init__(open(name, mode))
    FileNotFoundError: [Errno 2] No such file or directory: '/home/scott/.torch/models/'

    The models must be renamed and manually copied to ~/.torch/models (i.e., /home/$USER/.torch/models) to proceed with the tutorial.
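The traceback shows that load_from_checkpoint resolves a relative checkpoint name against Path.home() / ".torch" / "models", so the manual workaround amounts to "copy under the expected name". A minimal sketch of that step, using a hypothetical helper (install_checkpoint is not part of rai_toolbox):

```python
import shutil
from pathlib import Path
from typing import Optional


def install_checkpoint(
    downloaded: Path, expected_name: str, models_dir: Optional[Path] = None
) -> Path:
    """Hypothetical helper: copy a downloaded checkpoint to the location and
    file name that the notebook's `load_from_checkpoint` call will look up.
    """
    if models_dir is None:
        # Relative checkpoint names are resolved against ~/.torch/models
        models_dir = Path.home() / ".torch" / "models"
    models_dir.mkdir(parents=True, exist_ok=True)
    destination = models_dir / expected_name
    shutil.copy2(downloaded, destination)  # copy (not move) and rename in one step
    return destination
```

expected_name must match whatever the notebook assigns to ckpt_robust and its standard-model counterpart; those file names are elided in the report above, so they are not filled in here.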

  • `zen` should not attempt to populate `*args` and `**kwargs`

    Previously zen would attempt to find a kwargs field in the config:


    def f(x, **kwargs): return x
    cfg = make_config(x=1)
    zen(f)(cfg)  # AttributeError: 'Config' object has no attribute 'kwargs'

    Now zen skips *args and **kwargs:

    def f(x, **kwargs): return x
    cfg = make_config(x=1)
    zen(f)(cfg)  # returns 1

    In the future we might permit some configured behavior for populating these.
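The behavior change can be illustrated with a toy wrapper that reads the signature via inspect and never looks up variadic parameters on the config. This is a hypothetical stand-in (toy_zen), not zen's actual implementation; a SimpleNamespace plays the role of make_config(x=1), and positional-only parameters are not handled.

```python
import inspect
from types import SimpleNamespace
from typing import Any, Callable

_VARIADIC = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)


def toy_zen(func: Callable[..., Any]) -> Callable[[Any], Any]:
    """Hypothetical stand-in for `zen`: populate `func`'s named parameters from
    same-named config attributes, skipping *args and **kwargs entirely."""
    names = [
        name
        for name, param in inspect.signature(func).parameters.items()
        if param.kind not in _VARIADIC
    ]

    def wrapper(cfg: Any) -> Any:
        # Only named parameters are looked up, so no `cfg.kwargs` access occurs
        return func(**{name: getattr(cfg, name) for name in names})

    return wrapper


def f(x, **kwargs):
    return x


cfg = SimpleNamespace(x=1)  # plays the role of make_config(x=1)
print(toy_zen(f)(cfg))  # 1
```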

  • Update gradient-descent solver

    • Renames: gradient_descent -> gradient_ascent
    • (bug fix) Ensures that the returned loss always has the correct sign. Previously, when targeted=False, the returned loss values were negated relative to the actual loss landscape
    • Adds examples section to docs
    • Ensures that data and target can be any array-like input, not necessarily a tensor
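To make the sign fix concrete, here is a scalar sketch of the convention. It assumes the solver always takes descent steps on sign * loss (ascending the loss for untargeted attacks, descending it for targeted ones); toy_gradient_ascent, loss, and grad are hypothetical names, not the library's implementation. The point is that the reported value is the actual loss, never the negated objective.

```python
from typing import Callable, Tuple


def toy_gradient_ascent(
    loss_fn: Callable[[float], float],
    grad_fn: Callable[[float], float],
    delta: float = 0.0,
    lr: float = 0.1,
    steps: int = 10,
    targeted: bool = False,
) -> Tuple[float, float]:
    """Hypothetical scalar sketch of the solver's sign convention.

    Untargeted attacks *increase* the loss and targeted attacks *decrease* it.
    The update is always a descent step on `sign * loss`, but the value that is
    returned is the actual (un-negated) loss.
    """
    sign = 1.0 if targeted else -1.0
    for _ in range(steps):
        delta -= lr * sign * grad_fn(delta)  # descent on sign * loss
    return delta, loss_fn(delta)  # never return sign * loss


def loss(d: float) -> float:
    return (d - 1.0) ** 2


def grad(d: float) -> float:
    return 2.0 * (d - 1.0)


print(toy_gradient_ascent(loss, grad, targeted=False))  # loss grows, stays positive
print(toy_gradient_ascent(loss, grad, targeted=True))   # loss shrinks toward 0
```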
  • Update madry example

    Currently we use within hydra.main, which no longer works:

    We should update this to leverage zen while making sure that the plotting still works (i.e., we give the workflow the context it needs to gather the xarray).

  • Add Pickled Hydra Runs (e.g., Rerun) to Support PL `ddp`

    In this PR we will attempt to address two issues:

    1. Reproducible Hydra experiments purely from the run directory by pickling both the runtime configuration and the task function
      • An extension of Hydra's experimental rerun
    2. Solving Hydra+DDP for PyTorch Lightning ddp strategy by saving the task function
      • The current solution in HydraDDP has strong constraints on the expected task function. This limits what the user can do in their experiments.

    Hydra Rerun Capability

    Here we take advantage of Hydra callbacks to save the runtime configuration and the desired task function. Currently, our callback takes a task function on initialization, but future Hydra versions may allow Hydra to pass the task function to the callback methods.

    Callback implementation: MushinPickleJobCallback. This takes in a Hydra task function on initialization and saves the task function and runtime configuration in the hydra.runtime.output_dir folder. The pickled files are stored in:


    This implementation uses cloudpickle to pickle the task function. The only downside of this approach is that the task function must be hashable for pickling and "instantiable" for Hydra from the command line; e.g., a task function defined in a notebook won't work.

    Note: Submitit is capable of pickling functions that were created in __main__, so this should be possible.

    Execution: With the configuration and task function saved in the job directory, we can rerun any experiment using:

    $ python -m rai_toolbox.mushin._hydra_rerun +config=<path to config.pickle> +task_fn=<path to task_fn.pickle>
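The rerun idea reduces to two pickles and one function call. Below is a minimal sketch that uses the standard-library pickle module instead of cloudpickle, an assumption made so the example has no third-party dependencies. The names task_fn, save_job, and rerun are hypothetical; only the file names config.pickle and task_fn.pickle come from the command above. Note the trade-off: stdlib pickle stores a module-level function by reference, so the rerun process must be able to import it, whereas cloudpickle can serialize many functions (including ones defined in __main__) by value.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict


def task_fn(cfg: Dict[str, Any]) -> float:
    """Stand-in task function (hypothetical)."""
    return cfg["lr"] * cfg["steps"]


def save_job(job_dir: Path, cfg: Dict[str, Any]) -> None:
    """Write the two files consumed by a rerun: the config and the task function."""
    (job_dir / "config.pickle").write_bytes(pickle.dumps(cfg))
    # stdlib pickle stores a *reference* (module + qualified name) to task_fn
    (job_dir / "task_fn.pickle").write_bytes(pickle.dumps(task_fn))


def rerun(config_path: Path, task_fn_path: Path) -> Any:
    """Reload both pickles and call the task function on the saved config."""
    cfg = pickle.loads(config_path.read_bytes())
    fn = pickle.loads(task_fn_path.read_bytes())
    return fn(cfg)


with tempfile.TemporaryDirectory() as job_dir:
    save_job(Path(job_dir), {"lr": 0.5, "steps": 4})
    print(rerun(Path(job_dir) / "config.pickle", Path(job_dir) / "task_fn.pickle"))  # 2.0
```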

    Lightning DDP

    Challenges this PR solves for Hydra+DDP:

    • Runs from notebook
    • Supports generic task functions (i.e., solves HydraDDP issue)
    • Task functions can run multiple Trainer methods (e.g., followed by Trainer.test). HydraDDP does not support these types of task functions

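    First, we must configure our custom Hydra callback, MushinPickleJobCallback:

    from hydra.core.config_store import ConfigStore
    from hydra_zen import builds
    # MushinPickleJobCallback is the callback introduced in this PR
    task_fn_cfg = builds(...)
    callback_cfg = dict(
        save_job_info=builds(MushinPickleJobCallback, task_fn=task_fn_cfg)
    )
    cs = ConfigStore.instance()
    cs.store(name="pickle_job", group="hydra/callbacks", node=callback_cfg)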

    The Trainer strategy can then be configured with our custom Lightning DDP strategy, HydraRerunDDP:

    TrainerConfig = builds(Trainer,   strategy=builds(HydraRerunDDP))

    We must set hydra/callbacks in the overrides to launch a job:

    task_fn = instantiate(task_fn_cfg)
    launch(Config, task_fn, overrides=["hydra/callbacks=pickle_job", ...])


    • MushinPickleJobCallback will clean up the PL environment automatically at the end of a job.
    • See tests for examples.

    I plan to update this comment to better describe everything.


    • [ ] Should we deprecate HydraDDP in favor of this
    • [ ] Can we pickle and use task functions built in a "main" setting like the notebook?
    • [ ] Structure of Hydra specific and Lightning specific code
    • [ ] More tests: validate results (not just that the pickle files exist); test Hydra rerun without Lightning
  • Implements elastic-net attack

    Derived from:

    Here is a trivial scenario where we merely perturb the "logits" themselves so that the specified targets are optimized for. Note that the longer we run the optimizer, the more the learned perturbation shrinks (while still amounting to a successful attack).

    >>> from rai_toolbox.perturbations.solvers import elastic_net_attack
    >>> logits = [[0.497, 0.503]]
    >>> target = [0]
    >>> for num_steps in [1, 10, 100]:
    ...     _, x_adv, _ = elastic_net_attack(
    ...         model=lambda x: x,
    ...         data=logits,
    ...         target=target,
    ...         beta=1e-3,
    ...         c=2,
    ...         steps=num_steps,
    ...         confidence=.01,
    ...         lr=0.5,
    ...     )
    ...     print(f"num-steps: {num_steps}\n{x_adv}")
    num-steps: 1
    tensor([[ 1.4960, -0.4960]])
    num-steps: 10
    tensor([[0.5062, 0.4938]])
    num-steps: 100
    tensor([[0.5018, 0.4982]])
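The shrinking perturbation comes from the L1 term of the elastic-net objective. Assuming the elided "Derived from" source is the EAD formulation (Chen et al., 2018), the step that handles the non-differentiable L1 term is a projected shrinkage-thresholding operator. The elementwise sketch below illustrates that operator only; shrink_threshold is a hypothetical name and this is not rai_toolbox's implementation. The box defaults to unbounded, as when perturbing logits directly.

```python
import math
from typing import List, Sequence


def shrink_threshold(
    z: Sequence[float],
    x0: Sequence[float],
    beta: float,
    lo: float = -math.inf,
    hi: float = math.inf,
) -> List[float]:
    """Elementwise projected shrinkage-thresholding, as described for EAD.

    Coordinates of `z` within `beta` of the clean input `x0` snap back to `x0`
    (this is what produces sparse perturbations); all others are moved `beta`
    toward `x0` and then clipped to the box [lo, hi].
    """
    out: List[float] = []
    for zi, xi in zip(z, x0):
        diff = zi - xi
        if diff > beta:
            out.append(min(zi - beta, hi))
        elif diff < -beta:
            out.append(max(zi + beta, lo))
        else:
            out.append(xi)
    return out


# Unbounded, as when perturbing logits directly
print(shrink_threshold([0.5, 0.0005, -0.3], [0.0, 0.0, 0.0], beta=1e-3))
```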
  • Use fused multiply-add to apply `grad_scale` and `grad_bias`

    >>> a = torch.randn(4)
    >>> a
    tensor([ 0.0202,  1.0985,  1.3506, -0.6056])
    >>> b = torch.randn(4)
    >>> b
    tensor([-0.9732, -0.3497,  0.6245,  0.4022])
    >>> c = torch.randn(4, 1)
    >>> c
    tensor([[ 0.3743],
            [-1.7724],
            [-0.5811],
            [-0.8017]])
    >>> torch.add(b, c, alpha=10)
    tensor([[  2.7695,   3.3930,   4.3672,   4.1450],
            [-18.6971, -18.0736, -17.0994, -17.3216],
            [ -6.7845,  -6.1610,  -5.1868,  -5.4090],
            [ -8.9902,  -8.3667,  -7.3925,  -7.6147]])
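Applied to the PR's goal, torch.add(input, other, alpha=a) evaluates input + a * other in one call, so the bias can go in input and the gradient can be scaled through alpha. A minimal sketch, assuming grad_scale and grad_bias are Python floats (their types aren't stated above); scale_and_bias_grad is a hypothetical name.

```python
import torch


def scale_and_bias_grad(grad: torch.Tensor, grad_scale: float, grad_bias: float) -> torch.Tensor:
    """Hypothetical helper: return grad_bias + grad_scale * grad.

    torch.add(input, other, alpha=a) evaluates input + a * other in one call,
    so the bias goes in `input` and the gradient is scaled through `alpha`.
    """
    # 0-dim bias tensor broadcasts against grad's shape
    bias = torch.as_tensor(grad_bias, dtype=grad.dtype, device=grad.device)
    return torch.add(bias, grad, alpha=grad_scale)


g = torch.tensor([1.0, -2.0, 0.5])
print(scale_and_bias_grad(g, grad_scale=2.0, grad_bias=0.5))
```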
