Efficient Householder transformation in PyTorch

License: CC BY 4.0

This repository implements the Householder transformation algorithm for calculating orthogonal matrices and Stiefel frames. The algorithm is implemented as a Python package with differentiable bindings to PyTorch. In particular, the package provides an enhanced drop-in replacement for the torch.orgqr function.

Overview

APIs for orthogonal transformations have been around since LAPACK; however, their support in deep learning frameworks is lacking. Recently, orthogonal constraints have become popular in deep learning as a way to regularize models and improve training dynamics [1, 2], and hence the need arose to backpropagate through orthogonal transformations.

PyTorch 1.7 implements the matrix exponential function torch.matrix_exp, which can be repurposed to perform the orthogonal transformation when the input matrix is skew-symmetric. This is the baseline we use in the Speed and Numerical Precision evaluations.
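As a minimal sketch of the baseline idea (not the benchmark code itself), the snippet below builds a skew-symmetric matrix from a random square matrix and checks that its matrix exponential is orthogonal. The size, the seed, and the variable names are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
d = 4
a = torch.randn(d, d, dtype=torch.float64)

# a - a^T is skew-symmetric by construction: skew^T == -skew
skew = a - a.T

# The exponential of a skew-symmetric matrix is an orthogonal matrix
q = torch.matrix_exp(skew)

# Largest deviation of q^T q from the identity (should be near machine precision)
identity = torch.eye(d, dtype=torch.float64)
max_err = (q.T @ q - identity).abs().max().item()
print(max_err)
```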

Compared to torch.matrix_exp, the Householder transformation implemented in this package has the following advantages:

  • Orders of magnitude lower memory footprint
  • Ability to transform non-square matrices (Stiefel frames)
  • A significant speed-up for non-square matrices
  • Better numerical precision for all matrix and batch sizes

Usage

Installation

pip install torch-householder

API

The Householder transformation takes a matrix of Householder reflector parameters of shape d x r with d >= r > 0 (referred to as a 'thin' matrix from now on) and produces a matrix of the same shape with orthonormal columns.

torch_householder_orgqr(param) is the recommended API in the deep learning context. The other arguments of this function are provided for compatibility with the torch.orgqr interface.

Unlike torch.orgqr, torch_householder_orgqr supports:

  • Both CPU and GPU devices
  • Backpropagation through arguments
  • Batched inputs

The parameter param is a matrix of size d x r, or a batch of matrices of size b x d x r, of Householder reflector parameters. The LAPACK convention suggests structuring the matrix of parameters as shown in the figure on the right (unit diagonal, zeros above it, and the parameters below it, as in the construction that follows).

Given a matrix param of size d x r, here is a simple way to construct a valid matrix of Householder reflector parameters from it:

hh = param.tril(diagonal=-1) + torch.eye(d, r)

This result can be used as an argument to torch_householder_orgqr.
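To make the role of this matrix concrete, here is a slow reference sketch in plain PyTorch. It is not the package's implementation: it assumes each column v_i of hh defines the reflection H(v_i) = I - 2 v_i v_i^T / ||v_i||^2 and that the result is the first r columns of H(v_1) H(v_2) ... H(v_r), which is the LAPACK ordering. The function name householder_product_reference is hypothetical.

```python
import torch

def householder_product_reference(hh):
    """Multiply H_1 H_2 ... H_r explicitly and keep the first r columns.

    Illustrative only: O(r * d^3) work, with no batching support.
    """
    d, r = hh.shape
    q = torch.eye(d, dtype=hh.dtype)
    for i in range(r):
        v = hh[:, i].clone()
        v[:i] = 0.0   # entries above the diagonal do not contribute
        v[i] = 1.0    # LAPACK convention: implicit unit diagonal
        h = torch.eye(d, dtype=hh.dtype) - 2.0 * torch.outer(v, v) / v.dot(v)
        q = q @ h
    return q[:, :r]

torch.manual_seed(0)
d, r = 5, 3
param = torch.randn(d, r, dtype=torch.float64)
hh = param.tril(diagonal=-1) + torch.eye(d, r, dtype=torch.float64)

frame = householder_product_reference(hh)
gram = frame.T @ frame   # expected to be close to the r x r identity
print(gram)
```

Each H(v_i) is orthogonal, so their product is orthogonal and any subset of its columns is orthonormal; this is why the Gram matrix comes out as the identity up to rounding.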

Example

import torch
from torch_householder import torch_householder_orgqr

D, R = 4, 2
param = torch.randn(D, R)
hh = param.tril(diagonal=-1) + torch.eye(D, R)

mat = torch_householder_orgqr(hh)

print(mat)              # a 4x2 Stiefel frame
print(mat.T.mm(mat))    # a 2x2 identity matrix

Output:

tensor([[ 0.4141, -0.3049],
        [ 0.2262,  0.5306],
        [-0.5587,  0.6074],
        [ 0.6821,  0.5066]])
tensor([[ 1.0000e+00, -2.9802e-08],
        [-2.9802e-08,  1.0000e+00]])
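The batched and differentiable behaviour listed under API can be sketched in the same way. To keep the snippet runnable without installing this package, it uses PyTorch's built-in torch.linalg.householder_product as a stand-in, which is an assumption of this example; that function needs an explicit tau, computed here as 2 / ||v_i||^2, following the formula used in one of the issues below. With the package installed, torch_householder_orgqr(hh) would take the place of that call.

```python
import torch

torch.manual_seed(0)
b, d, r = 3, 5, 2
param = torch.randn(b, d, r, dtype=torch.float64, requires_grad=True)

# tril and the added identity both broadcast over the leading batch dimension
hh = param.tril(diagonal=-1) + torch.eye(d, r, dtype=torch.float64)

# tau_i = 2 / ||v_i||^2, where v_i is the i-th column of hh; shape b x r
tau = 2.0 / hh.pow(2).sum(dim=-2)

# Stand-in for torch_householder_orgqr(hh); output shape is b x d x r
frames = torch.linalg.householder_product(hh, tau)

# Per-sample Gram matrices, each close to the r x r identity
gram = frames.transpose(-1, -2) @ frames

# Backpropagate a dummy scalar loss to the raw parameters
frames.sum().backward()
print(param.grad.shape)
```

Because of the tril call, only the strictly lower-triangular entries of param receive a gradient; the diagonal and upper entries are not free parameters.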

Speed

Given a tuple of b (batch size), d (matrix height), and r (matrix width), we generate a random batch of orthogonal parameters of the given shape and perform a fixed number of orthogonal transformations with both the torch.matrix_exp and torch_householder_orgqr functions. We then associate each such tuple with the ratio of the run times taken by the two functions.

We perform a sweep of matrix dimensions d and r, starting with 1 and doubling each time until reaching 32768. The batch dimension is swept similarly until reaching the maximum size of 2048. The sweeps were performed on a single NVIDIA P-100 GPU with 16 GB of RAM using the code from the benchmark:

Speed chart

Since the ORGQR function's convention assumes only thin matrices with d >= r > 0, we skip the evaluation of fat matrices altogether.
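The sweep described above can be enumerated as follows. This is only a sketch of the configuration grid under the stated rules (powers of two, thin matrices only); the actual benchmark code may order or prune the configurations differently, for example when a configuration does not fit in GPU memory.

```python
# Matrix dimensions: 1, 2, 4, ..., 32768
sizes = [2 ** k for k in range(16)]

# Batch sizes: 1, 2, 4, ..., 2048
batches = [2 ** k for k in range(12)]

# Keep only thin matrices (d >= r), as required by the ORGQR convention
configs = [
    (b, d, r)
    for b in batches
    for d in sizes
    for r in sizes
    if d >= r
]

print(len(configs))
```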

The torch.matrix_exp function only deals with square matrices; hence, to parameterize a thin matrix with d > r, we transform a square skew-symmetric matrix of size d x d and then take a d x r submatrix of the result. This makes torch.matrix_exp especially inefficient for parameterizing Stiefel frames and gives the Householder transformation major speed gains.
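A minimal sketch of this baseline for thin matrices is shown below. Taking the first r columns is an assumption of the example; any r columns of an orthogonal matrix form a Stiefel frame.

```python
import torch

torch.manual_seed(0)
d, r = 6, 2
a = torch.randn(d, d, dtype=torch.float64)
skew = a - a.T                     # d x d skew-symmetric input

full = torch.matrix_exp(skew)      # full d x d orthogonal matrix
frame = full[:, :r]                # keep a d x r block of columns

gram = frame.T @ frame             # close to the r x r identity

# Entries that were computed and then thrown away
wasted = full.numel() - frame.numel()
print(wasted)
```

The discarded part grows as d * (d - r), which illustrates why the baseline becomes wasteful when r is much smaller than d.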

Numerical Precision

The numerical precision of an orthogonal transformation is evaluated using a synthetic test. Given a matrix size d x r, we generate random inputs and perform orthogonal transformation with the tested function. Since the output M of size d x r is expected to be a Stiefel frame, we calculate transformation error using the formula below. This calculation is repeated for each matrix size at least 5 times, and the results are averaged.
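The formula itself is not reproduced in this text. As an illustration only, one common way to measure how far a d x r matrix M is from a Stiefel frame is the mean absolute deviation of M^T M from the r x r identity. The function below is a hypothetical sketch of such a measure and may differ from the one used for the charts.

```python
import torch

def orthogonality_error(m):
    """Mean absolute deviation of M^T M from the r x r identity.

    Assumed error measure, for illustration only.
    """
    r = m.shape[-1]
    gram = m.transpose(-1, -2) @ m
    return (gram - torch.eye(r, dtype=m.dtype)).abs().mean().item()

# A trivially exact Stiefel frame: the first two columns of the 4 x 4 identity
exact = torch.eye(4, 2, dtype=torch.float64)

# Scaling the frame breaks orthonormality and yields a non-zero error
scaled = 2.0 * exact

err_exact = orthogonality_error(exact)
err_scaled = orthogonality_error(scaled)
print(err_exact, err_scaled)
```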

We re-use the sweep used for benchmarking speed, compute errors of both functions using the formula above, and report their ratio in the following chart:

Error chart

Conclusions

The speed chart suggests that the Householder transformation is especially efficient when either matrix width r is smaller than matrix height d or when batch size is comparable with matrix dimensions. In these cases, the Householder transformation provides up to a few orders of magnitude speed-up. However, for the rest of the square matrix cases, torch.matrix_exp appears to be faster.

Another benefit of torch_householder_orgqr is its memory usage, which is much lower than that of torch.matrix_exp. This property allows us to transform either much larger matrices or larger batches with torch_householder_orgqr. For example, in the sub-figure corresponding to "batch size: 1", the largest matrix transformed with torch_householder_orgqr has size 16384 x 16384, whereas the largest matrix transformed with torch.matrix_exp is only 4096 x 4096.

As can be seen from the precision chart, tests with the Householder transformation lead to orders of magnitude more accurate orthogonality constraints in all tested configurations.

Citation

To cite this repository, use the following BibTeX:

@misc{obukhov2021torchhouseholder,
    author = {Anton Obukhov},
    year = 2021,
    title = {Efficient Householder transformation in PyTorch},
    url = {www.github.com/toshas/torch-householder}
}
Comments
  • How are the gradients implemented for non-full rank matrices?

    It is not clear how to implement the "gradient" (adjoint) of the QR decomposition for a matrix that is not full rank (e.g. the zero matrix). How does this package handle this?

    Also, if this is a drop-in replacement for the QR decomposition implemented in PyTorch and it works better, why not make a PR to core PyTorch with this? Where does the speed-up come from vs LAPACK / MAGMA?

    opened by lezcano 9
  • error when installing torch-householder via pip

    Hello, I am having trouble installing the package. I don't know whether this is a dependency problem. Could the installation fail because of the CUDA version? I am using Ubuntu 18.04 and CUDA 11.3. Is there any difference from installing with CUDA 10.2?

    Collecting torch-householder
      Using cached torch_householder-1.0.1.tar.gz (457 kB)
      Installing build dependencies ... done
      Getting requirements to build wheel ... error
      ERROR: Command errored out with exit status 1:
       command: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmp_p_1kd61
           cwd: /tmp/pip-install-ot451iqp/torch-householder_bee491246ef740cbb618d42e92070ff3
      Complete output (21 lines):
      Traceback (most recent call last):
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 363, in <module>
          main()
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 345, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 130, in get_requires_for_build_wheel
          return hook(config_settings)
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 338, in get_requires_for_build_wheel
          return self._get_build_requires(config_settings, requirements=['wheel'])
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 320, in _get_build_requires
          self.run_setup()
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 335, in run_setup
          exec(code, locals())
        File "<string>", line 4, in <module>
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/__init__.py", line 191, in <module>
          _load_global_deps()
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/__init__.py", line 153, in _load_global_deps
          ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/ctypes/__init__.py", line 374, in __init__
          self._handle = _dlopen(self._name, mode)
      OSError: /tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/lib/../../nvidia/cublas/lib/libcublas.so.11: symbol cublasLtGetStatusString version libcublasLt.so.11 not defined in file libcublasLt.so.11 with link time reference
      ----------------------------------------
    WARNING: Discarding https://files.pythonhosted.org/packages/f3/7d/a87d4ea6c11f23d237fc81c094a6c18909486fdb9914599479cbeb5d089f/torch_householder-1.0.1.tar.gz#sha256=9a4b240c68947491c4e96a78771497562650f9a555001e062a0969fce206f786 (from https://pypi.org/simple/torch-householder/) (requires-python:>=3.6). Command errored out with exit status 1: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmp_p_1kd61 Check the logs for full command output.
      Using cached torch_householder-1.0.0.tar.gz (177 kB)
      Installing build dependencies ... done
      Getting requirements to build wheel ... error
      ERROR: Command errored out with exit status 1:
       command: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmpiqcotimp
           cwd: /tmp/pip-install-ot451iqp/torch-householder_bb93639781d44e6799b67c9f4f83fae9
      Complete output (21 lines):
      Traceback (most recent call last):
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 363, in <module>
          main()
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 345, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 130, in get_requires_for_build_wheel
          return hook(config_settings)
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 338, in get_requires_for_build_wheel
          return self._get_build_requires(config_settings, requirements=['wheel'])
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 320, in _get_build_requires
          self.run_setup()
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 335, in run_setup
          exec(code, locals())
        File "<string>", line 4, in <module>
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/__init__.py", line 191, in <module>
          _load_global_deps()
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/__init__.py", line 153, in _load_global_deps
          ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/ctypes/__init__.py", line 374, in __init__
          self._handle = _dlopen(self._name, mode)
      OSError: /tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/lib/../../nvidia/cublas/lib/libcublas.so.11: symbol cublasLtGetStatusString version libcublasLt.so.11 not defined in file libcublasLt.so.11 with link time reference
      ----------------------------------------
    WARNING: Discarding https://files.pythonhosted.org/packages/c2/9c/7af0f1414e24c09ddc67439364c70c7e894c65ed83f94aa82a0ff3308673/torch_householder-1.0.0.tar.gz#sha256=e9f06c29685a6bcbc360af5c8cbae9227d9990e3e9acc744b0ea15654ab782c0 (from https://pypi.org/simple/torch-householder/) (requires-python:>=3.6). Command errored out with exit status 1: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmpiqcotimp Check the logs for full command output.
    ERROR: Could not find a version that satisfies the requirement torch-householder (from versions: 1.0.0, 1.0.1)
    ERROR: No matching distribution found for torch-householder
    
    opened by jeongwhanchoi 2
  • [Noob] Can I use this somehow as a drop-in replacement for matrix_exp?

    I have some code that uses torch.matrix_exp and would be very happy to speed it up. Is this possible using your library? I sort of lost my way trying to figure it out from the benchmark code. Many thanks in advance!

    opened by generalizedeigenvector 1
  • The parametrisation does not seem to be surjective.

    When having a look at the implementation and looking at the differences with torch.householder_product, I found a weird thing.

    At the moment, when called with a tensor of the form hh = param.tril(diagonal=-1) + torch.eye(d, r) (as per the documentation) we are passing nk - k(k+1)/2 parameters. This is correct, as it is the number of parameters necessary to parametrise the orthogonal matrices (i.e. it is the dimension of the Stiefel manifold).

    Now, when this matrix is passed to torch_householder_orgqr, its columns are normalised: https://github.com/toshas/torch-householder/blob/afe72ebf9a01c257a5070a6d01b3f81be9e8efd6/torch_householder/householder.py#L67 This is another constraint. Each column is normalised, which removes another k degrees of freedom (k dimensions to be precise). This means that the current implementation cannot represent all the possible orthogonal matrices.

    Do you know what is going on here?

    opened by lezcano 1
  • > \prod_{i=1}^k H(v_i)

    \prod_{i=1}^k H(v_i), is that right? Right.

    Hi toshas, your work is great, thanks! You confirmed that, if we denote a Householder reflection by H(v) = I_n - v v^T / \norm{v}^2, then torch_householder_orgqr() computes \prod_{i=1}^k H(v_i). I have two questions. First, I think the reflection should be H(v) = I_n - 2 v v^T / \norm{v}^2. Second, does \prod_{i=1}^k H(v_i) mean H_1 H_2 ... H_(n-1) or H_(n-1) H_(n-2) ... H_1? If it is H_1 H_2 ... H_(n-1), it refers to Q in the QR decomposition, right? But you said earlier that this is not a QR decomposition, so I feel a little confused. Hope to get your prompt reply, thank you.

    opened by xjzhao001 1
  • Slow performance compared to `torch.linalg.householder_product` in forward pass

    Problem

    I'm using an orthonormal constrained matrix in a learnable filterbank setting. Now I want to optimize the training and run some profiling with torch, but getting strange results. Just want to double-check here whether I'm doing something wrong.

    Code

    I'm constructing the matrix during forward pass like this:

    def __init__(self, ..):
        [..]

        # householder decomposition
        decomp, tau = torch.geqrf(filters)

        # assume that DCT is orthogonal
        filters = decomp.tril(diagonal=-1) + torch.eye(decomp.shape[0], decomp.shape[1])

        # register everything as parameter and set gradient flags
        self.filters = torch.nn.Parameter(filters, requires_grad=True)
        self.register_parameter('filter_q', self.filters)

    def filters(self):
        valid_coeffs = self.filters.tril(diagonal=-1)
        tau = 2. / (1 + torch.norm(valid_coeffs, dim=0) ** 2)
        return torch.linalg.householder_product(valid_coeffs, tau)
        # return torch_householder_orgqr(valid_coeffs, tau)
    

    Profiles

    All profiles are created with the pytorch profiler with warmup of one and two trial runs:

    Profile torch householder_product (matrix 512x512 f32)

    • forward pass: ~823us
    • backward pass: ~790ms

    Marked forward pass and backward pass visible in light green:

    image

    Profile torch-householder (matrix 512x512 f32)

    • forward pass: ~240ms
    • backward pass: ~513ms

    image

    Questions

    I'm not an expert in torch and do not follow the development closely. There is an issue, https://github.com/pytorch/pytorch/issues/50104, about adding CUDA support to orgqr; could this cause the difference in time?

    • why is the torch-householder library much slower in the forward pass
    • is this performance expected from AD of a matrix w.r.t. its Householder parametrization, or am I doing something wrong here?
    • why does the number actually add up again to ~800ms, this makes me suspect that my profiling is doing something wrong but couldn't find a cause

    I'm also happy to share the traces with you, please just ping me then :)

    opened by bytesnake 6
  • [POLL] Should the package switch install-time to run-time native code compilation?

    Currently, the package compiles native code (C++) upon package installation. This saves a few seconds at run time, as the compilation does not happen when the user code starts. However, it hurts when the package is installed from a different environment or machine than the one the code will actually run on. This is the case in most cluster environments, where packages may be installed from a login node rather than from the machine with the GPU.

    Should compilation be rather performed at run time?

    👍 - Move compilation to run time 👎 - Keep as is

    opened by toshas 0
Releases(v1.0.1)
Owner
Anton Obukhov
CV+ML PhD student with industrial past