Implementing the Word2Vec Skip-gram Model
2022-07-19 00:02:00 【哇哢哢負負得正】
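Word2Vec Skip-gram model
Input: the word vector of the center word
Output: a prediction of which word (or words) appear in its context

The core idea is to predict the surrounding context words from the center word.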
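For the model built below (an embedding lookup followed by a linear layer with bias, trained with cross-entropy), the objective can be written as follows. The symbols are our own labels, not names from the code: $E \in \mathbb{R}^{|V| \times d}$ stands for the embedding table, $W \in \mathbb{R}^{|V| \times d}$ and $\mathbf{b} \in \mathbb{R}^{|V|}$ for the linear layer's weight and bias, $c$ for a center word id and $o$ for a context word id.

```latex
\mathbf{v}_c = E[c], \qquad
\mathbf{z} = W\mathbf{v}_c + \mathbf{b}, \qquad
P(o \mid c) = \frac{\exp(z_o)}{\sum_{w=1}^{|V|} \exp(z_w)}

\mathcal{L} = -\frac{1}{N} \sum_{(c,\,o) \in \text{batch}} \log P(o \mid c)
```

Here $N$ is the number of (center, context) pairs in a batch; averaging over the batch matches the default `reduction='mean'` of `nn.CrossEntropyLoss`.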
0. Imports
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA, TruncatedSVD
1. Corpus, vocabulary and Id2Word
# Corpus
sentences = ["i like dog", "i like cat", "i like animal",
             "dog cat animal", "apple cat dog like", "cat like fish",
             "dog like meat", "i like apple", "i hate apple",
             "i like movie book music apple", "dog like bark", "dog friend cat"]
word_sequence = ' '.join(sentences).split()
# Vocabulary (sorted so that word ids are the same on every run)
word_list = sorted(set(word_sequence))
# word2Id dictionary
word_dict = {word: i for i, word in enumerate(word_list)}
# Id2Word dictionary
id_dict = {i: word for i, word in enumerate(word_list)}
word_dict and id_dict map words to ids and ids back to words:
word_dict, id_dict
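A small self-contained check of the two lookup tables: encode a sentence into ids and decode it back. Sorting the vocabulary is our choice here; with `list(set(...))` the id order can change between Python runs because string hashing is randomized per process.

```python
# Build both lookup tables from the same corpus and check that they are inverses.
sentences = ["i like dog", "i like cat", "i like animal",
             "dog cat animal", "apple cat dog like", "cat like fish",
             "dog like meat", "i like apple", "i hate apple",
             "i like movie book music apple", "dog like bark", "dog friend cat"]
word_sequence = ' '.join(sentences).split()

# sorted() gives a deterministic id order across runs
word_list = sorted(set(word_sequence))
word_dict = {word: i for i, word in enumerate(word_list)}
id_dict = {i: word for i, word in enumerate(word_list)}

# Encode the first sentence to ids, then decode the ids back to words
encoded = [word_dict[w] for w in sentences[0].split()]
decoded = [id_dict[i] for i in encoded]
print(len(word_list), encoded, decoded)  # → 14 [9, 10, 5] ['i', 'like', 'dog']
```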

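2. Training samples
The input is the center word and the label is a context word. Here the context is limited to the previous word and the next word. Because all sentences are concatenated into one sequence, some pairs span the boundary between two sentences.
skip_grams = []
# Start at 1 because the first word has no previous word, and stop at the second-to-last word because the last word has no next word
for i in range(1, len(word_sequence) - 1):
    center = word_dict[word_sequence[i]]  # id of the center word
    context = [word_dict[word_sequence[i - 1]], word_dict[word_sequence[i + 1]]]  # ids of the two context words
    for w in context:
        skip_grams.append([center, w])
skip_grams[index][0] is the input (center word id) and skip_grams[index][1] is the label (context word id).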
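The loop in this step uses exactly one word on each side and never uses the first or last word of the whole sequence as a center. A common generalization is a window of `window` words on each side, clamped at the ends. The helper `make_skip_grams` below is a hypothetical sketch of that variant, not part of the original code.

```python
def make_skip_grams(ids, window=1):
    """Hypothetical helper: [center, context] pairs for every position,
    using up to `window` words on each side and clamping at the sequence ends."""
    pairs = []
    for i, center in enumerate(ids):
        lo = max(0, i - window)
        hi = min(len(ids), i + window + 1)
        for j in range(lo, hi):
            if j != i:  # a word is not its own context
                pairs.append([center, ids[j]])
    return pairs


print(make_skip_grams([0, 1, 2, 3], window=1))
# [[0, 1], [1, 0], [1, 2], [2, 1], [2, 3], [3, 2]]
print(len(make_skip_grams([0, 1, 2, 3], window=2)))  # 10
```

With `window=1` the pairs for interior positions are the same as the original loop's; the only difference is that the two edge words also become centers.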
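3. Network model
# Vocabulary size
voc_size = len(word_dict)
# Number of dimensions used to represent each word
embedding_size = 4
# Network model: an embedding lookup followed by a linear layer that scores every vocabulary word
class Word2Vec(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(voc_size, embedding_size)
        self.linear = nn.Linear(embedding_size, voc_size)

    def forward(self, x):
        x = self.embedding(x)  # [batch] -> [batch, embedding_size]
        x = self.linear(x)     # [batch, embedding_size] -> [batch, voc_size] (raw scores / logits)
        return x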
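To make the forward pass concrete, here is a dependency-free sketch of what it computes for a single word id. `E`, `W` and `b` are illustrative stand-ins with hand-picked values, not trained parameters. It also shows that an embedding lookup gives the same result as multiplying a one-hot vector by the embedding matrix.

```python
# E plays the role of self.embedding.weight, shape [voc_size, embedding_size];
# W and b play the role of self.linear.weight ([voc_size, embedding_size]) and self.linear.bias.
voc_size, embedding_size = 3, 2
E = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b = [0.0, 0.0, 0.5]


def embed_by_index(x):
    # nn.Embedding is a plain row lookup
    return E[x]


def embed_by_one_hot(x):
    # Equivalent view: one-hot row vector times E
    one_hot = [1.0 if k == x else 0.0 for k in range(voc_size)]
    return [sum(one_hot[k] * E[k][d] for k in range(voc_size)) for d in range(embedding_size)]


def forward(x):
    v = embed_by_index(x)
    # nn.Linear computes W @ v + b: one score (logit) per vocabulary word
    return [sum(W[o][d] * v[d] for d in range(embedding_size)) + b[o] for o in range(voc_size)]


print(embed_by_index(1), embed_by_one_hot(1))
print(forward(1))
```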
4. Reading a batch
def random_batch(data, batch_size):
    input_batch = []
    target_batch = []
    # Randomly pick batch_size samples (4 here); replace=False means no sample appears twice within one batch
    random_index = np.random.choice(len(data), batch_size, replace=False)
    for i in random_index:
        input_batch.append(data[i][0])
        target_batch.append(data[i][1])
    return input_batch, target_batch
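`random_batch` draws each batch independently, so across training some pairs are seen more often than others. A common alternative is to shuffle once per pass and walk through the data so every pair is used exactly once. `iterate_batches` below is a hypothetical sketch of that approach, not the original author's method.

```python
import random


def iterate_batches(data, batch_size, seed=None):
    """Hypothetical alternative to random_batch: shuffle the indices once,
    then yield consecutive chunks so each pair appears exactly once per pass."""
    indices = list(range(len(data)))
    random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        yield [data[i][0] for i in chunk], [data[i][1] for i in chunk]


pairs = [[0, 1], [1, 0], [1, 2], [2, 1], [2, 3], [3, 2]]
for inputs, targets in iterate_batches(pairs, batch_size=4, seed=0):
    print(inputs, targets)  # the last batch may be smaller than batch_size
```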
5. Loss function, optimizer and batch_size
model = Word2Vec()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)
# Batch size
batch_size = 4
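`Word2Vec.forward` returns raw scores without a softmax because `nn.CrossEntropyLoss` applies log-softmax itself and then takes the negative log-likelihood of the target class. The plain-Python sketch below reproduces that computation for illustration; the function names are our own.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: log-softmax of the raw logits,
    then the negative log-probability of the target class."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def batch_cross_entropy(batch_logits, targets):
    # Default reduction='mean': average the per-sample losses over the batch
    return sum(cross_entropy(z, t) for z, t in zip(batch_logits, targets)) / len(targets)


print(cross_entropy([0.0, 0.0], 0))  # log(2), about 0.6931
print(batch_cross_entropy([[2.0, 0.0], [0.0, 2.0]], [0, 0]))
```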
6. Training
# Each iteration (called "epoch" here) trains on one randomly drawn batch
for epoch in range(50000):
    input_batch, target_batch = random_batch(skip_grams, batch_size)
    # nn.Embedding expects integer (LongTensor) indices, and CrossEntropyLoss expects LongTensor class targets
    input_batch = torch.tensor(input_batch, dtype=torch.long)
    target_batch = torch.tensor(target_batch, dtype=torch.long)
    optimizer.zero_grad()
    output = model(input_batch)
    loss = criterion(output, target_batch)
    if (epoch + 1) % 10000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))
    loss.backward()
    optimizer.step()
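`loss.backward()` computes gradients through autograd. At the output layer the gradient of cross-entropy with respect to the logits is simply `softmax(z) - one_hot(target)`, which raises the score of the true context word and lowers the others. The sketch below checks that formula against finite differences; it is an illustration, not part of the training code.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def loss(z, target):
    return -math.log(softmax(z)[target])


def grad_logits(z, target):
    # Analytic gradient of cross-entropy w.r.t. the logits: softmax(z) - one_hot(target)
    p = softmax(z)
    return [p[k] - (1.0 if k == target else 0.0) for k in range(len(z))]


def numeric_grad(z, target, eps=1e-6):
    # Central finite differences, used only to check the analytic formula
    g = []
    for k in range(len(z)):
        zp = list(z)
        zp[k] += eps
        zm = list(z)
        zm[k] -= eps
        g.append((loss(zp, target) - loss(zm, target)) / (2 * eps))
    return g


z, target = [0.5, -1.0, 2.0], 0
print(grad_logits(z, target))
print(numeric_grad(z, target))
```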

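7. Extract the word vectors learned by the Embedding layer
# Weight matrix of the Embedding layer, shape [voc_size, embedding_size]; row i is the vector of id_dict[i]
word_vector = model.embedding.weight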
word_vector
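A typical use of the extracted vectors is finding a word's nearest neighbours by cosine similarity. `most_similar` is a hypothetical helper, and the toy vectors are hand-picked for illustration, not trained values.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def most_similar(word, vectors, topn=2):
    """Hypothetical helper: rank the other words by cosine similarity to `word`.
    `vectors` maps word -> list of floats."""
    query = vectors[word]
    scores = [(other, cosine(query, vec)) for other, vec in vectors.items() if other != word]
    return sorted(scores, key=lambda item: item[1], reverse=True)[:topn]


# Toy 2-d vectors for illustration only (not trained values)
toy = {"dog": [1.0, 0.9], "cat": [0.9, 1.0], "apple": [-1.0, 0.2], "fish": [0.5, 0.8]}
print(most_similar("dog", toy))
```

With the trained model, the dictionary could be built as `vectors = {id_dict[i]: row.tolist() for i, row in enumerate(word_vector.detach())}`.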

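8. Reduce to two dimensions and plot on a plane
# PCA suits dense matrices, TruncatedSVD suits sparse matrices
# Dimensionality reduction (detach() drops the autograd graph so the tensor can be converted to NumPy)
X_embedded = TruncatedSVD(n_components=2).fit_transform(word_vector.detach().numpy())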
X_embedded
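The practical difference between the two reducers is centring: PCA subtracts the column means before the SVD, while TruncatedSVD projects the raw matrix. The NumPy sketch below computes both projections exactly (scikit-learn's TruncatedSVD uses a randomized solver by default, so its output can differ slightly and in sign); the random matrix is only a stand-in for a 14 x 4 embedding matrix.

```python
import numpy as np


def pca_2d(X):
    # PCA: centre each column, then project onto the top-2 right singular vectors
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:2].T


def truncated_svd_2d(X):
    # Same projection WITHOUT centring, which is what TruncatedSVD approximates
    _, _, Vt = np.linalg.svd(X, full_matrices=False)
    return X @ Vt[:2].T


rng = np.random.default_rng(0)
X = rng.normal(loc=3.0, size=(14, 4))  # stand-in matrix, deliberately shifted away from the origin
print(pca_2d(X).mean(axis=0))           # close to [0, 0]: PCA output is centred
print(truncated_svd_2d(X).mean(axis=0)) # far from 0: the first component mostly encodes the mean
```

For a dense embedding matrix whose rows share a large common offset, the first TruncatedSVD component can end up encoding that offset rather than differences between words; PCA avoids this.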

# Plot each word at its 2-D coordinates (row i of X_embedded belongs to word_list[i])
for i, label in enumerate(word_list):
    x, y = float(X_embedded[i][0]), float(X_embedded[i][1])
    plt.scatter(x, y)
    plt.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points', ha='right', va='bottom')
plt.show()
