Toolkit for building machine learning models that generalize to unseen domains and are robust to privacy and other attacks.

Overview

Toolkit for Building Robust ML Models that Generalize to Unseen Domains (RobustDG)

Divyat Mahajan, Shruti Tople, Amit Sharma

Privacy & Causal Learning (ICML 2020) | MatchDG: Causal View of DG (ICML 2021) | Privacy & DG Connection paper

For machine learning models to be reliable, they need to generalize to data beyond the training distribution. In addition, ML models should be robust to privacy attacks such as membership inference, and to domain knowledge-based attacks such as adversarial attacks.

To advance research in building robust and generalizable models, we are releasing RobustDG, a toolkit for building and evaluating ML models. RobustDG contains implementations of domain generalization (DG) algorithms and includes evaluation benchmarks based on out-of-distribution accuracy and robustness to membership privacy attacks. We will be adding evaluation for adversarial attacks and more privacy attacks soon.
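To make "robustness to membership privacy attacks" concrete, here is a minimal sketch of a simple loss-threshold membership inference attack, a common baseline in the literature: the attacker guesses that an example was in the training set when the model's loss on it is at or below a threshold. This is an illustration under stated assumptions, not RobustDG's attack implementation; the helper name `loss_threshold_attack` and the loss values in the usage line are hypothetical.

```python
import numpy as np

def loss_threshold_attack(member_losses, nonmember_losses):
    """Hypothetical helper (not part of RobustDG): best balanced accuracy of
    an attack that predicts 'member' whenever loss <= threshold.

    Every observed loss value is tried as the threshold; returns
    (best_threshold, best_balanced_accuracy).
    """
    member_losses = np.asarray(member_losses, dtype=float)
    nonmember_losses = np.asarray(nonmember_losses, dtype=float)
    candidates = np.unique(np.concatenate([member_losses, nonmember_losses]))
    best_t, best_acc = None, -1.0
    for t in candidates:
        # Members are flagged correctly if loss <= t, non-members if loss > t.
        tpr = np.mean(member_losses <= t)
        tnr = np.mean(nonmember_losses > t)
        acc = 0.5 * (tpr + tnr)  # balanced accuracy over both groups
        if acc > best_acc:
            best_t, best_acc = t, acc
    return float(best_t), float(best_acc)

# Made-up losses: an overfit model has lower loss on its training members.
threshold, attack_acc = loss_threshold_attack([0.1, 0.2, 0.15], [0.9, 1.2, 0.3])
print(threshold, attack_acc)
```

An attack accuracy close to 0.5 means the attacker does no better than chance, i.e. the model leaks little membership information through its losses.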

RobustDG is easily extensible: you can add your own DG algorithms and evaluate them on the different benchmarks.

Installation

To use the command-line interface of RobustDG, clone this repo and add the folder to your system's PATH (or alternatively, run the commands from the RobustDG root directory).

Load dataset

First, generate the Rotated MNIST dataset in a format suitable for the ResNet-18 architecture (--img_h and --img_w set the image size to 224×224):

python data/data_gen_mnist.py --dataset rot_mnist --model resnet18 --img_h 224 --img_w 224 --subset_size 2000
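Rotated MNIST builds each domain by rotating digit images by a fixed angle. The sketch below illustrates that idea with a nearest-neighbour rotation in plain NumPy. It is not the code in `data/data_gen_mnist.py`; the helper names `rotate_nn` and `make_domains`, and the default angle list (0° to 75° in 15° steps, a common choice in the DG literature), are assumptions for illustration.

```python
import numpy as np

def rotate_nn(img, angle_deg):
    """Rotate a 2-D image about its centre by angle_deg (nearest-neighbour).

    Uses inverse mapping: each output pixel looks up the source pixel it
    came from; pixels that map outside the image are filled with 0.
    """
    h, w = img.shape
    theta = np.deg2rad(angle_deg)
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    ys, xs = np.indices((h, w))
    y0, x0 = ys - cy, xs - cx
    # Source coordinates for every output pixel.
    src_x = np.cos(theta) * x0 + np.sin(theta) * y0 + cx
    src_y = -np.sin(theta) * x0 + np.cos(theta) * y0 + cy
    sx = np.rint(src_x).astype(int)
    sy = np.rint(src_y).astype(int)
    valid = (sx >= 0) & (sx < w) & (sy >= 0) & (sy < h)
    out = np.zeros_like(img)
    out[valid] = img[sy[valid], sx[valid]]
    return out

def make_domains(images, angles=(0, 15, 30, 45, 60, 75)):
    """Hypothetical helper: one domain per angle, built from the same base images."""
    return {a: np.stack([rotate_nn(x, a) for x in images]) for a in angles}

# Example: three blank 28x28 "digits" turned into two domains.
domains = make_domains(np.zeros((3, 28, 28)), angles=(0, 45))
print({angle: batch.shape for angle, batch in domains.items()})
```

Because every domain is derived from the same base images, each rotated image has a known counterpart in the other domains, which is what makes cross-domain matching on this dataset easy to check.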

Train and evaluate ML model

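The following commands train and evaluate the MatchDG method on the Rotated MNIST dataset. The first command trains the contrastive matching phase (matchdg_ctr), the second trains the classifier using matches from the contrastive model (matchdg_erm), and the two test.py runs report test accuracy (--test_metric acc) and the quality of the learned matches (--test_metric match_score).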

python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --match_func_aug_case 1

python train.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --epochs 25

python test.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --epochs 25 --test_metric acc

python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --pos_metric cos --test_metric match_score
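Both training phases above rely on matching inputs of the same class across domains. As a rough illustration of the ideas behind --pos_metric cos and --penalty_ws, the sketch below (a) pairs each representation from one domain with its most cosine-similar same-class representation from another domain, and (b) adds the mean squared L2 distance between matched representations to a classification loss, weighted by `penalty_ws`. This is a simplified NumPy sketch, not RobustDG's implementation: `cosine_match` and `matched_penalty_loss` are hypothetical names, the squared L2 distance is an assumption, and the full method learns the representation with a contrastive loss rather than taking fixed features.

```python
import numpy as np

def cosine_match(feats_a, labels_a, feats_b, labels_b):
    """For each row of feats_a, return the index of the most cosine-similar
    row of feats_b with the same label (-1 if that class is absent).
    Assumes no all-zero feature vectors."""
    a = feats_a / np.linalg.norm(feats_a, axis=1, keepdims=True)
    b = feats_b / np.linalg.norm(feats_b, axis=1, keepdims=True)
    sim = a @ b.T  # (n_a, n_b) cosine similarities
    # Rule out cross-class pairs so every match is same-class.
    same_class = labels_a[:, None] == labels_b[None, :]
    sim = np.where(same_class, sim, -np.inf)
    idx = np.argmax(sim, axis=1)
    idx[~same_class.any(axis=1)] = -1
    return idx

def matched_penalty_loss(ce_loss, feats_a, feats_b, match_idx, penalty_ws=0.1):
    """Classification loss plus penalty_ws times the mean squared L2
    distance between matched representations (unmatched rows skipped)."""
    keep = match_idx >= 0
    if not keep.any():
        return ce_loss
    diff = feats_a[keep] - feats_b[match_idx[keep]]
    return ce_loss + penalty_ws * np.mean(np.sum(diff ** 2, axis=1))

# Toy features from two domains (values made up for illustration).
fa, la = np.array([[1.0, 0.0], [0.0, 1.0]]), np.array([0, 1])
fb, lb = np.array([[0.0, 2.0], [3.0, 0.1], [2.0, 2.0]]), np.array([1, 0, 0])
idx = cosine_match(fa, la, fb, lb)
print(idx, matched_penalty_loss(1.0, fa, fb, idx, penalty_ws=0.1))
```

With `penalty_ws=0` the loss reduces to the plain classification (ERM) loss; larger values push matched inputs from different domains toward the same representation.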

Demo

For a quick introduction to using the repository, see the Getting Started notebook.

If you are interested in reproducing results from the MatchDG paper, check out the Reproducing results notebook.

Roadmap

  • Support for more domain generalization algorithms like CSD and IRM. If you are an author of a DG algorithm and would like to contribute, please raise a pull request here or get in touch.
  • More evaluation metrics based on adversarial attacks and privacy attacks such as model inversion. If you'd like to see an evaluation metric implemented, please raise an issue here.

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information, see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.
