PyTorch implementation of SSTNs for hyperspectral image classification, from the IEEE T-GRS paper "Spectral-Spatial Transformer Network for Hyperspectral Image Classification: A Factorized Architecture Search (FAS) Framework."

Overview

PyTorch Implementation of SSTN for Hyperspectral Image Classification

Paper link: SSTN was published in IEEE T-GRS. The implementations of SSTN and SSRN can also be found directly in NetworkBlocks.

UPDATE: Source code for training and testing SSTN/SSRN on the Kennedy Space Center (KSC) dataset has been added, in addition to the existing code for the Pavia Center (PC), Indian Pines (IN), and University of Pavia (UP) datasets.

Bibliographic information:

Zilong Zhong, Ying Li, Lingfei Ma, Jonathan Li, Wei-Shi Zheng. "Spectral-Spatial Transformer Network for Hyperspectral Image Classification: A Factorized Architecture Search Framework." IEEE Transactions on Geoscience and Remote Sensing, 2021. DOI: 10.1109/TGRS.2021.3115699.

Description

Neural networks have dominated hyperspectral image classification research, owing to the feature-learning capacity of convolution operations. However, the fixed geometric structure of convolution kernels hinders long-range interactions between features at distant locations. In this work, we propose a novel spectral-spatial transformer network (SSTN), which consists of spatial attention and spectral association modules, to overcome these constraints of convolution kernels. Extensive experiments on three popular hyperspectral image benchmarks demonstrate the versatility of SSTNs compared with other state-of-the-art (SOTA) methods. Most importantly, SSTN achieves accuracy comparable to or better than SOTA methods while requiring only 1.2% of the multiply-accumulate (MAC) operations of a strong baseline, SSRN.
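The efficiency claim is stated in multiply-accumulate (MAC) operations. As a rough illustration of how such counts are usually derived (this is not the authors' profiling code, and bias terms are ignored), the MACs of one convolution layer equal the number of output elements times the work per output element. The helper name `conv_macs` is hypothetical.

```python
from math import prod


def conv_macs(out_spatial, in_channels, out_channels, kernel, groups=1):
    """MACs of a single convolution layer, ignoring the bias term.

    out_spatial: output size per spatial axis, e.g. (H, W) or (D, H, W)
    kernel:      kernel size per spatial axis, same length as out_spatial
    """
    if len(out_spatial) != len(kernel):
        raise ValueError("out_spatial and kernel must have the same length")
    # Every output element reads (in_channels / groups) * prod(kernel) inputs,
    # each contributing one multiply-accumulate.
    per_output = (in_channels // groups) * prod(kernel)
    return prod(out_spatial) * out_channels * per_output


# A 3x3 2D convolution mapping 16 -> 32 channels on a 7x7 output patch.
print(conv_macs((7, 7), 16, 32, (3, 3)))  # 225792
```

Summing this quantity over all layers of two networks and dividing gives a relative cost such as the 1.2% figure; the same function also covers 3D kernels by passing three-element tuples.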

Fig.1 Spectral-Spatial Transformer Network (SSTN) with the 'AEAE' architecture, where 'A' and 'E' denote a spatial attention block and a spectral association block, respectively. (a) Search space for unit settings. (b) Search space for block sequences.
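The block-sequence notation can be made concrete with a small sketch. It assumes, purely for illustration, a search space of fixed-length sequences over the two block types 'A' and 'E'; the paper's actual search space (for example, variable depth or additional block types) may differ, and the names `enumerate_sequences` and `describe` are hypothetical.

```python
from itertools import product

# Letter codes follow the Fig. 1 caption; the mapping itself is illustrative.
BLOCK_NAMES = {"A": "spatial attention", "E": "spectral association"}


def enumerate_sequences(length, alphabet="AE"):
    """All block sequences of a fixed length over the given block types
    (assumed search space, not necessarily the one used in the paper)."""
    return ["".join(p) for p in product(alphabet, repeat=length)]


def describe(sequence):
    """Expand a sequence string such as 'AEAE' into block names."""
    return [BLOCK_NAMES[ch] for ch in sequence]


candidates = enumerate_sequences(4)
print(len(candidates))  # 16
print(describe("AEAE"))
```

Under this assumption the space grows as 2^L with depth L, which is one reason an architecture search is used rather than training every candidate.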

Fig.2 Illustration of the spatial attention module (left) and the spectral association module (right). The attention maps Attn in the spatial attention module are produced by multiplying two reshaped tensors, Q and K. In contrast, the attention maps M1 and M2 in the spectral association module are the direct output of a convolution operation. The spectral association kernels Asso represent a compact set of spectral vectors used to reconstruct the input feature X.
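The contrast between the two modules can be sketched in PyTorch. This is a simplified illustration under stated assumptions, not the repository's implementation (see NetworkBlocks for that): inputs are 2D feature maps of shape (B, C, H, W) with spectral features as channels, the spatial module is a non-local-style attention with a residual connection, and the spectral module collapses M1 and M2 into a single convolution-produced map M. The class names, `num_vectors`, and the softmax normalization are all hypothetical choices.

```python
import torch
import torch.nn as nn


class SpatialAttention(nn.Module):
    """Sketch: Attn = softmax(Q K) over all spatial positions."""

    def __init__(self, channels, reduced=None):
        super().__init__()
        reduced = reduced or max(channels // 2, 1)
        self.q = nn.Conv2d(channels, reduced, kernel_size=1)
        self.k = nn.Conv2d(channels, reduced, kernel_size=1)
        self.v = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        b, c, h, w = x.shape
        q = self.q(x).flatten(2).transpose(1, 2)  # (B, N, C'), N = H*W
        k = self.k(x).flatten(2)                  # (B, C', N)
        attn = torch.softmax(q @ k, dim=-1)       # (B, N, N): Q times K
        v = self.v(x).flatten(2)                  # (B, C, N)
        # Output at position i is the attention-weighted sum of values over j.
        out = (v @ attn.transpose(1, 2)).reshape(b, c, h, w)
        return x + out                            # residual connection


class SpectralAssociation(nn.Module):
    """Sketch: reconstruct X from a learnable set of spectral vectors (Asso)."""

    def __init__(self, channels, num_vectors=8):
        super().__init__()
        # Asso: num_vectors learnable spectral vectors of length `channels`.
        self.asso = nn.Parameter(torch.randn(num_vectors, channels) * 0.02)
        # The map comes straight from a convolution; no Q-K product is formed.
        self.to_map = nn.Conv2d(channels, num_vectors, kernel_size=1)

    def forward(self, x):
        m = torch.softmax(self.to_map(x), dim=1)              # (B, K, H, W)
        recon = torch.einsum("bkhw,kc->bchw", m, self.asso)   # rebuild X
        return x + recon


x = torch.randn(2, 16, 7, 7)
blocks = nn.Sequential(  # an 'AEAE' stack built from the sketch modules
    SpatialAttention(16), SpectralAssociation(16),
    SpatialAttention(16), SpectralAssociation(16),
)
print(blocks(x).shape)  # torch.Size([2, 16, 7, 7])
```

The spatial module's cost grows with (H·W)² because of the N×N map, whereas the association module only needs a 1×1 convolution and a K×C reconstruction, which illustrates why the two are complementary.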

Prerequisites

When creating a conda environment, make sure that all packages in the package-list are installed. See also the conda documentation on managing environments.
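A quick way to confirm that key Python packages are present in the active environment is to query their installed versions. The `REQUIRED` list below is a hypothetical subset, not the contents of the package-list; note also that conda package names can differ from Python distribution names (the conda package `pytorch` registers as `torch`).

```python
from importlib.metadata import PackageNotFoundError, version

# Hypothetical subset of dependencies; consult the package-list for the real set.
REQUIRED = ["torch", "numpy", "scipy"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(REQUIRED).items():
    print(f"{name}: {ver or 'MISSING'}")
```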

Usage

HSI data can be downloaded from the HyperspectralData website. Before training or evaluating models, make sure the datasets are placed in the correct folders. The Pavia Center (PC) dataset is too large to include in this repository, so it must be downloaded separately. For example, the raw HSI imagery and its ground-truth map for the PC dataset should be placed at the following two paths:

./dataset/PC/Pavia.mat
./dataset/PC/Pavia_gt.mat 
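Before launching training, it can help to verify that both files load and that the image cube and the ground-truth map agree in size. The sketch below uses `scipy.io.loadmat` and, to avoid assuming variable names inside the .mat files, takes the single non-header variable from each file. The function names are hypothetical, and treating label 0 as unlabeled is a common convention for these benchmarks rather than something stated here.

```python
from pathlib import Path

import numpy as np
from scipy.io import loadmat


def load_mat_array(path):
    """Return the single data array stored in a .mat file."""
    data = loadmat(str(path))
    keys = [k for k in data if not k.startswith("__")]  # skip MATLAB header entries
    if len(keys) != 1:
        raise ValueError(f"expected one variable in {path}, found {keys}")
    return np.asarray(data[keys[0]])


def load_pc(root="./dataset/PC"):
    """Load the PC image cube (H, W, bands) and its label map (H, W)."""
    root = Path(root)
    cube = load_mat_array(root / "Pavia.mat")
    labels = load_mat_array(root / "Pavia_gt.mat")  # 0 usually marks unlabeled pixels
    if cube.shape[:2] != labels.shape:
        raise ValueError("image and ground-truth sizes differ")
    return cube, labels
```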

Commands to train SSTNs on widely studied hyperspectral imagery (HSI) datasets:

$ python train_PC.py
$ python train_KSC.py
$ python train_IN.py
$ python train_UP.py

Commands to train SSRNs on widely studied hyperspectral imagery (HSI) datasets:

$ python train_PC.py --model SSRN
$ python train_KSC.py --model SSRN
$ python train_IN.py --model SSRN
$ python train_UP.py --model SSRN

Commands to test trained SSTNs and generate classification maps:

$ python test_IN.py
$ python test_KSC.py
$ python test_UP.py
$ python test_PC.py

Commands to test trained SSRNs and generate classification maps:

$ python test_IN.py --model SSRN
$ python test_KSC.py --model SSRN
$ python test_UP.py --model SSRN
$ python test_PC.py --model SSRN
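To run all of the commands above in one go, a small driver can loop over the datasets. This sketch only uses the script names and the `--model SSRN` flag shown above; it assumes that omitting `--model` selects SSTN, and the helper names are hypothetical.

```python
import subprocess
import sys

DATASETS = ["PC", "KSC", "IN", "UP"]


def build_commands(stage, model=None):
    """Command lines for train_<DS>.py or test_<DS>.py across all datasets."""
    if stage not in ("train", "test"):
        raise ValueError("stage must be 'train' or 'test'")
    cmds = []
    for ds in DATASETS:
        cmd = [sys.executable, f"{stage}_{ds}.py"]
        if model is not None:  # e.g. "SSRN"; omitted -> the scripts' default model
            cmd += ["--model", model]
        cmds.append(cmd)
    return cmds


def run_all(stage, model=None):
    """Run each script in turn, stopping at the first failure."""
    for cmd in build_commands(stage, model):
        print("Running:", " ".join(cmd))
        subprocess.run(cmd, check=True)


# Example usage from the repository root:
# run_all("train"); run_all("train", model="SSRN")
# run_all("test");  run_all("test", model="SSRN")
```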

Results on the Pavia Center (PC) Dataset

Fig.3 Classification maps of different models trained with 200 samples on the PC dataset. (a) False-color image. (b) Ground-truth labels. (c) Classification map of SSRN (overall accuracy 97.64%). (d) Classification map of SSTN (overall accuracy 98.95%).

Results on the Kennedy Space Center (KSC) Dataset

Fig.4 Classification maps of different models trained with 200 samples on the KSC dataset. (a) False-color image. (b) Ground-truth labels. (c) Classification map of SSRN (overall accuracy 96.60%). (d) Classification map of SSTN (overall accuracy 97.66%).

Results on the Indian Pines (IN) Dataset

Fig.5 Classification maps of different models trained with 200 samples on the IN dataset. (a) False-color image. (b) Ground-truth labels. (c) Classification map of SSRN (overall accuracy 91.75%). (d) Classification map of SSTN (overall accuracy 94.78%).

Results on the University of Pavia (UP) Dataset

Fig.6 Classification maps of different models trained with 200 samples on the UP dataset. (a) False-color image. (b) Ground-truth labels. (c) Classification map of SSRN (overall accuracy 95.09%). (d) Classification map of SSTN (overall accuracy 98.21%).
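Overall accuracy (OA) is the fraction of labeled pixels whose predicted class matches the ground truth. A minimal sketch, assuming flattened label lists in which 0 marks unlabeled pixels (in practice, training pixels are often also excluded before scoring; that filtering is not shown):

```python
def overall_accuracy(y_true, y_pred, ignore_label=0):
    """Fraction of labeled pixels whose predicted class matches the ground truth."""
    pairs = [(t, p) for t, p in zip(y_true, y_pred) if t != ignore_label]
    if not pairs:
        raise ValueError("no labeled pixels")
    correct = sum(t == p for t, p in pairs)
    return correct / len(pairs)


# Toy flattened ground-truth map (0 = unlabeled) and predictions.
gt = [0, 1, 1, 2, 2, 3, 0]
pred = [5, 1, 2, 2, 2, 3, 1]
print(f"OA = {overall_accuracy(gt, pred):.2%}")  # OA = 80.00%
```

Predictions on unlabeled pixels (the first and last entries) do not affect the score, which is why the maps above can color every pixel while OA is computed only on labeled ones.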

References

  1. TensorFlow implementation of SSRN: https://github.com/zilongzhong/SSRN
  2. Auto-CNN-HSI-Classification: https://github.com/YushiChen/Auto-CNN-HSI-Classification

Owner
Zilong Zhong
PhD in Machine Learning and Intelligence at the Department of Systems Design Engineering, University of Waterloo