Efficient Householder transformation in PyTorch

License: CC BY 4.0

This repository implements the Householder transformation algorithm for calculating orthogonal matrices and Stiefel frames. The algorithm is implemented as a Python package with differentiable bindings to PyTorch. In particular, the package provides an enhanced drop-in replacement for the torch.orgqr function.

Overview

APIs for orthogonal transformations have been around since LAPACK; however, support for them in deep learning frameworks is lacking. Recently, orthogonality constraints have become popular in deep learning as a way to regularize models and improve training dynamics [1, 2], and hence the need to backpropagate through orthogonal transformations has arisen.

PyTorch 1.7 implements the matrix exponential function torch.matrix_exp, which can be repurposed to perform an orthogonal transformation when the input matrix is skew-symmetric. This is the baseline we use in the Speed and Numerical Precision evaluations.
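A minimal sketch of this baseline (not the benchmark code itself): if A^T = -A, then exp(A)^T exp(A) = exp(-A) exp(A) = I, so exponentiating a skew-symmetric matrix yields an orthogonal matrix.

```python
import torch

torch.manual_seed(0)
d = 5

# Any square matrix minus its transpose is skew-symmetric: skew.T == -skew.
param = torch.randn(d, d, dtype=torch.float64)
skew = param - param.T

# The matrix exponential of a skew-symmetric matrix is orthogonal.
q = torch.matrix_exp(skew)
identity = torch.eye(d, dtype=torch.float64)
print(torch.allclose(q.T @ q, identity))  # True
print(torch.det(q))  # close to 1, since det(exp(A)) = exp(trace(A)) = exp(0)
```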

Compared to torch.matrix_exp, the Householder transformation implemented in this package has the following advantages:

  • Orders of magnitude lower memory footprint
  • Ability to transform non-square matrices (Stiefel frames)
  • A significant speed-up for non-square matrices
  • Better numerical precision for all matrix and batch sizes

Usage

Installation

pip install torch-householder

API

The Householder transformation takes a matrix of Householder reflector parameters of shape d x r, with d >= r > 0 (referred to as a 'thin' matrix from now on), and produces a matrix of the same shape with orthonormal columns (an orthogonal matrix when d = r).
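To make the transformation concrete, here is a minimal pure-PyTorch sketch of the product that ORGQR-style routines compute. It is not the package's native implementation. Each column v_i of a d x r input defines a reflection H_i = I - 2 v_i v_i^T / (v_i^T v_i), and the output is the first r columns of H_1 H_2 ... H_r. The helper name `householder_frame` is hypothetical.

```python
import torch


def householder_frame(hh: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: first r columns of H_1 H_2 ... H_r for a d x r input.

    Each column v_i of `hh` is used as a full reflector vector, with
    H_i = I - 2 v_i v_i^T / (v_i^T v_i). Columns must be non-zero.
    """
    d, r = hh.shape
    assert d >= r > 0, "expects a thin matrix"
    # Start from the first r columns of the d x d identity matrix.
    q = torch.eye(d, r, dtype=hh.dtype, device=hh.device)
    # Apply the reflectors right to left: Q = H_1 (H_2 (... (H_r E))).
    for i in reversed(range(r)):
        v = hh[:, i:i + 1]               # d x 1 reflector vector
        tau = 2.0 / (v.T @ v)            # 1 x 1 scaling factor
        q = q - tau * (v @ (v.T @ q))    # H_i q as a rank-1 update, H_i never formed
    return q


torch.manual_seed(0)
d, r = 6, 3
hh = torch.randn(d, r, dtype=torch.float64)
frame = householder_frame(hh)
print(frame.shape)  # torch.Size([6, 3])
print(torch.allclose(frame.T @ frame, torch.eye(r, dtype=torch.float64)))  # True
```

Because each reflector is applied as a rank-1 update, the loop never materializes a d x d matrix. This suggests why a Householder-based approach can keep memory usage low, although the package's kernels may be organized differently.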

torch_householder_orgqr(param) is the recommended API in the deep learning context. The other arguments of this function are provided for compatibility with the torch.orgqr interface.

Unlike torch.orgqr, torch_householder_orgqr supports:

  • Both CPU and GPU devices
  • Backpropagation through arguments
  • Batched inputs

The parameter param is a matrix of size d x r, or a batch of matrices of size b x d x r, of Householder reflector parameters. The LAPACK convention suggests structuring the matrix of parameters as shown in the figure on the right: ones on the main diagonal, zeros above it, and the free reflector parameters below it.

Given a matrix param of size d x r, here is a simple way to construct a valid matrix of Householder reflector parameters from it:

hh = param.tril(diagonal=-1) + torch.eye(d, r)

This result can be used as an argument to torch_householder_orgqr.
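For comparison, and assuming a PyTorch version that provides torch.linalg.householder_product (1.9 or later), core PyTorch computes the same LAPACK-style product but expects the reflection scaling factors tau explicitly. With the unit-diagonal layout above, the scaling for column i is tau_i = 2 / ||v_i||^2. The sketch below builds hh as described, checks its layout, and compares the result with an explicit product of dense reflection matrices.

```python
import torch

torch.manual_seed(0)
d, r = 5, 3
param = torch.randn(d, r, dtype=torch.float64)

# LAPACK layout: zeros above the diagonal, ones on it, free parameters below.
hh = param.tril(diagonal=-1) + torch.eye(d, r, dtype=torch.float64)
assert torch.all(hh.triu(diagonal=1) == 0)
assert torch.all(hh.diagonal() == 1)

# Scaling for H_i = I - tau_i v_i v_i^T, where v_i = hh[:, i].
tau = 2.0 / (hh * hh).sum(dim=0)
q = torch.linalg.householder_product(hh, tau)

# Reference: first r columns of the dense product H_1 H_2 ... H_r.
dense = torch.eye(d, dtype=torch.float64)
for i in range(r):
    v = hh[:, i:i + 1]
    dense = dense @ (torch.eye(d, dtype=torch.float64) - tau[i] * v @ v.T)
print(torch.allclose(q, dense[:, :r]))  # True
```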

Example

import torch
from torch_householder import torch_householder_orgqr

D, R = 4, 2
param = torch.randn(D, R)
hh = param.tril(diagonal=-1) + torch.eye(D, R)

mat = torch_householder_orgqr(hh)

print(mat)              # a 4x2 Stiefel frame
print(mat.T.mm(mat))    # a 2x2 identity matrix

Output:

tensor([[ 0.4141, -0.3049],
        [ 0.2262,  0.5306],
        [-0.5587,  0.6074],
        [ 0.6821,  0.5066]])
tensor([[ 1.0000e+00, -2.9802e-08],
        [-2.9802e-08,  1.0000e+00]])
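Because the reflectors enter the result only through ordinary tensor operations, gradients can flow back to the unconstrained parameters. This is what makes such a parameterization usable during training. The sketch below illustrates this on a batch of inputs. It does not call torch_householder_orgqr. Instead, it uses a hypothetical pure-PyTorch helper, `batched_householder_frame`, takes one SGD step on a toy loss, and then checks that the updated parameters still map to Stiefel frames.

```python
import torch


def batched_householder_frame(hh: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: first r columns of H_1 ... H_r for each b x d x r input."""
    b, d, r = hh.shape
    q = torch.eye(d, r, dtype=hh.dtype, device=hh.device).expand(b, d, r)
    for i in reversed(range(r)):
        v = hh[:, :, i:i + 1]            # b x d x 1
        vt = v.transpose(-2, -1)         # b x 1 x d
        tau = 2.0 / (vt @ v)             # b x 1 x 1
        q = q - tau * (v @ (vt @ q))     # batched rank-1 update
    return q


torch.manual_seed(0)
b, d, r = 8, 4, 2
param = torch.randn(b, d, r, dtype=torch.float64, requires_grad=True)
optimizer = torch.optim.SGD([param], lr=0.1)

# Unit-diagonal LAPACK layout, built differentiably from the free parameters.
hh = param.tril(diagonal=-1) + torch.eye(d, r, dtype=torch.float64)
frame = batched_householder_frame(hh)
loss = frame[:, 0, 0].sum()  # toy objective on one entry of each frame
optimizer.zero_grad()
loss.backward()
optimizer.step()

# The updated parameters still map to orthonormal frames.
with torch.no_grad():
    hh = param.tril(diagonal=-1) + torch.eye(d, r, dtype=torch.float64)
    frame = batched_householder_frame(hh)
    gram = frame.transpose(-2, -1) @ frame
print(torch.allclose(gram, torch.eye(r, dtype=torch.float64).expand(b, r, r)))  # True
```

With the package installed, the helper call would presumably be replaced by torch_householder_orgqr(hh), leaving the rest of the step unchanged.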

Speed

Given a tuple of b (batch size), d (matrix height), and r (matrix width), we generate a random batch of orthogonal parameters of the given shape and perform a fixed number of orthogonal transformations with both torch.matrix_exp and torch_householder_orgqr. We then associate each such tuple with the ratio of the run times of the two functions.

We sweep the matrix dimensions d and r, starting at 1 and doubling each time until reaching 32768. The batch dimension is swept similarly, up to a maximum size of 2048. The sweeps were performed on a single NVIDIA P100 GPU with 16 GB of memory using the code from the benchmark:

Speed chart
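A single point of this sweep can be sketched as follows. This is not the repository's benchmark code. It times CPU calls with time.perf_counter over a fixed number of repetitions, builds the torch.matrix_exp baseline from a d x d skew-symmetric matrix as described in this section, and uses torch.linalg.householder_product (PyTorch 1.9 or later) as a stand-in for torch_householder_orgqr so that it runs without the package installed. The helper names are hypothetical.

```python
import time

import torch


def time_fn(fn, repeats: int = 10) -> float:
    """Average wall-clock seconds per call, after one warm-up call."""
    fn()  # warm-up (e.g. first-call allocations)
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats


def run_time_ratio(b: int, d: int, r: int) -> float:
    """Hypothetical helper: matrix_exp time divided by Householder time."""
    assert d >= r > 0, "only thin matrices are evaluated"

    # Baseline: exponentiate a b x d x d skew-symmetric batch, keep a d x r minor.
    a = torch.randn(b, d, d)
    skew = a - a.transpose(-2, -1)

    def baseline():
        return torch.matrix_exp(skew)[..., :r]

    # Stand-in for torch_householder_orgqr: unit-diagonal reflectors, tau = 2 / ||v||^2.
    p = torch.randn(b, d, r)
    hh = p.tril(diagonal=-1) + torch.eye(d, r)
    tau = 2.0 / (hh * hh).sum(dim=-2)

    def householder():
        return torch.linalg.householder_product(hh, tau)

    return time_fn(baseline) / time_fn(householder)


torch.manual_seed(0)
print(f"matrix_exp / Householder run-time ratio: {run_time_ratio(4, 64, 8):.2f}")
```

On a GPU, the tensors would be created with device="cuda", and torch.cuda.synchronize() would need to be called before each clock read, because CUDA kernels are launched asynchronously.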

Since the ORGQR function's convention assumes only thin matrices with d >= r > 0, we skip the evaluation of fat matrices altogether.

torch.matrix_exp only accepts square matrices; hence, to parameterize a thin matrix with d > r, we transform a square d x d skew-symmetric matrix and then take a d x r minor of the result. This makes torch.matrix_exp especially inefficient for parameterizing Stiefel frames, and it accounts for the major speed gains of the Householder transformation in these cases.
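The sketch below shows the square-then-slice approach and why it scales poorly with d. The byte counts cover only the dense float32 input matrix and ignore any intermediate buffers that torch.matrix_exp may allocate. The helper `float32_bytes` is hypothetical.

```python
import torch


def float32_bytes(rows: int, cols: int) -> int:
    """Storage for a dense float32 matrix (4 bytes per element)."""
    return rows * cols * 4


torch.manual_seed(0)
d, r = 6, 2

# matrix_exp needs a full d x d skew-symmetric input even though only r columns are kept.
a = torch.randn(d, d, dtype=torch.float64)
frame = torch.matrix_exp(a - a.T)[:, :r]  # d x r minor: the first r columns
print(torch.allclose(frame.T @ frame, torch.eye(r, dtype=torch.float64)))  # True

# Largest sweep size: the square input vs. a single-column frame.
print(float32_bytes(32768, 32768))  # 4294967296 bytes (4 GiB)
print(float32_bytes(32768, 1))      # 131072 bytes (128 KiB)
```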

Numerical Precision

The numerical precision of an orthogonal transformation is evaluated using a synthetic test. Given a matrix size d x r, we generate random inputs and perform the orthogonal transformation with the tested function. Since the output M of size d x r is expected to be a Stiefel frame, we calculate the transformation error using the formula below. This calculation is repeated at least 5 times for each matrix size, and the results are averaged.

We reuse the sweep from the speed benchmark, compute the errors of both functions using the formula above, and report their ratio in the following chart:

Error chart
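The error formula referred to in this section did not survive in this copy of the README. As an assumption only, a common way to measure how far M is from a Stiefel frame is the Frobenius norm of M^T M - I_r. The chart may use a different normalization. The sketch below averages this assumed metric over 5 random inputs for each function and reports the ratio. All helper names are hypothetical, and torch.linalg.householder_product stands in for torch_householder_orgqr.

```python
import torch


def stiefel_error(m: torch.Tensor) -> float:
    """Assumed metric: Frobenius norm of M^T M - I_r for a d x r matrix M."""
    r = m.shape[-1]
    eye = torch.eye(r, dtype=m.dtype, device=m.device)
    return torch.linalg.matrix_norm(m.T @ m - eye).item()  # 'fro' is the default


def mean_error(transform, d: int, r: int, trials: int = 5) -> float:
    """Average error of `transform` over several random d x r draws."""
    return sum(stiefel_error(transform(d, r)) for _ in range(trials)) / trials


def householder_frame(d: int, r: int) -> torch.Tensor:
    # Stand-in for torch_householder_orgqr: unit-diagonal layout, tau = 2 / ||v||^2.
    hh = torch.randn(d, r).tril(diagonal=-1) + torch.eye(d, r)
    return torch.linalg.householder_product(hh, 2.0 / (hh * hh).sum(dim=0))


def matrix_exp_frame(d: int, r: int) -> torch.Tensor:
    # Baseline: exponentiate a random d x d skew-symmetric matrix, keep r columns.
    a = torch.randn(d, d)
    return torch.matrix_exp(a - a.T)[:, :r]


torch.manual_seed(0)
ratio = mean_error(matrix_exp_frame, 256, 16) / mean_error(householder_frame, 256, 16)
print(f"error ratio (matrix_exp / Householder): {ratio:.2f}")
```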

Conclusions

The speed chart suggests that the Householder transformation is especially efficient when the matrix width r is smaller than the matrix height d, or when the batch size is comparable to the matrix dimensions. In these cases, the Householder transformation provides a speed-up of up to a few orders of magnitude. However, for the remaining square-matrix cases, torch.matrix_exp appears to be faster.

Another benefit of torch_householder_orgqr is its memory usage, which is much lower than that of torch.matrix_exp. This allows us to transform much larger matrices, or larger batches, with torch_householder_orgqr. For example, in the sub-figure corresponding to "batch size: 1", the largest matrix transformed with torch_householder_orgqr has size 16384 x 16384, whereas the largest matrix transformed with torch.matrix_exp is only 4096 x 4096.

As the precision chart shows, the Householder transformation satisfies the orthogonality constraints orders of magnitude more accurately than torch.matrix_exp in all tested configurations.

Citation

To cite this repository, use the following BibTeX:

@misc{obukhov2021torchhouseholder,
    author = {Anton Obukhov},
    year = 2021,
    title = {Efficient Householder transformation in PyTorch},
    url = {www.github.com/toshas/torch-householder}
}

Comments
  • How are the gradients implemented for non-full rank matrices?

    It is not clear how to implement the "gradient" (adjoint) of the QR decomposition for a matrix that is not full rank (e.g. the zero matrix). How does this package handle this?

    Also, if this is a drop-in replacement for the QR decomposition implemented in PyTorch and it works better, why not make a PR to core PyTorch with it? Where does the speed-up come from compared to LAPACK / MAGMA?

    opened by lezcano 9
  • error when installing torch-householder via pip

    Hello, I am having trouble installing the package, and I don't know whether it is a dependency problem. Could the installation fail because of the CUDA version? I am using Ubuntu 18.04 with CUDA 11.3. Is there any difference compared with installing under CUDA 10.2?

    Collecting torch-householder
      Using cached torch_householder-1.0.1.tar.gz (457 kB)
      Installing build dependencies ... done
      Getting requirements to build wheel ... error
      ERROR: Command errored out with exit status 1:
       command: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmp_p_1kd61
           cwd: /tmp/pip-install-ot451iqp/torch-householder_bee491246ef740cbb618d42e92070ff3
      Complete output (21 lines):
      Traceback (most recent call last):
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 363, in <module>
          main()
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 345, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 130, in get_requires_for_build_wheel
          return hook(config_settings)
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 338, in get_requires_for_build_wheel
          return self._get_build_requires(config_settings, requirements=['wheel'])
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 320, in _get_build_requires
          self.run_setup()
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 335, in run_setup
          exec(code, locals())
        File "<string>", line 4, in <module>
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/__init__.py", line 191, in <module>
          _load_global_deps()
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/__init__.py", line 153, in _load_global_deps
          ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/ctypes/__init__.py", line 374, in __init__
          self._handle = _dlopen(self._name, mode)
      OSError: /tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/lib/../../nvidia/cublas/lib/libcublas.so.11: symbol cublasLtGetStatusString version libcublasLt.so.11 not defined in file libcublasLt.so.11 with link time reference
      ----------------------------------------
    WARNING: Discarding https://files.pythonhosted.org/packages/f3/7d/a87d4ea6c11f23d237fc81c094a6c18909486fdb9914599479cbeb5d089f/torch_householder-1.0.1.tar.gz#sha256=9a4b240c68947491c4e96a78771497562650f9a555001e062a0969fce206f786 (from https://pypi.org/simple/torch-householder/) (requires-python:>=3.6). Command errored out with exit status 1: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmp_p_1kd61 Check the logs for full command output.
      Using cached torch_householder-1.0.0.tar.gz (177 kB)
      Installing build dependencies ... done
      Getting requirements to build wheel ... error
      ERROR: Command errored out with exit status 1:
       command: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmpiqcotimp
           cwd: /tmp/pip-install-ot451iqp/torch-householder_bb93639781d44e6799b67c9f4f83fae9
      Complete output (21 lines):
      Traceback (most recent call last):
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 363, in <module>
          main()
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 345, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 130, in get_requires_for_build_wheel
          return hook(config_settings)
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 338, in get_requires_for_build_wheel
          return self._get_build_requires(config_settings, requirements=['wheel'])
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 320, in _get_build_requires
          self.run_setup()
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 335, in run_setup
          exec(code, locals())
        File "<string>", line 4, in <module>
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/__init__.py", line 191, in <module>
          _load_global_deps()
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/__init__.py", line 153, in _load_global_deps
          ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/ctypes/__init__.py", line 374, in __init__
          self._handle = _dlopen(self._name, mode)
      OSError: /tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/lib/../../nvidia/cublas/lib/libcublas.so.11: symbol cublasLtGetStatusString version libcublasLt.so.11 not defined in file libcublasLt.so.11 with link time reference
      ----------------------------------------
    WARNING: Discarding https://files.pythonhosted.org/packages/c2/9c/7af0f1414e24c09ddc67439364c70c7e894c65ed83f94aa82a0ff3308673/torch_householder-1.0.0.tar.gz#sha256=e9f06c29685a6bcbc360af5c8cbae9227d9990e3e9acc744b0ea15654ab782c0 (from https://pypi.org/simple/torch-householder/) (requires-python:>=3.6). Command errored out with exit status 1: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmpiqcotimp Check the logs for full command output.
    ERROR: Could not find a version that satisfies the requirement torch-householder (from versions: 1.0.0, 1.0.1)
    ERROR: No matching distribution found for torch-householder
    
    opened by jeongwhanchoi 2
  • [Noob] Can I use this somehow as a drop-in replacement for matrix_exp?

    I have some code that uses torch.matrix_exp and would be very happy to speed it up. Is this possible with your library? I somewhat lost my way trying to figure it out from the benchmark code. Many thanks in advance!

    opened by generalizedeigenvector 1
  • The parametrisation does not seem to be surjective.

    While looking at the implementation and comparing it with torch.linalg.householder_product, I found something odd.

    At the moment, when called with a tensor of the form hh = param.tril(diagonal=-1) + torch.eye(d, r) (as per the documentation), we are passing nk - k(k+1)/2 parameters. This is correct, as it is the number of parameters necessary to parametrise the orthogonal matrices (i.e. it is the dimension of the Stiefel manifold).

    Now, when this matrix is passed to torch_householder_orgqr, its columns are normalised: https://github.com/toshas/torch-householder/blob/afe72ebf9a01c257a5070a6d01b3f81be9e8efd6/torch_householder/householder.py#L67 This is another constraint. Each column is normalised, which removes another k degrees of freedom (k dimensions, to be precise). This means that the current implementation cannot represent all the possible orthogonal matrices.

    Do you know what is going on here?

    opened by lezcano 1
  • > \prod_{i=1}^k H(v_i)

    > \prod_{i=1}^k H(v_i), is that right? Right.

    Hi toshas, your work is great, thanks! You confirmed that, if we denote by H(v) = I_n - v v^T / \norm{v}^2 a Householder reflection, torch_householder_orgqr() computes \prod_{i=1}^k H(v_i). I have two questions. First, I think the reflection should be H(v) = I_n - 2 v v^T / \norm{v}^2. Second, does \prod_{i=1}^k H(v_i) mean H_1 H_2 ... H_(n-1) or H_(n-1) H_(n-2) ... H_1? If it is H_1 H_2 ... H_(n-1), it is the Q of the QR decomposition, right? But you said earlier that this is not a QR decomposition, so I am a little confused. I hope to hear from you soon, thank you.

    opened by xjzhao001 1
  • Slow performance compared to `torch.linalg.householder_product` in forward pass

    Problem

    I'm using an orthonormality-constrained matrix in a learnable filterbank. I want to optimize training, so I ran some profiling with torch, but I'm getting strange results. I just want to double-check whether I'm doing something wrong.

    Code

    I'm constructing the matrix during forward pass like this:

    def __init__(self, ..):
        [..]

        # householder decomposition
        decomp, tau = torch.geqrf(filters)

        # assume that DCT is orthogonal
        filters = decomp.tril(diagonal=-1) + torch.eye(decomp.shape[0], decomp.shape[1])

        # register everything as parameter and set gradient flags
        self.filters = torch.nn.Parameter(filters, requires_grad=True)
        self.register_parameter('filter_q', self.filters)

    def filters(self):
        valid_coeffs = self.filters.tril(diagonal=-1)
        tau = 2. / (1 + torch.norm(valid_coeffs, dim=0) ** 2)
        return torch.linalg.householder_product(valid_coeffs, tau)
        # return torch_householder_orgqr(valid_coeffs, tau)
    

    Profiles

    All profiles were created with the PyTorch profiler, using one warm-up run and two trial runs:

    Profile torch householder_product (matrix 512x512 f32)

    • forward pass: ~823us
    • backward pass: ~790ms

    [Profiler trace: forward and backward passes marked in light green]

    Profile torch-householder (matrix 512x512 f32)

    • forward pass: ~240ms
    • backward pass: ~513ms

    [Profiler trace]

    Questions

    I'm not an expert in torch and do not follow its development closely. There is an issue, https://github.com/pytorch/pytorch/issues/50104, about adding CUDA support to orgqr; could this explain the difference in timing?

    • Why is the torch-householder library much slower in the forward pass?
    • Is this performance expected for AD of a matrix with respect to its Householder parameters, or am I doing something wrong here?
    • Why do both totals add up to roughly 800 ms? This makes me suspect that my profiling is wrong, but I couldn't find a cause.

    I'm also happy to share the traces with you; just ping me :)

    opened by bytesnake 6
  • [POLL] Should the package switch from install-time to run-time native code compilation?

    Currently, the package compiles its native (C++) code when the package is installed. This saves a few seconds at run time, since no compilation happens when the user code starts. However, it causes problems when the package is installed from a different environment or machine than the one the code will run on. This is the case in most cluster environments, where packages may be installed from a login node rather than from the actual machine with the GPU.

    Should compilation be rather performed at run time?

    👍 - Move compilation to run time
    👎 - Keep as is

    opened by toshas 0
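Two points raised in the issues above can be checked numerically with plain PyTorch. The sketch below is illustrative only, and its helper names are hypothetical. It confirms that the unit-diagonal layout hh = param.tril(diagonal=-1) + torch.eye(d, k) has exactly dk - k(k+1)/2 free entries, the dimension of the Stiefel manifold. It also confirms that H(v) = I - 2 v v^T / ||v||^2 is an orthogonal involution, whereas I - v v^T / ||v||^2 is only a projector.

```python
import torch


def free_entries(d: int, k: int) -> int:
    """Number of strictly-lower-triangular entries in a d x k matrix."""
    return int(torch.ones(d, k).tril(diagonal=-1).sum().item())


def stiefel_dim(d: int, k: int) -> int:
    """Dimension of the Stiefel manifold of d x k orthonormal frames."""
    return d * k - k * (k + 1) // 2


for d, k in [(4, 2), (8, 8), (10, 3)]:
    print(d, k, free_entries(d, k), stiefel_dim(d, k))  # the two counts agree

torch.manual_seed(0)
n = 5
v = torch.randn(n, 1, dtype=torch.float64)
eye = torch.eye(n, dtype=torch.float64)
reflection = eye - 2 * v @ v.T / (v.T @ v)  # Householder reflection
projector = eye - v @ v.T / (v.T @ v)       # projection onto the complement of v

print(torch.allclose(reflection @ reflection, eye))      # True: H is an involution
print(torch.allclose(reflection.T @ reflection, eye))    # True: H is orthogonal
print(torch.allclose(projector @ projector, projector))  # True: P is idempotent
print(torch.allclose(projector.T @ projector, eye))      # False: P is singular
```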
Releases: v1.0.1

Owner: Anton Obukhov, CV+ML PhD student with industrial past