This article originally appeared on the 吾爱破解 (52pojie) forum.
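Until now my machine learning work has mostly been supervised learning, and I never paid much attention to unsupervised learning. Recently I wanted to move toward NLP (natural language processing) and came across the paper "A Neural Conversational Model", which inevitably leads to seq2seq and LSTM. So I decided to lay a proper foundation in deep learning first. I have been studying TensorFlow for the past couple of days, so here is a note on my MNIST handwritten-digit recognition code.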
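The test accuracy reaches about 0.98. The code is still far from polished: for example, there are too many hidden layers, and the number of training epochs is a bit low. Let's learn from it together.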
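To put the remark about too many hidden layers into numbers, here is a small plain-Python count of the trainable parameters (weights plus biases) in the network that follows, whose layer widths are 784, 500, 300, 100 and 10. The helper name `count_params` is only for illustration; nothing here calls TensorFlow.

```python
# Layer widths of the fully connected network in the code below:
# 784 inputs, three tanh hidden layers (500, 300, 100), 10 outputs
layer_sizes = [784, 500, 300, 100, 10]

def count_params(sizes):
    """Total weights plus biases of a fully connected network with the given widths."""
    total = 0
    for n_in, n_out in zip(sizes, sizes[1:]):
        total += n_in * n_out + n_out  # weight matrix + bias vector
    return total

total = count_params(layer_sizes)
per_layer = [n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:])]
first_layer_share = per_layer[0] / total

print("total parameters:", total)
print("per layer:", per_layer)
print("share of the first layer: %.1f%%" % (100 * first_layer_share))
```

The model has about 574k parameters, and the first 784×500 weight matrix alone holds roughly 68% of them, so the width of the first hidden layer dominates the model's size.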
```python
# Written for TensorFlow 1.x: tensorflow.examples.tutorials, tf.placeholder and
# tf.Session are not available in TensorFlow 2.x
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

batch_size = 100
n_batch = mnist.train.num_examples // batch_size  # batches per epoch

x = tf.placeholder(tf.float32, [None, 784])  # flattened 28x28 images
y = tf.placeholder(tf.float32, [None, 10])   # one-hot labels
# Fraction of neurons kept active by dropout, e.g. 1.0 means 100%.
# Dropout helps suppress overfitting, i.e. keeps test accuracy close to training accuracy.
keep_prob = tf.placeholder(tf.float32)
lr = tf.Variable(0.001)  # learning rate, decayed once per epoch below

# Hidden layer 1: 784 -> 500, tanh
Weight_L1 = tf.Variable(tf.truncated_normal([784, 500], stddev=0.1))
biases_L1 = tf.Variable(tf.zeros([500]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, Weight_L1) + biases_L1)
L1_drop = tf.nn.dropout(L1, keep_prob)

# Hidden layer 2: 500 -> 300, tanh
Weight_L2 = tf.Variable(tf.truncated_normal([500, 300], stddev=0.1))
biases_L2 = tf.Variable(tf.zeros([300]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, Weight_L2) + biases_L2)
L2_drop = tf.nn.dropout(L2, keep_prob)

# Hidden layer 3: 300 -> 100, tanh
Weight_L3 = tf.Variable(tf.truncated_normal([300, 100], stddev=0.1))
biases_L3 = tf.Variable(tf.zeros([100]) + 0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop, Weight_L3) + biases_L3)
L3_drop = tf.nn.dropout(L3, keep_prob)

# Output layer: 100 -> 10 classes
Weight_L4 = tf.Variable(tf.truncated_normal([100, 10], stddev=0.1))
biases_L4 = tf.Variable(tf.zeros([10]) + 0.1)
logits = tf.matmul(L3_drop, Weight_L4) + biases_L4
prediction = tf.nn.softmax(logits)  # multi-class probabilities

# Loss function
# loss = tf.reduce_mean(tf.square(y - prediction))  # quadratic cost
# Cross-entropy. Pass the raw logits: this op applies softmax itself,
# so passing `prediction` would apply softmax twice.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))

# Optimizer
# train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)  # plain gradient descent
train_step = tf.train.AdamOptimizer(lr).minimize(loss)

# Created after minimize() so that Adam's own variables are initialized too
init = tf.global_variables_initializer()

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # cast booleans to 0.0/1.0

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(31):
        sess.run(tf.assign(lr, 0.001 * (0.95 ** epoch)))  # decay the learning rate each epoch
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # keep_prob 1.0 keeps every neuron, so dropout is off; feed e.g. 0.7 to enable it
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 1.0})
        # Evaluate with dropout disabled (keep_prob 1.0)
        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
        train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
        learning_rate = sess.run(lr)
        print("Iter " + str(epoch) + ", Testing Accuracy " + str(test_acc)
              + ", Training Accuracy " + str(train_acc) + ", Learning Rate " + str(learning_rate))
```
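`softmax_cross_entropy_with_logits_v2` applies softmax to its `logits` argument internally, so that argument should be the raw output-layer scores rather than the already-softmaxed `prediction`. The sketch below re-implements that softmax-then-cross-entropy math in plain Python (the helper names and the 10-class example scores are hypothetical, and no TensorFlow is called) to show what happens when softmax is applied twice.

```python
import math

def softmax(z):
    """Numerically stable softmax over a list of floats."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_ce(labels, logits):
    """Softmax followed by cross-entropy against a one-hot label: the math that
    softmax_cross_entropy_with_logits_v2 performs, re-implemented here."""
    probs = softmax(logits)
    return -sum(t * math.log(p) for t, p in zip(labels, probs) if t > 0)

# Hypothetical 10-class example: the true class is 3 and the network is confident about it
labels = [0.0] * 10
labels[3] = 1.0
logits = [0.0] * 10
logits[3] = 9.0

correct_loss = softmax_ce(labels, logits)          # raw logits passed in: softmax applied once
buggy_loss = softmax_ce(labels, softmax(logits))   # softmax output passed in: softmax applied twice
# Best case for the double-softmax version: a perfect one-hot probability vector as input
floor = softmax_ce(labels, labels)

print("correct loss: %.4f" % correct_loss)
print("double-softmax loss: %.4f (floor %.4f)" % (buggy_loss, floor))
```

With 10 classes the double-softmax loss can never fall below about 1.46, however confident the network becomes, so the loss value stops reflecting how good the predictions are. Training still makes progress, which may be why the original run reached about 0.98 anyway; passing the raw `logits` removes this floor.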
Here is the output for each epoch:
```
Iter 0, Testing Accuracy 0.9502, Training Accuracy 0.95552725, Learning Rate 0.001
Iter 1, Testing Accuracy 0.96, Training Accuracy 0.9675636, Learning Rate 0.00095
Iter 2, Testing Accuracy 0.9679, Training Accuracy 0.97810906, Learning Rate 0.0009025
Iter 3, Testing Accuracy 0.9707, Training Accuracy 0.98196363, Learning Rate 0.000857375
Iter 4, Testing Accuracy 0.9716, Training Accuracy 0.9841273, Learning Rate 0.00081450626
Iter 5, Testing Accuracy 0.9749, Training Accuracy 0.98634547, Learning Rate 0.0007737809
Iter 6, Testing Accuracy 0.9747, Training Accuracy 0.9884909, Learning Rate 0.0007350919
Iter 7, Testing Accuracy 0.9767, Training Accuracy 0.98965454, Learning Rate 0.0006983373
Iter 8, Testing Accuracy 0.9756, Training Accuracy 0.99114543, Learning Rate 0.0006634204
Iter 9, Testing Accuracy 0.9792, Training Accuracy 0.99325454, Learning Rate 0.0006302494
Iter 10, Testing Accuracy 0.9774, Training Accuracy 0.99332726, Learning Rate 0.0005987369
Iter 11, Testing Accuracy 0.9765, Training Accuracy 0.99332726, Learning Rate 0.0005688001
Iter 12, Testing Accuracy 0.9802, Training Accuracy 0.9935273, Learning Rate 0.0005403601
Iter 13, Testing Accuracy 0.9815, Training Accuracy 0.9952, Learning Rate 0.0005133421
Iter 14, Testing Accuracy 0.9777, Training Accuracy 0.99514544, Learning Rate 0.000487675
Iter 15, Testing Accuracy 0.9804, Training Accuracy 0.9954, Learning Rate 0.00046329122
Iter 16, Testing Accuracy 0.9813, Training Accuracy 0.9958909, Learning Rate 0.00044012666
Iter 17, Testing Accuracy 0.9802, Training Accuracy 0.9961636, Learning Rate 0.00041812033
Iter 18, Testing Accuracy 0.9766, Training Accuracy 0.99463636, Learning Rate 0.00039721432
Iter 19, Testing Accuracy 0.9809, Training Accuracy 0.99652725, Learning Rate 0.0003773536
Iter 20, Testing Accuracy 0.9758, Training Accuracy 0.99563634, Learning Rate 0.00035848594
Iter 21, Testing Accuracy 0.9807, Training Accuracy 0.99667275, Learning Rate 0.00034056162
Iter 22, Testing Accuracy 0.979, Training Accuracy 0.9961636, Learning Rate 0.00032353355
Iter 23, Testing Accuracy 0.9813, Training Accuracy 0.99692726, Learning Rate 0.00030735688
Iter 24, Testing Accuracy 0.9806, Training Accuracy 0.997, Learning Rate 0.000291989
Iter 25, Testing Accuracy 0.9799, Training Accuracy 0.9967818, Learning Rate 0.00027738957
Iter 26, Testing Accuracy 0.981, Training Accuracy 0.99725455, Learning Rate 0.0002635201
Iter 27, Testing Accuracy 0.9813, Training Accuracy 0.9972909, Learning Rate 0.00025034408
Iter 28, Testing Accuracy 0.9806, Training Accuracy 0.9973636, Learning Rate 0.00023782688
Iter 29, Testing Accuracy 0.981, Training Accuracy 0.9973818, Learning Rate 0.00022593554
Iter 30, Testing Accuracy 0.9812, Training Accuracy 0.99743634, Learning Rate 0.00021463877
```
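A quick plain-Python check of the log above (the accuracy lists and the three sample learning rates are copied from the output; nothing here runs TensorFlow, and the helper name `scheduled_lr` is only for illustration). It confirms that the learning rate follows 0.001 × 0.95^epoch, finds the epoch with the best test accuracy, and measures how far training accuracy pulls ahead of test accuracy, which is the overfitting that dropout is meant to curb. In this run dropout was effectively off, because `keep_prob` was fed as 1.0 during training.

```python
import math

# Test and training accuracy for epochs 0-30, copied from the log above
test_acc = [
    0.9502, 0.96, 0.9679, 0.9707, 0.9716, 0.9749, 0.9747, 0.9767, 0.9756, 0.9792,
    0.9774, 0.9765, 0.9802, 0.9815, 0.9777, 0.9804, 0.9813, 0.9802, 0.9766, 0.9809,
    0.9758, 0.9807, 0.979, 0.9813, 0.9806, 0.9799, 0.981, 0.9813, 0.9806, 0.981, 0.9812,
]
train_acc = [
    0.95552725, 0.9675636, 0.97810906, 0.98196363, 0.9841273, 0.98634547, 0.9884909,
    0.98965454, 0.99114543, 0.99325454, 0.99332726, 0.99332726, 0.9935273, 0.9952,
    0.99514544, 0.9954, 0.9958909, 0.9961636, 0.99463636, 0.99652725, 0.99563634,
    0.99667275, 0.9961636, 0.99692726, 0.997, 0.9967818, 0.99725455, 0.9972909,
    0.9973636, 0.9973818, 0.99743634,
]
# A few logged learning rates (epoch -> value)
logged_lr = {0: 0.001, 13: 0.0005133421, 30: 0.00021463877}

def scheduled_lr(epoch, base=0.001, decay=0.95):
    """Rate set by tf.assign(lr, 0.001 * (0.95 ** epoch)) in the training loop."""
    return base * decay ** epoch

# Logged values are float32, so compare with a small relative tolerance
lr_matches = all(math.isclose(scheduled_lr(e), v, rel_tol=1e-6) for e, v in logged_lr.items())

# Epoch with the highest test accuracy
best_epoch = max(range(len(test_acc)), key=lambda e: test_acc[e])

# Training minus test accuracy per epoch; a growing gap indicates overfitting
gaps = [tr - te for tr, te in zip(train_acc, test_acc)]

print("learning rate schedule matches log:", lr_matches)
print("best test accuracy %.4f at epoch %d" % (test_acc[best_epoch], best_epoch))
print("train-test gap: %.4f at epoch 0, %.4f at epoch 30" % (gaps[0], gaps[-1]))
```

Test accuracy peaks at 0.9815 in epoch 13 and then plateaus around 0.98, while training accuracy keeps climbing toward 0.997, so the gap roughly triples over the run.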