keras 简单 lstm实例(基于one-hot编码)

1. 简介

Keras是一个基于Python的深度学习库,它支持多个后端,包括TensorFlow,CNTK和Theano等,并提供了一个高级别、简洁的API,使得深度学习在Python中变得更加容易。其中LSTM是一种循环神经网络,常用于自然语言处理等领域。本文将介绍在Keras中使用LSTM进行序列预测,并且基于one-hot编码对模型进行训练,代码实现部分使用Python语言。

2. 数据准备

2.1 读取数据

我们选择了一份Shakespeare的作品作为数据。其中每行数据表示一些字符序列。我们以字符序列中的前100个字符作为输入,作为模型的训练样本。代码如下:

import urllib.request

import numpy as np

url = 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt'

path = 'shakespeare.txt'

urllib.request.urlretrieve(url, path)

text = open(path, 'rb').read().decode(encoding='utf-8')

print('Length of text: {} characters'.format(len(text)))

input_len = 100

step = 1

sentences = []

next_chars = []

for i in range(0, len(text) - input_len, step):

sentences.append(text[i:i + input_len])

next_chars.append(text[i + input_len])

print('Number of sentences:', len(sentences))

这里通过循环遍历文本,从第0个字符开始,取出100个字符作为序列,以第101个字符作为每个序列的下一个字符,得到训练数据。此时得到的sequences即为我们模型的输入,labels则为对应的下一个字符,作为模型的输出。由于LSTM需要数值类型作为输入,所以我们需要进行one-hot编码。代码如下:

chars = sorted(list(set(text)))

char_indices = dict((char, chars.index(char)) for char in chars)

print('Number of distinct characters:', len(chars))

x = np.zeros((len(sentences), input_len, len(chars)), dtype=np.float32)

y = np.zeros((len(sentences), len(chars)), dtype=np.float32)

for i, sentence in enumerate(sentences):

for t, char in enumerate(sentence):

x[i, t, char_indices[char]] = 1

y[i, char_indices[next_chars[i]]] = 1

对于每个字符,我们将它映射为数字,即0-25(我们选择了只使用小写字母)。对于每个序列,我们将它和输出都进行one-hot编码,得到$x,y$两个数组。

3. LSTM模型

我们将使用Keras构建一个LSTM模型,来对文本进行处理,并预测下一个字符。代码如下:

import keras

from keras import layers

lstm_model = keras.Sequential([

layers.LSTM(128, input_shape=(input_len, len(chars))),

layers.Dense(len(chars), activation='softmax')

])

optimizer = keras.optimizers.RMSprop(lr=0.01)

lstm_model.compile(loss='categorical_crossentropy', optimizer=optimizer)

lstm_model.summary()

这里我们使用了一个简单的LSTM模型,其中有一个LSTM层,输入形状为(100 x 26),输出为(128 x 1)。紧跟着一个Dense层,将输出转换为各个字符的概率分布。优化器使用RMSprop,损失函数为交叉熵。下面我们来训练模型。

4. 模型训练和预测

在模型训练中,我们训练10个epochs,每个epoch都把整个数据集过一遍,即len(sentences)个step。代码如下:

epochs = 10

batch_size = 128

temperature = 0.6

for epoch in range(epochs):

lstm_model.fit(x, y, batch_size=batch_size, epochs=1)

print('=== Generating text with temperature: {} ==='.format(temperature))

start_index = np.random.randint(0, len(text) - input_len - 1)

generated_text = text[start_index:start_index + input_len]

print('--- Generating with seed: "{}"'.format(generated_text) + ' ---')

for i in range(400):

x_pred = np.zeros((1, input_len, len(chars)))

for t, char in enumerate(generated_text):

x_pred[0, t, char_indices[char]] = 1.

preds = lstm_model.predict(x_pred, verbose=0)[0]

next_char_index = sample_from_output(preds, temperature)

next_char = chars[next_char_index]

generated_text += next_char

generated_text = generated_text[1:]

print(next_char, end='')

这里我们选择了一个batch_size为128,共训练了10个epochs。在每个epoch训练结束后,我们都会使用temperature对模型进行采样,以生成新的文本。下面我们来看看sample_from_output函数是如何实现的。

4.1 sample_from_output函数

该函数用于基于模型的输出得到下一个字符。由于LSTM生成的输出为概率分布,因此我们需要根据temperature对该分布进行加权,提高随机性。代码如下:

def sample_from_output(preds, temperature):

exp_preds = np.exp(np.log(preds) / temperature)

preds = exp_preds / np.sum(exp_preds)

probas = np.random.multinomial(1, preds, 1)

return np.argmax(probas)

该函数中的temperature表示温度,值越高,softmax函数输出的概率分布越平坦,增加了随机性,生成的文本也更加随机。下面我们来看一下训练展示的效果。

5. 训练展示

以下是我们训练的结果。我们随机选择了一些文本进行输出,每个文本的前100个字符作为模型的输入。随着温度的升高,生成的文本越来越随机。同时,由于训练数据量较小,所以模型在生成长文本时会出现重复的情况。同时,我们也可以看到,当温度很高时,生成的文本并不具有实际意义。

--- Generating with seed: "ter Christmas;

he is, indeed; but he is fi" ---

ter Christmas;

he is, indeed; but he is five,

Aly-bel my grievenes balatchien

And their happiers han?

Here under; so cold the case of doing.

Have put out modesty I would

Thus houses, the carbodinability

Only more of its challacknd, our loss,

I stand I saw thee dare I live it? it did

It for his constancy; which I will take

Our fingers in ways of the plainances,

Himself to keep this tale

When mans for love of in pastinable

his there never I'll and death-

like phrases,--the tuh they--no;-

=== Generating text with temperature: 0.6 ===

--- Generating with seed: "t speedily attend me here: some league" ---

t speedily attend me here: some league!

Oth: For this time, sir?

Ban: I'll bring her to it: she shall go hard

Before she shall be work'd, I know,

All my dishonours ; and nothing in the world

Could make me jeep, and my thankfulness,

And I will never fear our neake is a good child

To thee condianed king.

Pro: So such it perfect....

Bern: how blood enmity and suffer death.

That hate the Moor: yet I'll keep mine shold.

And art your duty to the bower of death,

With so sweet a grace.

Kath: The moor!

Good Rudopify. Entic'd them in their hearts KN

6. 总结

本文介绍了如何使用Keras对文本进行序列预测,并且在训练时采用了one-hot编码对模型进行训练。通过使用LSTM,在对模型进行训练后,我们可以使用temperature进行生成随机文本,从而增加文本的多样性和独创性。同时,我们也发现,随着temperature的升高,生成的文本越来越具有随机性,而温度过高时会导致生成文本没有意义。在实际应用中,我们可以根据具体需求选择合适的temperature值,从而得到符合实际需求的文本。

后端开发标签