trajax

A Python library for differentiable optimal control on accelerators.

Trajax builds on JAX and hence code written with Trajax supports JAX's transformations. In particular, Trajax's solvers:

  1. Are automatically efficiently differentiable, via jax.grad.
  2. Scale up to parallel instances via jax.vmap and jax.pmap.
  3. Can run on CPUs, GPUs, and TPUs without code changes, and support end-to-end compilation with jax.jit.
  4. Are made available from Python, and are written in NumPy style via jax.numpy.

In Trajax, differentiation through the solution of a trajectory optimization problem is done more efficiently than by differentiating the solver implementation directly. Specifically, Trajax defines custom differentiation routines for its solvers. It registers these with JAX so that they are picked up whenever using JAX's autodiff features (e.g. jax.grad) to differentiate functions that call a Trajax solver.
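
For instance, gradients of a solver's optimal objective value can be taken directly with jax.grad. The following is a minimal sketch, assuming the ilqr solver is importable as trajax.optimizers.ilqr; the toy cost and dynamics are purely illustrative:

import jax
import jax.numpy as jnp
from trajax.optimizers import ilqr  # assumed import path

def cost(x, u, t):
  return jnp.dot(x, x) + jnp.dot(u, u)

def dynamics(x, u, t):
  return x + 0.1 * u  # single-integrator dynamics, n = d = 1

U0 = jnp.zeros((10, 1))  # initial guess for the action sequence

def solve_and_return_obj(x0):
  X, U, obj, *_ = ilqr(cost, dynamics, x0, U0)
  return obj

# Differentiates through the solve via Trajax's custom derivative rule.
dobj_dx0 = jax.grad(solve_and_return_obj)(jnp.array([1.0]))

The same function composes with jax.vmap (to solve a batch of initial states in parallel) and with jax.jit (to compile the entire solve end to end).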

This is a research project, not an official Google product.

Trajax is currently a work in progress, maintained by a few individuals at Google Research. While we are actively using Trajax in our own research projects, expect there to be bugs and rough edges compared to commercially available solvers.

Trajectory optimization and optimal control

We consider the classical optimal control task of optimizing the trajectory of a given discrete-time dynamical system by solving the following problem. Given a cost function c, a dynamics function f, and an initial state x0, the goal is to compute:

argmin(lambda X, U: sum(c(X[t], U[t], t) for t in range(T)) + c_final(X[T]))

subject to the constraint that X[0] == x0 and that:

all(X[t + 1] == f(X[t], U[t], t) for t in range(T))

There are many resources for more on trajectory optimization, including Dynamic Programming and Optimal Control by Dimitri Bertsekas and Underactuated Robotics by Russ Tedrake.

API

In describing the API, it will be useful to abbreviate a JAX/NumPy floating point ndarray of shape (a, b, …) as a type denoted F[a, b, …]. Assume n is the state dimension, d is the control dimension, and T is the time horizon.

Problem setup convention/signature

Setting up a problem requires writing two functions, cost and dynamics, with type signatures:

cost(state: F[n], action: F[d], time_step: int) : float
dynamics(state: F[n], action: F[d], time_step: int) : F[n]

Note that even if a dimension n or d is 1, the corresponding state or action representation is still a rank-1 ndarray (i.e. a vector, of length 1).

Because Trajax uses JAX, the cost and dynamics functions must be written in a functional programming style as required by JAX. See the JAX readme for details on writing JAX-friendly functional code. By and large, functions that have no side effects and that use jax.numpy in place of numpy are likely to work.
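
As a concrete illustration, here is a minimal sketch of a cost/dynamics pair for a pendulum swing-up task, written with jax.numpy. The dynamics, time step, and cost weights here are illustrative choices, not part of the Trajax API:

import jax.numpy as jnp

def pendulum_dynamics(state, action, time_step):
  # state = [angle, angular velocity] (n = 2), action = [torque] (d = 1).
  theta, theta_dot = state
  theta_ddot = jnp.sin(theta) + action[0]
  dt = 0.1  # Euler integration step
  return jnp.array([theta + dt * theta_dot, theta_dot + dt * theta_ddot])

def pendulum_cost(state, action, time_step):
  # Penalize distance from the upright position and control effort.
  theta, theta_dot = state
  return (theta - jnp.pi) ** 2 + 0.1 * theta_dot ** 2 + 0.01 * action[0] ** 2

Both functions are side-effect free and use jax.numpy throughout, so they compose with jax.grad, jax.vmap, and jax.jit.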

Solvers

If we abbreviate the type of the above two functions as CostFn and DynamicsFn, then our solvers have the following type signature prefix in common:

solver(cost: CostFn, dynamics: DynamicsFn, initial_state: F[n], initial_actions: F[T, d], *solver_args, **solver_kwargs): SolverOutput

SolverOutput is a tuple of (F[T + 1, n], F[T, d], float, *solver_outputs). The first three tuple components represent the optimal state trajectory, optimal control sequence, and the optimal objective value achieved, respectively. The remaining *solver_outputs are specific to the particular solver (such as number of iterations, norm of the final gradient, etc.).

There are currently four solvers provided: ilqr, scipy_minimize, cem, and random_shooting. Each extends the signatures above with solver-specific arguments and output values. Details are provided in each solver function's docstring.
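
For example, a call to ilqr might look like the following sketch, reusing the pendulum functions above (again assuming the trajax.optimizers.ilqr import path, and leaving the trailing solver-specific outputs unpacked generically):

import jax.numpy as jnp
from trajax.optimizers import ilqr  # assumed import path

x0 = jnp.zeros(2)        # initial state, F[n] with n = 2
U0 = jnp.zeros((50, 1))  # initial action sequence, F[T, d] with T = 50, d = 1

X, U, obj, *solver_outputs = ilqr(pendulum_cost, pendulum_dynamics, x0, U0)
# X:   F[T + 1, n], optimal state trajectory
# U:   F[T, d], optimal control sequence
# obj: float, objective value achieved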

Underlying the ilqr implementation is a time-varying LQR routine, which solves a special case of the above problem in which costs are convex quadratic and dynamics are affine. To capture this special case, both are represented directly as matrices rather than as general functions. This routine is also made available as tvlqr.

Objectives

One might want to write a custom solver, or work with the objective function directly for some other reason. To that end, Trajax offers the optimal control objective in the form of an API function:

objective(cost: CostFn, dynamics: DynamicsFn, initial_state: F[n], actions: F[T, d]): float

Combining this function with JAX's autodiff capabilities offers a starting point for writing a custom first-order solver. For example:

import jax
import trajax

def improve_controls(cost, dynamics, U, x0, eta, num_iters):
  # Gradient of the objective with respect to the action sequence,
  # which is argument index 3 in the signature above.
  grad_fn = jax.grad(trajax.objective, argnums=3)
  for _ in range(num_iters):
    U = U - eta * grad_fn(cost, dynamics, x0, U)
  return U

The solvers provided by Trajax are themselves built around this objective function. For instance, the scipy_minimize solver simply calls scipy.optimize.minimize with gradient and Hessian-vector-product functions derived from objective using jax.grad and jax.hessian.
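
Here is a sketch of one way to derive such functions, closing over cost, dynamics, and x0 from the surrounding scope; the forward-over-reverse Hessian-vector product shown is a standard JAX recipe, not necessarily what scipy_minimize does internally:

import jax
import trajax

# Reduce the objective to a function of the action sequence alone.
obj_fn = lambda U: trajax.objective(cost, dynamics, x0, U)

# Gradient with respect to the action sequence.
grad_fn = jax.grad(obj_fn)

def hvp(U, v):
  # Hessian-vector product via jvp-of-grad, which avoids
  # materializing the full Hessian.
  return jax.jvp(jax.grad(obj_fn), (U,), (v,))[1]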

Limitations

Just as Trajax inherits the autodiff, compilation, and parallelism features of JAX, it also inherits its corresponding limitations. Functions such as the cost and dynamics given to a solver must be written using jax.numpy in place of standard numpy, and must conform to a functional style; see the JAX readme. Due to the complexity of trajectory optimizer implementations, initial compilation times can be long.
