Enabling Lightweight Fine-tuning for Pre-trained Language Model Compression based on Matrix Product Operators

Related tags

Deep LearningMPOP
Overview

Enabling Lightweight Fine-tuning for Pre-trained Language Model Compression based on Matrix Product Operators

This is our Pytorch implementation for the paper:

Peiyu Liu, Ze-Feng Gao, Wayne Xin Zhao, Zhi-Yuan Xie, Zhong-Yi Lu and Ji-Rong Wen(2021). Enabling Lightweight Fine-tuning for Pre-trained Language Model Compression based on Matrix Product Operators

Introduction

This paper presents a novel pre-trained language models (PLM) compression approach based on the matrix product operator (short as MPO) from quantum many-body physics. It can decompose an original matrix into central tensors (containing the core information) and auxiliary tensors (with only a small proportion of parameters). With the decomposed MPO structure, we propose a novel fine-tuning strategy by only updating the parameters from the auxiliary tensors, and design an optimization algorithm for MPO-based approximation over stacked network architectures. Our approach can be applied to the original or the compressed PLMs in a general way, which derives a lighter network and significantly reduces the parameters to be fine-tuned. Extensive experiments have demonstrated the effectiveness of the proposed approach in model compression, especially the reduction in fine-tuning parameters (91% reduction on average).

image

For more details about the technique of MPOP, refer to our paper

Release Notes

  • First version: 2021/05/21
  • add albert code: 2021/06/08

Requirements

  • python 3.7
  • torch >= 1.8.0

Installation

pip install mpo_lab

Lightweight fine-tuning

In lightweight fine-tuning, we use original ALBERT without fine-tuning as to be compressed. By performing MPO decomposition on each weight matrix, we obtain four auxiliary tensors and one central tensor per tensor set. This provides a good initialization for the task-specific distillation. Refer to run_all_albert_fine_tune.sh

Important arguments:

--data_dir          Path to load dataset
--mpo_lr            Learning rate of tensors produced by MPO
--mpo_layers        Name of components to be decomposed with MPO
--emb_trunc         Truncation number of the central tensor in word embedding layer
--linear_trunc      Truncation number of the central tensor in linear layer
--attention_trunc   Truncation number of the central tensor in attention layer
--load_layer        Name of components to be loaded from exist checkpoint file
--update_mpo_layer  Name of components to be update when training the model

Dimension squeezing

In Dimension squeezing, we compute approiate truncation order for the whole model. In order to re-produce the results in paper, we prepare the model after lightweight fine-tuning. Refer to run_all_albert_fine_tune.sh

albert models google drive

Acknowledgment

Any scientific publications that use our codes should cite the following paper as the reference:

@inproceedings{Liu-ACL-2021,
  author    = {Peiyu Liu and
               Ze{-}Feng Gao and
               Wayne Xin Zhao and
               Z. Y. Xie and
               Zhong{-}Yi Lu and
               Ji{-}Rong Wen},
  title     = "Enabling Lightweight Fine-tuning for Pre-trained Language Model Compression
               based on Matrix Product Operators",
  booktitle = {{ACL}},
  year      = {2021},
}

TODO

  • prepare data and code
  • upload models in order to reproduce experiments
  • supplementary details for paper
Owner
RUCAIBox
An enthusiastic group that aims to create beautiful things with AI
RUCAIBox
Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.

Diffrax Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. Diffrax is a JAX-based library providing numerical differe

Patrick Kidger 717 Jan 09, 2023
Implementation of the paper "Generating Symbolic Reasoning Problems with Transformer GANs"

Generating Symbolic Reasoning Problems with Transformer GANs This is the implementation of the paper Generating Symbolic Reasoning Problems with Trans

Reactive Systems Group 1 Apr 18, 2022
PASSL包含 SimCLR,MoCo,BYOL,CLIP等基于对比学习的图像自监督算法以及 Vision-Transformer,Swin-Transformer,BEiT,CVT,T2T,MLP_Mixer等视觉Transformer算法

PASSL Introduction PASSL is a Paddle based vision library for state-of-the-art Self-Supervised Learning research with PaddlePaddle. PASSL aims to acce

186 Dec 29, 2022
Yolov5 + Deep Sort with PyTorch

딥소트 수정중 Yolov5 + Deep Sort with PyTorch Introduction This repository contains a two-stage-tracker. The detections generated by YOLOv5, a family of obj

1 Nov 26, 2021
A Pytorch implementation of "Manifold Matching via Deep Metric Learning for Generative Modeling" (ICCV 2021)

Manifold Matching via Deep Metric Learning for Generative Modeling A Pytorch implementation of "Manifold Matching via Deep Metric Learning for Generat

69 Dec 10, 2022
Rlmm blender toolkit - A set of tools to streamline level generation in UDK straight from Blender

rlmm_blender_toolkit A set of tools to streamline level generation in UDK straig

Rocket League Mapmaking 0 Jan 15, 2022
GND-Nets (Graph Neural Diffusion Networks) in TensorFlow.

GNDC For submission to IEEE TKDE. Overview Here we provide the implementation of GND-Nets (Graph Neural Diffusion Networks) in TensorFlow. The reposit

Wei Ye 3 Aug 08, 2022
The source code of CVPR 2019 paper "Deep Exemplar-based Video Colorization".

Deep Exemplar-based Video Colorization (Pytorch Implementation) Paper | Pretrained Model | Youtube video 🔥 | Colab demo Deep Exemplar-based Video Col

Bo Zhang 253 Dec 27, 2022
A Lightweight Hyperparameter Optimization Tool 🚀

Lightweight Hyperparameter Optimization 🚀 The mle-hyperopt package provides a simple and intuitive API for hyperparameter optimization of your Machin

136 Jan 08, 2023
Deployment of PyTorch chatbot with Flask

Chatbot Deployment with Flask and JavaScript In this tutorial we deploy the chatbot I created in this tutorial with Flask and JavaScript. This gives 2

Patrick Loeber (Python Engineer) 107 Dec 29, 2022
Self-Adaptable Point Processes with Nonparametric Time Decays

NPPDecay This is our implementation for the paper Self-Adaptable Point Processes with Nonparametric Time Decays, by Zhimeng Pan, Zheng Wang, Jeff M. P

zpan 2 Sep 24, 2022
Materials for my scikit-learn tutorial

Scikit-learn Tutorial Jake VanderPlas email: [email protected] twitter: @jakevdp gith

Jake Vanderplas 1.6k Dec 30, 2022
Implementation of fast algorithms for Maximum Spanning Tree (MST) parsing that includes fast ArcMax+Reweighting+Tarjan algorithm for single-root dependency parsing.

Fast MST Algorithm Implementation of fast algorithms for (Maximum Spanning Tree) MST parsing that includes fast ArcMax+Reweighting+Tarjan algorithm fo

Miloš Stanojević 11 Oct 14, 2022
Code for Ditto: Building Digital Twins of Articulated Objects from Interaction

Ditto: Building Digital Twins of Articulated Objects from Interaction Zhenyu Jiang, Cheng-Chun Hsu, Yuke Zhu CVPR 2022, Oral Project | arxiv News 2022

UT Robot Perception and Learning Lab 78 Dec 22, 2022
Implementation of "Bidirectional Projection Network for Cross Dimension Scene Understanding" CVPR 2021 (Oral)

Bidirectional Projection Network for Cross Dimension Scene Understanding CVPR 2021 (Oral) [ Project Webpage ] [ arXiv ] [ Video ] Existing segmentatio

Hu Wenbo 135 Dec 26, 2022
KE-Dialogue: Injecting knowledge graph into a fully end-to-end dialogue system.

Learning Knowledge Bases with Parameters for Task-Oriented Dialogue Systems This is the implementation of the paper: Learning Knowledge Bases with Par

CAiRE 42 Nov 10, 2022
3D mesh stylization driven by a text input in PyTorch

Text2Mesh [Project Page] Text2Mesh is a method for text-driven stylization of a 3D mesh, as described in "Text2Mesh: Text-Driven Neural Stylization fo

Threedle (University of Chicago) 649 Dec 27, 2022
"NAS-Bench-301 and the Case for Surrogate Benchmarks for Neural Architecture Search".

NAS-Bench-301 This repository containts code for the paper: "NAS-Bench-301 and the Case for Surrogate Benchmarks for Neural Architecture Search". The

AutoML-Freiburg-Hannover 57 Nov 30, 2022
Neural-net-from-scratch - A simple Neural Network from scratch in Python using the Pymathrix library

A Simple Neural Network from scratch A Simple Neural Network from scratch in Pyt

Youssef Chafiqui 2 Jan 07, 2022
GAN JAX - A toy project to generate images from GANs with JAX

GAN JAX - A toy project to generate images from GANs with JAX This project aims to bring the power of JAX, a Python framework developped by Google and

Valentin Goldité 14 Nov 29, 2022