没有合适的资源?快使用搜索试试~ 我知道了~
首页循环神经网络RNN实现手写数字识别
循环神经网络RNN实现手写数字识别
8 下载量 182 浏览量
更新于2023-03-16
评论
收藏 44KB PDF 举报
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist=input_data.read_data_sets('mnist_data/',one_hot=True) #注意这里用了one_hot表示,标签的形状是(batch_size,num_batches),类型是float,如果不用one_hot,那么标签的形状是(batch_size,),类型是int num_classes=10 batch_size=64 hidden_dim1=32 hidden_dim2=64
资源详情
资源评论
资源推荐
循环神经网络循环神经网络RNN实现手写数字识别实现手写数字识别
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('mnist_data/',one_hot=True)
#注意这里用了one_hot表示,标签的形状是(batch_size,num_batches),类型是float,如果不用one_hot,那么标签的形状是(batch_size,),类型是int
num_classes=10
batch_size=64
hidden_dim1=32
hidden_dim2=64
epochs=10
embedding_dim=28
time_step=28
'''
这里解析一下tensorflow.nn.dynamic_rnn(cell,inputs,initial_state=None,dtype=None,time_major=False)
返回值有两个outputs,states.outputs的形状是[batch_size,time_step,cell.output_dim] 这个三维张量的第一个维度是batch_size,我们来看二三维度,相当于
每一个batch_size对应于一个二维矩阵,
二维矩阵的每一行长度是rnn的输出维度,代表在当前的batch_size中的当前的time_step的特征。
将outputs取转置并且在第一个维度上取最后一个元素[-1,batch_size,:],也就得到了每一个batch_size中所有数据的最后一个time_step,也就是
最后一个时刻的输出[batch_size,output_dim],
在语言模型中,time_step就是一个句子,所以这个张量[batch_size,output_dim]通常作为下一个时刻的输入,因为这个张量的每一行代表的意义是在
读取了整个句子后所得到的特征,这是很有意义的。如果将outputs取了转置后[time_step,batch_size,output_dim]对这个张量取[1,batch_size,output_dim]
那么现在得到的张量仅仅是在读取了句子中第一个单词后获得的输出,这样的张量如果作为下一层的输入那么模型的效果将惨不忍睹。
现在来讨论states,如果cell是LSTM类型的cell,那么states是一个有两个元素的元祖,为什么有两个元素呢,因为LSTM类型的
cell有两个状态
一个是cell state代表该神经元的细胞状态,另一个是hidden state代表该神经元的隐藏状态。
而且这两个状态都是最后一个时刻的神经元的状态,所以states的形状与time_step无关
这两个状态张量的形状都是[batch_size,output_dim] states[0]是cell的状态,states[1]是hidden的状态
所以states[1]与上面提到的提取每一个time_step的最后一个时刻得到的张量[batch_size,output_dim]是一样的,因为它们都是
每一个batch_size中
每一个数据的最后一个时刻的特征提取。
'''
class RNN_model:
def __init__(self):
tf.reset_default_graph()
def add_placeholder(self):
self.xs=tf.placeholder(shape=[None,784],dtype=tf.float32)
self.ys=tf.placeholder(shape=[None,10],dtype=tf.float32)
#由于用了one_hot表示,所以shape是(batch_size,10),dtype是float
#不用one_hot表示的话应该改为self.ys=tf.placeholder(shape=[None],dtype=tf.int32)
def rnn_layer(self):
rnn_input=tf.reshape(tensor=self.xs,shape=[-1,time_step,embedding_dim])
cell_1=tf.contrib.rnn.BasicLSTMCell(num_units=hidden_dim1)
cell_2=tf.contrib.rnn.BasicLSTMCell(num_units=hidden_dim2)
cells=tf.contrib.rnn.MultiRNNCell([cell_1,cell_2])
initial_state=cells.zero_state(batch_size=batch_size,dtype=tf.float32)
outputs,states=tf.nn.dynamic_rnn(cells,rnn_input,initial_state=initial_state,time_major=False)
#outputs.shape==(batch_size,time_step,hidden_dim2)
#states表示的是batch_size里的每一个长度为time_step的数据的最后一个时刻状态,所以与time_step无关
#states[0][0].shape==states[0][1].shape==(batch_size,hidden_dim1)
#states[1][0].shape==states[1][1].shape==(batch_size,hidden_dim2)
#states[1][1]==tf.transpose(outputs,[1,0,2])[-1] outputs=tf.transpose(outputs,perm=[1,0,2])
self.rnn_output=outputs[-1]#(batch_size,hidden_dim2)
def output_layer(self):
weights=tf.Variable(tf.random_normal(shape=[hidden_dim2,num_classes],dtype=tf.float32))
biases=tf.Variable(tf.random_normal(shape=[num_classes],dtype=tf.float32))
self.predict=tf.matmul(self.rnn_output,weights)+biases#(batch_size,num_classes)
#注意softmax_cross_entropy_with_logits与sparse_softmax_cross_entropy_with_logits的区别
#前者的logits与labels必须是同样的shape(batch_size,num_classes)以及同样的dtype(float32)
#后者的logits形状是(batch_size,num_classes),dtype是float32,而labels的形状是(batch_size),dtype必须是int,每一个数值表
示logits的每一行数据属于哪一类
def loss_layer(self):
self.loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.predict,labels=self.ys))
self.train_op=tf.train.AdamOptimizer(0.01).minimize(self.loss)
self.accuracy=tf.reduce_mean(tf.cast(tf.equal(tf.argmax(self.ys,1),tf.argmax(self.predict,1)),dtype=tf.float32))
def build_graph(self):
weixin_38620734
- 粉丝: 4
- 资源: 974
上传资源 快速赚钱
- 我的内容管理 收起
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
会员权益专享
最新资源
- RTL8188FU-Linux-v5.7.4.2-36687.20200602.tar(20765).gz
- c++校园超市商品信息管理系统课程设计说明书(含源代码) (2).pdf
- 建筑供配电系统相关课件.pptx
- 企业管理规章制度及管理模式.doc
- vb打开摄像头.doc
- 云计算-可信计算中认证协议改进方案.pdf
- [详细完整版]单片机编程4.ppt
- c语言常用算法.pdf
- c++经典程序代码大全.pdf
- 单片机数字时钟资料.doc
- 11项目管理前沿1.0.pptx
- 基于ssm的“魅力”繁峙宣传网站的设计与实现论文.doc
- 智慧交通综合解决方案.pptx
- 建筑防潮设计-PowerPointPresentati.pptx
- SPC统计过程控制程序.pptx
- SPC统计方法基础知识.pptx
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
评论0