Edge-Augmented Graph Transformer

Introduction

This is the official implementation of the Edge-augmented Graph Transformer (EGT) as described in https://arxiv.org/abs/2108.03348, which augments the Transformer architecture with residual edge channels. The resulting architecture can directly process graph-structured data and achieves good results on the supervised graph-learning benchmarks presented by Dwivedi et al. It also performs well on the large-scale PCQM4M-LSC dataset (0.1263 MAE on the validation set). EGT beats convolutional/message-passing graph neural networks on a wide range of supervised tasks, which demonstrates that convolutional aggregation is not an essential inductive bias for graphs.

Requirements

  • python >= 3.7
  • tensorflow >= 2.1.0
  • h5py >= 2.8.0
  • numpy >= 1.18.4
  • scikit-learn >= 0.22.1

Download the Datasets

For our experiments, we converted the datasets to HDF5 format for the convenience of using them without any specific library. Only the h5py library is required. The datasets can be downloaded from -

Alternatively, you can run the provided bash scripts download_medium_scale_datasets.sh and download_large_scale_datasets.sh. The default location of the datasets is the datasets directory.

Run Training and Evaluations

You must create a JSON config file containing the configuration of the model together with its training and evaluation configurations. The same config file is used for both training and evaluation.

  • To run training: python run_training.py <config_file.json>
  • To end training (prematurely): python end_training.py <config_file.json>
  • To perform evaluations: python do_evaluations.py <config_file.json>

Config files for the main results presented in the paper are contained in the configs/main directory, whereas configurations for the ablation study are contained in the configs/ablation directory. The paths and names of the files are self-explanatory.
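
As an illustration, a config file can be generated from Python with the standard json module. This is only a sketch: the keys are the ones documented in the Config File section, but every value, the model name and the file name my_config.json are placeholders rather than settings from the paper.

```python
import json

# Placeholder values for illustration; only "scheme" is required,
# everything else falls back to a default when omitted.
config = {
    "scheme": "zinc.eig",
    "model_name": "egt_zinc_example",
    "batch_size": 128,
    "num_epochs": 100,
    "initial_lr": 5e-4,
    "model_height": 4,
    "model_width": 64,
    "edge_width": 16,
}

def write_config(path, cfg):
    """Serialize the config dictionary to a JSON file and return the path."""
    with open(path, "w") as f:
        json.dump(cfg, f, indent=2)
    return path

config_file = write_config("my_config.json", config)

# The resulting file is what the three scripts expect as their argument.
print(f"python run_training.py {config_file}")
```

For real experiments, the files in configs/main and configs/ablation are a safer starting point than values chosen by hand.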

More About Training and Evaluations

Once training starts, a model folder is created in the models directory, under the specified dataset name. This folder contains a copy of the input config file, for convenience when resuming training or evaluation. It also contains a config.json file that records all the configs used for training, including default values that were not specified. Training is checkpointed every epoch. If training is interrupted, you can resume it by running run_training.py again with the config.json file.

If you wish to finalize training midway, stop training and run the end_training.py script with the config.json file to save the model weights.

After training, you can run the do_evaluations.py script with the same config file to perform evaluations. Alongside being printed to stdout, results will be saved in the predictions directory, under the model directory.

Config File

The config file can contain many different configurations; however, the only required one is scheme, which specifies the training scheme. Any configuration that is not specified takes its default value. Here are some of the commonly used configurations:

scheme: Used to specify the training scheme. It has a format <dataset_name>.<positional_encoding>. For example: cifar10.svd or zinc.eig. If no encoding is to be used it can be something like pcqm4m.mat. For a full list you can explore the lib/training/schemes directory.

dataset_path: If the datasets are contained in the default location in the datasets directory, this config need not be specified. Otherwise you have to point it towards the <dataset_name>.h5 file.

model_name: Serves as an identifier for the model, also specifies default path of the model directory, weight files etc.

save_path: The training process will create a model directory containing the logs, checkpoints, configs, model summary and predictions/evaluations. By default it creates a folder at models/<dataset_name>/<model_name> but it can be changed via this config.

cache_dir: The first time training or evaluation is run, the data is cached in a TensorFlow cache format. The default path is data_cache/<dataset_name>/<positional_encoding>, but it can be changed via this config.
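
The default locations described above follow directly from the scheme and the model name. The sketch below shows that derivation; default_paths is a hypothetical helper written for illustration and is not a function from this repository.

```python
from pathlib import PurePosixPath

def default_paths(scheme, model_name):
    """Derive the documented default save_path and cache_dir from a scheme string."""
    # A scheme has the form <dataset_name>.<positional_encoding>, e.g. "zinc.eig".
    dataset_name, positional_encoding = scheme.split(".", 1)
    return {
        "save_path": str(PurePosixPath("models") / dataset_name / model_name),
        "cache_dir": str(PurePosixPath("data_cache") / dataset_name / positional_encoding),
    }

paths = default_paths("zinc.eig", "my_model")
print(paths["save_path"])  # models/zinc/my_model
print(paths["cache_dir"])  # data_cache/zinc/eig
```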

distributed: In a multi-GPU setting, set this to True to enable distributed training.

batch_size: Batch size.

num_epochs: Maximum number of epochs.

initial_lr: Initial learning rate. In case of warmup it is the maximum learning rate.

rlr_factor: Reduce LR on plateau factor. Setting it to a value >= 1.0 turns off Reduce LR.

rlr_patience: Reduce LR patience, i.e. the number of epochs after which LR is reduced if validation loss doesn't improve.

min_lr_factor: The minimum LR, expressed as a fraction of the initial LR. Default is 0.01.
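
To make the interaction between initial_lr, rlr_factor, rlr_patience and min_lr_factor concrete, here is a small pure-Python simulation of a reduce-on-plateau rule. It is a sketch of the general technique, not the repository's training loop: the function name, the default values for rlr_factor and rlr_patience, and the choice to reset the patience counter after each reduction are all assumptions.

```python
def reduce_lr_on_plateau(val_losses, initial_lr, rlr_factor=0.5,
                         rlr_patience=2, min_lr_factor=0.01):
    """Return the LR in effect during each epoch for a sequence of validation losses."""
    min_lr = initial_lr * min_lr_factor  # floor below which the LR is never reduced
    lr, best, wait, history = initial_lr, float("inf"), 0, []
    for loss in val_losses:
        history.append(lr)
        if loss < best:
            best, wait = loss, 0
        else:
            wait += 1
            # A factor >= 1.0 disables the reduction, as documented for rlr_factor.
            if wait >= rlr_patience and rlr_factor < 1.0:
                lr = max(lr * rlr_factor, min_lr)
                wait = 0  # assumption: start counting again after a reduction
    return history

history = reduce_lr_on_plateau([1.0, 0.9, 0.95, 0.93, 0.92, 0.91], initial_lr=1e-3)
print(history)
```

In this run the validation loss stops improving after the second epoch, so the LR is halved once two epochs have passed without improvement and the new rate takes effect from the fifth epoch onward.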

model_height: The number of layers L.

model_width: The dimensionality of the node channels d_h.

edge_width: The dimensionality of the edge channels d_e.

num_heads: The number of attention heads. Default is 8.

ffn_multiplier: FFN multiplier for both channels. Default is 2.0.

virtual_nodes: Number of virtual nodes. 0 (default) would result in global average pooling being used instead of virtual nodes.

upto_hop: Clipping value of the input distance matrix. A value of 1 (default) would result in adjacency matrix being used as input structural matrix.

mlp_layers: Dimensionality of the final MLP layers, specified as a list of factors with respect to d_h. Default is [0.5, 0.25].
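
Both ffn_multiplier and mlp_layers are given as factors rather than absolute sizes. The sketch below shows one way to turn them into layer widths, assuming the FFN inner width is the channel width times the multiplier and that fractional results are rounded; layer_widths is a hypothetical helper and model_width=64 is an arbitrary example value.

```python
def layer_widths(model_width, ffn_multiplier=2.0, mlp_layers=(0.5, 0.25)):
    """Translate the width factors into concrete layer sizes."""
    # Assumed: inner FFN width = channel width * multiplier.
    ffn_width = round(model_width * ffn_multiplier)
    # Each final MLP layer is specified as a factor of d_h (model_width).
    mlp_widths = [round(model_width * f) for f in mlp_layers]
    return ffn_width, mlp_widths

print(layer_widths(64))  # (128, [32, 16])
```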

gate_attention: Set this to False to get the ungated EGT variant (EGT-U).

dropout: Dropout rate for both channels. Default is 0.

edge_dropout: If specified, applies a different dropout rate to the edge channels.

edge_channel_type: Used to create ablated variants of EGT. A value of "residual" (default) implies pure/full EGT. "constrained" implies EGT-constrained. "bias" implies EGT-simple.

warmup_steps: If specified, performs a linear learning rate warmup for the specified number of gradient update steps.

total_steps: If specified, performs a cosine annealing after warmup, so that the model is trained for the specified number of steps.
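
The schedule implied by warmup_steps and total_steps can be sketched as a function of the gradient step. This is an illustration of linear warmup followed by cosine annealing in general, not the repository's code: the function name is hypothetical, and annealing down to initial_lr * min_lr_factor (rather than to zero) is an assumption.

```python
import math

def learning_rate(step, initial_lr, warmup_steps, total_steps, min_lr_factor=0.01):
    """LR at a given gradient step: linear warmup, then cosine annealing."""
    min_lr = initial_lr * min_lr_factor  # assumed floor of the cosine phase
    if step < warmup_steps:
        # Linear ramp that reaches initial_lr (the maximum) at the end of warmup.
        return initial_lr * (step + 1) / warmup_steps
    # Fraction of the post-warmup phase completed, clipped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (initial_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Example values only: 1,000 warmup steps out of 10,000 total steps.
for step in (0, 999, 1000, 5500, 10000):
    print(step, learning_rate(step, 1e-3, 1000, 10000))
```

With these settings the rate climbs to initial_lr at the end of warmup, which matches the note that initial_lr is the maximum learning rate when warmup is used, and then decays smoothly until total_steps is reached.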

[For SVD-based encodings]:

use_svd: Turning this off (False) would result in no positional encoding being used.

sel_svd_features: Rank of the SVD encodings r.

random_neg: Augment SVD encodings by random negation.

[For Eigenvectors encodings]:

use_eig: Turning this off (False) would result in no positional encoding being used.

sel_eig_features: Number of eigenvectors.

[For Distance prediction Objective (DO)]:

distance_target: Predict distance up to the specified hop, nu.

distance_loss: Factor by which to multiply the distance prediction loss, kappa.

Creation of the HDF5 Datasets from Scratch

We have included two Jupyter notebooks to demonstrate how the HDF5 datasets are created:

  • For the medium scale datasets view create_hdf_benchmarking_datasets.ipynb. You will need pytorch, ogb==1.1.1 and dgl==0.4.2 libraries to run the notebook. The notebook is also runnable on Google Colaboratory.
  • For the large scale pcqm4m dataset view create_hdf_pcqm4m.ipynb. You will need pytorch, ogb>=1.3.0 and rdkit>=2019.03.1 to run the notebook.

Python Environment

The Anaconda environment in which our experiments were conducted is specified in the environment.yml file.

Citation

Please cite the following paper if you find the code useful:

@article{hussain2021edge,
  title={Edge-augmented Graph Transformers: Global Self-attention is Enough for Graphs},
  author={Hussain, Md Shamim and Zaki, Mohammed J and Subramanian, Dharmashankar},
  journal={arXiv preprint arXiv:2108.03348},
  year={2021}
}
Owner
Md Shamim Hussain
Md Shamim Hussain is a Ph.D. student in Computer Science at Rensselaer Polytechnic Institute, NY. He got his B.Sc. and M.Sc. in EEE from BUET, Dhaka.