Convert BART models to ONNX with quantization. 3X reduction in size, and upto 3X boost in inference speed

Overview

fast-Bart

Reduction of BART model size by 3X, and boost in inference speed up to 3X

BART implementation of the fastT5 library (https://github.com/Ki6an/fastT5)

Pytorch model -> ONNX model -> Quantized ONNX model


Install

Install using requirements.txt file

git clone https://github.com/siddharth-sharma7/fast-Bart
cd fast-Bart
pip install -r requirements.txt

Usage

The export_and_get_onnx_model() method exports the given pretrained Bart model to onnx, quantizes it and runs it on the onnxruntime with default settings. The returned model from this method supports the generate() method of huggingface.

If you don't wish to quantize the model then use quantized=False in the method.

from fastBart import export_and_get_onnx_model
from transformers import AutoTokenizer

model_name = 'facebook/bart-base'
model = export_and_get_onnx_model(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)
input = "This is a very long sentence and needs to be summarized."
token = tokenizer(input, return_tensors='pt')

tokens = model.generate(input_ids=token['input_ids'],
               attention_mask=token['attention_mask'],
               num_beams=3)

output = tokenizer.decode(tokens.squeeze(), skip_special_tokens=True)
print(output)

to run the already exported model use get_onnx_model()

you can customize the whole pipeline as shown in the below code example:

from fastBart import (OnnxBart, get_onnx_runtime_sessions,
                    generate_onnx_representation, quantize)
from transformers import AutoTokenizer

model_or_model_path = 'facebook/bart-base'

# Step 1. convert huggingfaces bart model to onnx
onnx_model_paths = generate_onnx_representation(model_or_model_path)

# Step 2. (recommended) quantize the converted model for fast inference and to reduce model size.
# The process is slow for the decoder and init-decoder onnx files (can take up to 15 mins)
quant_model_paths = quantize(onnx_model_paths)

# step 3. setup onnx runtime
model_sessions = get_onnx_runtime_sessions(quant_model_paths)

# step 4. get the onnx model
model = OnnxBart(model_or_model_path, model_sessions)

                      ...
custom output paths

By default, fastBart creates a models-bart folder in the current directory and stores all the models. You can provide a custom path for a folder to store the exported models. And to run already exported models that are stored in a custom folder path: use get_onnx_model(onnx_models_path="/path/to/custom/folder/")

from fastBart import export_and_get_onnx_model, get_onnx_model

model_name = "facebook/bart-base"
custom_output_path = "/path/to/custom/folder/"

# 1. stores models to custom_output_path
model = export_and_get_onnx_model(model_name, custom_output_path)

# 2. run already exported models that are stored in custom path
# model = get_onnx_model(model_name, custom_output_path)

Functionalities

  • Export any pretrained Bart model to ONNX easily.
  • The exported model supports beam search and greedy search and more via generate() method.
  • Reduce the model size by 3X using quantization.
  • Up to 3X speedup compared to PyTorch execution for greedy search and 2-3X for beam search.
Owner
Siddharth Sharma
Machine learning | NLP | Computer Vision
Siddharth Sharma
Optimizing DR with hard negatives and achieving SOTA first-stage retrieval performance on TREC DL Track (SIGIR 2021 Full Paper).

Optimizing Dense Retrieval Model Training with Hard Negatives Jingtao Zhan, Jiaxin Mao, Yiqun Liu, Jiafeng Guo, Min Zhang, Shaoping Ma This repo provi

Jingtao Zhan 99 Dec 27, 2022
PyTorch implementation of NeurIPS 2021 paper: "CoFiNet: Reliable Coarse-to-fine Correspondences for Robust Point Cloud Registration"

CoFiNet: Reliable Coarse-to-fine Correspondences for Robust Point Cloud Registration (NeurIPS 2021) PyTorch implementation of the paper: CoFiNet: Reli

76 Jan 03, 2023
Connecting Java/ImgLib2 + Python/NumPy

imglyb imglyb aims at connecting two worlds that have been seperated for too long: Python with numpy Java with ImgLib2 imglyb uses jpype to access num

ImgLib2 29 Dec 21, 2022
A Genetic Programming platform for Python with TensorFlow for wicked-fast CPU and GPU support.

Karoo GP Karoo GP is an evolutionary algorithm, a genetic programming application suite written in Python which supports both symbolic regression and

Kai Staats 149 Jan 09, 2023
Curated list of awesome GAN applications and demo

gans-awesome-applications Curated list of awesome GAN applications and demonstrations. Note: General GAN papers targeting simple image generation such

Minchul Shin 4.5k Jan 07, 2023
The repository for our EMNLP 2021 paper "Finnish Dialect Identification: The Effect of Audio and Text"

Finnish Dialect Identification The repository for our EMNLP 2021 paper "Finnish Dialect Identification: The Effect of Audio and Text". We present a te

Rootroo Ltd 2 Dec 25, 2021
Code for "Reconstructing 3D Human Pose by Watching Humans in the Mirror", CVPR 2021 oral

Reconstructing 3D Human Pose by Watching Humans in the Mirror Qi Fang*, Qing Shuai*, Junting Dong, Hujun Bao, Xiaowei Zhou CVPR 2021 Oral The videos a

ZJU3DV 178 Dec 13, 2022
Generalized Proximal Policy Optimization with Sample Reuse (GePPO)

Generalized Proximal Policy Optimization with Sample Reuse This repository is the official implementation of the reinforcement learning algorithm Gene

Jimmy Queeney 9 Nov 28, 2022
Code for CVPR 2021 oral paper "Exploring Data-Efficient 3D Scene Understanding with Contrastive Scene Contexts"

Exploring Data-Efficient 3D Scene Understanding with Contrastive Scene Contexts The rapid progress in 3D scene understanding has come with growing dem

Facebook Research 182 Dec 30, 2022
SPCL: A New Framework for Domain Adaptive Semantic Segmentation via Semantic Prototype-based Contrastive Learning

SPCL SPCL: A New Framework for Domain Adaptive Semantic Segmentation via Semantic Prototype-based Contrastive Learning Update on 2021/11/25: ArXiv Ver

Binhui Xie (谢斌辉) 11 Oct 29, 2022
Revisting Open World Object Detection

Revisting Open World Object Detection Installation See INSTALL.md. Dataset Our n

58 Dec 23, 2022
Implementation for the paper: Invertible Denoising Network: A Light Solution for Real Noise Removal (CVPR2021).

Invertible Image Denoising This is the PyTorch implementation of paper: Invertible Denoising Network: A Light Solution for Real Noise Removal (CVPR 20

157 Dec 25, 2022
A sample pytorch Implementation of ACL 2021 research paper "Learning Span-Level Interactions for Aspect Sentiment Triplet Extraction".

Span-ASTE-Pytorch This repository is a pytorch version that implements Ali's ACL 2021 research paper Learning Span-Level Interactions for Aspect Senti

来自丹麦的天籁 10 Dec 06, 2022
A different spin on dataclasses.

dataklasses Dataklasses is a library that allows you to quickly define data classes using Python type hints. Here's an example of how you use it: from

David Beazley 752 Nov 18, 2022
Human Dynamics from Monocular Video with Dynamic Camera Movements

Human Dynamics from Monocular Video with Dynamic Camera Movements Ri Yu, Hwangpil Park and Jehee Lee Seoul National University ACM Transactions on Gra

215 Jan 01, 2023
MediaPipeのPythonパッケージのサンプルです。2020/12/11時点でPython実装のある4機能(Hands、Pose、Face Mesh、Holistic)について用意しています。

mediapipe-python-sample MediaPipeのPythonパッケージのサンプルです。 2020/12/11時点でPython実装のある以下4機能について用意しています。 Hands Pose Face Mesh Holistic Requirement mediapipe 0.

KazuhitoTakahashi 217 Dec 12, 2022
Experiments for Neural Flows paper

Neural Flows: Efficient Alternative to Neural ODEs [arxiv] TL;DR: We directly model the neural ODE solutions with neural flows, which is much faster a

54 Dec 07, 2022
School of Artificial Intelligence at the Nanjing University (NJU)School of Artificial Intelligence at the Nanjing University (NJU)

F-Principle This is an exercise problem of the digital signal processing (DSP) course at School of Artificial Intelligence at the Nanjing University (

Thyrix 5 Nov 23, 2022
Official Keras Implementation for UNet++ in IEEE Transactions on Medical Imaging and DLMIA 2018

UNet++: A Nested U-Net Architecture for Medical Image Segmentation UNet++ is a new general purpose image segmentation architecture for more accurate i

Zongwei Zhou 1.8k Dec 27, 2022
Official repo for BMVC2021 paper ASFormer: Transformer for Action Segmentation

ASFormer: Transformer for Action Segmentation This repo provides training & inference code for BMVC 2021 paper: ASFormer: Transformer for Action Segme

42 Dec 23, 2022