Tensorflow-Project-Template - A best practice for tensorflow project template architecture.

Overview

Tensorflow Project Template

A simple and well designed structure is essential for any Deep Learning project, so after a lot of practice and contributing in tensorflow projects here's a tensorflow project template that combines simplcity, best practice for folder structure and good OOP design. The main idea is that there's much stuff you do every time you start your tensorflow project, so wrapping all this shared stuff will help you to change just the core idea every time you start a new tensorflow project.

So, here's a simple tensorflow template that help you get into your main project faster and just focus on your core (Model, Training, ...etc)

Table Of Contents

In a Nutshell

In a nutshell here's how to use this template, so for example assume you want to implement VGG model so you should do the following:

  • In models folder create a class named VGG that inherit the "base_model" class
    class VGGModel(BaseModel):
        def __init__(self, config):
            super(VGGModel, self).__init__(config)
            #call the build_model and init_saver functions.
            self.build_model() 
            self.init_saver() 
  • Override these two functions "build_model" where you implement the vgg model, and "init_saver" where you define a tensorflow saver, then call them in the initalizer.
     def build_model(self):
        # here you build the tensorflow graph of any model you want and also define the loss.
        pass
            
     def init_saver(self):
        # here you initalize the tensorflow saver that will be used in saving the checkpoints.
        self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)
  • In trainers folder create a VGG trainer that inherit from "base_train" class
    class VGGTrainer(BaseTrain):
        def __init__(self, sess, model, data, config, logger):
            super(VGGTrainer, self).__init__(sess, model, data, config, logger)
  • Override these two functions "train_step", "train_epoch" where you write the logic of the training process
    def train_epoch(self):
        """
       implement the logic of epoch:
       -loop on the number of iterations in the config and call the train step
       -add any summaries you want using the summary
        """
        pass

    def train_step(self):
        """
       implement the logic of the train step
       - run the tensorflow session
       - return any metrics you need to summarize
       """
        pass
  • In main file, you create the session and instances of the following objects "Model", "Logger", "Data_Generator", "Trainer", and config
    sess = tf.Session()
    # create instance of the model you want
    model = VGGModel(config)
    # create your data generator
    data = DataGenerator(config)
    # create tensorboard logger
    logger = Logger(sess, config)
  • Pass the all these objects to the trainer object, and start your training by calling "trainer.train()"
    trainer = VGGTrainer(sess, model, data, config, logger)

    # here you train your model
    trainer.train()

You will find a template file and a simple example in the model and trainer folder that shows you how to try your first model simply.

In Details

Project architecture

Folder structure

├──  base
│   ├── base_model.py   - this file contains the abstract class of the model.
│   └── base_train.py   - this file contains the abstract class of the trainer.
│
│
├── model               - this folder contains any model of your project.
│   └── example_model.py
│
│
├── trainer             - this folder contains trainers of your project.
│   └── example_trainer.py
│   
├──  mains              - here's the main(s) of your project (you may need more than one main).
│    └── example_main.py  - here's an example of main that is responsible for the whole pipeline.

│  
├──  data _loader  
│    └── data_generator.py  - here's the data_generator that is responsible for all data handling.
│ 
└── utils
     ├── logger.py
     └── any_other_utils_you_need

Main Components

Models


  • Base model

    Base model is an abstract class that must be Inherited by any model you create, the idea behind this is that there's much shared stuff between all models. The base model contains:

    • Save -This function to save a checkpoint to the desk.
    • Load -This function to load a checkpoint from the desk.
    • Cur_epoch, Global_step counters -These variables to keep track of the current epoch and global step.
    • Init_Saver An abstract function to initialize the saver used for saving and loading the checkpoint, Note: override this function in the model you want to implement.
    • Build_model Here's an abstract function to define the model, Note: override this function in the model you want to implement.
  • Your model

    Here's where you implement your model. So you should :

    • Create your model class and inherit the base_model class
    • override "build_model" where you write the tensorflow model you want
    • override "init_save" where you create a tensorflow saver to use it to save and load checkpoint
    • call the "build_model" and "init_saver" in the initializer.

Trainer


  • Base trainer

    Base trainer is an abstract class that just wrap the training process.

  • Your trainer

    Here's what you should implement in your trainer.

    1. Create your trainer class and inherit the base_trainer class.
    2. override these two functions "train_step", "train_epoch" where you implement the training process of each step and each epoch.

Data Loader

This class is responsible for all data handling and processing and provide an easy interface that can be used by the trainer.

Logger

This class is responsible for the tensorboard summary, in your trainer create a dictionary of all tensorflow variables you want to summarize then pass this dictionary to logger.summarize().

This class also supports reporting to Comet.ml which allows you to see all your hyper-params, metrics, graphs, dependencies and more including real-time metric. Add your API key in the configuration file:

For example: "comet_api_key": "your key here"

Comet.ml Integration

This template also supports reporting to Comet.ml which allows you to see all your hyper-params, metrics, graphs, dependencies and more including real-time metric.

Add your API key in the configuration file:

For example: "comet_api_key": "your key here"

Here's how it looks after you start training:

You can also link your Github repository to your comet.ml project for full version control. Here's a live page showing the example from this repo

Configuration

I use Json as configuration method and then parse it, so write all configs you want then parse it using "utils/config/process_config" and pass this configuration object to all other objects.

Main

Here's where you combine all previous part.

  1. Parse the config file.
  2. Create a tensorflow session.
  3. Create an instance of "Model", "Data_Generator" and "Logger" and parse the config to all of them.
  4. Create an instance of "Trainer" and pass all previous objects to it.
  5. Now you can train your model by calling "Trainer.train()"

Future Work

  • Replace the data loader part with new tensorflow dataset API.

Contributing

Any kind of enhancement or contribution is welcomed.

Acknowledgments

Thanks for my colleague Mo'men Abdelrazek for contributing in this work. and thanks for Mohamed Zahran for the review. Thanks for Jtoy for including the repo in Awesome Tensorflow.

Owner
Mahmoud G. Salem
MSc. in AI at university of Guelph and Vector Institute. AI intern @samsung
Mahmoud G. Salem
v objective diffusion inference code for JAX.

v-diffusion-jax v objective diffusion inference code for JAX, by Katherine Crowson (@RiversHaveWings) and Chainbreakers AI (@jd_pressman). The models

Katherine Crowson 186 Dec 21, 2022
Official implementation of "Dynamic Anchor Learning for Arbitrary-Oriented Object Detection" (AAAI2021).

DAL This project hosts the official implementation for our AAAI 2021 paper: Dynamic Anchor Learning for Arbitrary-Oriented Object Detection [arxiv] [c

ming71 215 Nov 28, 2022
Code for Understanding Pooling in Graph Neural Networks

Select, Reduce, Connect This repository contains the code used for the experiments of: "Understanding Pooling in Graph Neural Networks" Setup Install

Daniele Grattarola 37 Dec 13, 2022
NeuralDiff: Segmenting 3D objects that move in egocentric videos

NeuralDiff: Segmenting 3D objects that move in egocentric videos Project Page | Paper + Supplementary | Video About This repository contains the offic

Vadim Tschernezki 14 Dec 05, 2022
Deeply Supervised, Layer-wise Prediction-aware (DSLP) Transformer for Non-autoregressive Neural Machine Translation

Non-Autoregressive Translation with Layer-Wise Prediction and Deep Supervision Training Efficiency We show the training efficiency of our DSLP model b

Chenyang Huang 36 Oct 31, 2022
My 1st place solution at Kaggle Hotel-ID 2021

1st place solution at Kaggle Hotel-ID My 1st place solution at Kaggle Hotel-ID to Combat Human Trafficking 2021. https://www.kaggle.com/c/hotel-id-202

Kohei Ozaki 18 Aug 19, 2022
Zero-Cost Proxies for Lightweight NAS

Zero-Cost-NAS Companion code for the ICLR2021 paper: Zero-Cost Proxies for Lightweight NAS tl;dr A single minibatch of data is used to score neural ne

SamsungLabs 108 Dec 20, 2022
Simulation of the solar system using various nummerical methods

solar-system Simulation of the solar system using various nummerical methods Download the repo Make shure matplotlib, scipy etc. are installed execute

Caspar 7 Jul 15, 2022
DAT4 - General Assembly's Data Science course in Washington, DC

DAT4 Course Repository Course materials for General Assembly's Data Science course in Washington, DC (12/15/14 - 3/16/15). Instructors: Sinan Ozdemir

Kevin Markham 779 Dec 25, 2022
modelvshuman is a Python library to benchmark the gap between human and machine vision

modelvshuman is a Python library to benchmark the gap between human and machine vision. Using this library, both PyTorch and TensorFlow models can be evaluated on 17 out-of-distribution datasets with

Bethge Lab 244 Jan 03, 2023
This repository contains several image-to-image translation models, whcih were tested for RGB to NIR image generation. The models are Pix2Pix, Pix2PixHD, CycleGAN and PointWise.

RGB2NIR_Experimental This repository contains several image-to-image translation models, whcih were tested for RGB to NIR image generation. The models

5 Jan 04, 2023
Segmentation models with pretrained backbones. PyTorch.

Python library with Neural Networks for Image Segmentation based on PyTorch. The main features of this library are: High level API (just two lines to

Pavel Yakubovskiy 6.6k Jan 06, 2023
PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning.

neural-combinatorial-rl-pytorch PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning. I have implemented the basic

Patrick E. 454 Jan 06, 2023
Code for binary and multiclass model change active learning, with spectral truncation implementation.

Model Change Active Learning Paper (To Appear) Python code for doing active learning in graph-based semi-supervised learning (GBSSL) paradigm. Impleme

Kevin Miller 1 Jul 24, 2022
A high-performance anchor-free YOLO. Exceeding yolov3~v5 with ONNX, TensorRT, NCNN, and Openvino supported.

YOLOX is an anchor-free version of YOLO, with a simpler design but better performance! It aims to bridge the gap between research and industrial communities. For more details, please refer to our rep

7.7k Jan 06, 2023
Official public repository of paper "Intention Adaptive Graph Neural Network for Category-Aware Session-Based Recommendation"

Intention Adaptive Graph Neural Network (IAGNN) This is the official repository of paper Intention Adaptive Graph Neural Network for Category-Aware Se

9 Nov 22, 2022
(Personalized) Page-Rank computation using PyTorch

torch-ppr This package allows calculating page-rank and personalized page-rank via power iteration with PyTorch, which also supports calculation on GP

Max Berrendorf 69 Dec 03, 2022
A PyTorch implementation of "ANEMONE: Graph Anomaly Detection with Multi-Scale Contrastive Learning", CIKM-21

ANEMONE A PyTorch implementation of "ANEMONE: Graph Anomaly Detection with Multi-Scale Contrastive Learning", CIKM-21 Dependencies python==3.6.1 dgl==

Graph Analysis & Deep Learning Laboratory, GRAND 30 Dec 14, 2022
Neural Magic Eye: Learning to See and Understand the Scene Behind an Autostereogram, arXiv:2012.15692.

Neural Magic Eye Preprint | Project Page | Colab Runtime Official PyTorch implementation of the preprint paper "NeuralMagicEye: Learning to See and Un

Zhengxia Zou 56 Jul 15, 2022
Official implementation of the NeurIPS'21 paper 'Conditional Generation Using Polynomial Expansions'.

Conditional Generation Using Polynomial Expansions Official implementation of the conditional image generation experiments as described on the NeurIPS

Grigoris 4 Aug 07, 2022