Pretty Tensor - Fluent Neural Networks in TensorFlow

Overview

Pretty Tensor provides a high level builder API for TensorFlow. It provides thin wrappers on Tensors so that you can easily build multi-layer neural networks.

Pretty Tensor provides a set of objects that behave like Tensors, but also support a chainable object syntax to quickly define neural networks and other layered architectures in TensorFlow.

result = (pretty_tensor.wrap(input_data, m)
          .flatten()
          .fully_connected(200, activation_fn=tf.nn.relu)
          .fully_connected(10, activation_fn=None)
          .softmax(labels, name=softmax_name))

For the full list of available operations, see the PrettyTensor object documentation (Available Operations) or the complete documentation.

See the tutorial directory for samples: tutorial/

Installation

The easiest installation is just to use pip:

  1. Follow the instructions at tensorflow.org
  2. pip install prettytensor

Note: Head is tested against the TensorFlow nightly builds and pip is tested against TensorFlow release.

Quick start

Imports

import prettytensor as pt
import tensorflow as tf

Set up your input

# Supply your own data:
# my_inputs: numpy array of shape (BATCHES, BATCH_SIZE, DATA_SIZE)
# my_labels: numpy array of shape (BATCHES, BATCH_SIZE, CLASSES)
input_tensor = tf.placeholder(tf.float32, shape=(BATCH_SIZE, DATA_SIZE))
label_tensor = tf.placeholder(tf.float32, shape=(BATCH_SIZE, CLASSES))
pretty_input = pt.wrap(input_tensor)

Define your model

softmax, loss = (pretty_input.
                 fully_connected(100).
                 softmax_classifier(CLASSES, labels=label_tensor))

Train and evaluate

accuracy = softmax.evaluate_classifier(label_tensor)

optimizer = tf.train.GradientDescentOptimizer(0.1)  # learning rate
train_op = pt.apply_optimizer(optimizer, losses=[loss])

init_op = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init_op)
    for inp, label in zip(my_inputs, my_labels):
        # Run train_op so that each step actually updates the variables.
        _, accuracy_value = sess.run([train_op, accuracy],
                                     {input_tensor: inp, label_tensor: label})
        print('Accuracy: %g' % accuracy_value)

Features

Thin

Full power of TensorFlow is easy to use

Pretty Tensors can be used (almost) everywhere that a tensor can. Just call pt.wrap to make a tensor pretty.

You can also add any existing TensorFlow function to the chain using apply. apply passes the current Tensor as the first argument and takes all the other arguments as normal.

Note: because apply is so generic, Pretty Tensor doesn't try to wrap the world.
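
How apply threads the current value into an ordinary function can be sketched without TensorFlow. The Chain class below is a hypothetical stand-in that wraps a plain Python value; it is not Pretty Tensor's implementation and only illustrates the "current value becomes the first argument" rule described above.

```python
import operator


class Chain(object):
    """Hypothetical stand-in for a wrapped tensor; holds a plain value."""

    def __init__(self, value):
        self.value = value

    def apply(self, fn, *args, **kwargs):
        # The wrapped value becomes the first positional argument of fn;
        # every other argument is passed through unchanged.
        return Chain(fn(self.value, *args, **kwargs))


# Calls operator.mul(3, 5), then pow(15, 2).
result = Chain(3).apply(operator.mul, 5).apply(pow, 2)
print(result.value)  # 225
```

Because apply returns a new wrapper, any function whose first parameter is the value can join the chain without a dedicated method.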

Plays well with other libraries

It also uses standard TensorFlow idioms so that it plays well with other libraries. This means that you can use it a little bit in a model or throughout. Just make sure to run the update_ops on each training step (see with_update_ops).

Terse

You've already seen how a Pretty Tensor is chainable and you may have noticed that it takes care of handling the input shape. One other feature worth noting is defaults. Using defaults, you can specify reused values in a single place without having to repeat yourself.

with pt.defaults_scope(activation_fn=tf.nn.relu):
  hidden_output2 = (pretty_images.flatten()
                   .fully_connected(100)
                   .fully_connected(100))

Check out the documentation to see all supported defaults.

Code matches model

Sequential mode lets you break model construction across lines and provides the subdivide syntactic sugar that makes it easy to define and understand complex structures like an inception module:

with pt.defaults_scope(activation_fn=tf.nn.relu):
  seq = pretty_input.sequential()
  with seq.subdivide(4) as towers:
    towers[0].conv2d(1, 64)
    towers[1].conv2d(1, 112).conv2d(3, 224)
    towers[2].conv2d(1, 32).conv2d(5, 64)
    towers[3].max_pool(2, 3).conv2d(1, 32)

Figure: Inception module showing branch and rejoin.
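
The branch-and-rejoin pattern can be sketched in plain Python. The Seq class below is hypothetical: its "layers" just scale a list of numbers, and the rejoin is a simple concatenation, which is an assumption of this sketch rather than a statement about how subdivide joins its towers.

```python
import contextlib


class Seq(object):
    """Hypothetical sequential builder holding a flat list of numbers."""

    def __init__(self, values):
        self.values = list(values)

    def scale(self, factor):
        # Stand-in for a layer: updates the head of this sequence in place.
        self.values = [factor * v for v in self.values]
        return self

    @contextlib.contextmanager
    def subdivide(self, branches):
        # Every tower starts from a copy of the current head.
        towers = [Seq(self.values) for _ in range(branches)]
        yield towers
        # On exit, rejoin by concatenating the tower outputs in order.
        self.values = [v for tower in towers for v in tower.values]


seq = Seq([1, 2])
with seq.subdivide(3) as towers:
    towers[0].scale(10)
    towers[1].scale(2).scale(3)
    # towers[2] is left untouched and passes the input through.
print(seq.values)  # [10, 20, 6, 12, 1, 2]
```

The with block makes the extent of the branching visible in the code's indentation, which is the point of the "code matches model" idea.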

Templates provide guaranteed parameter reuse and make unrolling recurrent networks easy:

output = []
s = tf.zeros([BATCH, 256 * 2])

A = (pt.template('x')
     .lstm_cell(num_units=256, state=pt.UnboundVariable('state')))

for x in pretty_input_array:
  h, s = A.construct(x=x, state=s)
  output.append(h)

There are also some convenient shorthands for LSTMs and GRUs:

pretty_input_array.sequence_lstm(num_units=256)

Figure: Unrolled RNN.

Extensible

You can call any existing operation by using apply, and it will simply substitute the current tensor for the first argument.

pretty_input.apply(tf.mul, 5)

You can also create a new operation. There are two supported registration mechanisms to add your own functions. @Register() allows you to create a method on PrettyTensor that operates on the Tensors and returns either a loss or a new value. Name scoping and variable scoping are handled by the framework.

The following method adds the leaky_relu method to every Pretty Tensor:

@pt.Register()
def leaky_relu(input_pt):
  return tf.select(tf.greater(input_pt, 0.0), input_pt, 0.01 * input_pt)

@RegisterCompoundOp() is like adding a macro; it is designed to group together common sets of operations.
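
To make the registration idea concrete, here is a minimal library-free sketch of a decorator that attaches a function to a wrapper class as a chainable method. The Wrapped and Register classes are hypothetical stand-ins written for this example; they leave out the name and variable scoping that the real framework handles.

```python
class Wrapped(object):
    """Hypothetical stand-in for PrettyTensor; wraps a list of floats."""

    def __init__(self, values):
        self.values = list(values)


class Register(object):
    """Hypothetical decorator factory that adds a function to Wrapped."""

    def __call__(self, fn):
        def method(wrapped, *args, **kwargs):
            # Unwrap the input and re-wrap the result so that the new
            # method stays chainable.
            return Wrapped(fn(wrapped.values, *args, **kwargs))
        setattr(Wrapped, fn.__name__, method)
        return fn


@Register()
def leaky_relu(values, slope=0.01):
    return [v if v > 0.0 else slope * v for v in values]


out = Wrapped([-2.0, 3.0]).leaky_relu().values
print(out)  # [-0.02, 3.0]
```

Note that the decorator is instantiated (Register()) before it is applied, which leaves room for configuration arguments on the decorator itself.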

Safe variable reuse

Within a graph, you can reuse variables by using templates. A template is just like a regular graph except that some variables are left unbound.

See more details in PrettyTensor class.
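
The notion of a template with unbound variables can be sketched as a deferred computation whose parameters are created once and whose missing inputs are supplied at construct time. The Template and UnboundVariable classes below are hypothetical, and the "cell" is a toy affine update instead of an LSTM.

```python
class UnboundVariable(object):
    """Hypothetical placeholder for a value supplied at construct time."""

    def __init__(self, key):
        self.key = key


class Template(object):
    """Hypothetical deferred computation: weight * x + state."""

    def __init__(self, x_key, state):
        self.x_key = x_key
        self.state = state
        self.weight = 2.0  # created once, shared by every construct call

    def construct(self, **bindings):
        # Resolve each unbound name from the keyword arguments.
        x = bindings[self.x_key]
        state = self.state
        if isinstance(state, UnboundVariable):
            state = bindings[state.key]
        new_state = self.weight * x + state
        return new_state, new_state


cell = Template('x', state=UnboundVariable('state'))
outputs, s = [], 0.0
for x in [1.0, 2.0, 3.0]:
    h, s = cell.construct(x=x, state=s)
    outputs.append(h)
print(outputs)  # [2.0, 6.0, 12.0]
```

Every construct call reads the same weight attribute, which is the sense in which a template gives parameter reuse without any explicit sharing code at the call site.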

Accessing Variables

Pretty Tensor uses the standard graph collections from TensorFlow to store variables. These can be accessed using tf.get_collection(key) with the following keys:

  • tf.GraphKeys.VARIABLES: all variables that should be saved (including some statistics).
  • tf.GraphKeys.TRAINABLE_VARIABLES: all variables that can be trained (including those before a stop_gradients call). These are what would typically be called parameters of the model in ML parlance.
  • pt.GraphKeys.TEST_VARIABLES: variables used to evaluate a model. These are typically not saved and are reset by the LocalRunner.evaluate method to get a fresh evaluation.
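
How one variable can belong to several collections while another belongs to only one can be sketched with a plain dictionary. The registry, the string keys, and the variable names below are all hypothetical; in real code the lookups would go through tf.get_collection with the keys listed above.

```python
import collections

# Hypothetical registry: maps a collection key to the variables stored in it.
_registry = collections.defaultdict(list)


def add_to_collections(keys, variable):
    """Stores one variable under every key in keys."""
    for key in keys:
        _registry[key].append(variable)


def get_collection(key):
    """Returns a copy so callers cannot mutate the registry by accident."""
    return list(_registry[key])


# A weight is both saved and trained; a statistic is saved but not trained;
# an evaluation counter is kept apart from both.
add_to_collections(['variables', 'trainable_variables'], 'layer1/weights')
add_to_collections(['variables'], 'layer1/moving_average')
add_to_collections(['test_variables'], 'accuracy/count')

print(get_collection('trainable_variables'))  # ['layer1/weights']
```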

Authors

Eider Moore (eiderman)

with key contributions from:

  • Hubert Eichner
  • Oliver Lange
  • Sagar Jain (sagarjn)