
RNN-Pytorch实现demo

easy demo
- 根据语料整理出字典dic
- 将所有的句子统一长度
- 给出input target
- 将input target对应dit中index
- one-hot编码 Input shape: (3, 14, 17) –> (Batch Size, Sequence Length, One-Hot Encoding Size)
- 训练model
- model的参数:1. input_size . hidden_dim . output_size batch_size
- 预测
from Limu
想象成一个长方体 长为时间步长 宽为vocab_size 高为batch_size =》 又相当于 每层都是一个batch 每个时间步 进去竖着的一个切片
batch_size = 32
num_steps = 35
num_hiddens = 256
num_epochs = 500
learning_rate = 1
加载数据
train_iter, vocab = load_data_time_machine(batch_size, num_steps)
train_iter
使用id表示了数据集
vocab 相当于字典的 len=28 包括26个字母和空格和unk 包括词频
以上是包括了文章所有的数据
batch_size = 32
num_steps = 35
num_hiddens = 512
num_epochs = 500
lr = 1
use_random_iter = True
train_iter, vocab = load_data_time_machine(batch_size, num_steps,
use_random_iter=use_random_iter)
net = RNNModelScratch(len(vocab), num_hiddens, device,
get_params, init_rnn_state, gru)
train(net, train_iter, vocab, lr, num_epochs, device,
use_random_iter=use_random_iter)
plt.show()
- Title: RNN-Pytorch实现demo
- Author: Jason
- Created at : 2023-09-22 08:49:07
- Updated at : 2023-09-27 18:08:30
- Link: https://xxxijason1201.github.io/2023/09/22/NLP/CS224n/RNN/
- License: This work is licensed under CC BY-NC-SA 4.0.
Comments