trajax

A Python library for differentiable optimal control on accelerators.

Trajax builds on JAX and hence code written with Trajax supports JAX's transformations. In particular, Trajax's solvers:

  1. Are automatically efficiently differentiable, via jax.grad.
  2. Scale up to parallel instances via jax.vmap and jax.pmap.
  3. Can run on CPUs, GPUs, and TPUs without code changes, and support end-to-end compilation with jax.jit.
  4. Are made available from Python, written with JAX's NumPy API (jax.numpy).
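A minimal sketch of how these transformations compose on an ordinary JAX function. The one-step dynamics and cost below are made up for illustration. Per the points above, a function that calls a Trajax solver can be transformed the same way.

```python
import jax
import jax.numpy as jnp

def one_step_cost(u, x):
    # Toy linear dynamics x' = x + 0.1 * u, then a quadratic cost on x' and u.
    x_next = x + 0.1 * u
    return jnp.sum(x_next ** 2) + 0.01 * jnp.sum(u ** 2)

# Gradient with respect to the action u (positional argument 0).
grad_u = jax.grad(one_step_cost)

# Vectorize over a batch of (u, x) pairs, then compile the whole pipeline.
batched_grad_u = jax.jit(jax.vmap(grad_u))

us = jnp.zeros((8, 2))  # batch of 8 actions, d = 2
xs = jnp.ones((8, 2))   # batch of 8 states, n = 2
g = batched_grad_u(us, xs)  # shape (8, 2)
```

The composition order matters only for efficiency: `jax.jit` on the outside compiles the batched gradient as a single program.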

In Trajax, differentiation through the solution of a trajectory optimization problem is done more efficiently than by differentiating the solver implementation directly. Specifically, Trajax defines custom differentiation routines for its solvers. It registers these with JAX so that they are picked up whenever using JAX's autodiff features (e.g. jax.grad) to differentiate functions that call a Trajax solver.
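JAX exposes this mechanism through custom derivative rules such as `jax.custom_vjp`. The toy below is not Trajax's code. It illustrates the same idea on a hypothetical `argmin_quadratic` "solver": the forward pass runs gradient descent, and the registered backward rule applies the implicit function theorem at the optimum. Because Q x* = theta at the solution, dx*/dtheta = Q^{-1}, so the rule never backpropagates through the iterations.

```python
import jax
import jax.numpy as jnp

# Fixed symmetric positive-definite matrix (illustrative values).
Q = jnp.array([[3.0, 0.5], [0.5, 2.0]])

def _solve(theta, num_iters=500, lr=0.2):
    # Inner iterative "solver": gradient descent on 0.5 x'Qx - theta'x.
    def step(_, x):
        return x - lr * (Q @ x - theta)
    return jax.lax.fori_loop(0, num_iters, step, jnp.zeros_like(theta))

@jax.custom_vjp
def argmin_quadratic(theta):
    return _solve(theta)

def _fwd(theta):
    # No residuals are needed, because Q is a constant.
    return _solve(theta), None

def _bwd(_, g):
    # Implicit differentiation: VJP of x*(theta) = Q^{-1} theta is Q^{-T} g,
    # and Q is symmetric, so a single linear solve replaces 500 backward steps.
    return (jnp.linalg.solve(Q, g),)

argmin_quadratic.defvjp(_fwd, _bwd)
```

`jax.grad` of any function that calls `argmin_quadratic` now uses `_bwd` automatically, which is the behavior the paragraph above describes for Trajax's solvers.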

This is a research project, not an official Google product.

Trajax is currently a work in progress, maintained by a few individuals at Google Research. While we are actively using Trajax in our own research projects, expect there to be bugs and rough edges compared to commercially available solvers.

Trajectory optimization and optimal control

We consider classical optimal control tasks: optimizing the trajectory of a given discrete-time dynamical system. Given a stage cost function c, a final cost function c_final, a dynamics function f, an initial state x0, and a time horizon T, the goal is to compute:

argmin(lambda X, U: sum(c(X[t], U[t], t) for t in range(T)) + c_final(X[T]))

subject to the constraint that X[0] == x0 and that:

all(X[t + 1] == f(X[t], U[t], t) for t in range(T))
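The dynamics constraints determine X completely from x0 and U. Simulating the system forward therefore turns the problem into an unconstrained minimization over U alone. Below is a minimal JAX sketch of that elimination. It assumes dynamics and cost take (state, action, time_step) as described in the API section. `rollout` and `total_cost` are illustrative names, not Trajax functions.

```python
import jax
import jax.numpy as jnp

def rollout(dynamics, x0, U):
    """Return X of shape (T + 1, n) with X[0] == x0 and X[t + 1] == dynamics(X[t], U[t], t)."""
    T = U.shape[0]

    def step(x, inputs):
        u, t = inputs
        x_next = dynamics(x, u, t)
        return x_next, x_next  # carry the state forward and also record it

    _, X_rest = jax.lax.scan(step, x0, (U, jnp.arange(T)))
    return jnp.concatenate([x0[None, :], X_rest], axis=0)

def total_cost(cost, c_final, dynamics, x0, U):
    """sum(c(X[t], U[t], t) for t in range(T)) + c_final(X[T]) along the rolled-out X."""
    X = rollout(dynamics, x0, U)
    T = U.shape[0]
    stage_costs = jax.vmap(cost)(X[:-1], U, jnp.arange(T))
    return jnp.sum(stage_costs) + c_final(X[-1])
```

Since the whole map from U to the scalar cost is a JAX function, `jax.grad(total_cost, argnums=4)` gives the gradient with respect to the controls.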

There are many resources for more on trajectory optimization, including Dynamic Programming and Optimal Control by Dimitri Bertsekas and Underactuated Robotics by Russ Tedrake.

API

In describing the API, it will be useful to abbreviate a JAX/NumPy floating point ndarray of shape (a, b, …) as a type denoted F[a, b, …]. Assume n is the state dimension, d is the control dimension, and T is the time horizon.

Problem setup convention/signature

Setting up a problem requires writing two functions, cost and dynamics, with type signatures:

cost(state: F[n], action: F[d], time_step: int) : float
dynamics(state: F[n], action: F[d], time_step: int) : F[n]

Note that even if a dimension n or d is 1, the corresponding state or action representation is still a rank-1 ndarray (i.e. a vector, of length 1).
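For example, a pendulum-style system with n = 2 and d = 1 still receives its single torque as a length-1 vector. The model, its constants, and the upright target below are illustrative assumptions, not part of Trajax.

```python
import jax.numpy as jnp

dt = 0.05  # illustrative time step

def dynamics(state, action, time_step):
    # Explicit-Euler pendulum; state = (angle, angular velocity), action is F[1].
    theta, omega = state[0], state[1]
    omega_next = omega + dt * (-9.81 * jnp.sin(theta) + action[0])  # read the single torque entry
    return jnp.array([theta + dt * omega, omega_next])  # F[n] with n = 2

def cost(state, action, time_step):
    # Scalar cost pulling the pendulum toward the upright angle pi, with a small control penalty.
    return (state[0] - jnp.pi) ** 2 + 0.1 * state[1] ** 2 + 0.01 * jnp.sum(action ** 2)
```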

Because Trajax uses JAX, the cost and dynamics functions must be written in a functional programming style as required by JAX. See the JAX readme for details on writing JAX-friendly functional code. By and large, functions that have no side effects and that use jax.numpy in place of numpy are likely to work.
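Two common adjustments are replacing Python branches on array values with `jnp.where` or `jnp.clip`, and replacing in-place assignment with `.at[...].set(...)`. The sketch below shows both in toy functions of our own, not Trajax examples, and checks that they survive `jax.jit`. The step size 0.1 and the switch at time step 10 are arbitrary choices.

```python
import jax
import jax.numpy as jnp

def dynamics(state, action, time_step):
    # Saturate with jnp.clip rather than `if action > 1:`; under jax.jit the
    # values are abstract tracers, so Python control flow on them fails.
    u = jnp.clip(action, -1.0, 1.0)
    # JAX arrays are immutable: build updated copies instead of `state[0] = ...`.
    next_state = state.at[0].set(state[0] + 0.1 * state[1])
    next_state = next_state.at[1].set(state[1] + 0.1 * u[0])
    return next_state

def cost(state, action, time_step):
    # Time-dependent weighting via jnp.where keeps the function traceable.
    weight = jnp.where(time_step >= 10, 10.0, 1.0)
    return weight * jnp.sum(state ** 2) + jnp.sum(action ** 2)

jitted_dynamics = jax.jit(dynamics)
jitted_cost = jax.jit(cost)
```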

Solvers

If we abbreviate the type of the above two functions as CostFn and DynamicsFn, then our solvers have the following type signature prefix in common:

solver(cost: CostFn, dynamics: DynamicsFn, initial_state: F[n], initial_actions: F[T, d], *solver_args, **solver_kwargs): SolverOutput

SolverOutput is a tuple of (F[T + 1, n], F[T, d], float, *solver_outputs). The first three tuple components represent the optimal state trajectory, optimal control sequence, and the optimal objective value achieved, respectively. The remaining *solver_outputs are specific to the particular solver (such as number of iterations, norm of the final gradient, etc.).

There are currently four solvers provided: ilqr, scipy_minimize, cem, and random_shooting. Each extends the signatures above with solver-specific arguments and output values. Details are provided in each solver function's docstring.
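To make the shared signature and SolverOutput layout concrete, here is a hypothetical minimal random-shooting solver. It is a sketch, not Trajax's random_shooting, with these assumptions:

- The arguments `key`, `num_samples`, and `stddev` stand in for `*solver_args`/`**solver_kwargs`.
- The fourth output, the index of the best sample, stands in for `*solver_outputs`.
- For brevity, only stage costs are summed; there is no final cost term.

```python
import jax
import jax.numpy as jnp

def random_shooting_sketch(cost, dynamics, initial_state, initial_actions,
                           key, num_samples=256, stddev=0.5):
    """Hypothetical solver returning (X: F[T+1, n], U: F[T, d], objective, best_index)."""
    T = initial_actions.shape[0]
    ts = jnp.arange(T)

    def rollout_cost(U):
        def step(x, inputs):
            u, t = inputs
            return dynamics(x, u, t), (x, cost(x, u, t))
        x_final, (X, costs) = jax.lax.scan(step, initial_state, (U, ts))
        return jnp.concatenate([X, x_final[None, :]], axis=0), jnp.sum(costs)

    # Gaussian perturbations around the initial guess; sample 0 is the unperturbed
    # guess, so the returned objective is never worse than that of initial_actions.
    noise = stddev * jax.random.normal(key, (num_samples,) + initial_actions.shape)
    candidates = (initial_actions[None] + noise).at[0].set(initial_actions)

    Xs, objs = jax.vmap(rollout_cost)(candidates)  # evaluate all samples in parallel
    best = jnp.argmin(objs)
    return Xs[best], candidates[best], objs[best], best
```

Usage follows the common pattern: `X, U, obj, *extras = random_shooting_sketch(cost, dynamics, x0, U0, jax.random.PRNGKey(0))`.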

Underlying the ilqr implementation is a time-varying LQR routine, which solves a special case of the above problem, where costs are convex quadratic and dynamics are affine. To capture this, both are represented as matrices. This routine is also made available as tvlqr.
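The core of such a routine is a backward Riccati recursion. The sketch below is restricted to a narrower special case than Trajax's tvlqr, which handles the general convex quadratic and affine case with its own signature (see its docstring). Here the dynamics are linear, x[t+1] = A[t] x[t] + B[t] u[t]; stage costs are 0.5 x'Q[t]x + 0.5 u'R[t]u; and the final cost is 0.5 x'Q_final x. The function names and argument layout are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def lqr_gains(A, B, Q, R, Q_final):
    """Backward Riccati pass. A: (T, n, n), B: (T, n, d), Q: (T, n, n), R: (T, d, d)."""
    def step(P_next, mats):
        A_t, B_t, Q_t, R_t = mats
        BtP = B_t.T @ P_next
        K_t = jnp.linalg.solve(R_t + BtP @ B_t, BtP @ A_t)  # (R + B'PB)^{-1} B'PA
        P_t = Q_t + A_t.T @ P_next @ (A_t - B_t @ K_t)       # value-function matrix at t
        return P_t, K_t
    # reverse=True runs t = T-1, ..., 0 but still stores K[t] at index t.
    _, K = jax.lax.scan(step, Q_final, (A, B, Q, R), reverse=True)
    return K

def closed_loop_actions(A, B, K, x0):
    """Apply u[t] = -K[t] x[t] along the trajectory; returns U of shape (T, d)."""
    def step(x, mats):
        A_t, B_t, K_t = mats
        u = -K_t @ x
        return A_t @ x + B_t @ u, u
    _, U = jax.lax.scan(step, x0, (A, B, K))
    return U

def quadratic_objective(U, x0, A, B, Q, R, Q_final):
    """Open-loop total cost of an action sequence U."""
    def step(x, mats):
        A_t, B_t, Q_t, R_t, u = mats
        return A_t @ x + B_t @ u, 0.5 * (x @ Q_t @ x + u @ R_t @ u)
    x_T, costs = jax.lax.scan(step, x0, (A, B, Q, R, U))
    return jnp.sum(costs) + 0.5 * x_T @ Q_final @ x_T

# Example: a double integrator (illustrative values), repeated over T time steps.
T, dt = 20, 0.1
A = jnp.tile(jnp.array([[1.0, dt], [0.0, 1.0]]), (T, 1, 1))
B = jnp.tile(jnp.array([[0.0], [dt]]), (T, 1, 1))
Q = jnp.tile(jnp.eye(2), (T, 1, 1))
R = jnp.tile(0.1 * jnp.eye(1), (T, 1, 1))
Q_final = 10.0 * jnp.eye(2)
x0 = jnp.array([1.0, 0.0])

K = lqr_gains(A, B, Q, R, Q_final)
U = closed_loop_actions(A, B, K, x0)
```

The system is deterministic, so the actions produced by the feedback law are also the optimal open-loop sequence. The gradient of `quadratic_objective` at `U` should therefore vanish up to floating-point error.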

Objectives

One might want to write a custom solver, or work with an objective function for any other reason. To that end, Trajax offers the optimal control objective in the form of an API function:

objective(cost: CostFn, dynamics: DynamicsFn, actions: F[T, d], initial_state: F[n]): float

Combining this function with JAX's autodiff capabilities offers a starting point for writing a first-order custom solver, such as this plain gradient-descent loop:

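import jax
import trajax

def improve_controls(cost, dynamics, U, x0, eta, num_iters):
  # Gradient of the objective with respect to the actions U (positional argument 2).
  grad_fn = jax.grad(trajax.objective, argnums=2)
  for _ in range(num_iters):
    # Gradient-descent step with step size eta.
    U = U - eta * grad_fn(cost, dynamics, U, x0)
  return U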

The solvers provided by Trajax are built around this objective function. For instance, the scipy_minimize solver calls scipy.optimize.minimize with the gradient and Hessian-vector product functions derived from objective using jax.grad and jax.hessian.
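scipy.optimize.minimize can consume Hessian-vector products (through its `hessp` argument) without materializing the Hessian. One standard way to obtain them in JAX is forward-over-reverse differentiation. This is not necessarily how Trajax's scipy_minimize computes them. The sketch below checks the product against the dense `jax.hessian` result on `toy_objective`, a made-up stand-in for a trajectory objective.

```python
import jax
import jax.numpy as jnp

def toy_objective(U):
    # Stand-in for an objective as a function of the (flattened) actions.
    return jnp.sum(jnp.sin(U) * U ** 2) + 0.5 * jnp.sum(U ** 2)

grad_fn = jax.grad(toy_objective)

def hvp(U, v):
    # Directional derivative of the gradient along v equals H(U) @ v;
    # costs about one extra gradient evaluation and never forms H.
    return jax.jvp(grad_fn, (U,), (v,))[1]

U = jnp.array([0.3, -1.2, 2.0])
v = jnp.array([1.0, 0.5, -0.25])
H = jax.hessian(toy_objective)(U)  # dense (3, 3) matrix; fine only for tiny problems
```

For a horizon T and control dimension d, the dense Hessian has (T * d)^2 entries, which is why the matrix-free product scales better.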

Limitations

Just as Trajax inherits the autodiff, compilation, and parallelism features of JAX, it also inherits its corresponding limitations. Functions such as the cost and dynamics given to a solver must be written using jax.numpy in place of standard numpy, and must conform to a functional style; see the JAX readme. Due to the complexity of trajectory optimizer implementations, initial compilation times can be long.
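Compilation cost is paid once per input signature. `jax.jit` traces and compiles for each new combination of argument shapes and dtypes, then reuses the compiled code. In this setting, changing the horizon T triggers a fresh compile. The sketch counts traces with a Python-level counter, which runs only during tracing. That is exactly the kind of side effect that does not execute on cached calls.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def running_cost(X):
    global trace_count
    trace_count += 1  # Python side effect: executes only while JAX traces the function
    return jnp.sum(X ** 2)

running_cost(jnp.ones((10, 2)))   # first call with shape (10, 2): trace + compile
running_cost(jnp.zeros((10, 2)))  # same shape and dtype: cached executable, no trace
running_cost(jnp.ones((20, 2)))   # new shape (a longer horizon): traced and compiled again
```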

Owner: Google