Unverified Commit 3c4135b0 authored by Kiryuu Sakuya's avatar Kiryuu Sakuya 🎵
Browse files

Finish 05

parent 19c2d8b9
This diff is collapsed.
# -*- coding: utf-8 -*-
import os
import numpy as np
from time import time
import tensorflow as tf
import matplotlib.pyplot as plt
from mnist import input_data #help reading the mnist data
def load_data():
mnist = input_data.read_data_sets("data", one_hot=True)
return mnist
def fc_layer(input, #input data
input_dim, #the number of input neurals
output_dim, #the number of output neurals
activation=None): #activation function
W = tf.Variable(tf.random.normal([input_dim, output_dim], stddev=0.1))
b = tf.Variable(tf.zeros([output_dim]))
output = tf.matmul(input, W) + b
if activation is not None: #默认不使用激活函数
output = activation(output) #若传入激活函数,则用其对输出结果进行变换
return output
def model(x, mlp):
with tf.name_scope("mlp"):
fc_layer1 = fc_layer(x, 784, mlp[0], activation=tf.nn.relu)
#fc_layer2 = fc_layer(fc_layer1, mlp[0], mlp[1], activation=tf.nn.relu)
final_fc_layer = fc_layer(fc_layer1, mlp[0], mlp[1], activation=None)
with tf.name_scope("softmax"):
pred = tf.nn.softmax(final_fc_layer)
return pred, final_fc_layer
def train(mnist):
#tensorboard settings
tf.compat.v1.reset_default_graph() #reset compution graoh
logdir = "logs"
#training hyperparameters
train_epochs = 40
learning_rate = 0.0001
batch_size = 50
#other parameters
loss_list = []
display_step = 1
save_step = 10 #存储模型的粒度
ckpt_dir = "ckpt_dir" #保存模型文件的目录
total_batch = int(mnist.train.num_examples/batch_size) #how many batches in 1 epoch
with tf.name_scope("Input"):
x = tf.compat.v1.placeholder(tf.float32, [None, 784], name="X") #'None' help to use mini-batch gradient descent
y = tf.compat.v1.placeholder(tf.float32, [None, 10], name="Y")
image_shaped_input = tf.reshape(x, [-1, 28, 28, 1]) #[B, H, W, C]
tf.summary.image("input", image_shaped_input, 10) #write 25 images to tensorboard
with tf.name_scope("Model"):
pred, final_fc_layer = model(x, mlp=[256, 10])
tf.summary.histogram("final_fc_layer", final_fc_layer) #将前向输出值记录在tensorboard并以直方图显示
with tf.name_scope("Loss"):
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=final_fc_layer, labels=y)) #cross entropy loss
tf.summary.scalar("loss", loss_function) #write the loss as scalar to tensorboard
with tf.name_scope("Optimizer"):
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate).minimize(loss_function)
with tf.name_scope("Accuracy"):
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) #argmax(-, 1) - return the index of the max number in each column
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) #tf.cast(-, tf.float32) - transform the type to tf.float32
tf.summary.scalar("accuracy", accuracy) #write the accuracy as scalar to tensorboard
init = tf.compat.v1.global_variables_initializer()
#模型持久化 - 声明完所有变量后
saver = tf.compat.v1.train.Saver()
start_time = time() #记录训练开始时间
with tf.compat.v1.Session() as sess:
merged = tf.compat.v1.summary.merge_all() #合并所有的summary
writer = tf.compat.v1.summary.FileWriter(logdir, sess.graph) #create writer for writing the computation graph to the log
for epoch in range(train_epochs):
for batch in range(total_batch):
#next_batch() is made for training the model, so it will shuffle the data inside automatically
# and read the next batch data when it is run again.
batch_images_xs, batch_labels_ys = mnist.train.next_batch(batch_size=batch_size)
sess.run(optimizer, feed_dict={x:batch_images_xs, y:batch_labels_ys})
# writer.add_summary(summary, epoch) #write the summary into the file
#Every 1 epoch, compute loss and accuracy on validation set. Validation set is not batched.
loss, acc = sess.run([loss_function, accuracy], feed_dict={x:mnist.validation.images, y:mnist.validation.labels})
#print loss and accuracy
if (epoch+1) % display_step == 0:
print("epoch=%02d" % (epoch+1), "loss=", "{:.9f}".format(loss), "accuracy=", "{:.4f}".format(acc))
if (epoch+1) % save_step == 0:
#model_persistence(saver, ckpt_dir)
if not os.path.exists(ckpt_dir):
saver.save(sess, os.path.join(ckpt_dir, "mnist_model_{:06d}.ckpt".format(epoch+1))) #save model
print("mnist_model_{:06d}.ckpt saved".format(epoch+1))
saver.save(sess, os.path.join(ckpt_dir, "mnist_model.ckpt")) #save model
print("Model saved.")
duration = time() - start_time #运行总时间
print("Train Finished. Take:", "{:.2f}".format(duration), "s")
#Test the model on training set
acc_test = sess.run(accuracy, feed_dict={x:mnist.train.images, y:mnist.train.labels})
print("Training set Accuracy=", acc_test)
#Test the model on test set
pred_show, acc_test = sess.run([tf.argmax(pred, 1), accuracy], feed_dict={x:mnist.test.images, y:mnist.test.labels})
print("Test set Accuracy=", acc_test)
#output loss curve
fig1 = plt.figure()
return pred_show
def visualize_prediction(images, #images list
labels, #labels list
pred_show, #pred_show list
index, #start visualizing the prediction from the 'index'th image
num=10): #display 'num=10' images at a time
fig = plt.gcf() #获取当前图表(get current figure)
fig.set_size_inches(10, 12) #单位-英寸,1英寸=2.54cm
if num > 25:
num = 25 #最多显示25个子图
for i in range(0, num):
ax = plt.subplot(5, 5, i+1) #获取当前要处理的子图
ax.imshow(np.reshape(images[index], (28, 28)), cmap="binary") #显示第index个image
title = "label=" + str(np.argmax(labels[index])) #构建该图要显示的title
if len(pred_show) > 0: #若pred_show不为空,则显示
title += ", predict=" + str(pred_show[index])
ax.set_title(title, fontsize=10) #显示图上的title信息
ax.set_xticks([]) #不显示坐标轴
index += 1
def pred_errors(pred_show):
compare_list = pred_show == np.argmax(mnist.test.labels, 1)
err_list = [i for i in range(len(compare_list)) if compare_list[i]==False]
print("sum:" + len(err_list))
#print error predict and corresponding labels
for x in err_list:
print("index=" + str(x) + "labels=", np.argmax(labels[x]), "predict=", pred_show[x])
return err_list
def main(argv=None):
mnist = load_data()
pred_show = train(mnist)
visualize_prediction(mnist.test.images, mnist.test.labels, pred_show, 10, 25)
if __name__=='__main__':
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment