Word2Vec Skip-gram Model Implementation
2022-07-16 22:19:00 【哇咔咔负负得正】
Word2Vec Skip-gram Model
Input: the word vector of the center word
Output: a prediction of which word (or which words) appear in its context

The core idea is to predict the context from the center word.
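For example (a small illustration added here, not from the original post), with a window of one word on each side, the center word "like" in the sentence "i like dog" should predict its neighbors "i" and "dog":

# Toy illustration of skip-gram pairs for one sentence with a +/-1 word window.
sentence = "i like dog".split()
pairs = []
for i in range(1, len(sentence) - 1):
    for context in (sentence[i - 1], sentence[i + 1]):
        pairs.append((sentence[i], context))
print(pairs)  # [('like', 'i'), ('like', 'dog')]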
0. Imports
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA, TruncatedSVD
1. Corpus, dictionary, and Id2Word
# corpus
sentences = ["i like dog", "i like cat", "i like animal",
             "dog cat animal", "apple cat dog like", "cat like fish",
             "dog like meat", "i like apple", "i hate apple",
             "i like movie book music apple", "dog like bark", "dog friend cat"]
word_sequence = ' '.join(sentences).split()
# vocabulary
word_list = list(set(word_sequence))
# word2Id dictionary
word_dict = {word: i for i, word in enumerate(word_list)}
# Id2Word dictionary
id_dict = {i: word for i, word in enumerate(word_list)}
word_dict and id_dict give the two-way mapping between words and ids:
word_dict, id_dict
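As a quick sanity check (my addition; the exact ids depend on Python's set ordering, so only the round trip is verified):

# Mapping a word to its id and back should return the same word.
assert all(id_dict[word_dict[w]] == w for w in word_list)
print(word_dict['dog'], id_dict[word_dict['dog']])  # e.g. 7 dog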

2. Training samples
The input is the center word and the label is a context word. Here the context is restricted to the previous word and the next word.
skip_grams = []
# start at 1 because the word at index 0 has no previous word, and stop at the
# second-to-last word because the last word has no next word
for i in range(1, len(word_sequence) - 1):
    input = word_dict[word_sequence[i]]  # id of the center word
    context = [word_dict[word_sequence[i - 1]], word_dict[word_sequence[i + 1]]]  # ids of the two context words
    for w in context:
        skip_grams.append([input, w])
skip_grams[index][0] is the input and skip_grams[index][1] is the label.
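To make the pairs easier to read, here is a small optional sketch (my addition) that decodes a few of them back into words:

# Decode the first few (center, context) id pairs back into words.
for center_id, context_id in skip_grams[:4]:
    print(id_dict[center_id], '->', id_dict[context_id])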
3. Network model
# vocabulary size
voc_size = len(word_dict)
# dimensionality of the vector used to represent each word
embedding_size = 4
# network model
class Word2Vec(nn.Module):
    def __init__(self):
        super(Word2Vec, self).__init__()
        self.embedding = nn.Embedding(voc_size, embedding_size)
        self.linear = nn.Linear(embedding_size, voc_size)

    def forward(self, x):
        x = self.embedding(x)
        x = self.linear(x)
        return x
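A quick shape check (my addition): a batch of word ids of shape (batch_size,) goes in, and one row of voc_size logits per word comes out.

# Feed four dummy word ids through an untrained model and inspect the output shape.
dummy_ids = torch.LongTensor([0, 1, 2, 3])
print(Word2Vec()(dummy_ids).shape)  # torch.Size([4, voc_size])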
4. Sampling a batch
def random_batch(data, batch_size):
    input_batch = []
    target_batch = []
    # randomly pick batch_size (here 4) indices; replace=False means each index is drawn only once
    random_index = np.random.choice(range(len(data)), batch_size, replace=False)
    for i in random_index:
        input_batch.append(data[i][0])
        target_batch.append(data[i][1])
    return input_batch, target_batch
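Example usage (my addition); the sampled ids differ from run to run:

# Draw one random batch of (input, target) ids and inspect it.
inp, tgt = random_batch(skip_grams, 4)
print(inp, tgt)  # e.g. [5, 12, 3, 7] [1, 5, 9, 3]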
5. Define the loss function, optimizer, and batch_size
model = Word2Vec()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)
# batch size
batch_size = 4
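Note that nn.CrossEntropyLoss applies log-softmax internally, which is why the model outputs raw logits with no softmax layer. A short check (my addition):

# CrossEntropyLoss takes raw logits of shape (batch, voc_size) and integer class
# labels of shape (batch,), and returns a scalar loss tensor.
dummy_logits = torch.randn(4, voc_size)
dummy_labels = torch.LongTensor([0, 1, 2, 3])
print(criterion(dummy_logits, dummy_labels))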
6. Training
for epoch in range(50000):
    input_batch, target_batch = random_batch(skip_grams, batch_size)
    # nn.Embedding expects torch.LongTensor inputs, and the loss function expects
    # torch.LongTensor targets
    input_batch, target_batch = torch.Tensor(input_batch).type(torch.LongTensor), torch.Tensor(target_batch).type(torch.LongTensor)
    optimizer.zero_grad()
    output = model(input_batch)
    loss = criterion(output, target_batch)
    if (epoch + 1) % 10000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))
    loss.backward()
    optimizer.step()
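After training, a quick qualitative check (my addition) is to ask the model which context words it considers most likely for a given center word:

# For the center word "dog", print the three context words with the highest logits.
with torch.no_grad():
    logits = model(torch.LongTensor([word_dict['dog']]))
    top_ids = torch.topk(logits, k=3, dim=1).indices[0].tolist()
print([id_dict[i] for i in top_ids])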

7. Extract the word vectors learned by the Embedding layer
# extract the parameters of the Embedding layer
word_vector = (list(model.parameters())[0])
word_vector
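The rows of this parameter matrix are the word vectors, indexed by word id. For example (my addition):

# Look up the learned embedding_size-dimensional vector for a single word.
dog_vec = word_vector[word_dict['dog']]
print(dog_vec)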

8. Reduce to two dimensions and plot in the plane
# PCA suits dense matrices, TruncatedSVD suits sparse matrices
# dimensionality reduction
X_embedded = TruncatedSVD(n_components=2).fit_transform(word_vector.detach().numpy())
X_embedded

for i, label in enumerate(word_list):
    x, y = float(X_embedded[i][0]), float(X_embedded[i][1])
    plt.scatter(x, y)
    plt.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points', ha='right', va='bottom')
plt.show()
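Besides the 2-D plot, a simple way to compare the learned vectors directly (my addition) is cosine similarity; with such a tiny corpus the numbers vary between runs, but related words such as "dog" and "cat" tend to score higher than unrelated pairs:

# Cosine similarity between pairs of learned word vectors.
vecs = word_vector.detach()
cos = nn.functional.cosine_similarity
print(cos(vecs[word_dict['dog']], vecs[word_dict['cat']], dim=0).item())
print(cos(vecs[word_dict['dog']], vecs[word_dict['apple']], dim=0).item())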
