Distributionally robust neural networks for group shifts

Overview

Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization

This code implements the group DRO algorithm from the following paper:

Shiori Sagawa*, Pang Wei Koh*, Tatsunori Hashimoto, and Percy Liang

Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization

The experiments use the following datasets:

For an executable, Dockerized version of the experiments in these paper, please see our Codalab worksheet.

Abstract

Overparameterized neural networks can be highly accurate on average on an i.i.d. test set yet consistently fail on atypical groups of the data (e.g., by learning spurious correlations that hold on average but not in such groups). Distributionally robust optimization (DRO) allows us to learn models that instead minimize the worst-case training loss over a set of pre-defined groups. However, we find that naively applying group DRO to overparameterized neural networks fails: these models can perfectly fit the training data, and any model with vanishing average training loss also already has vanishing worst-case training loss. Instead, their poor worst-case performance arises from poor generalization on some groups. By coupling group DRO models with increased regularization---stronger-than-typical L2 regularization or early stopping---we achieve substantially higher worst-group accuracies, with 10-40 percentage point improvements on a natural language inference task and two image tasks, while maintaining high average accuracies. Our results suggest that regularization is critical for worst-group generalization in the overparameterized regime, even if it is not needed for average generalization. Finally, we introduce and give convergence guarantees for a stochastic optimizer for the group DRO setting, underpinning the empirical study above.

Prerequisites

  • python 3.6.8
  • matplotlib 3.0.3
  • numpy 1.16.2
  • pandas 0.24.2
  • pillow 5.4.1
  • pytorch 1.1.0
  • pytorch_transformers 1.2.0
  • torchvision 0.5.0a0+19315e3
  • tqdm 4.32.2

Datasets and code

To run our code, you will need to change the root_dir variable in data/data.py. The main point of entry to the code is run_expt.py. Below, we provide sample commands for each dataset.

CelebA

Our code expects the following files/folders in the [root_dir]/celebA directory:

  • data/list_eval_partition.csv
  • data/list_attr_celeba.csv
  • data/img_align_celeba/

You can download these dataset files from this Kaggle link. The original dataset, due to Liu et al. (2015), can be found here. The version of the CelebA dataset that we use in the paper (with the (hair, gender) groups) can also be accessed through the WILDS package, which will automatically download the dataset.

A sample command to run group DRO on CelebA is: python run_expt.py -s confounder -d CelebA -t Blond_Hair -c Male --lr 0.0001 --batch_size 128 --weight_decay 0.0001 --model resnet50 --n_epochs 50 --reweight_groups --robust --gamma 0.1 --generalization_adjustment 0

Waterbirds

The Waterbirds dataset is constructed by cropping out birds from photos in the Caltech-UCSD Birds-200-2011 (CUB) dataset (Wah et al., 2011) and transferring them onto backgrounds from the Places dataset (Zhou et al., 2017).

Our code expects the following files/folders in the [root_dir]/cub directory:

  • data/waterbird_complete95_forest2water2/

You can download a tarball of this dataset here. The Waterbirds dataset can also be accessed through the WILDS package, which will automatically download the dataset.

A sample command to run group DRO on Waterbirds is: python run_expt.py -s confounder -d CUB -t waterbird_complete95 -c forest2water2 --lr 0.001 --batch_size 128 --weight_decay 0.0001 --model resnet50 --n_epochs 300 --reweight_groups --robust --gamma 0.1 --generalization_adjustment 0

Note that compared to the training set, the validation and test sets are constructed with different proportions of each group. We describe this in more detail in Appendix C.1 of our paper, which we reproduce here for convenience:

We use the official train-test split of the CUB dataset, randomly choosing 20% of the training data to serve as a validation set. For the validation and test sets, we allocate distribute landbirds and waterbirds equally to land and water backgrounds (i.e., there are the same number of landbirds on land vs. water backgrounds, and separately, the same number of waterbirds on land vs. water backgrounds). This allows us to more accurately measure the performance of the rare groups, and it is particularly important for the Waterbirds dataset because of its relatively small size; otherwise, the smaller groups (waterbirds on land and landbirds on water) would have too few samples to accurately estimate performance on. We note that we can only do this for the Waterbirds dataset because we control the generation process; for the other datasets, we cannot generate more samples from the rare groups.

In a typical application, the validation set might be constructed by randomly dividing up the available training data. We emphasize that this is not the case here: the training set is skewed, whereas the validation set is more balanced. We followed this construction so that we could better compare ERM vs. reweighting vs. group DRO techniques using a stable set of hyperparameters. In practice, if the validation set were also skewed, we might expect hyperparameter tuning based on worst-group accuracy to be more challenging and noisy.

Due to the above procedure, when reporting average test accuracy in our experiments, we calculate the average test accuracy over each group and then report a weighted average, with weights corresponding to the relative proportion of each group in the (skewed) training dataset.

If you'd like to generate variants of this dataset, we have included the script we used to generate this dataset (from the CUB and Places datasets) in dataset_scripts/generate_waterbirds.py. Note that running this script will not create the exact dataset we provide above, due to random seed differences. You will need to download the CUB dataset as well as the Places dataset. We use the high-resolution training images (MD5: 67e186b496a84c929568076ed01a8aa1) from Places. Once you have downloaded and extracted these datasets, edit the corresponding paths in generate_waterbirds.py.

MultiNLI with annotated negations

Our code expects the following files/folders in the [root_dir]/multinli directory:

  • data/metadata_random.csv
  • glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli
  • glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli-mm
  • glue_data/MNLI/cached_train_bert-base-uncased_128_mnli

We have included the metadata file in dataset_metadata/multinli in this repository. The metadata file records whether each example belongs to the train/val/test dataset as well as whether it contains a negation word.

The glue_data/MNLI files are generated by the huggingface Transformers library and can be downloaded here.

A sample command to run group DRO on MultiNLI is: python run_expt.py -s confounder -d MultiNLI -t gold_label_random -c sentence2_has_negation --lr 2e-05 --batch_size 32 --weight_decay 0 --model bert --n_epochs 3 --reweight_groups --robust --generalization_adjustment 0

We created our own train/val/test split of the MultiNLI dataset, as described in Appendix C.1 of our paper:

The standard MultiNLI train-test split allocates most examples (approximately 90%) to the training set, with another 5% as a publicly-available development set and the last 5% as a held-out test set that is only accessible through online competition leaderboards (Williams et al., 2018). To accurately estimate performance on rare groups in the validation and test sets, we combine the training set and development set and then randomly resplit it to a 50-20-30 train-val-test split that allocates more examples to the validation and test sets than the standard split.

If you'd like to modify the metadata file (e.g., considering other confounders than the presence of negation words), we have included the script we used to generate the metadata file in dataset_scripts/generate_multinli.py. Note that running this script will not create the exact dataset we provide above, due to random seed differences. You will need to download the MultiNLI dataset and edit the paths in that script accordingly.

A library built upon PyTorch for building embeddings on discrete event sequences using self-supervision

pytorch-lifestream a library built upon PyTorch for building embeddings on discrete event sequences using self-supervision. It can process terabyte-si

Dmitri Babaev 103 Dec 17, 2022
Pytorch implementation of Implicit Behavior Cloning.

Implicit Behavior Cloning - PyTorch (wip) Pytorch implementation of Implicit Behavior Cloning. Install conda create -n ibc python=3.8 pip install -r r

Kevin Zakka 49 Dec 25, 2022
UniFormer - official implementation of UniFormer

UniFormer This repo is the official implementation of "Uniformer: Unified Transf

SenseTime X-Lab 573 Jan 04, 2023
Gym environments used in the paper: "Developmental Reinforcement Learning of Control Policy of a Quadcopter UAV with Thrust Vectoring Rotors"

gym_multirotor Gym to train reinforcement learning agents on UAV platforms Quadrotor Tiltrotor Requirements This package has been tested on Ubuntu 18.

Aditya M. Deshpande 19 Dec 29, 2022
TrTr: Visual Tracking with Transformer

TrTr: Visual Tracking with Transformer We propose a novel tracker network based on a powerful attention mechanism called Transformer encoder-decoder a

趙 漠居(Zhao, Moju) 66 Dec 27, 2022
Bald-to-Hairy Translation Using CycleGAN

GANiry: Bald-to-Hairy Translation Using CycleGAN Official PyTorch implementation of GANiry. GANiry: Bald-to-Hairy Translation Using CycleGAN, Fidan Sa

Fidan Samet 10 Oct 27, 2022
上海交通大学全自动抢课脚本,支持准点开抢与抢课后持续捡漏两种模式。2021/06/08更新。

Welcome to Course-Bullying-in-SJTU-v3.1! 2021/6/8 紧急更新v3.1 更新说明 为了更好地保护用户隐私,将原来用户名+密码的登录方式改为微信扫二维码+cookie登录方式,不再需要配置使用pytesseract。在使用扫码登录模式时,请稍等,二维码将马

87 Sep 13, 2022
Official code repository for the publication "Latent Equilibrium: A unified learning theory for arbitrarily fast computation with arbitrarily slow neurons"

Latent Equilibrium: A unified learning theory for arbitrarily fast computation with arbitrarily slow neurons This repository contains the code to repr

Computational Neuroscience, University of Bern 3 Aug 04, 2022
An Unpaired Sketch-to-Photo Translation Model

Unpaired-Sketch-to-Photo-Translation We have released our code at https://github.com/rt219/Unsupervised-Sketch-to-Photo-Synthesis This project is the

38 Oct 28, 2022
Official code of Team Yao at Multi-Modal-Fact-Verification-2022

Official code of Team Yao at Multi-Modal-Fact-Verification-2022 A Multi-Modal Fact Verification dataset released as part of the De-Factify workshop in

Wei-Yao Wang 11 Nov 15, 2022
Safe Model-Based Reinforcement Learning using Robust Control Barrier Functions

README Repository containing the code for the paper "Safe Model-Based Reinforcement Learning using Robust Control Barrier Functions". Specifically, an

Yousef Emam 13 Nov 24, 2022
Self-Supervised Learning

Self-Supervised Learning Features self_supervised offers features like modular framework support for multi-gpu training using PyTorch Lightning easy t

Robin 1 Dec 14, 2021
Hands-On Machine Learning for Algorithmic Trading, published by Packt

Hands-On Machine Learning for Algorithmic Trading Hands-On Machine Learning for Algorithmic Trading, published by Packt This is the code repository fo

Packt 981 Dec 29, 2022
Unofficial implementation of Perceiver IO: A General Architecture for Structured Inputs & Outputs

Perceiver IO Unofficial implementation of Perceiver IO: A General Architecture for Structured Inputs & Outputs Usage import torch from src.perceiver.

Timur Ganiev 111 Nov 15, 2022
Sequence lineage information extracted from RKI sequence data repo

Pango lineage information for German SARS-CoV-2 sequences This repository contains a join of the metadata and pango lineage tables of all German SARS-

Cornelius Roemer 24 Oct 26, 2022
Official Implementation of Domain-Aware Universal Style Transfer

Domain Aware Universal Style Transfer Official Pytorch Implementation of 'Domain Aware Universal Style Transfer' (ICCV 2021) Domain Aware Universal St

KibeomHong 80 Dec 30, 2022
Python periodic table module

elemenpy Hello! elements.py is a small Python periodic table module that is used for calling certain information about an element. Installation Instal

Eric Cheng 2 Dec 27, 2021
Repository accompanying the "Sign Pose-based Transformer for Word-level Sign Language Recognition" paper

by Matyáš Boháček and Marek Hrúz, University of West Bohemia Should you have any questions or inquiries, feel free to contact us here. Repository acco

Matyáš Boháček 30 Dec 30, 2022
This is the dataset for testing the robustness of various VO/VIO methods

KAIST VIO dataset This is the dataset for testing the robustness of various VO/VIO methods You can download the whole dataset on KAIST VIO dataset Ind

1 Sep 01, 2022
ALFRED - A Benchmark for Interpreting Grounded Instructions for Everyday Tasks

ALFRED A Benchmark for Interpreting Grounded Instructions for Everyday Tasks Mohit Shridhar, Jesse Thomason, Daniel Gordon, Yonatan Bisk, Winson Han,

ALFRED 204 Dec 15, 2022