neural-combinatorial-rl-pytorch

PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning.

I have implemented the basic RL pretraining model with greedy decoding from the paper. An implementation of the supervised learning baseline model is available here. Instead of a critic network, I used an exponential moving average critic to obtain the TSP results below; the critic network is simply commented out in my code right now. Correspondence with a few others suggested that the exponential moving average critic significantly improved results.
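The exponential moving average critic replaces a learned value network with a running average of recent rewards, which then serves as the REINFORCE baseline. The block below is a minimal, framework-free sketch of that idea, not the code in this repository. It assumes a decay rate `beta` of 0.9, initializes the average from the first batch, and updates it before computing the advantage. `ExponentialMovingAverageBaseline` and `reinforce_loss` are hypothetical names.

```python
from typing import Optional, Sequence


class ExponentialMovingAverageBaseline:
    """Hypothetical EMA critic: a running average of batch-mean rewards."""

    def __init__(self, beta: float = 0.9) -> None:
        self.beta = beta                    # assumed decay rate
        self.value: Optional[float] = None  # set from the first batch

    def update(self, rewards: Sequence[float]) -> float:
        batch_mean = sum(rewards) / len(rewards)
        if self.value is None:
            # First batch: start the average at the observed mean.
            self.value = batch_mean
        else:
            # b <- beta * b + (1 - beta) * mean(R)
            self.value = self.beta * self.value + (1.0 - self.beta) * batch_mean
        return self.value


def reinforce_loss(rewards: Sequence[float], log_probs: Sequence[float],
                   baseline: float) -> float:
    """REINFORCE loss for a reward to be maximized.

    Each entry of log_probs is the summed log-probability of one sampled solution.
    """
    advantages = [r - baseline for r in rewards]
    return -sum(a * lp for a, lp in zip(advantages, log_probs)) / len(rewards)


if __name__ == "__main__":
    critic = ExponentialMovingAverageBaseline(beta=0.9)
    b = critic.update([1.0, 3.0])
    print(b, reinforce_loss([1.0, 3.0], [-0.5, -0.2], b))
```

For TSP, where the quantity is a tour length to be minimized, the sign flips and the loss becomes the mean of `(L - b) * log_prob`. In a PyTorch training loop, the log-probabilities would be tensors and the baseline a plain (detached) number, so gradients flow only through the policy.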

During training, my implementation uses a stochastic decoding policy in the pointer network, realized via PyTorch's torch.multinomial(). At test time, it decodes with beam search, which is not yet finished and only supports 1 beam (i.e., greedy decoding).
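The difference between the two decoding modes comes down to one step: mask out positions already chosen, then either sample from the remaining distribution (training) or take the argmax (single-beam search at test time). The sketch below assumes the per-step probabilities are given up front, whereas a real pointer network recomputes them from its decoder state after every choice. It uses Python's `random.choices` as a stand-in for `torch.multinomial`. `select_next` and `decode` are hypothetical names.

```python
import random
from typing import List, Optional, Sequence, Set


def select_next(probs: Sequence[float], visited: Set[int], greedy: bool,
                rng: Optional[random.Random] = None) -> int:
    """Pick the next input position for one pointer-network decoding step."""
    # Mask positions that were already selected so each input is used once.
    masked = [0.0 if i in visited else p for i, p in enumerate(probs)]
    total = sum(masked)
    if total <= 0.0:
        raise ValueError("no unvisited position has positive probability")
    weights = [p / total for p in masked]
    if greedy:
        # Test time: beam search with a single beam reduces to argmax.
        return max(range(len(weights)), key=lambda i: weights[i])
    # Training: sample from the categorical distribution
    # (stdlib stand-in for torch.multinomial).
    rng = rng or random.Random()
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


def decode(step_probs: Sequence[Sequence[float]], greedy: bool,
           seed: int = 0) -> List[int]:
    """Decode a full permutation from (assumed fixed) per-step probabilities."""
    rng = random.Random(seed)
    visited: Set[int] = set()
    order: List[int] = []
    for probs in step_probs:
        idx = select_next(probs, visited, greedy, rng)
        visited.add(idx)
        order.append(idx)
    return order
```

Sampling gives the policy-gradient estimator varied solutions to compare against the baseline. Greedy decoding returns the single most likely solution, which is what the reported test results use.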

Currently, there is support for a sorting task and the planar symmetric Euclidean TSP.
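In the planar symmetric Euclidean TSP, cities are points in the plane, and the cost between two cities is their Euclidean distance, which is the same in both directions. A solution is a permutation of the cities, and its cost is the length of the closed tour. The following is a minimal sketch of that computation for a single tour. `tour_length` is a hypothetical name, and the actual tsp_task.py likely computes this over batches of tensors.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def tour_length(points: Sequence[Point], tour: Sequence[int]) -> float:
    """Length of a closed tour through 2-D points (returns to the start)."""
    if sorted(tour) != list(range(len(points))):
        raise ValueError("tour must visit every point exactly once")
    total = 0.0
    for k in range(len(tour)):
        x1, y1 = points[tour[k]]
        x2, y2 = points[tour[(k + 1) % len(tour)]]  # wrap around to the start
        total += math.hypot(x2 - x1, y2 - y1)       # symmetric Euclidean distance
    return total
```

For the unit square `[(0, 0), (1, 0), (1, 1), (0, 1)]`, visiting the corners in order gives a length of 4, while crossing the diagonals (`[0, 2, 1, 3]`) gives a longer tour of 2 + 2√2.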

See main.sh for an example of how to run the code.

Use the --load_path $LOAD_PATH and --is_train False flags to load a saved model.

To load a saved model and view the pointer network's attention layer, also use the --plot_attention True flag.

Please feel free to let me know if you encounter any errors, or if you'd like to submit a pull request to improve this implementation.

Adding other tasks

This implementation can be extended to support other combinatorial optimization problems. See sorting_task.py and tsp_task.py for examples of how to add one. The key thing is to provide a dataset class and a reward function that takes in a sample solution, selected by the pointer network from the input, and returns a scalar reward. For the sorting task, the agent receives a reward proportional to the length of the longest strictly increasing subsequence in the decoded output (e.g., [1, 3, 5, 2, 4] -> 3/5 = 0.6).
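The two pieces a new task needs, a dataset and a reward function, can be sketched for the sorting task as described above. This is an illustration under stated assumptions, not the repository's sorting_task.py. It assumes inputs are random permutations of 0..n-1 (consistent with the sort10 example further down). It computes the longest strictly increasing subsequence with the standard O(n log n) patience method. `lis_length`, `sort_reward`, and `RandomSortDataset` are hypothetical names, and the real code may differ in details such as batching.

```python
import bisect
import random
from typing import List, Sequence


def lis_length(seq: Sequence[int]) -> int:
    """Length of the longest strictly increasing subsequence, O(n log n)."""
    tails: List[int] = []  # tails[k] = smallest tail of an increasing subsequence of length k+1
    for x in seq:
        pos = bisect.bisect_left(tails, x)  # bisect_left keeps the subsequence strictly increasing
        if pos == len(tails):
            tails.append(x)
        else:
            tails[pos] = x
    return len(tails)


def sort_reward(decoded: Sequence[int]) -> float:
    """Reward in [0, 1]: LIS length divided by the sequence length."""
    return lis_length(decoded) / len(decoded)


class RandomSortDataset:
    """Hypothetical dataset of random permutations of 0..seq_len-1."""

    def __init__(self, seq_len: int, num_samples: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.samples = [rng.sample(range(seq_len), seq_len)
                        for _ in range(num_samples)]

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> List[int]:
        return self.samples[idx]
```

With this definition, a perfectly sorted output scores 1.0, and the example above, `[1, 3, 5, 2, 4]`, scores 0.6.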

Dependencies

  • Python=3.6 (should be OK with v >= 3.4)
  • PyTorch=0.2 and 0.3
  • tqdm
  • matplotlib
  • tensorboard_logger

PyTorch 0.4 compatibility is available on branch pytorch-0.4.

TSP Results

Results are for 1 random seed over 50 epochs (each epoch is 10,000 batches of size 128). After each epoch, I validated performance on 1,000 held-out graphs. I used the same hyperparameters as the paper, as can be seen in main.sh. The dashed line shows the value reported in Table 2 of Bello et al. for comparison. The training reward is plotted with a log-scale x-axis to show how the tour length drops early on.

[Figure panels: TSP 20 Train, TSP 20 Val, TSP 50 Train, TSP 50 Val]

Sort Results

I trained a model on sort10 for 4 epochs of 1,000,000 randomly generated samples. I tested it on a dataset of size 10,000. Then, I tested the same model on sort15 and sort20 to test the generalization capabilities.

Test results on 10,000 samples (a reward of 1.0 means the network perfectly sorted the input):

task      average reward   variance
sort10    0.9966           0.0005
sort15    0.7484           0.0177
sort20    0.5586           0.0060

Example prediction on sort10:

input: [4, 7, 5, 0, 3, 2, 6, 8, 9, 1]
output: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Attention visualization

Plot the pointer network's attention layer with the argument --plot_attention True.

TODO

  • Add RL pretraining-Sampling
  • Add RL pretraining-Active Search
  • Active Search
  • Asynchronous training a la A3C
  • Refactor USE_CUDA variable
  • Finish implementing beam search decoding to support > 1 beam
  • Add support for variable length inputs

Acknowledgements

Special thanks to the repos devsisters/neural-combinatorial-rl-tensorflow and MaximumEntropy/Seq2Seq-PyTorch for getting me started, and to @ricgama for figuring out that weird bug with clone().

Owner
Patrick E.
Machine Learning PhD Candidate at Univ. of Florida. Deep generative models | object-centric representation learning | RL | transportation