A PyTorch Implementation of "Watch Your Step: Learning Node Embeddings via Graph Attention" (NeurIPS 2018).

Overview

Attention Walk

Abstract

Graph embedding methods represent nodes in a continuous vector space, preserving different types of relational information from the graph. There are many hyper-parameters to these methods (e.g. the length of a random walk) which have to be manually tuned for every graph. In this paper, we replace previously fixed hyper-parameters with trainable ones that we automatically learn via backpropagation. In particular, we propose a novel attention model on the power series of the transition matrix, which guides the random walk to optimize an upstream objective. Unlike previous approaches to attention models, the method that we propose utilizes attention parameters exclusively on the data itself (e.g. on the random walk), and are not used by the model for inference. We experiment on link prediction tasks, as we aim to produce embeddings that best-preserve the graph structure, generalizing to unseen information. We improve state-of-the-art results on a comprehensive suite of real-world graph datasets including social, collaboration, and biological networks, where we observe that our graph attention model can reduce the error by up to 20%-40%. We show that our automatically-learned attention parameters can vary significantly per graph, and correspond to the optimal choice of hyper-parameter if we manually tune existing methods.

This repository provides an implementation of Attention Walk as described in the paper:

Watch Your Step: Learning Node Embeddings via Graph Attention. Sami Abu-El-Haija, Bryan Perozzi, Rami Al-Rfou, Alexander A. Alemi. NIPS, 2018. [Paper]

The original TensorFlow implementation is available [here].
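
The central idea in the abstract, attention over the power series of the transition matrix, can be illustrated with a small sketch. This is not the repository's code: it is a simplified, dense, pure-Python illustration that assumes an undirected toy graph with no isolated nodes, fixes the attention logits by hand instead of learning them, and omits the initial-distribution term used in the paper. All function and variable names are hypothetical.

```python
import math


def transition_matrix(adjacency):
    """Row-normalize an adjacency matrix so that each row sums to 1."""
    matrix = []
    for row in adjacency:
        degree = sum(row)  # assumes degree > 0 for every node
        matrix.append([value / degree for value in row])
    return matrix


def matmul(left, right):
    """Multiply two dense square matrices given as lists of rows."""
    size = len(left)
    return [[sum(left[i][k] * right[k][j] for k in range(size))
             for j in range(size)]
            for i in range(size)]


def softmax(logits):
    """Map unconstrained attention logits to weights that sum to 1."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def attention_weighted_walk(adjacency, logits):
    """Return sum_k q_k * T^k for k = 1..len(logits), with q = softmax(logits)."""
    transition = transition_matrix(adjacency)
    weights = softmax(logits)
    size = len(adjacency)
    power = transition  # T^1
    result = [[0.0] * size for _ in range(size)]
    for weight in weights:
        for i in range(size):
            for j in range(size):
                result[i][j] += weight * power[i][j]
        power = matmul(power, transition)  # advance to the next power
    return result


# Toy undirected path graph 0 - 1 - 2 (hypothetical example data).
ADJACENCY = [[0, 1, 0],
             [1, 0, 1],
             [0, 1, 0]]
# In training these logits would be learned; here they are fixed by hand.
LOGITS = [0.0, 0.0, 0.0, 0.0]
WALK = attention_weighted_walk(ADJACENCY, LOGITS)
```

Because every power of a row-stochastic matrix is row-stochastic and the softmax weights sum to 1, each row of `WALK` is again a probability distribution over nodes. Changing the logits shifts weight between short and long walks, which is the role the window-size hyperparameter plays when it is fixed by hand.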

Requirements

The codebase is implemented in Python 3.5.2. The package versions used for development are listed below.

networkx          2.4
tqdm              4.28.1
numpy             1.15.4
pandas            0.23.4
texttable         1.5.0
scipy             1.1.0
argparse          1.1.0
torch             1.1.0
torchvision       0.3.0

Datasets

The code takes an input graph in a CSV file. Every row indicates an edge between two nodes, given as two node IDs separated by a comma. The first row is a header. Nodes should be indexed starting with 0. Sample graphs for `Twitch Brasilians` and `Wikipedia Chameleons` are included in the `input/` directory.
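
If a graph uses arbitrary or very large node IDs, one possible preprocessing step is to map them to consecutive integers starting at 0 before training. The sketch below is an assumption-laden illustration using only the standard library; the helper names and the header names `id_1,id_2` are hypothetical and not taken from the repository.

```python
import csv
import io


def read_edges(handle):
    """Read (source, target) label pairs from a CSV stream, skipping the header row."""
    reader = csv.reader(handle)
    next(reader)  # the first row is a header
    return [(row[0].strip(), row[1].strip()) for row in reader if row]


def reindex_nodes(edges):
    """Map arbitrary node labels to consecutive integers starting at 0."""
    mapping = {}
    for source, target in edges:
        for label in (source, target):
            if label not in mapping:
                mapping[label] = len(mapping)
    reindexed = [(mapping[source], mapping[target]) for source, target in edges]
    return reindexed, mapping


# Hypothetical file content; the column names are placeholders.
SAMPLE = "id_1,id_2\n1005,42\n42,90210\n90210,1005\n"
EDGES, MAPPING = reindex_nodes(read_edges(io.StringIO(SAMPLE)))
```

Keeping `MAPPING` around makes it possible to translate rows of the learned embedding back to the original node labels.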

Options

Learning of the embedding is handled by the `src/main.py` script, which provides the following command-line arguments.

Input and output options

  --edge-path         STR   Input graph path.     Default is `input/chameleon_edges.csv`.
  --embedding-path    STR   Embedding path.       Default is `output/chameleon_AW_embedding.csv`.
  --attention-path    STR   Attention path.       Default is `output/chameleon_AW_attention.csv`.

Model options

  --dimensions           INT       Number of embedding dimensions.       Default is 128.
  --epochs               INT       Number of training epochs.            Default is 200.
  --window-size          INT       Skip-gram window size.                Default is 5.
  --learning-rate        FLOAT     Learning rate value.                  Default is 0.01.
  --beta                 FLOAT     Attention regularization parameter.   Default is 0.5.
  --gamma                FLOAT     Embedding regularization parameter.   Default is 0.5.
  --num-of-walks         INT       Number of walks per source node.      Default is 80.
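
For readers who want to see how such options are typically wired up, here is a minimal `argparse` sketch that mirrors the flags and defaults listed above. It is an illustration only and is not claimed to match the parser in `src/main.py`; the function name `build_parser` is hypothetical.

```python
import argparse


def build_parser():
    """Build a parser mirroring the documented options (illustrative sketch)."""
    parser = argparse.ArgumentParser(description="Run Attention Walk.")
    # Input and output options.
    parser.add_argument("--edge-path", type=str,
                        default="input/chameleon_edges.csv", help="Input graph path.")
    parser.add_argument("--embedding-path", type=str,
                        default="output/chameleon_AW_embedding.csv", help="Embedding path.")
    parser.add_argument("--attention-path", type=str,
                        default="output/chameleon_AW_attention.csv", help="Attention path.")
    # Model options.
    parser.add_argument("--dimensions", type=int, default=128,
                        help="Number of embedding dimensions.")
    parser.add_argument("--epochs", type=int, default=200,
                        help="Number of training epochs.")
    parser.add_argument("--window-size", type=int, default=5,
                        help="Skip-gram window size.")
    parser.add_argument("--learning-rate", type=float, default=0.01,
                        help="Learning rate value.")
    parser.add_argument("--beta", type=float, default=0.5,
                        help="Attention regularization parameter.")
    parser.add_argument("--gamma", type=float, default=0.5,
                        help="Embedding regularization parameter.")
    parser.add_argument("--num-of-walks", type=int, default=80,
                        help="Number of walks per source node.")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch is testable.
ARGS = build_parser().parse_args(["--dimensions", "256", "--window-size", "20"])
```

Note that `argparse` converts dashes in flag names to underscores, so `--window-size` is read as `ARGS.window_size`.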

Examples

The following commands learn a graph embedding and write it to disk. The node representations are ordered by node ID.

Creating an Attention Walk embedding of the default dataset with the standard hyperparameter settings and saving it at the default path.

```
python src/main.py
```

Creating an Attention Walk embedding of the default dataset with 256 dimensions.

```
python src/main.py --dimensions 256
```

Creating an Attention Walk embedding of the default dataset with a larger window size.

```
python src/main.py --window-size 20
```

Creating an embedding of another dataset, Twitch Brasilians, and saving the outputs under custom file names.

```
python src/main.py --edge-path input/ptbr_edges.csv --embedding-path output/ptbr_AW_embedding.csv --attention-path output/ptbr_AW_attention.csv
```

License


Comments
  • NaN parameters

    Thanks for your PyTorch code. I found that my parameters become NaN during training. The affected parameters include model.left_factors, model.right_factors, and model.attention: all of their entries become NaN during training, and so does the loss. I'm trying to find the reason. I would appreciate it if you could give me some help or hints.

    opened by kkkkk001 9
  • Memory Error

    I'm getting OOM errors even with small files. The attached file link_network.txt throws the following error:

    Adjacency matrix powers: 100%|███████████████████████████████████████████████████████| 4/4 [00:00<00:00, 108.39it/s]
    Traceback (most recent call last):
      File "src\main.py", line 79, in <module>
        main()
      File "src\main.py", line 74, in main
        model = AttentionWalkTrainer(args)
      File "E:\AttentionWalk\src\attentionwalk.py", line 70, in __init__
        self.initialize_model_and_features()
      File "E:\AttentionWalk\src\attentionwalk.py", line 76, in initialize_model_and_features
        self.target_tensor = feature_calculator(self.args, self.graph)
      File "E:\AttentionWalk\src\utils.py", line 53, in feature_calculator
        target_matrices = np.array(target_matrices)
    MemoryError
    

    I guess this is due to the large indices of the nodes. Any workarounds for this?

    opened by davidlenz 2
  • modified normalized_adjacency_matrix calculation

    As mentioned in this issue: https://github.com/benedekrozemberczki/AttentionWalk/issues/9

    Added normalization to the calculation. This prevents an unbalanced loss and keeps loss_on_mat from becoming extremely large when the data has many nodes.

    opened by neilctwu 1
  • miscalculations of normalized adjacency matrix

    Thanks for sharing this awesome repo.

    I found that loss_on_target becomes extremely large while training with the original code, and I think this is due to a miscalculation of normalized_adjacency_matrix.

    In the original code, normalized_adjacency_matrix is calculated by:

    normalized_adjacency_matrix = degs.dot(adjacency_matrix)
    

    However, this does not seem to normalize the matrix; it simply multiplies it by the node degrees. I think normalized_adjacency_matrix should be computed according to its original definition:

      normalized_adjacency_matrix = degs.power(-1/2)\
                                        .dot(adjacency_matrix)\
                                        .dot(degs.power(-1/2))
    

    This results in a more reasonable loss, as shown in the attached image.

    Am I understanding it correctly?

    opened by neilctwu 1
  • problem with being killed

    Hi, I tried to train the model on a new dataset with about 60,000 nodes, but the process is suddenly Killed. Do you have any idea why? Thanks :)

    opened by amy-hyunji 1
  • Directed weighted graphs

    Is it possible to use the code with directed and weighted graphs? The paper describes the Attention Walk framework for unweighted graphs only, but I'd like to use it for such networks. Thank you for your attention.

    opened by federicoairoldi 1
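
The normalization question raised in the comments above concerns two common ways to normalize an adjacency matrix. The sketch below contrasts them on a toy graph in pure Python. It is an illustration only: it does not reproduce the repository's `degs` or `normalized_adjacency_matrix` code, whose exact construction is not shown in the thread, and it assumes an undirected graph with no isolated nodes. All names are hypothetical.

```python
def degrees(adjacency):
    """Return the degree of every node (row sums of the adjacency matrix)."""
    return [sum(row) for row in adjacency]


def row_normalize(adjacency):
    """Compute D^-1 A: every row sums to 1 (a random-walk transition matrix)."""
    degs = degrees(adjacency)
    return [[value / degs[i] for value in row]
            for i, row in enumerate(adjacency)]


def symmetric_normalize(adjacency):
    """Compute D^-1/2 A D^-1/2: symmetric whenever the adjacency matrix is."""
    degs = degrees(adjacency)
    size = len(adjacency)
    return [[adjacency[i][j] / (degs[i] * degs[j]) ** 0.5
             for j in range(size)]
            for i in range(size)]


# Toy star graph: node 0 is connected to nodes 1, 2 and 3 (hypothetical data).
ADJACENCY = [[0, 1, 1, 1],
             [1, 0, 0, 0],
             [1, 0, 0, 0],
             [1, 0, 0, 0]]
ROW = row_normalize(ADJACENCY)
SYM = symmetric_normalize(ADJACENCY)
```

The row-normalized matrix has rows that sum to 1 but is generally not symmetric, while the symmetric version is symmetric but its rows generally do not sum to 1. Which one is appropriate depends on whether the downstream computation expects transition probabilities.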
Releases(v_00001)
Owner
Benedek Rozemberczki
Machine Learning Engineer at AstraZeneca | PhD from The University of Edinburgh.