#coding:utf-8
\'\'\'
created on February 11 15:38 2019
@author:zhulma
\'\'\'
import paddle
import paddle.dataset.imdb as imdb
import paddle.fluid as fluid
import numpy as np
# 简单的循环神经网络
def rnn_net(ipt, input_dim):
# 将句子分词的IDs作为输入
emb = fluid.layers.embedding(input=ipt, size=[input_dim, 128], is_sparse=True)
sentence = fluid.layers.fc(input=emb, size=128, act=\'tanh\')
# 循环神经网络块
rnn = fluid.layers.DynamicRNN()
with rnn.block():
word = rnn.step_input(sentence)
prev = rnn.memory(shape=[128])
hidden = fluid.layers.fc(input=[word, prev], size=128, act=\'relu\')
rnn.update_memory(prev, hidden)
rnn.output(hidden)
last = fluid.layers.sequence_last_step(rnn()) # 这个接口通常使用在序列函数最后一步
out = fluid.layers.fc(input=last, size=2, act=\'softmax\')
return out
# 长短期记忆网络
def lstm_net(ipt, input_dim):
# 将句子分词的词向量作为输入
emb = fluid.layers.embedding(input=ipt, size=[input_dim, 128], is_sparse=True)
# 第一个全连接层
fc1 = fluid.layers.fc(input=emb, size=128)
# 进行一个长短期记忆操作
lstm1, _ = fluid.layers.dynamic_lstm(input=fc1, size=128)
# 第一个最大序列池操作
fc2 = fluid.layers.sequence_pool(input=fc1, pool_type=\'max\')
# 第二个最大序列池操作
lstm2 = fluid.layers.sequence_pool(input=lstm1, pool_type=\'max\')
# 以softmax作为全连接的输出层,大小为2,也就是只有正反面两种评价
out = fluid.layers.fc(input=[fc2, lstm2], size=2, act=\'softmax\')
return out
#定义输入数据,lod_level不为0指定输入数据为序列数据
words=fluid.layers.data(name=\'words\',shape=[1],dtype=\'int64\',lod_level=1)
label=fluid.layers.data(name=\'label\',shape=[1],dtype=\'int64\')
#读取数据字典
print(\'加载数据字典中...\')
word_dict=imdb.word_dict()
#获取数据字典长度
dict_dim=len(word_dict)
#获取长短期记忆网络
model=lstm_net(words,dict_dim)
#获取循环神经网络
#model=rnn_net(words,dict_dim)
#定义损失函数和准确率,分类问题还是使用交叉熵损失函数
cost=fluid.layers.cross_entropy(input=model,label=label)
avg_cost=fluid.layers.mean(cost)
accuracy=fluid.layers.accuracy(input=model,label=label)
#克隆一个测试程序
test_program=fluid.default_main_program().clone(for_test=True)
#选择优化器,这里选择Adagrad优化方法,这种优化方法多用于处理稀疏数据
optimizer=fluid.optimizer.AdagradOptimizer(learning_rate=0.002)
opt=optimizer.minimize(avg_cost)
#创建执行器,这次数据集很大,但是我的虚拟机也跑不了CUDA,有条件的就跑CUDA
place=fluid.CPUPlace()
#place=fluid.CUDAPlace(1)
exe=fluid.Executor(place)
#初始化参数
exe.run(fluid.default_startup_program())
#由于数据集比较大,为了加快数据的读取速度,使用paddle.reader.shuffle()先将数据按照设置的大小读入到缓存中,并且打乱顺序
#获取训练和测试数据
print("获取训练数据中...")
train_reader=paddle.batch(paddle.reader.shuffle(imdb.train(word_dict),25000),batch_size=128)
print("获取测试数据中...")
test_reader=paddle.batch(imdb.test(word_dict),batch_size=128)
#定义数据输入维度
feeder=fluid.DataFeeder(place=place,feed_list=[words,label])
#开始训练
for pass_id in range(1):#跑一遍我的虚拟机就够呛,但是为了收敛效果更好,就跑三遍吧
train_cost=0
for batch_id,data in enumerate(train_reader()):
train_cost=exe.run(program=fluid.default_main_program(),feed=feeder.feed(data),fetch_list=[avg_cost])
if batch_id%50==0:
print(\'Pass:%d,Batch:%d,Cost:%0.5f\'%(pass_id,batch_id,train_cost[0]))
#开始测试
test_costs=[]
test_accs=[]
for batch_id,data in enumerate(test_reader()):
test_cost,test_acc=exe.run(program=test_program,feed=feeder.feed(data),fetch_list=[avg_cost,accuracy])
test_costs.append(test_cost[0])
test_accs.append(test_acc[0])
#计算这次batch经过后,平均的损失值和准确率
test_cost=(sum(test_costs)/len(test_costs))
test_acc=(sum(test_accs)/len(test_accs))
print(\'Test:%d,Cost:%0.5f,acc:%0.5f\'%(pass_id,test_cost,test_acc))
#预测数据
#定义预测句子,第一个是中性的,第二个是偏向正面的,第三个是偏向负面
test_str=[\'I read the book\',\'I am so happy\',\'this is a bad movie\']
#将句子转化成一个个单词
reviews=[c.split() for c in test_str]
#将句子中的单词转换成字典中的标签
#字典中没有的单词一律标签为<unk>对应的标签
unk=word_dict[\'<unk>\']
#获取每句话对应的标签
lod=[]
for c in reviews:
#需要将单词进行字符串的utf-8编码转换,对于没见过的单词就使用<unk>的标签
lod.append([word_dict.get(words.encode(\'utf-8\'),unk) for words in c])
#获取输入数据维度的大小,换句话说就是获取每句话的单词数量
base_shape=[[len(c) for c in lod]]
#将想要预测的数据转换成张量,准备开始预测
tensor_words=fluid.create_lod_tensor(lod,base_shape,place)
#开始预测
results=exe.run(program=test_program,feed={\'words\':tensor_words,\'label\':np.array([[0],[0],[0]]).astype(\'int64\')},fetch_list=[model])
#打印每句话的正负面概率
for i,r in enumerate(results[0]):
print(\'%s的预测结果是:正面概率为:%0.5f,反面概率为:%0.5f\'%(test_str[i],r[0],r[1]))
本文链接:http://task.lmcjl.com/news/12612.html