[ICLR'21] FedBN: Federated Learning on Non-IID Features via Local Batch Normalization

Related tags

Deep LearningFedBN
Overview

FedBN: Federated Learning on Non-IID Features via Local Batch Normalization

This is the PyTorch implemention of our paper FedBN: Federated Learning on Non-IID Features via Local Batch Normalization by Xiaoxiao Li, Meirui Jiang, Xiaofei Zhang, Michael Kamp and Qi Dou

Abstract

The emerging paradigm of federated learning (FL) strives to enable collaborative training of deep models on the network edge without centrally aggregating raw data and hence improving data privacy. In most cases, the assumption of independent and identically distributed samples across local clients does not hold for federated learning setups. Under this setting, neural network training performance may vary significantly according to the data distribution and even hurt training convergence. Most of the previous work has focused on a difference in the distribution of labels. Unlike those settings, we address an important problem of FL, e.g., different scanner/sensors in medical imaging, different scenery distribution in autonomous driving (highway vs. city), where local clients may store examples with different marginal or conditional feature distributions compared to other nodes, which we denote as feature shift non-iid. In this work, we propose an effective method that uses local batch normalization to alleviate the feature shift before averaging models. The resulting scheme, called FedBN, outperforms both classical FedAvg, as well as the state-of-the-art for non-iid data (FedProx) on our extensive experiments. These empirical results are supported by a convergence analysis that shows in a simplified setting that FedBN has a faster convergence rate in expectation than FedAvg.

avatar

Usage

Setup

pip

See the requirements.txt for environment configuration.

pip install -r requirements.txt

conda

We recommend using conda to quick setup the environment. Please use the following commands.

conda env create -f environment.yaml
conda activate fedbn

Dataset & Pretrained Modeel

Benchmark(Digits)

  • Please download our pre-processed datasets here, put under data/ directory and perform following commands:
    cd ./data
    unzip digit_dataset.zip
  • Please download our pretrained model here and put under snapshots/ directory, perform following commands:
    cd ./snapshots
    unzip digit_model.zip

office-caltech10

  • Please download our pre-processed datasets here, put under data/ directory and perform following commands:
    cd ./data
    unzip office_caltech_10_dataset.zip
  • Please download our pretrained model here and put under snapshots/ directory, perform following commands:
    cd ./snapshots
    unzip office_caltech_10_model.zip

DomainNet

  • Please first download our splition here, put under data/ directory and perform following commands:
    cd ./data
    unzip domainnet_dataset.zip
  • then download dataset including: Clipart, Infograph, Painting, Quickdraw, Real, Sketch, put under data/DomainNet directory and unzip them.
    cd ./data/DomainNet
    unzip [filename].zip
  • Please download our pretrained model here and put under snapshots/ directory, perform following commands:
    cd ./snapshots
    unzip domainnet_model.zip

Train

Federated Learning

Please using following commands to train a model with federated learning strategy.

  • --mode specify federated learning strategy, option: fedavg | fedprox | fedbn
cd federated
# benchmark experiment
python fed_digits.py --mode fedbn

# office-caltech-10 experiment
python fed_office.py --mode fedbn

# DomaiNnet experiment
python fed_domainnet.py --mode fedbn

SingleSet

Please using following commands to train a model using singleset data.

  • --data specify the single dataset
cd singleset 
# benchmark experiment, --data option: svhn | usps | synth | mnistm | mnist
python single_digits.py --data svhn

# office-caltech-10 experiment --data option: amazon | caltech | dslr | webcam
python single_office.py --data amazon

# DomaiNnet experiment --data option: clipart | infograph | painting | quickdraw | real | sketch
python single_domainnet.py --data clipart

Test

cd federated
# benchmark experiment
python fed_digits.py --mode fedbn --test

# office-caltech-10 experiment
python fed_office.py --mode fedbn --test

# DomaiNnet experiment
python fed_domainnet.py --mode fedbn --test

Citation

If you find the code and dataset useful, please cite our paper.

@inproceedings{
li2021fedbn,
title={Fed{\{}BN{\}}: Federated Learning on Non-{\{}IID{\}} Features via Local Batch Normalization},
author={Xiaoxiao Li and Meirui JIANG and Xiaofei Zhang and Michael Kamp and Qi Dou},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=6YEQUn0QICG}
}
Owner
[email protected]
Medical Image Analysis, Artificial Intelligence, Robotics
<a href=[email protected]">
Intel® Neural Compressor is an open-source Python library running on Intel CPUs and GPUs

Intel® Neural Compressor targeting to provide unified APIs for network compression technologies, such as low precision quantization, sparsity, pruning, knowledge distillation, across different deep l

Intel Corporation 846 Jan 04, 2023
This repository contains the needed resources to build the HIRID-ICU-Benchmark dataset

HiRID-ICU-Benchmark This repository contains the needed resources to build the HIRID-ICU-Benchmark dataset for which the manuscript can be found here.

Biomedical Informatics at ETH Zurich 30 Dec 16, 2022
Open source person re-identification library in python

Open-ReID Open-ReID is a lightweight library of person re-identification for research purpose. It aims to provide a uniform interface for different da

Tong Xiao 1.3k Jan 01, 2023
tf2onnx - Convert TensorFlow, Keras and Tflite models to ONNX.

tf2onnx converts TensorFlow (tf-1.x or tf-2.x), tf.keras and tflite models to ONNX via command line or python api.

Open Neural Network Exchange 1.8k Jan 08, 2023
Winners of DrivenData's Overhead Geopose Challenge

Winners of DrivenData's Overhead Geopose Challenge

DrivenData 22 Aug 04, 2022
CLIPImageClassifier wraps clip image model from transformers

CLIPImageClassifier CLIPImageClassifier wraps clip image model from transformers. CLIPImageClassifier is initialized with the argument classes, these

Jina AI 6 Sep 12, 2022
A repository for the paper "Improved Adversarial Systems for 3D Object Generation and Reconstruction".

Improved Adversarial Systems for 3D Object Generation and Reconstruction: This is a repository for the paper "Improved Adversarial Systems for 3D Obje

Edward Smith 188 Dec 25, 2022
Sequential Model-based Algorithm Configuration

SMAC v3 Project Copyright (C) 2016-2018 AutoML Group Attention: This package is a reimplementation of the original SMAC tool (see reference below). Ho

AutoML-Freiburg-Hannover 778 Jan 05, 2023
Artificial Intelligence search algorithm base on Pacman

Pacman Search Artificial Intelligence search algorithm base on Pacman Source The Pacman Projects by the University of California, Berkeley. Layouts Di

Day Fundora 6 Nov 17, 2022
QueryInst: Parallelly Supervised Mask Query for Instance Segmentation

QueryInst is a simple and effective query based instance segmentation method driven by parallel supervision on dynamic mask heads, which outperforms previous arts in terms of both accuracy and speed.

Hust Visual Learning Team 386 Jan 08, 2023
Open-source Monocular Python HawkEye for Tennis

Tennis Tracking 🎾 Objectives Track the ball Detect court lines Detect the players To track the ball we used TrackNet - deep learning network for trac

ArtLabs 188 Jan 08, 2023
⚓ Eurybia monitor model drift over time and securize model deployment with data validation

View Demo · Documentation · Medium article 🔍 Overview Eurybia is a Python library which aims to help in : Detecting data drift and model drift Valida

MAIF 172 Dec 27, 2022
Real-time face detection and emotion/gender classification using fer2013/imdb datasets with a keras CNN model and openCV.

Real-time face detection and emotion/gender classification using fer2013/imdb datasets with a keras CNN model and openCV.

Octavio Arriaga 5.3k Dec 30, 2022
A Dynamic Residual Self-Attention Network for Lightweight Single Image Super-Resolution

DRSAN A Dynamic Residual Self-Attention Network for Lightweight Single Image Super-Resolution Karam Park, Jae Woong Soh, and Nam Ik Cho Environments U

4 May 10, 2022
SSD: Single Shot MultiBox Detector pytorch implementation focusing on simplicity

SSD: Single Shot MultiBox Detector Introduction Here is my pytorch implementation of 2 models: SSD-Resnet50 and SSDLite-MobilenetV2.

Viet Nguyen 149 Jan 07, 2023
Uncertain natural language inference

Uncertain Natural Language Inference This repository hosts the code for the following paper: Tongfei Chen*, Zhengping Jiang*, Adam Poliak, Keisuke Sak

Tongfei Chen 14 Sep 01, 2022
Physics-Aware Training (PAT) is a method to train real physical systems with backpropagation.

Physics-Aware Training (PAT) is a method to train real physical systems with backpropagation. It was introduced in Wright, Logan G. & Onodera, Tatsuhiro et al. (2021)1 to train Physical Neural Networ

McMahon Lab 230 Jan 05, 2023
This repository attempts to replicate the SqueezeNet architecture and implement the same on an image classification task.

SqueezeNet-Implementation This repository attempts to replicate the SqueezeNet architecture using TensorFlow discussed in the research paper: "Squeeze

Rohan Mathur 3 Dec 13, 2022
Super-BPD: Super Boundary-to-Pixel Direction for Fast Image Segmentation (CVPR 2020)

Super-BPD for Fast Image Segmentation (CVPR 2020) Introduction We propose direction-based super-BPD, an alternative to superpixel, for fast generic im

189 Dec 07, 2022
Implementation of the Point Transformer layer, in Pytorch

Point Transformer - Pytorch Implementation of the Point Transformer self-attention layer, in Pytorch. The simple circuit above seemed to have allowed

Phil Wang 501 Jan 03, 2023