A PyTorch Implementation of "Watch Your Step: Learning Node Embeddings via Graph Attention" (NeurIPS 2018).

Overview

Attention Walk

Abstract

Graph embedding methods represent nodes in a continuous vector space, preserving different types of relational information from the graph. There are many hyper-parameters to these methods (e.g. the length of a random walk) which have to be manually tuned for every graph. In this paper, we replace previously fixed hyper-parameters with trainable ones that we automatically learn via backpropagation. In particular, we propose a novel attention model on the power series of the transition matrix, which guides the random walk to optimize an upstream objective. Unlike previous approaches to attention models, the method that we propose utilizes attention parameters exclusively on the data itself (e.g. on the random walk), which are not used by the model for inference. We experiment on link prediction tasks, as we aim to produce embeddings that best-preserve the graph structure, generalizing to unseen information. We improve state-of-the-art results on a comprehensive suite of real-world graph datasets including social, collaboration, and biological networks, where we observe that our graph attention model can reduce the error by up to 20%-40%. We show that our automatically-learned attention parameters can vary significantly per graph, and correspond to the optimal choice of hyper-parameter if we manually tune existing methods.
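The central idea of the abstract, attention over the power series of the transition matrix, can be illustrated with a short NumPy sketch. It assumes an undirected graph, the row-normalised transition matrix T = D^-1 A, a softmax over a vector of attention logits q whose length plays the role of the window size, and a uniform number of walks per node. The function names are hypothetical; this is an illustration of the technique, not the repository's code.

```python
import numpy as np

def transition_matrix(adjacency):
    """Row-normalise an adjacency matrix: T = D^-1 A (isolated nodes get all-zero rows)."""
    degrees = adjacency.sum(axis=1)
    inv_degrees = np.divide(1.0, degrees, out=np.zeros_like(degrees, dtype=float),
                            where=degrees > 0)
    return inv_degrees[:, None] * adjacency

def softmax(logits):
    shifted = np.exp(logits - logits.max())  # subtract the max for numerical stability
    return shifted / shifted.sum()

def expected_cooccurrence(adjacency, attention_logits, num_walks=1.0):
    """num_walks * sum_k softmax(q)_k * T^k for k = 1..len(q)."""
    T = transition_matrix(adjacency)
    weights = softmax(np.asarray(attention_logits, dtype=float))
    result = np.zeros_like(T)
    power = np.eye(T.shape[0])
    for weight in weights:
        power = power @ T            # T^1, T^2, ..., T^C
        result += weight * power
    return num_walks * result

# Usage: a 4-node graph and uniform attention over a window of 5.
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 1],
              [1, 1, 0, 0],
              [0, 1, 0, 0]], dtype=float)
E = expected_cooccurrence(A, np.zeros(5))
print(E.sum(axis=1))  # every row sums to 1 (up to rounding) because each T^k is row-stochastic
```

With uniform logits every power of T contributes equally. Training the logits shifts weight towards the walk distances that best explain the graph, which is how a learned attention vector can stand in for a hand-tuned window length.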

This repository provides an implementation of Attention Walk as described in the paper:

Watch Your Step: Learning Node Embeddings via Graph Attention. Sami Abu-El-Haija, Bryan Perozzi, Rami Al-Rfou, Alexander A. Alemi. NIPS, 2018. [Paper]

The original TensorFlow implementation is available [here].

Requirements

The codebase is implemented in Python 3.5.2. The package versions used for development are listed below.

networkx          2.4
tqdm              4.28.1
numpy             1.15.4
pandas            0.23.4
texttable         1.5.0
scipy             1.1.0
argparse          1.1.0
torch             1.1.0
torchvision       0.3.0

Datasets

The code takes an input graph in a csv file. Every row indicates an edge between two nodes separated by a comma. The first row is a header. Nodes should be indexed starting with 0. Sample graphs for the `Twitch Brasilians` and `Wikipedia Chameleons` are included in the `input/` directory.
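A minimal sketch, using only the standard library, of reading and writing files in this format. The helper names `load_edges` and `write_edges` and the `id_1,id_2` header are illustrative assumptions, not part of the repository. Node IDs are remapped to consecutive integers starting at 0 in order of first appearance, which is one way to meet the 0-based indexing requirement when raw IDs are large, sparse, or non-numeric; keeping the returned mapping lets you translate embedding rows back to the original IDs.

```python
import csv
import io

def load_edges(handle):
    """Read a header row plus one 'source,target' row per edge.

    Returns (edges, mapping): edges use consecutive IDs 0..n-1 assigned in order
    of first appearance, and mapping maps each raw ID string to its new ID.
    """
    reader = csv.reader(handle)
    next(reader)  # skip the header row
    mapping = {}

    def index_of(raw_id):
        raw_id = raw_id.strip()
        if raw_id not in mapping:
            mapping[raw_id] = len(mapping)
        return mapping[raw_id]

    edges = []
    for row in reader:
        if not row:
            continue  # tolerate blank lines
        edges.append((index_of(row[0]), index_of(row[1])))
    return edges, mapping

def write_edges(edges, handle, header=("id_1", "id_2")):
    """Write edges back out with a header row, one comma-separated edge per line."""
    writer = csv.writer(handle, lineterminator="\n")
    writer.writerow(header)
    writer.writerows(edges)

# Usage: raw IDs 1000, 2000 and 35 become 0, 1 and 2.
raw = io.StringIO("id_1,id_2\n1000,2000\n2000,35\n35,1000\n")
edges, mapping = load_edges(raw)
out = io.StringIO()
write_edges(edges, out)
print(out.getvalue())
```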

Options

Learning the embedding is handled by the `src/main.py` script, which provides the following command-line arguments.

Input and output options

  --edge-path         STR   Input graph path.     Default is `input/chameleon_edges.csv`.
  --embedding-path    STR   Embedding path.       Default is `output/chameleon_AW_embedding.csv`.
  --attention-path    STR   Attention path.       Default is `output/chameleon_AW_attention.csv`.

Model options

  --dimensions           INT       Number of embedding dimensions.       Default is 128.
  --epochs               INT       Number of training epochs.            Default is 200.
  --window-size          INT       Skip-gram window size.                Default is 5.
  --learning-rate        FLOAT     Learning rate value.                  Default is 0.01.
  --beta                 FLOAT     Attention regularization parameter.   Default is 0.5.
  --gamma                FLOAT     Embedding regularization parameter.   Default is 0.5.
  --num-of-walks         INT       Number of walks per source node.      Default is 80.
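To make the roles of `--beta` and `--gamma` concrete, here is a NumPy sketch of a loss modelled on the graph-likelihood objective in the paper: it rewards high embedding scores sigmoid(L R^T) where the expected co-occurrence matrix E is large, penalises them where the adjacency matrix A has no edge, and adds beta * ||q||^2 on the attention logits. The gamma term (mean absolute value of the left and right factors) and the averaging over entries are assumptions for illustration; the repository's exact penalty and scaling may differ, and `attention_walk_loss` is a hypothetical name. Computing log(sigmoid(x)) through `np.logaddexp` avoids the -inf values that a naive `log(sigmoid(x))` produces for saturated scores, which turn into NaN when multiplied by zero entries; this is one possible source of NaN parameters like those reported under Comments, not a confirmed diagnosis.

```python
import numpy as np

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x)) = -log(1 + exp(-x)); avoids log(0) = -inf."""
    return -np.logaddexp(0.0, -x)

def attention_walk_loss(E, A, L, R, q, beta=0.5, gamma=0.5):
    """Hypothetical loss: weighted log-likelihood + beta * ||q||^2 + gamma * embedding penalty.

    E: expected co-occurrence matrix (n x n), A: adjacency matrix (n x n),
    L, R: left/right embedding factors (n x k), q: attention logits (window_size,).
    """
    scores = L @ R.T                                  # n x n pairwise scores
    no_edge = (A == 0).astype(float)                  # indicator 1[A = 0]
    positive = -E * log_sigmoid(scores)               # push scores up where E is large
    negative = -no_edge * log_sigmoid(-scores)        # log(1 - sigmoid(s)) == log_sigmoid(-s)
    likelihood = (positive + negative).mean()         # both terms are non-negative
    attention_penalty = beta * np.sum(q ** 2)         # --beta: L2 penalty on attention logits
    embedding_penalty = gamma * (np.abs(L).mean() + np.abs(R).mean())  # --gamma (assumed form)
    return likelihood + attention_penalty + embedding_penalty

# Toy usage on a 4-node graph with a window of 5 (uniform attention logits).
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 1],
              [1, 1, 0, 0],
              [0, 1, 0, 0]], dtype=float)
E = A / A.sum(axis=1, keepdims=True)   # stand-in for the expected co-occurrence matrix
rng = np.random.default_rng(0)
L = rng.normal(scale=0.1, size=(4, 2))
R = rng.normal(scale=0.1, size=(4, 2))
q = np.zeros(5)
loss = attention_walk_loss(E, A, L, R, q)
print(loss)
```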

Examples

The following commands learn a graph embedding and write the embedding to disk. The node representations are ordered by the ID.

Creating an Attention Walk embedding of the default dataset with the standard hyperparameter settings. Saving this embedding at the default path.

python src/main.py

Creating an Attention Walk embedding of the default dataset with 256 dimensions.

python src/main.py --dimensions 256

Creating an Attention Walk embedding of the default dataset with a higher window size.

python src/main.py --window-size 20
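A larger window also costs memory. As a rough estimate, assume the training features are held as `window_size` dense n × n matrices, where n is the largest node index plus one; this is consistent with the `np.array(target_matrices)` call in the Memory Error traceback under Comments. Memory then grows linearly with the window size and quadratically with n. The helper below is hypothetical, ignores temporary copies and the model's own parameters, and its `bytes_per_entry=4` default assumes float32 storage (use 8 for float64).

```python
def dense_feature_bytes(max_node_index, window_size=5, bytes_per_entry=4):
    """Rough size of window_size dense n x n matrices, with n = max_node_index + 1.

    Assumes float32 entries by default; temporary copies made while stacking are ignored.
    """
    n = max_node_index + 1
    return window_size * n * n * bytes_per_entry

# Compare the default window (5) with --window-size 20 for a few graph sizes.
for max_index in (2_000, 10_000, 60_000):
    for window in (5, 20):
        gib = dense_feature_bytes(max_index, window_size=window) / 2**30
        print(f"max index {max_index:>6,}, window {window:>2}: ~{gib:,.1f} GiB")
```

Under these assumptions, the default window of 5 with a largest node index of 60,000 already corresponds to roughly 67 GiB, which may explain the out-of-memory and "Killed" reports in the comments. Relabelling nodes to consecutive IDs keeps n equal to the number of distinct nodes rather than the largest raw ID.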

Creating an embedding of another dataset, Twitch Brasilians. Saving the outputs under custom file names.

python src/main.py --edge-path input/ptbr_edges.csv --embedding-path output/ptbr_AW_embedding.csv --attention-path output/ptbr_AW_attention.csv

License


Comments
  • Nan parameters

    Thanks for your PyTorch code. I found that my parameters become NaN during training: model.left_factors, model.right_factors and model.attention all end up with every entry NaN, and so does the loss. I'm trying to find the reason and would appreciate any help or hints.

    opened by kkkkk001 9
  • Memory Error

    I'm getting OOM errors even with small files. The attached file link_network.txt throws the following error:

    Adjacency matrix powers: 100%|███████████████████████████████████████████████████████| 4/4 [00:00<00:00, 108.39it/s]
    Traceback (most recent call last):
      File "src\main.py", line 79, in <module>
        main()
      File "src\main.py", line 74, in main
        model = AttentionWalkTrainer(args)
      File "E:\AttentionWalk\src\attentionwalk.py", line 70, in __init__
        self.initialize_model_and_features()
      File "E:\AttentionWalk\src\attentionwalk.py", line 76, in initialize_model_and_features
        self.target_tensor = feature_calculator(self.args, self.graph)
      File "E:\AttentionWalk\src\utils.py", line 53, in feature_calculator
        target_matrices = np.array(target_matrices)
    MemoryError
    

    I guess this is due to the large indices of the nodes. Any workarounds for this?

    opened by davidlenz 2
  • modified normalized_adjacency_matrix calculation

    As mentioned in this issue: https://github.com/benedekrozemberczki/AttentionWalk/issues/9

    Adds normalization to the calculation, which prevents an unbalanced loss and keeps loss_on_mat from becoming extremely large when the graph has many nodes.

    opened by neilctwu 1
  • miscalculations of normalized adjacency matrix

    Thanks for sharing this awesome repo.

    The issue is that loss_on_target becomes extremely large when training with the original code, and I think this is due to a miscalculation of normalized_adjacency_matrix.

    In the original code, normalized_adjacency_matrix is calculated as:

    normalized_adjacency_matrix = degs.dot(adjacency_matrix)
    

    However, this does not normalize the matrix; it simply multiplies it by the node degrees. I think normalized_adjacency_matrix should instead follow its original definition:

      normalized_adjacency_matrix = degs.power(-1/2)\
                                        .dot(adjacency_matrix)\
                                        .dot(degs.power(-1/2))
    

    This yields a more reasonable loss.

    Am I understanding this correctly?

    opened by neilctwu 1
  • problem with being killed

    Hi, I tried to train the model on a new dataset with about 60000 nodes, but the process suddenly gets Killed. Do you have any idea why? Thanks :)

    opened by amy-hyunji 1
  • Directed weighted graphs

    Is it possible to use the code with directed and weighted graphs? The paper presents the Attention Walk framework for unweighted graphs only, but I'd like to use it for such networks. Thank you for your attention.

    opened by federicoairoldi 1
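The two formulas discussed in the "miscalculations of normalized adjacency matrix" issue above can be compared on a toy graph. In the SciPy sketch below (the graph and variable names are illustrative), `scaled` is the product D A from the original snippet, `random_walk` is the row-normalised random-walk transition matrix D^-1 A, and `symmetric` is the D^-1/2 A D^-1/2 form proposed in the issue. The rows of D^-1 A sum to one and the eigenvalues of the symmetric form lie in [-1, 1], so powers of either stay bounded, while entries of powers of D A grow with every multiplication; this is consistent with the growing loss reported in the issue. The sketch does not settle which normalisation the repository should use.

```python
import numpy as np
from scipy import sparse

# Undirected toy graph (symmetric 0/1 adjacency, no isolated nodes).
A = sparse.csr_matrix(np.array([[0, 1, 1, 0],
                                [1, 0, 1, 1],
                                [1, 1, 0, 0],
                                [0, 1, 0, 0]], dtype=float))
degrees = np.asarray(A.sum(axis=1)).ravel()

scaled = sparse.diags(degrees) @ A                    # D A, as in the original snippet
random_walk = sparse.diags(1.0 / degrees) @ A         # D^-1 A, row-stochastic
inv_sqrt = sparse.diags(1.0 / np.sqrt(degrees))
symmetric = inv_sqrt @ A @ inv_sqrt                   # D^-1/2 A D^-1/2, as proposed in the issue

print(np.asarray(random_walk.sum(axis=1)).ravel())   # each row sums to 1
print(np.linalg.eigvalsh(symmetric.toarray()))       # all eigenvalues within [-1, 1]
for k in (1, 3, 5):
    scaled_k = np.linalg.matrix_power(scaled.toarray(), k)
    walk_k = np.linalg.matrix_power(random_walk.toarray(), k)
    print(k, scaled_k.max(), walk_k.max())            # D A entries grow, D^-1 A entries stay <= 1
```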
Owner
Benedek Rozemberczki
Machine Learning Engineer at AstraZeneca | PhD from The University of Edinburgh.