level 3
import tensorflow as tf
# mnist是tensorflow中一个实例,使用input_data来下载/引入数据
from tensorflow.examples.tutorials.mnist import input_data
# 获取所有的数据,包括train-set test-set
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
#定义批次和一共有多少批次
batch_size = 50
n_batch = mnist.train.num_examples // batch_size
#命名空间
with tf.name_scope('input'):
# 定义两个占位符
# 占位符就是,它并没有真实的数据,但给了下面代码使用数据的机会,等在session中再通过feed-dict来把数据喂给模型
x = tf.placeholder(tf.float32, [None, 784], name='x_input')
y = tf.placeholder(tf.float32, [None, 10], name='y_input')
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, w) + b)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
train = tf.train.GradientDescentOptimizer(0.2).minimize(cost)
# 预测计算准确度
# equal比较两个参数大小是否一样,一样返回true,不一样是false ,得到的其实是true和false的向量
# argmax(y,1)求y在每一行最大元素所在的索引记录下来,最后返回每一行最大元素所在的索引数组在中最大值在哪个位置
# 找出同一行模型标签和真实标签最大值的位置,tf.equal将两个比较,一样就是true,不一样就是false
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# 把预测值转化为浮点 ,true变成1.0 false变成0.0 ,再求平均值,
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
with tf.Session() as sess:
# 前面定义的变量,就需要先初始化变量
sess.run(tf.global_variables_initializer())
writer = tf.summary.FileWriter('logs', sess.graph)
for i in range(1): # 迭代100次
for batch in range(n_batch): # 分批次迭代
# 获取本批次的数据
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train, feed_dict={x: batch_xs, y: batch_ys})
# 每迭代一次使用test集测试一下准确度
acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print('after ', i, 'the accuracy is ', acc)
2019年08月20日 08点08分
