FedJAX: Federated learning with JAX

What is FedJAX?

FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX. FedJAX prioritizes ease-of-use and is intended to be useful for anyone with knowledge of NumPy.

FedJAX is built around the common core components needed in the FL setting:

  • Federated datasets: Clients and a dataset for each client
  • Models: CNN, ResNet, etc.
  • Optimizers: SGD, Momentum, etc.
  • Federated algorithms: Client updates and server aggregation

For Models and Optimizers, FedJAX provides lightweight wrappers and containers that work with a variety of existing implementations (e.g. a model wrapper that supports both Haiku and Stax). Similarly, for Federated datasets, TensorFlow Federated (TFF) provides a well-established API for working with federated datasets, and FedJAX just provides utilities for converting them to NumPy input acceptable to JAX.

However, what FL researchers will likely find most useful is the collection of customizable Federated algorithms that FedJAX provides out of the box.
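To make the "client updates and server aggregation" component concrete, here is a minimal sketch of federated averaging on a toy least-squares problem. It is written in plain NumPy and deliberately does not use the FedJAX API; the names client_update and server_aggregate, the learning rates, and the toy data are all assumptions made for illustration.

```python
import numpy as np


def client_update(params, x, y, learning_rate=0.1, num_steps=5):
    """Runs a few full-batch gradient steps on one client's data (least squares)."""
    w = params.copy()
    for _ in range(num_steps):
        grad = 2.0 * x.T @ (x @ w - y) / len(y)
        w -= learning_rate * grad
    # Return the model delta and the number of examples, used later as its weight.
    return w - params, len(y)


def server_aggregate(params, client_outputs, server_learning_rate=1.0):
    """Applies the example-weighted mean of the client deltas to the server params."""
    deltas = np.stack([delta for delta, _ in client_outputs])
    weights = np.array([n for _, n in client_outputs], dtype=float)
    mean_delta = np.average(deltas, axis=0, weights=weights)
    return params + server_learning_rate * mean_delta


# Toy federated dataset: two clients of different sizes sharing one true model.
rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0])
clients = {}
for client_id, n in [("a", 20), ("b", 5)]:
    x = rng.normal(size=(n, 2))
    clients[client_id] = (x, x @ true_w)

params = np.zeros(2)
for _ in range(50):  # 50 federated rounds with full client participation
    outputs = [client_update(params, x, y) for x, y in clients.values()]
    params = server_aggregate(params, outputs)
```

Because the toy data is noise-free and both clients share the same underlying model, params should end up close to true_w. Real federated workloads add client sampling, mini-batching, and heterogeneous data, which this sketch leaves out.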

Quickstart

The FedJAX Intro notebook provides an introduction to running existing FedJAX experiments. For more custom use cases, please refer to the FedJAX Advanced notebook.

You can also take a look at some of our examples:

Installation

You will need Python 3.6 or later and a working JAX installation. For a CPU-only version:

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

For other devices (e.g. GPU), follow these instructions.

Then, install fedjax from PyPI:

pip install fedjax

Or, to upgrade to the latest version of fedjax:

pip install --upgrade git+https://github.com/google/fedjax.git

Useful pointers

NOTE: FedJAX is not an officially supported Google product. FedJAX is still in the early stages and the API will likely continue to change.

Comments
  • FedJax depends on TensorFlow Federated?

    I am helping users install FedJax for use in their federated learning research projects and I noticed that installing FedJax is pulling in TensorFlow Federated (0.17) and TensorFlow (2.3). I don't see either of these listed as dependencies of FedJax so I am trying to understand why they are being pulled in by pip install fedjax.

    opened by davidrpugh 7
  • CIFAR 100 Questions

    Hi, thanks for the awesome library! I want to ask a couple of questions related to CIFAR100 datasets.

    1. I noticed that while the dataset is available in the library, the model is not. Curious if a model for CIFAR100 is work-in-progress, or if there is no short-term plan for this?
    2. Looking at the CIFAR100 dataset, this seems to be inconsistent with Google's TFF. Notably, the cropping size and normalizing are done differently from TFF. Is this intentional? Would it be correct to say that we could expect this to mirror TFF's design eventually?

    Thanks in advance for all the help!

    opened by HanGuo97 5
  • unbiased scale for DRIVE

    Following a discussion with @stheertha, I suggest using the unbiased scale (section 4.2 in Drive's paper) for cases where there is more than 1 client.

    Thank you for considering.

    opened by amitport 3
  • Problem of Quick Start in Readme.md

    I tried to run the code in the Quick Start and ran into some problems. federated_data = fedjax.FederatedData() cannot be executed because FederatedData is an abstract class, so I replaced it with:

    import fedjax
    import numpy as np

    client_a_data = {
        'x': np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        'y': np.array([7, 8]),
    }
    client_b_data = {'x': np.array([[9.0, 10.0, 11.0]]), 'y': np.array([12])}
    client_to_data_mapping = {'a': client_a_data, 'b': client_b_data}
    federated_data = fedjax.InMemoryFederatedData(client_to_data_mapping)

    Everything else is the same as in the Quick Start, but I got an error:

    for client_id, client_output, _ in func(shared_input, clients):
    for client_id, client_batches, client_input in clients:
    ValueError: not enough values to unpack (expected 3, got 2)
    

    It seems that client_batches is missing and that we need to batch the dataset, but there is no example that fits this situation.

    opened by Ichiruchan 2
  • Full EMNIST example does not exhibit parallelization

    Hi! I am facing an issue with parallelizing the base code provided by the developers.

    • My local workstation contains two GPUs.
    • I installed FedJax in a conda environment
    • I downloaded the "emnist_fed_avg.py" file from the "examples" folder, deleted the "fedjax.training.set_tf_cpu_only()" line, and replaced fed_avg.federated_averaging with fedjax.algorithms.fed_avg.federated_averaging on line 61
    • Having activated the conda environment, I ran the file with python emnist_fed_avg.py. The file runs correctly and prints the expected output (round numbers and train/test metrics on every 10th round)
    • The nvidia-smi command shows zero percent utilization and almost zero memory usage on one of the GPUs (and ~40% utilization and maximum memory usage on the other)

    Any ideas what I am doing wrong?

    opened by gaseln 2
  • Clarifying the meaning of "weight"

    In the Intro notebook, the backward_pass_output from model.backward has a weight feature. It seems to me that this is used to perform weighted averaging in FedAvg, but it is not clear to me how. Perhaps it could be renamed to batch_size?

    opened by Saipraneet 1
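One plausible reading of the question above is that the weight is the number of examples behind a gradient, so that averaging across batches or clients counts each contribution in proportion to its size. This is an assumption rather than something the thread confirms. The NumPy sketch below uses hypothetical dictionary keys and values and only illustrates what such a weighted average does.

```python
import numpy as np

# Hypothetical per-batch outputs: a gradient and the number of examples behind it.
batch_outputs = [
    {"grads": np.array([1.0, 1.0]), "weight": 8.0},  # full batch
    {"grads": np.array([4.0, 0.0]), "weight": 2.0},  # smaller final batch
]

weights = np.array([out["weight"] for out in batch_outputs])
grads = np.stack([out["grads"] for out in batch_outputs])

# Weighted mean: each batch counts in proportion to its weight.
weighted_mean = (weights[:, None] * grads).sum(axis=0) / weights.sum()

# Unweighted mean, for contrast: the small batch is over-represented.
plain_mean = grads.mean(axis=0)
```

Here weighted_mean is [1.6, 0.8] while plain_mean is [2.5, 0.5]. If every batch has the same size the two coincide, which may be why the name batch_size was suggested.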
  • [NumPy] Remove references to deprecated NumPy type aliases.

    This change replaces references to a number of deprecated NumPy type aliases (np.bool, np.int, np.float, np.complex, np.object, np.str) with their recommended replacement (bool, int, float, complex, object, str).

    NumPy 1.24 drops the deprecated aliases, so we must remove uses before updating NumPy.

    opened by copybara-service[bot] 0
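The replacement described in this change can be illustrated with a short snippet. The array contents are made up for the example; only the dtype arguments matter.

```python
import numpy as np

# Before (raises AttributeError on NumPy >= 1.24):
#   np.zeros(3, dtype=np.float)
#   np.array([0, 1, 2]).astype(np.bool)

# After: the Python builtins are accepted wherever a dtype is expected.
values = np.zeros(3, dtype=float)
mask = np.array([0, 1, 2]).astype(bool)
counts = np.array([1.9, 2.1]).astype(int)
```

Sized types such as np.float64 or np.int32 are unaffected and remain the better choice when a specific width is required.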
  • Disable pytype import error for old stax import path

    Why? The deprecated jax.experimental.stax path will soon be removed (see https://github.com/google/jax/pull/11700), and this causes pytype to fail.

    opened by copybara-service[bot] 0
  • Rename jax.experimental.stax -> jax.example_libraries.stax

    Why? The former name has been deprecated since JAX version 0.2.25, released in November 2021 (see https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0225-nov-10-2021), and will soon be removed.

    opened by copybara-service[bot] 0
  • Implement standard CIFAR-100 model in fedjax.models.cifar100

    Add a standard implementation of the model for the CIFAR-100 task. The dataset can be found in fedjax.datasets.cifar100.

    For the model architecture, we should follow “Adaptive Federated Optimization”. The model architecture is detailed in section 4 as a ResNet-18 (replacing batch norm with group norm). Code for this paper and a Keras implementation of the model can be found here. We suggest using either haiku or flax to implement the model for use with JAX.

    If you choose to use haiku, you can use fedjax.create_model_from_haiku to create a fedjax compatible model. If you choose to use flax, wrapping it in a fedjax.Model is fairly straightforward and we can provide guidance for this.

    A good example to follow is #265 that checks in a simple linear model for CIFAR-100 and includes the model implementation, tests, and baseline results with FedAvg using this script. Make sure to add a flags file similar to https://github.com/google/fedjax/blob/main/experiments/fed_avg/fed_avg.CIFAR100_LOGISTIC.flags and add the new task to https://github.com/google/fedjax/blob/main/fedjax/training/tasks.py.

    Thanks for your contributions!

    enhancement contributions welcome 
    opened by jaehunro 1
  • Support for manually modifying client/server learning rate

    Hi, I'm playing around with the client learning rate, but I cannot find a clean way of modifying it.

    Basically, I need to change the LR following a schedule based on the current round. Is that possible?

    Thanks

    opened by marcociccone 1
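A round-based schedule of the kind requested above can be sketched as a plain function of the round number. The function name, the step-decay form, and the default values are all hypothetical. How the resulting rate would be passed into a FedJAX algorithm is not shown in this thread and is not assumed here.

```python
def client_learning_rate(round_num, base_lr=0.1, decay_rate=0.5, decay_every=100):
    """Step decay: multiply the rate by `decay_rate` every `decay_every` rounds."""
    return base_lr * decay_rate ** (round_num // decay_every)


# Learning rates at a few selected rounds.
schedule = [client_learning_rate(r) for r in (0, 99, 100, 250)]
```

With these defaults the rate stays at 0.1 for rounds 0 to 99, drops to 0.05 at round 100, and reaches 0.025 by round 250.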
  • Support for gldv2 and inaturalist datasets

    I think it would be great to port these datasets from tff to fedjax. I would be happy to make the effort and contribute to the library, but I need a bit of support from the fedjax team 🙂

    Looking at the tff codebase (gldv2, inaturalist), it looks like the load_data_from_cache function creates a tfrecords file for each client.

    The only concrete classes that I see are SQLiteFederatedData and InMemoryFederatedData, but I don't think they are meant for this use case. What would be the best way to map the clients into a FederatedDataset? We could replicate something like FilePerUserClientData.

    Thanks!

    opened by marcociccone 7
  • Support for haiku models with non-trainable state

    Hi! Congrats on this great library! I started using it a few days ago and I love it!

    Is there any way to use a haiku model with a non-trainable state (e.g. to use batch norm)? I didn't find any nontrivial way, but maybe I'm missing something.

    Thanks a lot for your help!

    opened by marcociccone 2
  • How to create a validation dataset?

    Hello!

    I may need to split each client's train dataset into train and validation parts for grid search purposes (for example, tuning the stepsizes in a method). How can this be achieved in the framework?

    opened by gaseln 4
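One way to approach the per-client split asked about above is to shuffle each client's examples and cut them into two dictionaries before building the federated dataset. The sketch below uses NumPy only; split_client_data and its parameters are hypothetical names, and it assumes each client's data is a dict of equally long arrays, as in the Quick Start issue earlier on this page.

```python
import numpy as np


def split_client_data(client_data, validation_fraction=0.2, seed=0):
    """Shuffles one client's examples and splits them into train/validation dicts."""
    num_examples = len(next(iter(client_data.values())))
    permutation = np.random.default_rng(seed).permutation(num_examples)
    num_validation = int(round(validation_fraction * num_examples))
    validation_idx = permutation[:num_validation]
    train_idx = permutation[num_validation:]
    # Index every feature with the same indices so examples stay aligned.
    train = {k: v[train_idx] for k, v in client_data.items()}
    validation = {k: v[validation_idx] for k, v in client_data.items()}
    return train, validation


client_to_data = {
    "a": {"x": np.arange(20.0).reshape(10, 2), "y": np.arange(10)},
    "b": {"x": np.arange(10.0).reshape(5, 2), "y": np.arange(5)},
}
train_mapping, validation_mapping = {}, {}
for client_id, data in client_to_data.items():
    train_mapping[client_id], validation_mapping[client_id] = split_client_data(data)
```

The two resulting mappings have the same client-to-dict shape as the mapping passed to fedjax.InMemoryFederatedData in the Quick Start issue, so each could presumably be wrapped the same way.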
  • Feature request: Convert standard dataset into a federated dataset

    Synthetic federated datasets can be constructed from standard centralized ones by artificially splitting them among clients. This is usually done using a Dirichlet distribution (e.g. Hsu et al. 2019). Such synthetic datasets are very useful since we can explicitly control the total number of users, as well as the heterogeneity.

    It would be great to have primitives that can automatically convert a standard NumPy dataset into a FedJAX dataset.

    contributions welcome 
    opened by Saipraneet 5
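The Dirichlet-based split described in this request can be sketched in NumPy as follows. This is a commonly used per-class variant, in which each class's examples are divided among clients according to Dirichlet-sampled proportions; it is not necessarily the exact procedure of Hsu et al. 2019. The function name dirichlet_partition, the toy data, and the client naming are all assumptions for illustration.

```python
import numpy as np


def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Assigns example indices to clients with per-class Dirichlet proportions."""
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for label in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == label))
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Cut points that split this class's examples according to the proportions.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return [np.array(sorted(ix), dtype=int) for ix in client_indices]


labels = np.repeat(np.arange(10), 100)  # 1000 toy examples, 10 classes
features = np.random.default_rng(1).normal(size=(1000, 4))
parts = dirichlet_partition(labels, num_clients=5, alpha=0.5)

# Client-to-data mapping in the dict-of-arrays form used elsewhere on this page.
client_to_data = {
    f"client_{i}": {"x": features[ix], "y": labels[ix]} for i, ix in enumerate(parts)
}
```

Smaller values of alpha concentrate each class on fewer clients (more heterogeneity), while large values approach a uniform split. The number of clients is controlled directly through num_clients.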
Releases(v0.0.15)
Owner
Google