首页 编程教程正文

神经网络-手写数字识别学习笔记

piaodoo 编程教程 2020-02-22 22:04:51 950 0 python教程

本文来源吾爱破解论坛

之前弄机器学习一直跟有监督学习打交道 ,没怎么看重无监督学习,最近想要往nlp自然语言处理走一走,发现一篇文章《A Neural Conversational Model

所以自然绕不开seq2seq与LSTM,所以准备打下深度学习的基础,这两天正在看tensorflow,遂mark一下MNIST的手写识别代码
准确率到了0.98,代码还很不完善,比如隐藏层过多,迭代次数略微少点等,大家一起学习下吧
[Python] 纯文本查看 复制代码

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot = True)

batch_size = 100
n_batch = mnist.train.num_examples // batch_size

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32) #参加计算神经元占比,eg. 1.0表示100%
                                       #有效抑制过拟合:测试数据与训练数据差别不大             
Weight_L1 = tf.Variable(tf.truncated_normal([784, 500], stddev = 0.1))
biases_L1 = tf.Variable(tf.zeros([500]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, Weight_L1) + biases_L1)
L1_drop = tf.nn.dropout(L1, keep_prob)
lr = tf.Variable(0.001)

Weight_L2 = tf.Variable(tf.truncated_normal([500, 300], stddev = 0.1))
biases_L2 = tf.Variable(tf.zeros([300]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, Weight_L2) + biases_L2)
L2_drop = tf.nn.dropout(L2, keep_prob)

Weight_L3 = tf.Variable(tf.truncated_normal([300, 100], stddev = 0.1))
biases_L3 = tf.Variable(tf.zeros([100]) + 0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop, Weight_L3) + biases_L3)
L3_drop = tf.nn.dropout(L3, keep_prob)

Weight_L4 = tf.Variable(tf.truncated_normal([100, 10], stddev = 0.1))
biases_L4 = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.nn.softmax(tf.matmul(L3_drop, Weight_L4) + biases_L4) #多分类

#损失函数
# loss = tf.reduce_mean(tf.square(y - prediction)) #二次代价函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels = y, logits = prediction)) #交叉熵
#优化器
# train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #梯度下降
train_step = tf.train.AdamOptimizer(lr).minimize(loss)

init = tf.global_variables_initializer()
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) #cast布尔转数值

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(31):
        sess.run(tf.assign(lr, 0.001 * (0.95 ** epoch))) #迭代降低学习率
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict = {x:batch_xs, 
                                              y:batch_ys,
                                              keep_prob:1.0})
        test_acc = sess.run(accuracy, feed_dict = {x:mnist.test.images, 
                                              y:mnist.test.labels,
                                              keep_prob:1.0})
        train_acc = sess.run(accuracy, feed_dict = {x:mnist.train.images, 
                                              y:mnist.train.labels,
                                              keep_prob:1.0})
        learning_rate = sess.run(lr)
        print("Iter " + str(epoch) + 
              ", Testing Accuracy " + str(test_acc) + 
              ", Training Accuracy " + str(train_acc) + 
              ", Learning Rate " + str(learning_rate))


以下是迭代结果
[Python] 纯文本查看 复制代码
Iter 0, Testing Accuracy 0.9502, Training Accuracy 0.95552725, Learning Rate 0.001
Iter 1, Testing Accuracy 0.96, Training Accuracy 0.9675636, Learning Rate 0.00095
Iter 2, Testing Accuracy 0.9679, Training Accuracy 0.97810906, Learning Rate 0.0009025
Iter 3, Testing Accuracy 0.9707, Training Accuracy 0.98196363, Learning Rate 0.000857375
Iter 4, Testing Accuracy 0.9716, Training Accuracy 0.9841273, Learning Rate 0.00081450626
Iter 5, Testing Accuracy 0.9749, Training Accuracy 0.98634547, Learning Rate 0.0007737809
Iter 6, Testing Accuracy 0.9747, Training Accuracy 0.9884909, Learning Rate 0.0007350919
Iter 7, Testing Accuracy 0.9767, Training Accuracy 0.98965454, Learning Rate 0.0006983373
Iter 8, Testing Accuracy 0.9756, Training Accuracy 0.99114543, Learning Rate 0.0006634204
Iter 9, Testing Accuracy 0.9792, Training Accuracy 0.99325454, Learning Rate 0.0006302494
Iter 10, Testing Accuracy 0.9774, Training Accuracy 0.99332726, Learning Rate 0.0005987369
Iter 11, Testing Accuracy 0.9765, Training Accuracy 0.99332726, Learning Rate 0.0005688001
Iter 12, Testing Accuracy 0.9802, Training Accuracy 0.9935273, Learning Rate 0.0005403601
Iter 13, Testing Accuracy 0.9815, Training Accuracy 0.9952, Learning Rate 0.0005133421
Iter 14, Testing Accuracy 0.9777, Training Accuracy 0.99514544, Learning Rate 0.000487675
Iter 15, Testing Accuracy 0.9804, Training Accuracy 0.9954, Learning Rate 0.00046329122
Iter 16, Testing Accuracy 0.9813, Training Accuracy 0.9958909, Learning Rate 0.00044012666
Iter 17, Testing Accuracy 0.9802, Training Accuracy 0.9961636, Learning Rate 0.00041812033
Iter 18, Testing Accuracy 0.9766, Training Accuracy 0.99463636, Learning Rate 0.00039721432
Iter 19, Testing Accuracy 0.9809, Training Accuracy 0.99652725, Learning Rate 0.0003773536
Iter 20, Testing Accuracy 0.9758, Training Accuracy 0.99563634, Learning Rate 0.00035848594
Iter 21, Testing Accuracy 0.9807, Training Accuracy 0.99667275, Learning Rate 0.00034056162
Iter 22, Testing Accuracy 0.979, Training Accuracy 0.9961636, Learning Rate 0.00032353355
Iter 23, Testing Accuracy 0.9813, Training Accuracy 0.99692726, Learning Rate 0.00030735688
Iter 24, Testing Accuracy 0.9806, Training Accuracy 0.997, Learning Rate 0.000291989
Iter 25, Testing Accuracy 0.9799, Training Accuracy 0.9967818, Learning Rate 0.00027738957
Iter 26, Testing Accuracy 0.981, Training Accuracy 0.99725455, Learning Rate 0.0002635201
Iter 27, Testing Accuracy 0.9813, Training Accuracy 0.9972909, Learning Rate 0.00025034408
Iter 28, Testing Accuracy 0.9806, Training Accuracy 0.9973636, Learning Rate 0.00023782688
Iter 29, Testing Accuracy 0.981, Training Accuracy 0.9973818, Learning Rate 0.00022593554
Iter 30, Testing Accuracy 0.9812, Training Accuracy 0.99743634, Learning Rate 0.00021463877

版权声明:

本站所有资源均为站长或网友整理自互联网或站长购买自互联网,站长无法分辨资源版权出自何处,所以不承担任何版权以及其他问题带来的法律责任,如有侵权或者其他问题请联系站长删除!站长QQ754403226 谢谢。

有关影视版权:本站只供百度云网盘资源,版权均属于影片公司所有,请在下载后24小时删除,切勿用于商业用途。本站所有资源信息均从互联网搜索而来,本站不对显示的内容承担责任,如您认为本站页面信息侵犯了您的权益,请附上版权证明邮件告知【754403226@qq.com】,在收到邮件后72小时内删除。本文链接:https://www.piaodoo.com/7486.html

搜索