Pointer networks Tensorflow2

Overview

Original paper: https://arxiv.org/abs/1506.03134
For reference and study only; the code includes explanatory comments.

Environment

tensorflow==2.6.0
tqdm
matplotlib
numpy

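Reading notes on "Pointer Networks"

Application scenarios:

Text summarization, convex hull, Delaunay triangulation, travelling salesman problem

These notes contain some LaTeX that GitHub cannot render, so it is recommended to clone the repository and view it in Typora.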

abstract

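This paper proposes a new network architecture in which the elements of the output sequence are discrete tokens corresponding to positions in the input sequence.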

an output sequence with elements that are discrete tokens corresponding to positions in an input sequence.

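Problems of this kind can currently be addressed by some existing methods, such as sequence-to-sequence models and neural Turing machines, but these methods are not particularly well suited to them.

The paper addresses sorting variable sized sequences as well as various combinatorial optimization problems. The model uses the attention mechanism to handle outputs of varying size.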

intro

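The output dimension of an RNN model is fixed. Sequence-to-sequence models remove this limitation by using one RNN to map the input to an embedding and a second RNN to map the embedding to the output sequence.

However, these sequence-to-sequence methods all use a fixed-size vocabulary.

For example, if the vocabulary contains only A, B, and C, then the inputs map as follows: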

1,2,3 ----> A,B,C

1,2,3,4 ----> A,B,C,A

The framework proposed in this paper applies when the size of the output vocabulary depends on the size of the input problem.

image-20211105133740833

image-20211105134312635

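Left figure: seq-2-seq

The blue RNN outputs a vector.

The purple RNN uses the chain rule of probability and produces an output of fixed dimension.

The contributions of the paper are:

  1. A new architecture, called the pointer network, that is simple and efficient
  2. Good generalization performance
  3. An approximate TSP solver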

Models

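sequence-to-sequence model

The training data are pairs
$$ (\mathcal{P}, \mathcal{C}^{\mathcal{P}}) $$
where $\mathcal{P}=\left\{P_{1}, \ldots, P_{n}\right\}$ is a sequence of $n$ vectors and $\mathcal{C}^{\mathcal{P}}=\left\{C_{1}, \ldots, C_{m(\mathcal{P})}\right\}$ is the corresponding sequence of $m(\mathcal{P})$ output indices, each between 1 and $n$. In a conventional sequence-to-sequence model, the size of the output dictionary for $\mathcal{C}^{\mathcal{P}}$ is fixed and has to be specified in advance. In this paper it equals $n$ and therefore changes with the input.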

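Writing the model parameters as $\theta$, the neural network models the conditional probability
$$ p\left(\mathcal{C}^{\mathcal{P}} \mid \mathcal{P} ; \theta\right) $$
which the chain rule factorizes as
$$ p\left(\mathcal{C}^{\mathcal{P}} \mid \mathcal{P} ; \theta\right)=\prod_{i=1}^{m(\mathcal{P})} p_{\theta}\left(C_{i} \mid C_{1}, \ldots, C_{i-1}, \mathcal{P} ; \theta\right) $$
Training maximizes the log-likelihood over the training set:
$$ \theta^{*}=\underset{\theta}{\arg \max } \sum_{\mathcal{P}, \mathcal{C}^{\mathcal{P}}} \log p\left(\mathcal{C}^{\mathcal{P}} \mid \mathcal{P} ; \theta\right) $$
A $\Rightarrow$ symbol is appended to the end of the input sequence to mark the switch to the generation phase, and $\Leftarrow$ marks the end of the generation phase.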
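As a small illustration of the chain-rule factorization, the sketch below sums per-step log-probabilities to obtain $\log p(\mathcal{C}^{\mathcal{P}} \mid \mathcal{P})$. The per-step distributions are made-up numbers rather than model outputs, and `sequence_log_prob` is a hypothetical helper name, not something taken from the repository.

```python
import math


def sequence_log_prob(step_probs, targets):
    """Sum log p(C_i | C_1..C_{i-1}, P) over the output sequence.

    step_probs[i] is the (assumed) distribution produced at step i,
    already conditioned on the previous outputs; targets[i] is the
    index of the correct token C_i at that step.
    """
    return sum(math.log(probs[t]) for probs, t in zip(step_probs, targets))


# Made-up per-step distributions over 3 possible outputs.
step_probs = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.25, 0.25, 0.5],
]
targets = [0, 1, 2]

log_p = sequence_log_prob(step_probs, targets)
# exp(log_p) equals the product 0.7 * 0.8 * 0.5 = 0.28
```

Maximizing the sum of this quantity over all training pairs is the same as minimizing the usual negative log-likelihood (cross-entropy) loss.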

Inference: $$ \hat{\mathcal{C}}^{\mathcal{P}}=\underset{\mathcal{C}^{\mathcal{P}}}{\arg \max } p\left(\mathcal{C}^{\mathcal{P}} \mid \mathcal{P} ; \theta^{*}\right) $$

content based input attention

For the attention mechanism, see the reading notes on "Neural Machine Translation By Jointly Learning To Align And Translate".

For an LSTM RNN:
$$ \begin{aligned} u_{j}^{i} &=v^{T} \tanh \left(W_{1} e_{j}+W_{2} d_{i}\right) & j \in(1, \ldots, n) \\ a_{j}^{i} &=\operatorname{softmax}\left(u_{j}^{i}\right) & j \in(1, \ldots, n) \\ d_{i}^{\prime} &=\sum_{j=1}^{n} a_{j}^{i} e_{j} & \end{aligned} $$
In this conventional attention mechanism, $u^{i}$ is a vector of length $n$.

As a result, every decoder time step produces a vector of length n, which can serve as a pointer into the preceding length-n input sequence.
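The three attention equations can be sketched in NumPy as follows. This is a minimal single-example sketch (no batch dimension) with random weights; the function name `content_attention` and the chosen sizes are assumptions for illustration and are not taken from the repository code.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - np.max(x))
    return e / e.sum()


def content_attention(e, d_i, W1, W2, v):
    """e: [n, units] encoder states e_j; d_i: [units] decoder state.

    Returns the scores u^i [n], the weights a^i [n], and the
    context vector d'_i [units].
    """
    u = np.tanh(e @ W1.T + d_i @ W2.T) @ v  # u_j = v^T tanh(W1 e_j + W2 d_i)
    a = softmax(u)                          # a^i = softmax(u^i)
    d_prime = a @ e                         # d'_i = sum_j a_j e_j
    return u, a, d_prime


rng = np.random.default_rng(0)
n, units = 5, 4  # illustrative sizes only
e = rng.normal(size=(n, units))
d_i = rng.normal(size=units)
W1 = rng.normal(size=(units, units))
W2 = rng.normal(size=(units, units))
v = rng.normal(size=units)

u, a, d_prime = content_attention(e, d_i, W1, W2, v)
```

Here `u` and `a` both have length n, which is the property the pointer network exploits.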

Ptr-Net

Ptr-Net is therefore computed as:
$$ \begin{aligned} u_{j}^{i} &=v^{T} \tanh \left(W_{1} e_{j}+W_{2} d_{i}\right) \quad j \in(1, \ldots, n) \\ p\left(C_{i} \mid C_{1}, \ldots, C_{i-1}, \mathcal{P}\right) &=\operatorname{softmax}\left(u^{i}\right) \end{aligned} $$

image-20211111103159924

image-20211111110334755

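Data enters the encoder LSTM (green part) with shape [Batch, time_steps, feature]. After iterating over $n$ time steps, the encoder yields:

  • n encoder states e, each [batch, units], which can be stacked into [batch, n, units]

  • the c [batch, units] output at the last time step

The decoder LSTM (blue part) takes as input:

  • the pointer obtained from the previous decoding step, or the initial pointer on the first step
  • the previous states d and c

How is the pointer obtained? It is computed as:
$$ \begin{aligned} u_{j}^{i} &=v^{T} \tanh \left(W_{1} e_{j}+W_{2} d_{i}\right) \quad j \in(1, \ldots, n) \\ p\left(C_{i} \mid C_{1}, \ldots, C_{i-1}, \mathcal{P}\right) &=\operatorname{softmax}\left(u^{i}\right) \end{aligned} $$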
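To make the pointer computation concrete, the NumPy sketch below applies the same parameters to inputs of two different lengths and obtains output distributions of matching lengths. It is a single-example sketch with random weights; the encoder states and the decoder state are random stand-ins for LSTM outputs, and `pointer_distribution` is a hypothetical name.

```python
import numpy as np


def pointer_distribution(e, d_i, W1, W2, v):
    """Return softmax(u^i) over the n input positions.

    e: [n, units] encoder states; d_i: [units] decoder state.
    """
    u = np.tanh(e @ W1.T + d_i @ W2.T) @ v  # u_j = v^T tanh(W1 e_j + W2 d_i)
    p = np.exp(u - u.max())                 # stable softmax
    return p / p.sum()


rng = np.random.default_rng(0)
units = 4  # illustrative size only
W1 = rng.normal(size=(units, units))
W2 = rng.normal(size=(units, units))
v = rng.normal(size=units)
d_i = rng.normal(size=units)  # stand-in for the decoder state

# The same parameters handle n = 5 and n = 8 without any change.
p5 = pointer_distribution(rng.normal(size=(5, units)), d_i, W1, W2, v)
p8 = pointer_distribution(rng.normal(size=(8, units)), d_i, W1, W2, v)

pointer = int(np.argmax(p5))  # greedy choice of the input position
```

Because none of `W1`, `W2`, or `v` has a dimension tied to n, the size of the output distribution follows the input length automatically.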

motivation and datasets structure

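The paper targets three problems: convex hull, Delaunay triangulation, and the travelling salesman problem. Only the travelling salesman problem is discussed here.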

travelling salesman problem

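Given a list of cities, we want to find the shortest route that visits each city exactly once and then returns to the starting point. The distance between two cities is assumed to be the same in both directions. This is an NP-hard problem, which makes it a test of the model's capabilities and limitations.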
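The objective can be written down directly: the length of a closed tour is the sum of the distances between consecutive cities, including the edge back to the start. The sketch below is illustrative only; `tour_length` and the square example are assumptions, not part of the repository.

```python
import math


def tour_length(points, tour):
    """Length of the closed tour visiting points in the given order."""
    total = 0.0
    for k in range(len(tour)):
        a = points[tour[k]]
        b = points[tour[(k + 1) % len(tour)]]  # wraps back to the start
        total += math.dist(a, b)
    return total


# Four corners of the unit square.
square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]

perimeter = tour_length(square, [0, 1, 2, 3])  # walks around the square
crossing = tour_length(square, [0, 2, 1, 3])   # uses both diagonals
```

Walking around the perimeter is shorter than the tour that crosses the diagonals, and reversing a tour leaves its length unchanged because distances are symmetric.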

Data generation:

Cartesian coordinates (two-dimensional), $[0,1] \times[0,1]$
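A problem instance can be sketched as n random points in the unit square. The notes do not state the sampling distribution, so uniform sampling is an assumption here, and `sample_cities` is a hypothetical helper.

```python
import random


def sample_cities(n, seed=None):
    """Draw n points from [0, 1) x [0, 1) (uniform sampling is assumed)."""
    rng = random.Random(seed)
    return [(rng.random(), rng.random()) for _ in range(n)]


cities = sample_cities(10, seed=0)
```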

The Held-Karp algorithm is used to obtain exact solutions, for n up to 20.
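For reference, a compact sketch of Held-Karp dynamic programming is shown below. It returns only the optimal tour length, not the tour itself, and is not the data-generation code used by the paper or the repository. Its running time grows as $O(n^{2} 2^{n})$, which is why exact labels are only practical for small n.

```python
import itertools
import math


def held_karp(points):
    """Exact TSP tour length via Held-Karp dynamic programming."""
    n = len(points)
    if n < 2:
        return 0.0
    dist = [[math.dist(a, b) for b in points] for a in points]
    # best[(mask, j)]: shortest path that starts at city 0, visits exactly
    # the cities in mask (a bitmask over cities 1..n-1), and ends at city j.
    best = {(1 << j, j): dist[0][j] for j in range(1, n)}
    for size in range(2, n):
        for subset in itertools.combinations(range(1, n), size):
            mask = sum(1 << j for j in subset)
            for j in subset:
                prev = mask & ~(1 << j)  # same subset without city j
                best[(mask, j)] = min(
                    best[(prev, k)] + dist[k][j] for k in subset if k != j
                )
    full = sum(1 << j for j in range(1, n))
    # Close the tour by returning to city 0.
    return min(best[(full, j)] + dist[j][0] for j in range(1, n))


square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
best_len = held_karp(square)  # optimal tour is the perimeter of the square
```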

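A1, A2, and A3 are three other algorithms. A1 and A2 have time complexity $O\left(n^{2}\right)$, and A3 has time complexity $O\left(n^{3}\right)$. A3, the Christofides algorithm, is guaranteed to find a solution within 1.5 times the optimal length; see the references of the original paper for details. 1M examples are generated for training.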

image-20211111111416012

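Analysis of the table:

  1. At n=5, all methods perform well.
  2. At n=10, Ptr-Net outperforms A1.
  3. At n=50, Ptr-Net cannot exceed the quality of its training data (because it was trained on inexact solutions).
  4. Training only on small n and generalizing to large n does not work very well.

For n=30, Ptr-Net has complexity $O(n \log n)$, far lower than A1, A2, and A3, yet it achieves similar performance, which suggests there is still considerable room for development.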
