trajax

A Python library for differentiable optimal control on accelerators.

Trajax builds on JAX and hence code written with Trajax supports JAX's transformations. In particular, Trajax's solvers:

  1. Are automatically and efficiently differentiable, via jax.grad.
  2. Scale up to parallel instances via jax.vmap and jax.pmap.
  3. Can run on CPUs, GPUs, and TPUs without code changes, and support end-to-end compilation with jax.jit.
  4. Are made available from Python, written with NumPy.

In Trajax, differentiation through the solution of a trajectory optimization problem is done more efficiently than by differentiating the solver implementation directly. Specifically, Trajax defines custom differentiation routines for its solvers. It registers these with JAX so that they are picked up whenever using JAX's autodiff features (e.g. jax.grad) to differentiate functions that call a Trajax solver.
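To illustrate the general mechanism (this is not Trajax's actual routine), here is a minimal sketch of registering a custom derivative with JAX through jax.custom_vjp. The toy solver newton_sqrt and its helper functions are hypothetical names chosen for this example. The backward rule applies the implicit function theorem at the solution instead of differentiating through the Newton iterations.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def newton_sqrt(a):
  """Toy iterative solver: finds x with x**2 == a by Newton's method."""
  x = jnp.maximum(a, 1.0)  # starting guess at or above the root for a > 0
  for _ in range(30):
    x = 0.5 * (x + a / x)
  return x


def _newton_sqrt_fwd(a):
  x = newton_sqrt(a)
  return x, x  # keep the solution as the residual for the backward pass


def _newton_sqrt_bwd(x, g):
  # Implicit function theorem on x**2 - a == 0 gives dx/da = 1 / (2 x).
  return (g / (2.0 * x),)


# Register the custom rule; jax.grad picks it up automatically.
newton_sqrt.defvjp(_newton_sqrt_fwd, _newton_sqrt_bwd)

value = newton_sqrt(4.0)
grad = jax.grad(newton_sqrt)(4.0)
```

Any function that calls newton_sqrt can then be differentiated with jax.grad as usual, and the registered rule is used in place of unrolling the loop.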

This is a research project, not an official Google product.

Trajax is currently a work in progress, maintained by a few individuals at Google Research. While we are actively using Trajax in our own research projects, expect there to be bugs and rough edges compared to commercially available solvers.

Trajectory optimization and optimal control

We consider classical optimal control tasks concerning optimizing trajectories of a given discrete time dynamical system by solving the following problem. Given a cost function c, dynamics function f, and initial state x0, the goal is to compute:

argmin(lambda X, U: sum(c(X[t], U[t], t) for t in range(T)) + c_final(X[T]))

subject to the constraint that X[0] == x0 and that:

all(X[t + 1] == f(X[t], U[t], t) for t in range(T))
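As a minimal sketch of the problem above (not Trajax's implementation), the objective can be evaluated in plain JAX by rolling out the dynamics with jax.lax.scan and summing the costs. The names rollout and total_cost, and the toy double-integrator problem, are hypothetical. Following the formula above, the final cost c_final is passed as a separate function.

```python
import jax
import jax.numpy as jnp


def rollout(dynamics, x0, U):
  """Returns X of shape (T + 1, n) with X[0] == x0, X[t + 1] == f(X[t], U[t], t)."""
  def step(x, inputs):
    u, t = inputs
    x_next = dynamics(x, u, t)
    return x_next, x_next

  T = U.shape[0]
  _, X_tail = jax.lax.scan(step, x0, (U, jnp.arange(T)))
  return jnp.concatenate([x0[None], X_tail], axis=0)


def total_cost(cost, cost_final, dynamics, x0, U):
  """Computes sum_t c(X[t], U[t], t) + c_final(X[T])."""
  X = rollout(dynamics, x0, U)
  T = U.shape[0]
  stage_costs = jax.vmap(cost)(X[:-1], U, jnp.arange(T))
  return jnp.sum(stage_costs) + cost_final(X[-1])


# Toy problem (illustrative only): a double integrator with time step 0.1.
def dynamics(x, u, t):
  del t  # time-invariant
  return jnp.array([x[0] + 0.1 * x[1], x[1] + 0.1 * u[0]])


def cost(x, u, t):
  del x, t
  return 0.01 * jnp.sum(u ** 2)


def cost_final(x):
  return jnp.sum(x ** 2)


x0 = jnp.array([1.0, 0.0])
U = jnp.zeros((20, 1))
X = rollout(dynamics, x0, U)
J = total_cost(cost, cost_final, dynamics, x0, U)
# Gradient of the objective with respect to the control sequence U.
dJ_dU = jax.grad(total_cost, argnums=4)(cost, cost_final, dynamics, x0, U)
```

Because the dynamics constraint is enforced by construction in the rollout, the controls U are the only free variables, which is what makes the gradient with respect to U meaningful.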

Many resources cover trajectory optimization in more depth, including Dynamic Programming and Optimal Control by Dimitri Bertsekas and Underactuated Robotics by Russ Tedrake.

API

In describing the API, it will be useful to abbreviate a JAX/NumPy floating point ndarray of shape (a, b, …) as a type denoted F[a, b, …]. Assume n is the state dimension, d is the control dimension, and T is the time horizon.

Problem setup convention/signature

Setting up a problem requires writing two functions, cost and dynamics, with type signatures:

cost(state: F[n], action: F[d], time_step: int) : float
dynamics(state: F[n], action: F[d], time_step: int) : F[n]

Note that even if a dimension n or d is 1, the corresponding state or action representation is still a rank-1 ndarray (i.e. a vector, of length 1).

Because Trajax uses JAX, the cost and dynamics functions must be written in a functional programming style as required by JAX. See the JAX readme for details on writing JAX-friendly functional code. By and large, functions that have no side effects and that use jax.numpy in place of numpy are likely to work.
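A minimal sketch of a cost and dynamics pair that follows these signatures, assuming a hypothetical Euler-discretized pendulum with n = 2 and d = 1. The constants and function bodies are illustrative only. Both functions are side-effect free and use jax.numpy, so they can be passed through jax.jit.

```python
import jax
import jax.numpy as jnp


def dynamics(state, action, time_step):
  """Toy pendulum: state is F[2] (angle, angular velocity), action is F[1]."""
  del time_step  # this toy system is time-invariant
  dt, gravity, length = 0.05, 9.81, 1.0
  theta, omega = state[0], state[1]
  torque = action[0]  # the action is a length-1 vector even though d == 1
  omega_next = omega + dt * (torque - (gravity / length) * jnp.sin(theta))
  theta_next = theta + dt * omega
  return jnp.array([theta_next, omega_next])


def cost(state, action, time_step):
  """Quadratic penalty on distance from upright and on control effort."""
  del time_step
  target = jnp.array([jnp.pi, 0.0])
  return jnp.sum((state - target) ** 2) + 0.1 * jnp.sum(action ** 2)


state = jnp.array([0.0, 0.0])
action = jnp.array([1.0])
next_state = jax.jit(dynamics)(state, action, 0)
stage_cost = jax.jit(cost)(state, action, 0)
```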

Solvers

If we abbreviate the type of the above two functions as CostFn and DynamicsFn, then our solvers have the following type signature prefix in common:

solver(cost: CostFn, dynamics: DynamicsFn, initial_state: F[n], initial_actions: F[T, d], *solver_args, **solver_kwargs): SolverOutput

SolverOutput is a tuple of (F[T + 1, n], F[T, d], float, *solver_outputs). The first three tuple components represent the optimal state trajectory, optimal control sequence, and the optimal objective value achieved, respectively. The remaining *solver_outputs are specific to the particular solver (such as number of iterations, norm of the final gradient, etc.).

There are currently four solvers provided: ilqr, scipy_minimize, cem, and random_shooting. Each extends the signatures above with solver-specific arguments and output values. Details are provided in each solver function's docstring.

Underlying the ilqr implementation is a time-varying LQR routine, which solves a special case of the above problem, where costs are convex quadratic and dynamics are affine. To capture this, both are represented as matrices. This routine is also made available as tvlqr.
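For intuition about this special case, here is a minimal sketch of a finite-horizon LQR backward pass (a Riccati recursion) in JAX. It is a simplification and not the tvlqr interface: it assumes time-invariant matrices, purely linear dynamics x[t + 1] = A x[t] + B u[t], and costs x'Qx + u'Ru with final cost x'Q_final x. The name lqr_gains is hypothetical.

```python
import jax
import jax.numpy as jnp


def lqr_gains(A, B, Q, R, Q_final, T):
  """Returns (P0, Ks): cost-to-go matrix at t = 0 and gains with u[t] = Ks[t] @ x[t]."""
  def backward_step(P, _):
    # Minimizing u'Ru + (Ax + Bu)'P(Ax + Bu) over u gives u = K x.
    K = -jnp.linalg.solve(R + B.T @ P @ B, B.T @ P @ A)
    # Cost-to-go one step earlier under the optimal feedback.
    P_prev = Q + A.T @ P @ (A + B @ K)
    return P_prev, K

  # reverse=True runs from t = T - 1 down to t = 0 and stores Ks in time order.
  P0, Ks = jax.lax.scan(backward_step, Q_final, None, length=T, reverse=True)
  return P0, Ks


# Scalar integrator x[t + 1] = x[t] + u[t], control cost u**2, final cost x**2.
A = jnp.array([[1.0]])
B = jnp.array([[1.0]])
Q = jnp.array([[0.0]])
R = jnp.array([[1.0]])
Q_final = jnp.array([[1.0]])
P0, Ks = lqr_gains(A, B, Q, R, Q_final, T=4)
```

The optimal cost from an initial state x0 is then x0' P0 x0, and the optimal controls follow by applying the gains along a forward rollout.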

Objectives

One might want to write a custom solver, or work with an objective function for any other reason. To that end, Trajax offers the optimal control objective in the form of an API function:

objective(cost: CostFn, dynamics: DynamicsFn, initial_state: F[n], actions: F[T, d]): float

Combining this function with JAX's autodiff capabilities offers a starting point for writing a custom first-order solver. For example:

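import jax
import trajax

def improve_controls(cost, dynamics, U, x0, eta, num_iters):
  # Differentiate with respect to the third positional argument (U in the call
  # below). An integer argnums makes grad_fn return one array, not a tuple.
  grad_fn = jax.grad(trajax.objective, argnums=2)
  for _ in range(num_iters):
    U = U - eta * grad_fn(cost, dynamics, U, x0)  # gradient descent step
  return U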

The solvers provided by Trajax are actually built around this objective function. For instance, the scipy_minimize solver simply calls scipy.optimize.minimize with the gradient and Hessian-vector product functions derived from objective using jax.grad and jax.hessian.
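A minimal sketch of this pattern on a toy objective, assuming SciPy's trust-ncg method and a Hessian-vector product built as a jax.jvp of jax.grad. That construction is a swapped-in technique for illustration and not necessarily what scipy_minimize does internally. The names toy_objective, fun, and hessp are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize


def toy_objective(U):
  """x[t + 1] = x[t] + U[t] with x[0] = 1; cost is sum(U**2) + x[T]**2."""
  x_final = 1.0 + jnp.sum(U)
  return jnp.sum(U ** 2) + x_final ** 2


value_and_grad = jax.jit(jax.value_and_grad(toy_objective))


def fun(u):
  # SciPy passes float64 NumPy arrays; JAX defaults to float32.
  value, grad = value_and_grad(jnp.asarray(u, dtype=jnp.float32))
  return float(value), np.asarray(grad, dtype=np.float64)


def hessp(u, v):
  # Hessian-vector product as the directional derivative of the gradient.
  u = jnp.asarray(u, dtype=jnp.float32)
  v = jnp.asarray(v, dtype=jnp.float32)
  _, hv = jax.jvp(jax.grad(toy_objective), (u,), (v,))
  return np.asarray(hv, dtype=np.float64)


# jac=True tells SciPy that fun returns (value, gradient) together.
result = scipy.optimize.minimize(
    fun, np.zeros(4), jac=True, hessp=hessp, method="trust-ncg")
U_opt = result.x
```

Computing Hessian-vector products this way avoids materializing the full Hessian, which matters when the flattened control sequence has T * d entries.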

Limitations

Just as Trajax inherits the autodiff, compilation, and parallelism features of JAX, it also inherits its corresponding limitations. Functions such as the cost and dynamics given to a solver must be written using jax.numpy in place of standard numpy, and must conform to a functional style; see the JAX readme. Due to the complexity of trajectory optimizer implementations, initial compilation times can be long.
