利用Tensorflow实现基于CNN的中文短文本分类

Overview

Text Classification with CNN

使用卷积神经网络进行中文文本分类

CNN做句子分类的论文可以参看: Convolutional Neural Networks for Sentence Classification

还可以去读dennybritz大牛的博客:Implementing a CNN for Text Classification in TensorFlow

以及字符级CNN的论文:Character-level Convolutional Networks for Text Classification

本文是基于TensorFlow在中文数据集上的简化实现,使用了字符级CNN对中文文本进行分类,达到了较好的效果。

文中所使用的Conv1D与论文中有些不同,详细参考官方文档:tf.nn.conv1d

环境

  • Python 2/3
  • TensorFlow 1.3以上(我的是2.x)
  • numpy
  • scikit-learn
  • scipy

数据集

使用THUCNews数据集的一个子集进行训练与测试,数据集可在THUCTC:一个高效的中文文本分类工具包下载,请遵循数据提供方的开源协议。

本次训练使用了其中的10个分类,每个分类6500条数据。

类别如下:

体育, 财经, 房产, 家居, 教育, 科技, 时尚, 时政, 游戏, 娱乐

这个子集可以在此下载:链接: https://pan.baidu.com/s/1hugrfRu 密码: qfud

数据集划分如下:

  • 训练集: 5000 x 10
  • 验证集: 500 x 10
  • 测试集: 1000 x 10

从原数据集生成子集的过程请参看helper下的两个脚本。其中,copy_data.sh用于从每个分类拷贝6500个文件,cnews_group.py用于将多个文件整合到一个文件中。执行该文件后,得到三个数据文件:

  • cnews.train.txt: 训练集(50000条)
  • cnews.val.txt: 验证集(5000条)
  • cnews.test.txt: 测试集(10000条)

预处理

data/cnews_loader.py为数据的预处理文件。

  • read_file(): 读取文件数据。
  • build_vocab(): 构建词汇表,使用字符级的表示,这一函数会将词汇表存储下来,避免每一次重复处理。
  • read_vocab(): 读取上一步存储的词汇表,转换为{词:id}表示。
  • read_category(): 将分类目录固定,转换为{类别: id}表示。
  • to_words(): 将一条由id表示的数据重新转换为文字。
  • process_file(): 将数据集从文字转换为固定长度的id序列表示。
  • batch_iter(): 为神经网络的训练准备经过shuffle的批次的数据。

经过数据预处理,数据的格式如下:

Data Shape Data Shape
x_train [50000, 600] y_train [50000, 10]
x_val [5000, 600] y_val [5000, 10]
x_test [10000, 600] y_test [10000, 10]

CNN卷积神经网络

配置项

CNN可配置的参数如下所示,在cnn_model.py中。

class TCNNConfig(object):
    """CNN配置参数"""

    embedding_dim = 64      # 词向量维度
    seq_length = 600        # 序列长度
    num_classes = 10        # 类别数
    num_filters = 128       # 卷积核数目
    kernel_size = 5         # 卷积核尺寸
    vocab_size = 5000       # 词汇表达小

    hidden_dim = 128        # 全连接层神经元数目

    dropout_keep_prob = 0.5 # dropout正则化保留比例
    learning_rate = 1e-3    # 学习率

    batch_size = 64         # 每批训练大小
    num_epochs = 10         # 总迭代轮次

    print_per_batch = 100   # 每多少轮输出一次结果
    save_per_batch = 10     # 每多少轮存入tensorboard

CNN模型

具体参看cnn_model.py的实现。

大致结构如下:

image-20211110151539493

训练与验证

用cmd命令在代码文件所在目录运行 python run_cnn.py train,可以开始训练。

若之前进行过训练,请把tensorboard/textcnn删除,避免TensorBoard多次训练结果重叠。

Configuring CNN model...
Configuring TensorBoard and Saver...
Loading training and validation data...
Time usage: 0:00:14
Training and evaluating...
Epoch: 1
Iter:      0, Train Loss:    2.3, Train Acc:  10.94%, Val Loss:    2.3, Val Acc:   8.92%, Time: 0:00:01 *
Iter:    100, Train Loss:   0.88, Train Acc:  73.44%, Val Loss:    1.2, Val Acc:  68.46%, Time: 0:00:04 *
Iter:    200, Train Loss:   0.38, Train Acc:  92.19%, Val Loss:   0.75, Val Acc:  77.32%, Time: 0:00:07 *
Iter:    300, Train Loss:   0.22, Train Acc:  92.19%, Val Loss:   0.46, Val Acc:  87.08%, Time: 0:00:09 *
Iter:    400, Train Loss:   0.24, Train Acc:  90.62%, Val Loss:    0.4, Val Acc:  88.62%, Time: 0:00:12 *
Iter:    500, Train Loss:   0.16, Train Acc:  96.88%, Val Loss:   0.36, Val Acc:  90.38%, Time: 0:00:15 *
Iter:    600, Train Loss:  0.084, Train Acc:  96.88%, Val Loss:   0.35, Val Acc:  91.36%, Time: 0:00:17 *
Iter:    700, Train Loss:   0.21, Train Acc:  93.75%, Val Loss:   0.26, Val Acc:  92.58%, Time: 0:00:20 *
Epoch: 2
Iter:    800, Train Loss:   0.07, Train Acc:  98.44%, Val Loss:   0.24, Val Acc:  94.12%, Time: 0:00:23 *
Iter:    900, Train Loss:  0.092, Train Acc:  96.88%, Val Loss:   0.27, Val Acc:  92.86%, Time: 0:00:25
Iter:   1000, Train Loss:   0.17, Train Acc:  95.31%, Val Loss:   0.28, Val Acc:  92.82%, Time: 0:00:28
Iter:   1100, Train Loss:    0.2, Train Acc:  93.75%, Val Loss:   0.23, Val Acc:  93.26%, Time: 0:00:31
Iter:   1200, Train Loss:  0.081, Train Acc:  98.44%, Val Loss:   0.25, Val Acc:  92.96%, Time: 0:00:33
Iter:   1300, Train Loss:  0.052, Train Acc: 100.00%, Val Loss:   0.24, Val Acc:  93.58%, Time: 0:00:36
Iter:   1400, Train Loss:    0.1, Train Acc:  95.31%, Val Loss:   0.22, Val Acc:  94.12%, Time: 0:00:39
Iter:   1500, Train Loss:   0.12, Train Acc:  98.44%, Val Loss:   0.23, Val Acc:  93.58%, Time: 0:00:41
Epoch: 3
Iter:   1600, Train Loss:    0.1, Train Acc:  96.88%, Val Loss:   0.26, Val Acc:  92.34%, Time: 0:00:44
Iter:   1700, Train Loss:  0.018, Train Acc: 100.00%, Val Loss:   0.22, Val Acc:  93.46%, Time: 0:00:47
Iter:   1800, Train Loss:  0.036, Train Acc: 100.00%, Val Loss:   0.28, Val Acc:  92.72%, Time: 0:00:50
No optimization for a long time, auto-stopping...

在验证集上的最佳效果为94.12%,且只经过了3轮迭代就已经停止。

准确率和误差如图所示:

accuracy_1

测试

用cmd命令在代码文件所在目录下运行 python run_cnn.py test 在测试集上进行测试。

Configuring CNN model...
Loading test data...
Testing...
Test Loss:   0.14, Test Acc:  96.04%
Precision, Recall and F1-Score...
             precision    recall  f1-score   support

         体育       0.99      0.99      0.99      1000
         财经       0.96      0.99      0.97      1000
         房产       1.00      1.00      1.00      1000
         家居       0.95      0.91      0.93      1000
         教育       0.95      0.89      0.92      1000
         科技       0.94      0.97      0.95      1000
         时尚       0.95      0.97      0.96      1000
         时政       0.94      0.94      0.94      1000
         游戏       0.97      0.96      0.97      1000
         娱乐       0.95      0.98      0.97      1000

avg / total       0.96      0.96      0.96     10000

Confusion Matrix...
[[991   0   0   0   2   1   0   4   1   1]
 [  0 992   0   0   2   1   0   5   0   0]
 [  0   1 996   0   1   1   0   0   0   1]
 [  0  14   0 912   7  15   9  29   3  11]
 [  2   9   0  12 892  22  18  21  10  14]
 [  0   0   0  10   1 968   4   3  12   2]
 [  1   0   0   9   4   4 971   0   2   9]
 [  1  16   0   4  18  12   1 941   1   6]
 [  2   4   1   5   4   5  10   1 962   6]
 [  1   0   1   6   4   3   5   0   1 979]]
Time usage: 0:00:05

在测试集上的准确率达到了96.04%,且各类的precision, recall和f1-score都超过了0.9。

损失函数变化如图所示:

loss

从混淆矩阵也可以看出分类效果非常优秀。

预测

为方便预测,predict.py 展示了一个简单demo的预测。

Owner
Jeremiah
如今的现在早已不是当初的未来、
Jeremiah
Learning cell communication from spatial graphs of cells

ncem Features Repository for the manuscript Fischer, D. S., Schaar, A. C. and Theis, F. Learning cell communication from spatial graphs of cells. 2021

Theis Lab 77 Dec 30, 2022
Repo for the Video Person Clustering dataset, and code for the associated paper

Video Person Clustering Repo for the Video Person Clustering dataset, and code for the associated paper. This reporsitory contains the Video Person Cl

Andrew Brown 47 Nov 02, 2022
Nest Protect integration for Home Assistant. This will allow you to integrate your smoke, heat, co and occupancy status real-time in HA.

Nest Protect integration for Home Assistant Custom component for Home Assistant to interact with Nest Protect devices via an undocumented and unoffici

Mick Vleeshouwer 175 Dec 29, 2022
v objective diffusion inference code for JAX.

v-diffusion-jax v objective diffusion inference code for JAX, by Katherine Crowson (@RiversHaveWings) and Chainbreakers AI (@jd_pressman). The models

Katherine Crowson 186 Dec 21, 2022
An unofficial personal implementation of UM-Adapt, specifically to tackle joint estimation of panoptic segmentation and depth prediction for autonomous driving datasets.

Semisupervised Multitask Learning This repository is an unofficial and slightly modified implementation of UM-Adapt[1] using PyTorch. This code primar

Abhinav Atrishi 11 Nov 25, 2022
Deep Learning pipeline for motor-imagery classification.

BCI-ToolBox 1. Introduction BCI-ToolBox is deep learning pipeline for motor-imagery classification. This repo contains five models: ShallowConvNet, De

DongHee 18 Oct 31, 2022
Politecnico of Turin Thesis: "Implementation and Evaluation of an Educational Chatbot based on NLP Techniques"

THESIS_CAIRONE_FIORENTINO Politecnico of Turin Thesis: "Implementation and Evaluation of an Educational Chatbot based on NLP Techniques" GENERATE TOKE

cairone_fiorentino97 1 Dec 10, 2021
A PyTorch implementation of "Signed Graph Convolutional Network" (ICDM 2018).

SGCN ⠀ A PyTorch implementation of Signed Graph Convolutional Network (ICDM 2018). Abstract Due to the fact much of today's data can be represented as

Benedek Rozemberczki 251 Nov 30, 2022
Official PyTorch implementation of SyntaSpeech (IJCAI 2022)

SyntaSpeech: Syntax-Aware Generative Adversarial Text-to-Speech | | | | 中文文档 This repository is the official PyTorch implementation of our IJCAI-2022

Zhenhui YE 116 Nov 24, 2022
Course on computational design, non-linear optimization, and dynamics of soft systems at UIUC.

Computational Design and Dynamics of Soft Systems · This is a repository that contains the source code for generating the lecture notes, handouts, exe

Tejaswin Parthasarathy 4 Jul 21, 2022
Fully Convolutional Networks for Semantic Segmentation by Jonathan Long*, Evan Shelhamer*, and Trevor Darrell. CVPR 2015 and PAMI 2016.

Fully Convolutional Networks for Semantic Segmentation This is the reference implementation of the models and code for the fully convolutional network

Evan Shelhamer 3.2k Jan 08, 2023
Serving PyTorch 1.0 Models as a Web Server in C++

Serving PyTorch Models in C++ This repository contains various examples to perform inference using PyTorch C++ API. Run git clone https://github.com/W

Onur Kaplan 223 Jan 04, 2023
Network Enhancement implementation in pytorch

network_enahncement_pytorch Network Enhancement implementation in pytorch Research paper Network Enhancement: a general method to denoise weighted bio

Yen 1 Nov 12, 2021
RGB-stacking 🛑 🟩 🔷 for robotic manipulation

RGB-stacking 🛑 🟩 🔷 for robotic manipulation BLOG | PAPER | VIDEO Beyond Pick-and-Place: Tackling Robotic Stacking of Diverse Shapes, Alex X. Lee*,

DeepMind 95 Dec 23, 2022
Numenta Platform for Intelligent Computing is an implementation of Hierarchical Temporal Memory (HTM), a theory of intelligence based strictly on the neuroscience of the neocortex.

NuPIC Numenta Platform for Intelligent Computing The Numenta Platform for Intelligent Computing (NuPIC) is a machine intelligence platform that implem

Numenta 6.3k Dec 30, 2022
OCR-D wrapper for detectron2 based segmentation models

ocrd_detectron2 OCR-D wrapper for detectron2 based segmentation models Introduction Installation Usage OCR-D processor interface ocrd-detectron2-segm

Robert Sachunsky 13 Dec 06, 2022
HINet: Half Instance Normalization Network for Image Restoration

HINet: Half Instance Normalization Network for Image Restoration Liangyu Chen, Xin Lu, Jie Zhang, Xiaojie Chu, Chengpeng Chen Paper: https://arxiv.org

303 Dec 31, 2022
Official implementation of NeuralFusion: Online Depth Map Fusion in Latent Space

NeuralFusion This is the official implementation of NeuralFusion: Online Depth Map Fusion in Latent Space. We provide code to train the proposed pipel

53 Jan 01, 2023
The authors' official PyTorch SigWGAN implementation

The authors' official PyTorch SigWGAN implementation This repository is the official implementation of [Sig-Wasserstein GANs for Time Series Generatio

9 Jun 16, 2022
Sparse Progressive Distillation: Resolving Overfitting under Pretrain-and-Finetune Paradigm

Sparse Progressive Distillation: Resolving Overfitting under Pretrain-and-Finetu

3 Dec 05, 2022