Pytorch implementation of "A simple neural network module for relational reasoning" (Relational Networks)

Overview

Pytorch implementation of Relational Networks - A simple neural network module for relational reasoning

Implemented & tested on Sort-of-CLEVR task.

Sort-of-CLEVR

Sort-of-CLEVR is simplified version of CLEVR.This is composed of 10000 images and 20 questions (10 relational questions and 10 non-relational questions) per each image. 6 colors (red, green, blue, orange, gray, yellow) are assigned to randomly chosen shape (square or circle), and placed in a image.

Non-relational questions are composed of 3 subtypes:

  1. Shape of certain colored object
  2. Horizontal location of certain colored object : whether it is on the left side of the image or right side of the image
  3. Vertical location of certain colored object : whether it is on the upside of the image or downside of the image

Theses questions are "non-relational" because the agent only need to focus on certain object.

Relational questions are composed of 3 subtypes:

  1. Shape of the object which is closest to the certain colored object
  2. Shape of the object which is furthest to the certain colored object
  3. Number of objects which have the same shape with the certain colored object

These questions are "relational" because the agent has to consider the relations between objects.

Questions are encoded into a vector of size of 11 : 6 for one-hot vector for certain color among 6 colors, 2 for one-hot vector of relational/non-relational questions. 3 for one-hot vector of 3 subtypes.

I.e., with the sample image shown, we can generate non-relational questions like:

  1. What is the shape of the red object? => Circle (even though it does not really look like "circle"...)
  2. Is green object placed on the left side of the image? => yes
  3. Is orange object placed on the upside of the image? => no

And relational questions:

  1. What is the shape of the object closest to the red object? => square
  2. What is the shape of the object furthest to the orange object? => circle
  3. How many objects have same shape with the blue object? => 3

Setup

Create conda environment from environment.yml file

$ conda env create -f environment.yml

Activate environment

$ conda activate RN3

If you don't use conda install python 3 normally and use pip install to install remaining dependencies. The list of dependencies can be found in the environment.yml file.

Usage

$ ./run.sh

or

$ python sort_of_clevr_generator.py

to generate sort-of-clevr dataset and

 $ python main.py 

to train the binary RN model. Alternatively, use

 $ python main.py --relation-type=ternary

to train the ternary RN model.

Modifications

In the original paper, Sort-of-CLEVR task used different model from CLEVR task. However, because model used CLEVR requires much less time to compute (network is much smaller), this model is used for Sort-of-CLEVR task.

Result

Relational Networks (20th epoch) CNN + MLP (without RN, 100th epoch)
Non-relational question 99% 66%
Relational question 89% 66%

CNN + MLP occured overfitting to the training data.

Relational networks shows far better results in relational questions and non-relation questions.

Contributions

@gngdb speeds up the model by 10 times.

Owner
Kim Heecheol
University of Tokyo, Intelligent systems & Informatics Lab.
Kim Heecheol
PyTorch deep learning projects made easy.

PyTorch Template Project PyTorch deep learning project made easy. PyTorch Template Project Requirements Features Folder Structure Usage Config file fo

Victor Huang 3.8k Jan 01, 2023
An image processing project uses Viola-jones technique to detect faces and then use SIFT algorithm for recognition.

Attendance_System An image processing project uses Viola-jones technique to detect faces and then use LPB algorithm for recognition. Face Detection Us

8 Jan 11, 2022
SwinIR: Image Restoration Using Swin Transformer

SwinIR: Image Restoration Using Swin Transformer This repository is the official PyTorch implementation of SwinIR: Image Restoration Using Shifted Win

Jingyun Liang 2.4k Jan 08, 2023
Supervised Classification from Text (P)

MSc-Thesis Module: Masters Research Thesis Language: Python Grade: 75 Title: An investigation of supervised classification of therapeutic process from

Matthew Laws 1 Nov 22, 2021
Memory efficient transducer loss computation

Introduction This project implements the optimization techniques proposed in Improving RNN Transducer Modeling for End-to-End Speech Recognition to re

Fangjun Kuang 51 Nov 25, 2022
On Generating Extended Summaries of Long Documents

ExtendedSumm This repository contains the implementation details and datasets used in On Generating Extended Summaries of Long Documents paper at the

Georgetown Information Retrieval Lab 76 Sep 05, 2022
This repository is the official implementation of Using Time-Series Privileged Information for Provably Efficient Learning of Prediction Models

Using Time-Series Privileged Information for Provably Efficient Learning of Prediction Models Link to paper Abstract We study prediction of future out

Rickard Karlsson 2 Aug 19, 2022
TPH-YOLOv5: Improved YOLOv5 Based on Transformer Prediction Head for Object Detection on Drone-Captured Scenarios

TPH-YOLOv5 This repo is the implementation of "TPH-YOLOv5: Improved YOLOv5 Based on Transformer Prediction Head for Object Detection on Drone-Captured

cv516Buaa 439 Dec 22, 2022
Deep Reinforcement Learning for Multiplayer Online Battle Arena

MOBA_RL Deep Reinforcement Learning for Multiplayer Online Battle Arena Prerequisite Python 3 gym-derk Tensorflow 2.4.1 Dotaservice of TimZaman Seed R

Dohyeong Kim 32 Dec 18, 2022
Hardware accelerated, batchable and differentiable optimizers in JAX.

JAXopt Installation | Examples | References Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX. Installation JAXopt can be

Google 621 Jan 08, 2023
Accelerated SMPL operation, commonly used in generate 3D human mesh, STAR included.

SMPL2 An enchanced and accelerated SMPL operation which commonly used in 3D human mesh generation. It takes a poses, shapes, cam_trans as inputs, outp

JinTian 20 Oct 17, 2022
Statsmodels: statistical modeling and econometrics in Python

About statsmodels statsmodels is a Python package that provides a complement to scipy for statistical computations including descriptive statistics an

statsmodels 8.1k Jan 02, 2023
Fine-grained Post-training for Improving Retrieval-based Dialogue Systems - NAACL 2021

Fine-grained Post-training for Multi-turn Response Selection Implements the model described in the following paper Fine-grained Post-training for Impr

Janghoon Han 83 Dec 20, 2022
Official PyTorch implementation of "AASIST: Audio Anti-Spoofing using Integrated Spectro-Temporal Graph Attention Networks"

AASIST This repository provides the overall framework for training and evaluating audio anti-spoofing systems proposed in 'AASIST: Audio Anti-Spoofing

Clova AI Research 56 Jan 02, 2023
Supplemental learning materials for "Fourier Feature Networks and Neural Volume Rendering"

Fourier Feature Networks and Neural Volume Rendering This repository is a companion to a lecture given at the University of Cambridge Engineering Depa

Matthew A Johnson 133 Dec 26, 2022
Ivy is a templated deep learning framework which maximizes the portability of deep learning codebases.

Ivy is a templated deep learning framework which maximizes the portability of deep learning codebases. Ivy wraps the functional APIs of existing frameworks. Framework-agnostic functions, libraries an

Ivy 8.2k Jan 02, 2023
Flower classification model that classifies flowers in 10 classes made using transfer learning (~85% accuracy).

flower-classification-inceptionV3 Flower classification model that classifies flowers in 10 classes. Training and validation are done using a pre-anot

Ivan R. Mršulja 1 Dec 12, 2021
DeepConsensus uses gap-aware sequence transformers to correct errors in Pacific Biosciences (PacBio) Circular Consensus Sequencing (CCS) data.

DeepConsensus DeepConsensus uses gap-aware sequence transformers to correct errors in Pacific Biosciences (PacBio) Circular Consensus Sequencing (CCS)

Google 149 Dec 19, 2022
Official PyTorch code for WACV 2022 paper "CFLOW-AD: Real-Time Unsupervised Anomaly Detection with Localization via Conditional Normalizing Flows"

CFLOW-AD: Real-Time Unsupervised Anomaly Detection with Localization via Conditional Normalizing Flows WACV 2022 preprint:https://arxiv.org/abs/2107.1

Denis 156 Dec 28, 2022