[ICLR'21] FedBN: Federated Learning on Non-IID Features via Local Batch Normalization

Related tags

Deep LearningFedBN
Overview

FedBN: Federated Learning on Non-IID Features via Local Batch Normalization

This is the PyTorch implemention of our paper FedBN: Federated Learning on Non-IID Features via Local Batch Normalization by Xiaoxiao Li, Meirui Jiang, Xiaofei Zhang, Michael Kamp and Qi Dou

Abstract

The emerging paradigm of federated learning (FL) strives to enable collaborative training of deep models on the network edge without centrally aggregating raw data and hence improving data privacy. In most cases, the assumption of independent and identically distributed samples across local clients does not hold for federated learning setups. Under this setting, neural network training performance may vary significantly according to the data distribution and even hurt training convergence. Most of the previous work has focused on a difference in the distribution of labels. Unlike those settings, we address an important problem of FL, e.g., different scanner/sensors in medical imaging, different scenery distribution in autonomous driving (highway vs. city), where local clients may store examples with different marginal or conditional feature distributions compared to other nodes, which we denote as feature shift non-iid. In this work, we propose an effective method that uses local batch normalization to alleviate the feature shift before averaging models. The resulting scheme, called FedBN, outperforms both classical FedAvg, as well as the state-of-the-art for non-iid data (FedProx) on our extensive experiments. These empirical results are supported by a convergence analysis that shows in a simplified setting that FedBN has a faster convergence rate in expectation than FedAvg.

avatar

Usage

Setup

pip

See the requirements.txt for environment configuration.

pip install -r requirements.txt

conda

We recommend using conda to quick setup the environment. Please use the following commands.

conda env create -f environment.yaml
conda activate fedbn

Dataset & Pretrained Modeel

Benchmark(Digits)

  • Please download our pre-processed datasets here, put under data/ directory and perform following commands:
    cd ./data
    unzip digit_dataset.zip
  • Please download our pretrained model here and put under snapshots/ directory, perform following commands:
    cd ./snapshots
    unzip digit_model.zip

office-caltech10

  • Please download our pre-processed datasets here, put under data/ directory and perform following commands:
    cd ./data
    unzip office_caltech_10_dataset.zip
  • Please download our pretrained model here and put under snapshots/ directory, perform following commands:
    cd ./snapshots
    unzip office_caltech_10_model.zip

DomainNet

  • Please first download our splition here, put under data/ directory and perform following commands:
    cd ./data
    unzip domainnet_dataset.zip
  • then download dataset including: Clipart, Infograph, Painting, Quickdraw, Real, Sketch, put under data/DomainNet directory and unzip them.
    cd ./data/DomainNet
    unzip [filename].zip
  • Please download our pretrained model here and put under snapshots/ directory, perform following commands:
    cd ./snapshots
    unzip domainnet_model.zip

Train

Federated Learning

Please using following commands to train a model with federated learning strategy.

  • --mode specify federated learning strategy, option: fedavg | fedprox | fedbn
cd federated
# benchmark experiment
python fed_digits.py --mode fedbn

# office-caltech-10 experiment
python fed_office.py --mode fedbn

# DomaiNnet experiment
python fed_domainnet.py --mode fedbn

SingleSet

Please using following commands to train a model using singleset data.

  • --data specify the single dataset
cd singleset 
# benchmark experiment, --data option: svhn | usps | synth | mnistm | mnist
python single_digits.py --data svhn

# office-caltech-10 experiment --data option: amazon | caltech | dslr | webcam
python single_office.py --data amazon

# DomaiNnet experiment --data option: clipart | infograph | painting | quickdraw | real | sketch
python single_domainnet.py --data clipart

Test

cd federated
# benchmark experiment
python fed_digits.py --mode fedbn --test

# office-caltech-10 experiment
python fed_office.py --mode fedbn --test

# DomaiNnet experiment
python fed_domainnet.py --mode fedbn --test

Citation

If you find the code and dataset useful, please cite our paper.

@inproceedings{
li2021fedbn,
title={Fed{\{}BN{\}}: Federated Learning on Non-{\{}IID{\}} Features via Local Batch Normalization},
author={Xiaoxiao Li and Meirui JIANG and Xiaofei Zhang and Michael Kamp and Qi Dou},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=6YEQUn0QICG}
}
Owner
[email protected]
Medical Image Analysis, Artificial Intelligence, Robotics
<a href=[email protected]">
Retinal vessel segmentation based on GT-UNet

Retinal vessel segmentation based on GT-UNet Introduction This project is a retinal blood vessel segmentation code based on UNet-like Group Transforme

Kent0n 27 Dec 18, 2022
Efficient Training of Visual Transformers with Small Datasets

Official codes for "Efficient Training of Visual Transformers with Small Datasets", NerIPS 2021.

Yahui Liu 112 Dec 25, 2022
PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-supervised ViT.

MAE for Self-supervised ViT Introduction This is an unofficial PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-sup

36 Oct 30, 2022
This repository accompanies our paper “Do Prompt-Based Models Really Understand the Meaning of Their Prompts?”

This repository accompanies our paper “Do Prompt-Based Models Really Understand the Meaning of Their Prompts?” Usage To replicate our results in Secti

Albert Webson 64 Dec 11, 2022
Official implementation of Meta-StyleSpeech and StyleSpeech

Meta-StyleSpeech : Multi-Speaker Adaptive Text-to-Speech Generation Dongchan Min, Dong Bok Lee, Eunho Yang, and Sung Ju Hwang This is an official code

min95 168 Dec 28, 2022
From a body shape, infer the anatomic skeleton.

OSSO: Obtaining Skeletal Shape from Outside (CVPR 2022) This repository contains the official implementation of the skeleton inference from: OSSO: Obt

Marilyn Keller 166 Dec 28, 2022
piSTAR Lab is a modular platform built to make AI experimentation accessible and fun. (pistar.ai)

piSTAR Lab WARNING: This is an early release. Overview piSTAR Lab is a modular deep reinforcement learning platform built to make AI experimentation a

piSTAR Lab 0 Aug 01, 2022
A curated list of automated deep learning (including neural architecture search and hyper-parameter optimization) resources.

Awesome AutoDL A curated list of automated deep learning related resources. Inspired by awesome-deep-vision, awesome-adversarial-machine-learning, awe

D-X-Y 2k Dec 30, 2022
The implementation of FOLD-R++ algorithm

FOLD-R-PP The implementation of FOLD-R++ algorithm. The target of FOLD-R++ algorithm is to learn an answer set program for a classification task. Inst

13 Dec 23, 2022
Script that attempts to force M1 macs into RGB mode when used with monitors that are defaulting to YPbPr.

fix_m1_rgb Script that attempts to force M1 macs into RGB mode when used with monitors that are defaulting to YPbPr. No warranty provided for using th

Kevin Gao 116 Jan 01, 2023
Image Segmentation and Object Detection in Pytorch

Image Segmentation and Object Detection in Pytorch Pytorch-Segmentation-Detection is a library for image segmentation and object detection with report

Daniil Pakhomov 732 Dec 10, 2022
Crawl & visualize ICLR papers and reviews

Crawl and Visualize ICLR 2022 OpenReview Data Descriptions This Jupyter Notebook contains the data crawled from ICLR 2022 OpenReview webpages and thei

Federico Berto 75 Dec 05, 2022
ReLoss - Official implementation for paper "Relational Surrogate Loss Learning" ICLR 2022

Relational Surrogate Loss Learning (ReLoss) Official implementation for paper "R

Tao Huang 31 Nov 22, 2022
PlenOctrees: NeRF-SH Training & Conversion

PlenOctrees Official Repo: NeRF-SH training and conversion This repository contains code to train NeRF-SH and to extract the PlenOctree, constituting

Alex Yu 323 Dec 29, 2022
Disease Informed Neural Networks (DINNs) — neural networks capable of learning how diseases spread, forecasting their progression, and finding their unique parameters (e.g. death rate).

DINN We introduce Disease Informed Neural Networks (DINNs) — neural networks capable of learning how diseases spread, forecasting their progression, a

19 Dec 10, 2022
Fuse radar and camera for detection

SAF-FCOS: Spatial Attention Fusion for Obstacle Detection using MmWave Radar and Vision Sensor This project hosts the code for implementing the SAF-FC

ChangShuo 18 Jan 01, 2023
A PoC Corporation Relationship Knowledge Graph System on top of Nebula Graph.

Corp-Rel is a PoC of Corpartion Relationship Knowledge Graph System. It's built on top of the Open Source Graph Database: Nebula Graph with a dataset

Wey Gu 20 Dec 11, 2022
Direct application of DALLE-2 to video synthesis, using factored space-time Unet and Transformers

DALLE2 Video (wip) ** only to be built after DALLE2 image is done and replicated, and the importance of the prior network is validated ** Direct appli

Phil Wang 105 May 15, 2022
Public scripts, services, and configuration for running a smart home K3S network cluster

makerhouse_network Public scripts, services, and configuration for running MakerHouse's home network. This network supports: TODO features here For mo

Scott Martin 1 Jan 15, 2022