Cross-Modal Contrastive Learning for Text-to-Image Generation

Overview

This repository hosts the open-source JAX implementation of XMC-GAN.

Setup instructions

Environment

Create and activate a virtual environment with virtualenv:

virtualenv venv
source venv/bin/activate

Add the XMC-GAN library to PYTHONPATH:

export PYTHONPATH=$PYTHONPATH:/home/path/to/xmcgan/root/
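To confirm both steps took effect before going further, a small check like the one below can help. It is a sketch, not part of the repository: the function names are our own, and the path in the example is the placeholder from the command above, so substitute your real checkout location.

```python
import os
import sys


def in_virtualenv() -> bool:
    """True when the interpreter runs inside a virtualenv/venv environment."""
    # virtualenv < 20 sets sys.real_prefix; venv and virtualenv >= 20 make
    # sys.prefix differ from sys.base_prefix.
    return hasattr(sys, "real_prefix") or sys.prefix != sys.base_prefix


def on_pythonpath(root: str) -> bool:
    """True if `root` (e.g. the XMC-GAN checkout) is on the module search path."""
    # PYTHONPATH entries are copied into sys.path at startup; realpath makes
    # trailing slashes and symlinks compare equal.
    target = os.path.realpath(root)
    return any(os.path.realpath(p or ".") == target for p in sys.path)


if __name__ == "__main__":
    repo_root = "/home/path/to/xmcgan/root/"  # placeholder: use your own path
    print("inside virtualenv:", in_virtualenv())
    print("XMC-GAN root on sys.path:", on_pythonpath(repo_root))
```

Run it with the virtual environment's python in the same shell where you exported PYTHONPATH; a new shell starts without the export.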

JAX Installation

Note: Please follow the official JAX instructions for installing a GPU-compatible version of JAX.
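A quick way to tell whether the GPU build was picked up is to ask JAX which backend and devices it sees. The sketch below uses only `jax.default_backend()`, `jax.devices()` and `jax.__version__`; the wrapper function and its return format are our own choices.

```python
import importlib.util


def describe_jax_backend() -> dict:
    """Report which backend JAX will use, without crashing if JAX is absent."""
    if importlib.util.find_spec("jax") is None:
        return {"installed": False}
    import jax  # imported lazily so the check also works before installation

    return {
        "installed": True,
        "version": jax.__version__,
        "backend": jax.default_backend(),  # platform name such as "cpu" or "gpu"
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    info = describe_jax_backend()
    print(info)
    if info.get("backend") == "cpu":
        print("JAX only sees the CPU; the GPU build is probably not installed.")
```

If the backend is reported as CPU on a GPU machine, revisit the official installation instructions before starting preprocessing or training.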

Other Dependencies

After installing JAX, install the remaining dependencies with:

pip install -r requirements.txt
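If an import fails later, it can help to check which entries of requirements.txt are not installed in the active environment. This is a rough sketch (Python 3.9+) built on the standard `importlib.metadata` module; it only extracts the distribution name from each line and does not check version specifiers, so treat its output as a hint rather than a verdict.

```python
import re
from importlib import metadata


def missing_requirements(path: str = "requirements.txt") -> list[str]:
    """Return distribution names listed in `path` that are not installed."""
    missing = []
    with open(path) as f:
        for line in f:
            line = line.split("#", 1)[0].strip()  # drop comments
            if not line or line.startswith("-"):  # skip blanks and pip options
                continue
            # Keep only the distribution name, e.g. "flax>=0.3" -> "flax".
            name = re.split(r"[\s<>=!~;\[@]", line, maxsplit=1)[0]
            try:
                metadata.version(name)
            except metadata.PackageNotFoundError:
                missing.append(name)
    return missing


if __name__ == "__main__":
    print(missing_requirements())
```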

Preprocess COCO-2014

To create the training and eval data, first create a directory for it. By default, the scripts expect this to be data/ in the base directory.

mkdir data/

The TFRecords required for training and validation on COCO-2014 can be created by running a preprocessing script over the TFDS coco_captions dataset:

python preprocess_data.py

This may take a while to complete, as it runs a pretrained BERT model over the captions and stores the embeddings. With a GPU, it takes about 2.5 hours for the training split and 1 hour for the validation split. Once it finishes, the train and validation TFRecord files are saved in the data/ directory. The train files require around 58 GB of disk space, and the validation files require around 29 GB.
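Since the two splits together need roughly 87 GB, it is worth checking free space before a multi-hour preprocessing run. The sketch below uses the standard `shutil.disk_usage`; reading the quoted sizes as GiB (2**30 bytes) is our assumption and errs toward requiring slightly more space. The data/ directory must already exist.

```python
import shutil

# Sizes quoted above for the preprocessed COCO-2014 TFRecords, read as GiB.
TRAIN_GIB = 58
VAL_GIB = 29


def enough_space(path: str = "data/", required_gib: float = TRAIN_GIB + VAL_GIB) -> bool:
    """True if the filesystem holding `path` has at least `required_gib` free."""
    free_bytes = shutil.disk_usage(path).free
    return free_bytes >= required_gib * 2**30


if __name__ == "__main__":
    print("enough space for both splits:", enough_space("data/"))
```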

Note: If you run into an error related to TensorFlow's gfile, one workaround is to edit site-packages/bert/tokenization.py and change tf.gfile.GFile to tf.io.gfile.GFile.
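The same edit can be scripted. The helper below is a hypothetical convenience, not part of XMC-GAN: it performs the textual replacement described in the note, keeps a .bak copy of the original file, and is safe to run twice because the replacement text does not contain the pattern it replaces.

```python
import shutil


def patch_gfile(path: str) -> int:
    """Rewrite tf.gfile.GFile -> tf.io.gfile.GFile in `path`; return the count."""
    with open(path) as f:
        source = f.read()
    count = source.count("tf.gfile.GFile")
    if count:
        shutil.copyfile(path, path + ".bak")  # keep the untouched original
        with open(path, "w") as f:
            f.write(source.replace("tf.gfile.GFile", "tf.io.gfile.GFile"))
    return count
```

To find the file, import the bert package in the same environment and look in `os.path.dirname(bert.__file__)` for tokenization.py.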

If you run into a tensorflow.python.framework.errors_impl.ResourceExhaustedError about having too many open files, you may have to increase the machine's open file limits. To do so, open the limits configuration file for editing (this requires root access):

vi /etc/security/limits.conf

and append the following lines to the end of the file:

*         hard    nofile      500000
*         soft    nofile      500000
root      hard    nofile      500000
root      soft    nofile      500000

You may have to adjust the limit values for your machine. Log out and log back in for the new values to take effect.
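A complementary, per-process option uses Python's standard `resource` module: any process may raise its own soft open-file limit up to its hard limit without root or a re-login. It cannot raise the hard limit, so the limits.conf edit above is still needed when the hard limit itself is too low. Calling this at the start of a preprocessing run is our suggestion; the repository does not do it.

```python
import resource


def raise_nofile_soft_limit(target: int = 500_000) -> tuple[int, int]:
    """Raise this process's soft open-file limit toward `target`.

    Affects only the current process and the children it starts. The soft
    limit is capped at the hard limit; it is never lowered.
    """
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    new_soft = target if hard == resource.RLIM_INFINITY else min(target, hard)
    if soft != resource.RLIM_INFINITY and new_soft > soft:
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))
        except (ValueError, OSError):
            pass  # an OS-specific cap below `target`; keep the old limit
    return resource.getrlimit(resource.RLIMIT_NOFILE)


if __name__ == "__main__":
    print("open-file limits (soft, hard):", raise_nofile_soft_limit())
```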

Download Pretrained ResNet

To train XMC-GAN, we need a network pretrained on ImageNet to extract image features. For this, we trained a ResNet-50. To download the weights (this requires gsutil from the Google Cloud SDK), run:

gsutil cp gs://gresearch/xmcgan/resnet_pretrained.npy data/
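To confirm the download is intact, you can load the file and list parameter shapes. The layout of resnet_pretrained.npy is not documented here; the sketch assumes it was written with `np.save` on a (possibly nested) mapping of arrays, which `np.load` returns as a 0-d object array. Loading pickled objects requires `allow_pickle=True`, which should only be used on files you trust.

```python
from collections.abc import Mapping

import numpy as np


def summarize_weights(path: str = "data/resnet_pretrained.npy") -> dict:
    """Map each parameter path in the file to its array shape (layout assumed)."""
    loaded = np.load(path, allow_pickle=True)  # pickled objects need opt-in
    tree = loaded.item() if loaded.dtype == object and loaded.ndim == 0 else loaded

    shapes = {}

    def walk(node, prefix):
        if isinstance(node, Mapping):
            for key, value in node.items():
                walk(value, f"{prefix}/{key}" if prefix else str(key))
        else:
            shapes[prefix] = tuple(np.shape(node))

    walk(tree, "")
    return shapes


if __name__ == "__main__":
    for name, shape in summarize_weights().items():
        print(name, shape)
```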

If you would like to pretrain your own network on ImageNet, please refer to the official Flax ImageNet example.

Training

To start a training run, first edit train.sh to specify an appropriate work directory. By default, the script assumes that 8 GPUs are available and runs training on the first 7, while test.sh assumes that evaluation runs on the last GPU. After configuring the training job, start an experiment by running the script with bash:

mkdir exp
bash train.sh exp_name &> train.txt
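On machines with a different GPU count, the same "all but the last GPU train, the last GPU evaluates" layout can be expressed as CUDA_VISIBLE_DEVICES values. This is a hypothetical helper: whether train.sh and test.sh already pin devices themselves is not stated here, so check the scripts before relying on it.

```python
def split_gpus(num_gpus: int = 8) -> tuple[str, str]:
    """Return CUDA_VISIBLE_DEVICES strings for (training, evaluation).

    Every GPU except the last trains; the last GPU evaluates.
    """
    if num_gpus < 2:
        raise ValueError("need at least one GPU for training and one for eval")
    train = ",".join(str(i) for i in range(num_gpus - 1))
    return train, str(num_gpus - 1)


if __name__ == "__main__":
    train_devices, eval_devices = split_gpus(8)
    # Only meaningful if the scripts do not override CUDA_VISIBLE_DEVICES.
    print(f"CUDA_VISIBLE_DEVICES={train_devices} bash train.sh exp_name")
    print(f"CUDA_VISIBLE_DEVICES={eval_devices} bash test.sh exp_name")
```

With the default batch size of 8 per GPU, seven training GPUs give a global batch of 8 × 7 = 56, assuming plain data parallelism across the training GPUs.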

Checkpoints and TensorBoard logs are saved in /path/to/exp/exp_name. By default, the configs/coco_xmc.py config is used, which trains on 128px images. This setting fits a batch size of 8 per GPU and achieves an FID of around 10.5 to 11.0 with the EMA weights. To reproduce the full 256px results in our paper, the full model must be run on a 32-core Pod slice of Google Cloud TPU v3 devices.
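"EMA weights" refers to an exponential moving average of the generator parameters, maintained alongside the raw training weights. The sketch below shows one update step on nested dicts; the decay of 0.999 is a placeholder, since the value XMC-GAN uses is set in its config and not stated here.

```python
from collections.abc import Mapping


def ema_update(ema_params, params, decay: float = 0.999):
    """One EMA step: ema <- decay * ema + (1 - decay) * params.

    Works on nested dicts whose leaves support * and + (floats, NumPy or
    JAX arrays). The default decay is a placeholder, not XMC-GAN's value.
    """
    if isinstance(ema_params, Mapping):
        return {k: ema_update(ema_params[k], params[k], decay) for k in ema_params}
    return decay * ema_params + (1.0 - decay) * params
```

In a JAX codebase the same step is usually written with `jax.tree_util.tree_map(lambda e, p: decay * e + (1 - decay) * p, ema_params, params)`, which also handles lists, tuples and other registered pytree containers.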

Evaluation

To run an evaluation job, update test.sh with the correct settings used in the training script. Then, execute

bash test.sh exp_name &> eval.txt

to start an evaluation job. Every checkpoint in the work directory is evaluated for FID and Inception Score. If you can spare the GPUs, you can also run train.sh and test.sh in parallel, so that new checkpoints are evaluated as they are saved into the work directory. Scores are written to TensorBoard and to eval.txt.
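For reference, these are the standard definitions of the two reported metrics. Here μ_r, Σ_r and μ_g, Σ_g are the mean and covariance of Inception-v3 features for real and generated images, p(y | x) is the Inception classifier's class posterior for a generated image x, and p(y) is its marginal over generated samples. Lower FID is better; higher Inception Score is better. The number of samples and the exact Inception variant used by test.sh are not stated here.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)

\mathrm{IS} = \exp\Big( \mathbb{E}_{x \sim p_g}\, D_{\mathrm{KL}}\big( p(y \mid x) \,\|\, p(y) \big) \Big)
```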

TensorBoard

To start TensorBoard for monitoring training progress, run:

tensorboard --logdir /path/to/exp/exp_name

Citation

If you find this work useful, please consider citing:

@inproceedings{zhang2021cross,
  title={Cross-Modal Contrastive Learning for Text-to-Image Generation},
  author={Zhang, Han and Koh, Jing Yu and Baldridge, Jason and Lee, Honglak and Yang, Yinfei},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2021}
}

Disclaimer

Not an official Google product.