
Using an RNN to Generate Sentences in Any Style

2018-05-12
飘的沙鸥
AI

Preparation

This example trains a char-level RNN language model on Tagore's Stray Birds, and then uses it to generate sentences in that same style.

Data preparation: the input file is a plain text file, and we use unidecode to convert Unicode into ASCII text.

import unidecode
import string
import random
import re

all_characters = string.printable
n_characters = len(all_characters)

file = unidecode.unidecode(open('date/flyaway-0.txt').read())
file_len = len(file)
print('file_len =', file_len)
file_len = 28297
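Why unidecode? It maps characters outside the ASCII range to close ASCII approximations, which lets the model's vocabulary stay at the 100 characters of string.printable. A quick illustration (my own, not from the original post):

print(unidecode.unidecode('naïve café'))  # -> 'naive cafe'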

The file is too long to feed in at once, so we randomly sample chunks from it. Note that random_chunk returns chunk_len + 1 characters, so that an input and a target of chunk_len characters each can later be carved out of it.

chunk_len = 200

def random_chunk():
    start_index = random.randint(0, file_len - chunk_len)
    end_index = start_index + chunk_len + 1
    return file[start_index:end_index]

print(random_chunk())
rejoice in our fulness, then we can part with our fruits with joy. The raindrops kissed the earth and whispered, --- We are thy homesick children, mother, come back to thee from the heaven. The cobweb 

定义模型¶

Recall the earlier char-RNN classifier, where we implemented the most basic RNN "by hand"; this time we use the more advanced GRU. Also, that model had no embedding layer and fed the one-hot encoding of each character directly as input.

import torch
import torch.nn as nn
from torch.autograd import Variable

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers)
        self.decoder = nn.Linear(hidden_size, output_size)
    
    def forward(self, input, hidden):
        # input holds a single character index; embed it, run one GRU step, project to logits
        input = self.encoder(input.view(1, -1))
        output, hidden = self.gru(input.view(1, 1, -1), hidden)
        output = self.decoder(output.view(1, -1))
        return output, hidden

    def init_hidden(self):
        return Variable(torch.zeros(self.n_layers, 1, self.hidden_size))
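As a quick sanity check (my addition, not in the original post), we can push a single character through the network and confirm we get one logit per character back:

# Hypothetical smoke test for the shapes
model = RNN(n_characters, 100, n_characters)
hidden = model.init_hidden()
inp = Variable(torch.LongTensor([all_characters.index('a')]))
output, hidden = model(inp, hidden)
print(output.size())  # torch.Size([1, 100]), i.e. (1, n_characters)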

Input and Output

Each chunk is turned into a LongTensor by looping over its characters and replacing each with its index in all_characters.

# Turn a string into a LongTensor of character indices
def char_tensor(string):
    tensor = torch.zeros(len(string)).long()
    for c in range(len(string)):
        tensor[c] = all_characters.index(string[c])
    return Variable(tensor)

print(char_tensor('abcDEF'))
Variable containing:
 10
 11
 12
 39
 40
 41
[torch.LongTensor of size 6]

Finally, we randomly pick a chunk as training data: the input runs from the first character to the second-to-last, and the target from the second character to the last. For example, if the string is "abc", the input is "ab" and the target is "bc".

def random_training_set():    
    chunk = random_chunk()
    inp = char_tensor(chunk[:-1])
    target = char_tensor(chunk[1:])
    return inp, target
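To make the one-character offset concrete, here is the "abc" example from above pushed through char_tensor (illustration only):

chunk = 'abc'
print(char_tensor(chunk[:-1]))  # input:  "ab" -> indices 10, 11
print(char_tensor(chunk[1:]))   # target: "bc" -> indices 11, 12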

Generating Sentences

To evaluate how well the model generates text, we first need to let it produce some sentences.

def evaluate(prime_str='A', predict_len=100, temperature=0.8):
    hidden = decoder.init_hidden()
    prime_input = char_tensor(prime_str)
    predicted = prime_str

    # Assume the prefix is the string prime_str; feed it through first to build up the hidden state
    for p in range(len(prime_str) - 1):
        _, hidden = decoder(prime_input[p], hidden)
    inp = prime_input[-1]
    
    for p in range(predict_len):
        output, hidden = decoder(inp, hidden)
        
        # Sample from the output distribution, scaled by the temperature
        output_dist = output.data.view(-1).div(temperature).exp()
        top_i = torch.multinomial(output_dist, 1)[0]
        
        # Use the sampled character as the next input
        predicted_char = all_characters[top_i]
        predicted += predicted_char
        inp = char_tensor(predicted_char)

    return predicted
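A note on temperature, since the code above uses it: dividing the logits by a temperature below 1 sharpens the distribution toward the most likely character, while a temperature above 1 flattens it, producing more surprising but noisier text. Once the model defined below has been trained, you could compare, for example:

print(evaluate('Th', 100, temperature=0.2))  # near-greedy, conservative output
print(evaluate('Th', 100, temperature=1.4))  # flatter distribution, wilder output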

Training

A few utility functions:

import time, math

def time_since(since):
    s = time.time() - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

The training function:

def train(inp, target):
    hidden = decoder.init_hidden()
    decoder.zero_grad()
    loss = 0

    # Feed the chunk one character at a time, accumulating the loss at each step
    for c in range(chunk_len):
        output, hidden = decoder(inp[c], hidden)
        loss += criterion(output, target[c])

    loss.backward()
    decoder_optimizer.step()

    # Average cross-entropy per character (loss.data[0] is the pre-0.4 way to read a scalar tensor)
    return loss.data[0] / chunk_len
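Since train returns the average cross-entropy per character in nats, exponentiating it gives the per-character perplexity, a common way to read language-model loss (this conversion is my addition, not part of the original post):

import math
print(math.exp(1.04))  # ~2.83: at the final loss below, the model is roughly as uncertain as a fair choice among ~3 characters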

Next we define the training hyperparameters, initialize the model, and start training:

n_epochs = 2000
print_every = 100
plot_every = 10
hidden_size = 100
n_layers = 1
lr = 0.005

decoder = RNN(n_characters, hidden_size, n_characters, n_layers)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

start = time.time()
all_losses = []
loss_avg = 0

for epoch in range(1, n_epochs + 1):
    loss = train(*random_training_set())       
    loss_avg += loss

    if epoch % print_every == 0:
        print('[%s (%d %d%%) %.4f]' % (time_since(start), epoch, epoch / n_epochs * 100, loss))
        print(evaluate('Wh', 100), '\n')

    if epoch % plot_every == 0:
        all_losses.append(loss_avg / plot_every)
        loss_avg = 0
[0m 10s (100 5%) 2.0145]
WhN pered. The sarin the ove in ahe wiand, lon fres, mows frou of the reart the is shellousts hollove  

[0m 20s (200 10%) 1.8616]
Whel the deeuss night, the to to ter a but ams, lit heirt its meraght. Light hat somper flofe theep bu 

[0m 30s (300 15%) 1.7634]
Whe man like shird's of thes acon ging thister pongs flongs tomy love, The but yest heer and pornach s 

[0m 40s (400 20%) 1.7827]
Why world of arind world like joog and the cor me thy some and the mesel it is the growides, love me o 

[0m 50s (500 25%) 1.7398]
When pain sormher pauth a meryvers the growny becomness my world, "I things. The sher shadd has with t 

[1m 0s (600 30%) 1.6644]
When beames is lamk. I commrend. My hear in the world can in cons's vomening. The ames chan lauty wain 

[1m 10s (700 35%) 1.6080]
Whe fiening wtard is the senst the stard bins. The earth life with the the saice of whe not waves of i 

[1m 20s (800 40%) 1.4164]
Whent mand it like the leave the dosten withe whens me unower stars the sungs lime flowers and rose, c 

[1m 30s (900 45%) 1.5041]
When your day satiwer morning to gonent. When on am the rust noth of light's thark and in wain. I voic 

[1m 41s (1000 50%) 1.4494]
When my wishes foad of the dising death of the sprowd, Wome is and knor is not when I a know hidd, "It 

[1m 51s (1100 55%) 1.6263]
When wor derlise the heart, "You world kisses of does my world but your summs the sunselt to men hide, 

[2m 1s (1200 60%) 1.7908]
When the star, like the whoice of love! World, and everrive of liver silent in the drown raid of your  

[2m 11s (1300 65%) 1.4474]
When my kise of the begad and we are are who where amone. I wave to her when whose dom, astine for my  

[2m 21s (1400 70%) 1.4267]
Wher back and the saglless the neishess for the death. I spaimnes, water its me, my child sing? Death  

[2m 31s (1500 75%) 1.7435]
Wher and souch the same at lote shame and all that the sky, your out and him your world flowers leshin 

[2m 41s (1600 80%) 1.3867]
Wher when here are one sea, before light in alour its heart wind the name where which but of the sea w 

[2m 51s (1700 85%) 1.5425]
Where death thy one and dechife, be fray light, for has sing to the demrancishil when huisting, do not 

[3m 1s (1800 90%) 1.3357]
When heart heaven the glown no the sunny in trush--s hern has narrience of lover. Toed naiden in falli 

[3m 11s (1900 95%) 1.4562]
Where it silence. The rowappses death. The rowed as the sand trut the stars flower is the earth into l 

[3m 21s (2000 100%) 1.0370]
When the sun, silence my cloud in they love. The sun your flower of the earth my woute in the hereep o 
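Before moving on to the loss plot, it is worth noting how to persist the trained model; this snippet is my addition, using PyTorch's standard state_dict API (the file name is illustrative):

# Save the trained weights
torch.save(decoder.state_dict(), 'char_rnn_tagore.pt')

# Later: rebuild the model and restore the weights
decoder = RNN(n_characters, hidden_size, n_characters, n_layers)
decoder.load_state_dict(torch.load('char_rnn_tagore.pt'))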

Plotting the Loss

Looking at how the loss changes over the course of training helps us understand the learning process.

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
%matplotlib inline

plt.figure()
plt.plot(all_losses)
[<matplotlib.lines.Line2D at 0x1f5ee7db5f8>]

(Figure: the training loss curve.)

