这个例子会用泰戈尔的《飞鸟集》来训练一个char-level RNN语言模型,然后用它来生成泰戈尔风格的句子。
数据准备 输入文件是纯文本文件,我们会使用unidecode来把unicode转成ASCII文本。
import unidecode
import string
import random
import re
# Vocabulary: every character in string.printable. A character's position in
# this string is what char_tensor() uses as its integer id.
all_characters = string.printable
n_characters = len(all_characters)

# Read the whole corpus and transliterate it to ASCII with unidecode.
# The `with` block closes the file handle as soon as the text is read.
# NOTE(review): no encoding is passed to open(), so the platform default is
# used -- confirm it matches the encoding of the corpus file.
# NOTE(review): the directory is spelled 'date' -- confirm it is not a typo
# for 'data'.
with open('date/flyaway-0.txt') as corpus_file:
    file = unidecode.unidecode(corpus_file.read())
file_len = len(file)
print('file_len =', file_len)
# Output of the original notebook run: file_len = 28297
# (kept as a comment; assigning it would override the real corpus length)

# Number of characters fed to the model per training sequence.
chunk_len = 200
def random_chunk(text=None, length=None):
    """Return a random run of ``length + 1`` consecutive characters.

    The extra character is needed because callers split the chunk into an
    input (all but the last character) and a target (all but the first).

    Args:
        text: String to sample from. Defaults to the module-level ``file``.
        length: Number of input characters. Defaults to the module-level
            ``chunk_len``.

    Returns:
        A substring of ``text`` that is exactly ``length + 1`` characters long.

    Raises:
        ValueError: If ``text`` has fewer than ``length + 1`` characters.
    """
    if text is None:
        text = file
    if length is None:
        length = chunk_len

    # random.randint is inclusive on both ends, so the largest valid start
    # must leave room for all length + 1 characters.
    last_start = len(text) - length - 1
    if last_start < 0:
        raise ValueError(
            'text has %d characters, need at least %d' % (len(text), length + 1)
        )
    start_index = random.randint(0, last_start)
    return text[start_index:start_index + length + 1]
rejoice in our fulness, then we can part with our fruits with joy. The raindrops kissed the earth and whispered, --- We are thy homesick children, mother, come back to thee from the heaven. The cobweb
回忆一下之前的Char RNN 分类器,我们是“手动”实现的最朴素的RNN,现在我们使用更加先进的GRU。 另外之前是没有Embedding的,直接用字母的one-hot作为输入;这里则先用nn.Embedding把字符的下标映射成向量,再把它送入GRU。
import torch
import torch.nn as nn
from torch.autograd import Variable
class RNN(nn.Module):
    """Character-level model: embedding -> GRU -> linear projection.

    Args:
        input_size: Number of distinct input symbols (embedding rows).
        hidden_size: Width of the embedding and of the GRU hidden state.
        output_size: Number of scores produced per step.
        n_layers: Number of stacked GRU layers.
    """

    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super().__init__()
        self.input_size, self.hidden_size = input_size, hidden_size
        self.output_size, self.n_layers = output_size, n_layers

        # Sub-modules are created in the order encoder, gru, decoder.
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden):
        """Advance the model by one step.

        The reshapes below fix both the sequence length and the batch size
        to 1, so ``input`` is expected to hold a single symbol index.

        Returns:
            A ``(scores, hidden)`` pair: scores reshaped to one row, and the
            updated GRU hidden state.
        """
        embedded = self.encoder(input.view(1, -1))
        gru_out, hidden = self.gru(embedded.view(1, 1, -1), hidden)
        return self.decoder(gru_out.view(1, -1)), hidden

    def init_hidden(self):
        """Return an all-zero hidden state of shape (n_layers, 1, hidden_size)."""
        shape = (self.n_layers, 1, self.hidden_size)
        return Variable(torch.zeros(*shape))
# Turn a string into a LongTensor of character ids.
def char_tensor(string):
    """Encode ``string`` as a 1-D LongTensor wrapped in a Variable.

    Each character is replaced by its position in ``all_characters``.
    ``str.index`` raises ValueError for a character that is not in
    ``all_characters``.
    """
    indices = [all_characters.index(ch) for ch in string]
    return Variable(torch.LongTensor(indices))
Variable containing:
[torch.LongTensor of size 6]
def random_training_set():
    """Build one (input, target) training pair from a random corpus chunk.

    The input is the chunk without its last character and the target is the
    chunk without its first character, so target[i] is the character that
    follows input[i].
    """
    chunk = random_chunk()
    return char_tensor(chunk[:-1]), char_tensor(chunk[1:])
def evaluate(prime_str='A', predict_len=100, temperature=0.8):
    """Generate text from the module-level ``decoder``.

    Args:
        prime_str: Prefix that seeds the hidden state; it is also the start
            of the returned string.
        predict_len: Number of characters to sample after the prefix.
        temperature: Divisor applied to the scores before exponentiation;
            smaller values sharpen the sampling distribution.

    Returns:
        ``prime_str`` followed by ``predict_len`` sampled characters.
    """
    state = decoder.init_hidden()
    prime_input = char_tensor(prime_str)
    pieces = [prime_str]

    # Feed every prefix character except the last one, only to build up the
    # hidden state; the outputs are discarded.
    for idx in range(len(prime_str) - 1):
        _, state = decoder(prime_input[idx], state)
    current = prime_input[-1]

    for _ in range(predict_len):
        scores, state = decoder(current, state)

        # Sample the next character id from the temperature-scaled scores.
        weights = scores.data.view(-1).div(temperature).exp()
        choice = torch.multinomial(weights, 1)[0]

        # The sampled character becomes the next step's input.
        next_char = all_characters[choice]
        pieces.append(next_char)
        current = char_tensor(next_char)

    return ''.join(pieces)
import time, math
def time_since(since):
    """Format the time elapsed since ``since`` as ``'<minutes>m <seconds>s'``.

    ``since`` is a timestamp on the same clock as ``time.time()``.
    """
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    # %d truncates the fractional part of the remaining seconds.
    return '%dm %ds' % (minutes, elapsed - minutes * 60)
def train(inp, target):
    """Run one optimisation step on a single character sequence.

    Args:
        inp: 1-D tensor of input character ids, at least ``chunk_len`` long.
        target: 1-D tensor of the same length; ``target[c]`` is the id the
            model should predict after seeing ``inp[c]``.

    Returns:
        The summed cross-entropy loss divided by ``chunk_len``, as a Python
        float.
    """
    hidden = decoder.init_hidden()
    # Gradients accumulate across backward() calls, so clear the ones left
    # over from the previous step.
    decoder.zero_grad()

    loss = 0
    for c in range(chunk_len):
        output, hidden = decoder(inp[c], hidden)
        # view(-1) turns the single target id into a batch of one, matching
        # the (1, n_classes) scores whether target[c] is 0-dim or 1-element.
        loss += criterion(output, target[c].view(-1))

    # Back-propagate through the whole sequence and update the weights;
    # without these two calls the model parameters never change.
    loss.backward()
    decoder_optimizer.step()

    # view(-1)[0] reads the scalar whether the loss is a 0-dim or a
    # 1-element tensor.
    return float(loss.data.view(-1)[0]) / chunk_len
# Number of training iterations; each one draws a single random chunk.
n_epochs = 2000
# Print progress and a generated sample every `print_every` iterations.
print_every = 100
# Record the averaged loss every `plot_every` iterations.
plot_every = 10
# Width of the embedding and of the GRU hidden state.
hidden_size = 100
# Number of stacked GRU layers.
n_layers = 1
# Learning rate passed to Adam.
lr = 0.005
# The model reads and predicts ids over the same character vocabulary, so
# input and output sizes are both n_characters.
decoder = RNN(n_characters, hidden_size, n_characters, n_layers)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
# Wall-clock start time, used for the elapsed-time column of the progress log.
start = time.time()
# Averaged losses, one entry per `plot_every` iterations.
all_losses = []
# Running loss sum since the last entry appended to all_losses.
loss_avg = 0
# Main training loop: one random chunk per iteration.
for epoch in range(1, n_epochs + 1):
    inp, target = random_training_set()
    loss = train(inp, target)
    loss_avg += loss

    # Periodically log elapsed time, progress and loss, then print a sample
    # generated from the prefix 'Wh'.
    if not epoch % print_every:
        percent_done = epoch / n_epochs * 100
        print('[%s (%d %d%%) %.4f]' % (time_since(start), epoch, percent_done, loss))
        print(evaluate('Wh', 100), '\n')

    # Periodically store the mean loss of the last window and reset the sum.
    if not epoch % plot_every:
        all_losses.append(loss_avg / plot_every)
        loss_avg = 0
[0m 10s (100 5%) 2.0145]
WhN pered. The sarin the ove in ahe wiand, lon fres, mows frou of the reart the is shellousts hollove
[0m 20s (200 10%) 1.8616]
Whel the deeuss night, the to to ter a but ams, lit heirt its meraght. Light hat somper flofe theep bu
[0m 30s (300 15%) 1.7634]
Whe man like shird's of thes acon ging thister pongs flongs tomy love, The but yest heer and pornach s
[0m 40s (400 20%) 1.7827]
Why world of arind world like joog and the cor me thy some and the mesel it is the growides, love me o
[0m 50s (500 25%) 1.7398]
When pain sormher pauth a meryvers the growny becomness my world, "I things. The sher shadd has with t
[1m 0s (600 30%) 1.6644]
When beames is lamk. I commrend. My hear in the world can in cons's vomening. The ames chan lauty wain
[1m 10s (700 35%) 1.6080]
Whe fiening wtard is the senst the stard bins. The earth life with the the saice of whe not waves of i
[1m 20s (800 40%) 1.4164]
Whent mand it like the leave the dosten withe whens me unower stars the sungs lime flowers and rose, c
[1m 30s (900 45%) 1.5041]
When your day satiwer morning to gonent. When on am the rust noth of light's thark and in wain. I voic
[1m 41s (1000 50%) 1.4494]
When my wishes foad of the dising death of the sprowd, Wome is and knor is not when I a know hidd, "It
[1m 51s (1100 55%) 1.6263]
When wor derlise the heart, "You world kisses of does my world but your summs the sunselt to men hide,
[2m 1s (1200 60%) 1.7908]
When the star, like the whoice of love! World, and everrive of liver silent in the drown raid of your
[2m 11s (1300 65%) 1.4474]
When my kise of the begad and we are are who where amone. I wave to her when whose dom, astine for my
[2m 21s (1400 70%) 1.4267]
Wher back and the saglless the neishess for the death. I spaimnes, water its me, my child sing? Death
[2m 31s (1500 75%) 1.7435]
Wher and souch the same at lote shame and all that the sky, your out and him your world flowers leshin
[2m 41s (1600 80%) 1.3867]
Wher when here are one sea, before light in alour its heart wind the name where which but of the sea w
[2m 51s (1700 85%) 1.5425]
Where death thy one and dechife, be fray light, for has sing to the demrancishil when huisting, do not
[3m 1s (1800 90%) 1.3357]
When heart heaven the glown no the sunny in trush--s hern has narrience of lover. Toed naiden in falli
[3m 11s (1900 95%) 1.4562]
Where it silence. The rowappses death. The rowed as the sand trut the stars flower is the earth into l
[3m 21s (2000 100%) 1.0370]
When the sun, silence my cloud in they love. The sun your flower of the earth my woute in the hereep o
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
%matplotlib inline
[<matplotlib.lines.Line2D at 0x1f5ee7db5f8>]
# Confusion matrix: indexed below as confusion[actual category][guessed category].
# NOTE(review): n_categories is not defined in the visible part of this file --
# confirm where it comes from.
confusion = torch.zeros(n_categories, n_categories)
# Number of random samples used to fill the matrix.
n_confusion = 10000
def evaluate(line_tensor):
    """Run ``rnn`` over ``line_tensor`` step by step and return the last output.

    Steps are taken along the first dimension of ``line_tensor``. If that
    dimension is empty, no output is produced and the ``return`` raises
    UnboundLocalError.

    NOTE(review): this reuses the name ``evaluate`` and therefore replaces
    the text-generation ``evaluate`` defined earlier in this file; ``rnn`` is
    not defined in the visible part of the file -- confirm both.
    """
    state = rnn.init_hidden()
    n_steps = line_tensor.size(0)
    for step in range(n_steps):
        output, state = rnn(line_tensor[step], state)
    return output
# Ideally this would use a held-out test set; here we sample at random from
# the training data instead.
for _ in range(n_confusion):
    category, _, _, line_tensor = random_training_pair()
    output = evaluate(line_tensor)
    _, guess_i = category_from_output(output)
    category_i = all_categories.index(category)
    confusion[category_i][guess_i] += 1

# Normalise every row to sum to 1. A category that was never sampled has an
# all-zero row; leave it as zeros rather than dividing 0 by 0 into NaN.
for i in range(n_categories):
    row_total = confusion[i].sum()
    if row_total > 0:
        confusion[i] = confusion[i] / row_total
# Draw the confusion matrix as an image.
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(confusion.numpy())
# Label both axes with the category names; x-axis labels are rotated to
# vertical. The leading '' entry is kept from the original label lists.
ax.set_xticklabels([''] + all_categories, rotation=90)
ax.set_yticklabels([''] + all_categories)
# Force one tick per matrix cell so that labels line up with rows and
# columns instead of being assigned to matplotlib's automatic tick positions.
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
