Convert ONNX model graph to Keras model format.

Overview

onnx2keras

ONNX to Keras deep neural network converter.

GitHub License Python Version Downloads PyPI

Requirements

TensorFlow 2.0

API

onnx_to_keras(onnx_model, input_names, input_shapes=None, name_policy=None, verbose=True, change_ordering=False) -> {Keras model}

onnx_model: ONNX model to convert

input_names: list with graph input names

input_shapes: override input shapes (experimental)

name_policy: ['renumerate', 'short', 'default'] override layer names (experimental)

verbose: detailed output

change_ordering: change ordering to HWC (experimental)

Getting started

ONNX model

import onnx
from onnx2keras import onnx_to_keras

# Load ONNX model
onnx_model = onnx.load('resnet18.onnx')

# Call the converter (input - is the main model input name, can be different for your model)
k_model = onnx_to_keras(onnx_model, ['input'])

Keras model will be stored to the k_model variable. So simple, isn't it?

PyTorch model

Using ONNX as intermediate format, you can convert PyTorch model as well.

import numpy as np
import torch
from torch.autograd import Variable
from pytorch2keras.converter import pytorch_to_keras
import torchvision.models as models

if __name__ == '__main__':
    input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
    input_var = Variable(torch.FloatTensor(input_np))
    model = models.resnet18()
    model.eval()
    k_model = \
        pytorch_to_keras(model, input_var, [(3, 224, 224,)], verbose=True, change_ordering=True)

    for i in range(3):
        input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
        input_var = Variable(torch.FloatTensor(input_np))
        output = model(input_var)
        pytorch_output = output.data.numpy()
        keras_output = k_model.predict(np.transpose(input_np, [0, 2, 3, 1]))
        error = np.max(pytorch_output - keras_output)
        print('error -- ', error)  # Around zero :)

Deplying model as frozen graph

You can try using the snippet below to convert your onnx / PyTorch model to frozen graph. It may be useful for deploy for Tensorflow.js / for Tensorflow for Android / for Tensorflow C-API.

import numpy as np
import torch
from pytorch2keras.converter import pytorch_to_keras
from torch.autograd import Variable
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2


# Create and load model
model = Model()
model.load_state_dict(torch.load('model-checkpoint.pth'))
model.eval()

# Make dummy variables (and checking if the model works)
input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
input_var = Variable(torch.FloatTensor(input_np))
output = model(input_var)

# Convert the model!
k_model = \
    pytorch_to_keras(model, input_var, (3, 224, 224), 
                     verbose=True, name_policy='short',
                     change_ordering=True)

# Save model to SavedModel format
tf.saved_model.save(k_model, "./models")

# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: k_model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(k_model.inputs[0].shape, k_model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

print("-" * 50)
print("Frozen model layers: ")
for layer in [op.name for op in frozen_func.graph.get_operations()]:
    print(layer)

print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)

# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="frozen_graph.pb",
                  as_text=False)

License

This software is covered by MIT License.

Owner
Grigory Malivenko
Machine Learning Engineer
Grigory Malivenko
The aim of this project is to build an AI bot that can play the Wordle game, or more generally Squabble

Wordle RL The aim of this project is to build an AI bot that can play the Wordle game, or more generally Squabble I know there are more deterministic

Aditya Arora 3 Feb 22, 2022
BC3407-Group-5-Project - BC3407 Group Project With Python

BC3407-Group-5-Project As the world struggles to contain the ever-changing varia

1 Jan 26, 2022
Resources related to our paper "CLIN-X: pre-trained language models and a study on cross-task transfer for concept extraction in the clinical domain"

CLIN-X (CLIN-X-ES) & (CLIN-X-EN) This repository holds the companion code for the system reported in the paper: "CLIN-X: pre-trained language models a

Bosch Research 4 Dec 05, 2022
A curated list and survey of awesome Vision Transformers.

English | 简体中文 A curated list and survey of awesome Vision Transformers. You can use mind mapping software to open the mind mapping source file. You c

OpenMMLab 281 Dec 21, 2022
Bayesian Optimization Library for Medical Image Segmentation.

bayesmedaug: Bayesian Optimization Library for Medical Image Segmentation. bayesmedaug optimizes your data augmentation hyperparameters for medical im

Şafak Bilici 7 Feb 10, 2022
[AAAI 2022] Separate Contrastive Learning for Organs-at-Risk and Gross-Tumor-Volume Segmentation with Limited Annotation

A paper Introduction This is an official release of the paper Separate Contrastive Learning for Organs-at-Risk and Gross-Tumor-Volume Segmentation wit

Jiacheng Wang 14 Dec 08, 2022
PyStan, a Python interface to Stan, a platform for statistical modeling. Documentation: https://pystan.readthedocs.io

PyStan NOTE: This documentation describes a BETA release of PyStan 3. PyStan is a Python interface to Stan, a package for Bayesian inference. Stan® is

Stan 229 Dec 29, 2022
View model summaries in PyTorch!

torchinfo (formerly torch-summary) Torchinfo provides information complementary to what is provided by print(your_model) in PyTorch, similar to Tensor

Tyler Yep 1.5k Jan 05, 2023
List of all dependencies affected by node-ipc malicious commit

node-ipc-dependencies-list List of all dependencies affected by node-ipc malicious commit as of 17/3/2022 - 19/3/2022 (timestamp) Please improve upon

99 Oct 15, 2022
SARS-Cov-2 Recombinant Finder for fasta sequences

Sc2rf - SARS-Cov-2 Recombinant Finder Pronounced: Scarf What's this? Sc2rf can search genome sequences of SARS-CoV-2 for potential recombinants - new

Lena Schimmel 41 Oct 03, 2022
A Comparative Review of Recent Kinect-Based Action Recognition Algorithms (TIP2020, Matlab codes)

A Comparative Review of Recent Kinect-Based Action Recognition Algorithms This repo contains: the HDG implementation (Matlab codes) for 'Analysis and

Lei Wang 5 Oct 22, 2022
Library extending Jupyter notebooks to integrate with Apache TinkerPop and RDF SPARQL.

Graph Notebook: easily query and visualize graphs The graph notebook provides an easy way to interact with graph databases using Jupyter notebooks. Us

Amazon Web Services 501 Dec 28, 2022
To propose and implement a multi-class classification approach to disaster assessment from the given data set of post-earthquake satellite imagery.

To propose and implement a multi-class classification approach to disaster assessment from the given data set of post-earthquake satellite imagery.

Kunal Wadhwa 2 Jan 05, 2022
Mask-invariant Face Recognition through Template-level Knowledge Distillation

Mask-invariant Face Recognition through Template-level Knowledge Distillation This is the official repository of "Mask-invariant Face Recognition thro

Fadi Boutros 35 Dec 06, 2022
This is the official pytorch implementation of AutoDebias, an automatic debiasing method for recommendation.

AutoDebias This is the official pytorch implementation of AutoDebias, a debiasing method for recommendation system. AutoDebias is proposed in the pape

Dong Hande 77 Nov 25, 2022
PyTorch implementation of SMODICE: Versatile Offline Imitation Learning via State Occupancy Matching

SMODICE: Versatile Offline Imitation Learning via State Occupancy Matching This is the official PyTorch implementation of SMODICE: Versatile Offline I

Jason Ma 14 Aug 30, 2022
Official PyTorch implementation of CAPTRA: CAtegory-level Pose Tracking for Rigid and Articulated Objects from Point Clouds

CAPTRA: CAtegory-level Pose Tracking for Rigid and Articulated Objects from Point Clouds Introduction This is the official PyTorch implementation of o

Yijia Weng 96 Dec 07, 2022
The Illinois repository for Climatehack (https://climatehack.ai/). We won 1st place!

Climatehack This is the repository for Illinois's Climatehack Team. We earned first place on the leaderboard with a final score of 0.87992. An overvie

Jatin Mathur 20 Jun 09, 2022
Tools for investing in Python

InvestOps Original repository on GitHub Original author is Magnus Erik Hvass Pedersen Introduction This is a Python package with simple and effective

24 Nov 26, 2022