Code for the NeurIPS 2021 paper "Scalable Rule-Based Representation Learning for Interpretable Classification".

Overview

Rule-based Representation Learner

This is a PyTorch implementation of the Rule-based Representation Learner (RRL), as described in the NeurIPS 2021 paper: Scalable Rule-Based Representation Learning for Interpretable Classification.


RRL aims for both good scalability and interpretability: it automatically learns interpretable non-fuzzy rules for data representation and classification. Moreover, RRL can be easily adjusted to trade off classification accuracy against model complexity for different scenarios.

Requirements

  • torch>=1.3.0
  • torchvision>=0.4.1
  • tensorboard>=2.0.0
  • sklearn>=0.22.2.post1
  • numpy>=1.17.2
  • pandas>=0.24.2
  • matplotlib>=3.0.3
  • CUDA==10.1
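
Assuming a standard pip environment, the Python dependencies above can be installed in one go (sklearn is distributed as the scikit-learn package; CUDA itself is installed separately):

pip3 install "torch>=1.3.0" "torchvision>=0.4.1" "tensorboard>=2.0.0" "scikit-learn>=0.22.2.post1" "numpy>=1.17.2" "pandas>=0.24.2" "matplotlib>=3.0.3"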

Run the demo

Put the data sets in the dataset folder. You can then specify one of those data sets and train the model on it as follows:

# trained on the tic-tac-toe data set with one GPU.
python3 experiment.py -d tic-tac-toe -bs 32 -s 1@16 -e401 -lrde 200 -lr 0.002 -ki 0 -mp 12481 -i 0 -wd 1e-6 &

The demo first reads the data set and the data set information file, then trains RRL on the training set. During training, you can check the training loss and the evaluation results on the validation set with:

tensorboard --logdir=log_folder/ --bind_all


The training log file (log.txt) can be found in a folder created in log_folder. In this example, the folder path is

log_folder/tic-tac-toe/tic-tac-toe_e401_bs32_lr0.002_lrdr0.75_lrde200_wd1e-06_ki0_rc0_useNOTFalse_saveBestFalse_estimatedGradFalse_L1@16

After training, the evaluation result on the test set is shown in the file test_res.txt:

[INFO] - On Test Set:
        Accuracy of RRL  Model: 1.0
        F1 Score of RRL  Model: 1.0

Moreover, the trained RRL model is saved in model.pth, and the discrete RRL is printed in rrl.txt:

RID      class_negative(b=-2.1733)  class_positive(b=1.9689)  Support  Rule
(-1, 1)  -5.8271                    6.3045                    0.0885   3_x & 6_x & 9_x
(-1, 2)  -5.4949                    5.4566                    0.0781   7_x & 8_x & 9_x
(-1, 4)  -4.5605                    4.7578                    0.1146   1_x & 2_x & 3_x
......   ......                     ......                    ......   ......
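
Each rule in the table carries one learned weight per class. Conceptually, following the paper's linear output layer, the discrete RRL scores a class by adding the class bias b to the weights of every rule the input satisfies, then predicts the highest-scoring class. A minimal Python sketch of that scoring, reusing the weights excerpted above (the satisfies helper is illustrative, not part of this repository):

# Rule weights copied from the rrl.txt excerpt above; a literal like "3_x"
# means "board position 3 takes the value x" in tic-tac-toe.
rules = [
    ("3_x & 6_x & 9_x", {"negative": -5.8271, "positive": 6.3045}),
    ("7_x & 8_x & 9_x", {"negative": -5.4949, "positive": 5.4566}),
    ("1_x & 2_x & 3_x", {"negative": -4.5605, "positive": 4.7578}),
]
bias = {"negative": -2.1733, "positive": 1.9689}

def satisfies(sample, rule):
    # A conjunction holds iff every literal "pos_value" matches the sample.
    return all(sample.get(pos) == val
               for pos, val in (lit.split("_") for lit in rule.split(" & ")))

def scores(sample):
    # Start from the class biases, then add the weights of fired rules.
    out = dict(bias)
    for rule, weights in rules:
        if satisfies(sample, rule):
            for cls in out:
                out[cls] += weights[cls]
    return out

# x on squares 1, 2, 3 fires the third rule above: "positive" wins.
print(scores({"1": "x", "2": "x", "3": "x"}))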

Your own data sets

You can use the demo to train RRL on your own data set by putting the data and data information files in the dataset folder. Please read DataSetDesc for specific guidelines; a rough sketch of the expected layout follows.
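
The sketch below is modeled on the bundled tic-tac-toe example (the exact file formats are defined in DataSetDesc, so treat the comments here as assumptions):

dataset/
├── my-data.data    # one comma-separated instance per line
└── my-data.info    # feature names and types, and which column is the label

# then train on it:
python3 experiment.py -d my-data -bs 32 -s 1@16 -e401 -lrde 200 -lr 0.002 -ki 0 -i 0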

Available arguments

List all available arguments and their default values with:

$ python3 experiment.py --help
usage: experiment.py [-h] [-d DATA_SET] [-i DEVICE_IDS] [-nr NR] [-e EPOCH]
                     [-bs BATCH_SIZE] [-lr LEARNING_RATE]
                     [-lrdr LR_DECAY_RATE] [-lrde LR_DECAY_EPOCH]
                     [-wd WEIGHT_DECAY] [-ki ITH_KFOLD] [-rc ROUND_COUNT]
                     [-ma MASTER_ADDRESS] [-mp MASTER_PORT] [-li LOG_ITER]
                     [--use_not] [--save_best] [--estimated_grad]
                     [-s STRUCTURE]

optional arguments:
  -h, --help            show this help message and exit
  -d DATA_SET, --data_set DATA_SET
                        Set the data set for training. All the data sets in
                        the dataset folder are available. (default: tic-tac-
                        toe)
  -i DEVICE_IDS, --device_ids DEVICE_IDS
                        Set the device (GPU ids). Split by @. E.g., 1@2@3.
                        (default: None)
  -nr NR, --nr NR       ranking within the nodes (default: 0)
  -e EPOCH, --epoch EPOCH
                        Set the total epoch. (default: 41)
  -bs BATCH_SIZE, --batch_size BATCH_SIZE
                        Set the batch size. (default: 64)
  -lr LEARNING_RATE, --learning_rate LEARNING_RATE
                        Set the initial learning rate. (default: 0.01)
  -lrdr LR_DECAY_RATE, --lr_decay_rate LR_DECAY_RATE
                        Set the learning rate decay rate. (default: 0.75)
  -lrde LR_DECAY_EPOCH, --lr_decay_epoch LR_DECAY_EPOCH
                        Set the learning rate decay epoch. (default: 10)
  -wd WEIGHT_DECAY, --weight_decay WEIGHT_DECAY
                        Set the weight decay (L2 penalty). (default: 0.0)
  -ki ITH_KFOLD, --ith_kfold ITH_KFOLD
                        Do the i-th 5-fold validation, 0 <= ki < 5. (default:
                        0)
  -rc ROUND_COUNT, --round_count ROUND_COUNT
                        Count the round of experiments. (default: 0)
  -ma MASTER_ADDRESS, --master_address MASTER_ADDRESS
                        Set the master address. (default: 127.0.0.1)
  -mp MASTER_PORT, --master_port MASTER_PORT
                        Set the master port. (default: 12345)
  -li LOG_ITER, --log_iter LOG_ITER
                        The number of iterations (batches) to log once.
                        (default: 50)
  --use_not             Use the NOT (~) operator in logical rules. It will
                        enhance model capability but make the RRL more
                        complex. (default: False)
  --save_best           Save the model with best performance on the validation
                        set. (default: False)
  --estimated_grad      Use estimated gradient. (default: False)
  -s STRUCTURE, --structure STRUCTURE
                        Set the number of nodes in the binarization layer and
                        logical layers. E.g., 10@64, 10@64@32@16. (default:
                        5@64)
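
The --estimated_grad option above enables training the discrete parts of RRL with an estimated gradient (the paper trains the discrete model via Gradient Grafting). As a general illustration of the estimated-gradient idea, and not this repository's exact implementation, a straight-through-style estimator in PyTorch looks like this:

import torch

class BinarizeSTE(torch.autograd.Function):
    """Binarize in the forward pass; in the backward pass, let the
    gradient pass straight through as if the forward were the identity."""

    @staticmethod
    def forward(ctx, x):
        return (x > 0.5).float()

    @staticmethod
    def backward(ctx, grad_output):
        # The "estimated" gradient: skip the non-differentiable step function.
        return grad_output

x = torch.rand(4, requires_grad=True)
BinarizeSTE.apply(x).sum().backward()
print(x.grad)  # tensor([1., 1., 1., 1.]): gradients flowed straight through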

Citation

If our work is helpful to you, please cite our paper as:

@article{wang2021scalable,
  title={Scalable Rule-Based Representation Learning for Interpretable Classification},
  author={Wang, Zhuo and Zhang, Wei and Liu, Ning and Wang, Jianyong},
  journal={arXiv preprint arXiv:2109.15103},
  year={2021}
}

License

MIT license
