Unverified Commit 9a3eda24 authored by Kiryuu Sakuya's avatar Kiryuu Sakuya 🎵
Browse files

Add 04 homework

parent c694500e
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
from mnist import input_data
import matplotlib.pyplot as plt
def load_data():
mnist = input_data.read_data_sets('data', one_hot=True)
return mnist
def model(x):
with tf.name_scope("weights"):
W = tf.Variable(tf.random.normal([784, 10]), name="W")
with tf.name_scope("biases"):
b = tf.Variable(tf.zeros([10]), name="b")
with tf.name_scope("neural"):
forward = tf.matmul(x, W) + b #don't use multiply or *
with tf.name_scope("softmax"):
pred = tf.nn.softmax(forward)
return pred
def train(mnist):
#tensorboard settings
tf.compat.v1.reset_default_graph() #reset compution graoh
logdir = "logs"
#training hyperparameters
train_epochs = 150
learning_rate = 0.025
batch_size = 100
#other parameters
loss_list = []
display_step = 1
total_batch = int(mnist.train.num_examples/batch_size) #how many batches in 1 epoch
with tf.name_scope("Input"):
x = tf.compat.v1.placeholder(tf.float32, [None, 784], name="X") #'None' help to use mini-batch gradient descent
y = tf.compat.v1.placeholder(tf.float32, [None, 10], name="Y")
with tf.name_scope("Model"):
pred = model(x)
with tf.name_scope("LossFunction"):
loss_function = tf.reduce_mean(-tf.compat.v1.reduce_sum(y*tf.math.log(pred), reduction_indices=1)) #cross entropy loss
#write loss as scalar to tensorboard
tf.summary.scalar("loss", loss_function)
with tf.name_scope("Optimizer"):
optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
with tf.name_scope("Accuracy"):
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) #argmax(-, 1) - return the index of the max number in each column
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) #tf.cast(-, tf.float32) - transform the type to tf.float32
init = tf.compat.v1.global_variables_initializer()
merged = tf.compat.v1.summary.merge_all() #合并需要记录的文件,方便一次性写入
with tf.compat.v1.Session() as sess:
#create writer for writing the computation graph to the log
writer = tf.compat.v1.summary.FileWriter(logdir, sess.graph)
for epoch in range(train_epochs):
for batch in range(total_batch):
#next_batch() is made for training the model, so it will shuffle the data inside automatically
# and read the next batch data when it is run again.
batch_images_xs, batch_labels_ys = mnist.train.next_batch(batch_size=batch_size)
sess.run(optimizer, feed_dict={x:batch_images_xs, y:batch_labels_ys})
#writer.add_summary(summary, epoch) #After 1 epoch, 将运行在val set的所有summary写入文件
#After 1 epoch, compute loss and accuracy on validation set. Validation set is not batched.
loss, acc = sess.run([loss_function, accuracy], feed_dict={x:mnist.validation.images, y:mnist.validation.labels})
#print loss and accuracy
if (epoch+1) % display_step == 0:
print("epoch=%02d" % (epoch+1), "loss=", "{:.9f}".format(loss), "accuracy=", "{:.4f}".format(acc))
print("Train Finished.")
#Test the model on training set
acc_test = sess.run(accuracy, feed_dict={x:mnist.train.images, y:mnist.train.labels})
print("Training set Accuracy=", acc_test)
#Test the model on test set
pred_show, acc_test = sess.run([tf.argmax(pred, 1), accuracy], feed_dict={x:mnist.test.images, y:mnist.test.labels})
print("Test set Accuracy=", acc_test)
#output loss curve
fig1 = plt.figure()
return pred_show
def visualize_prediction(images, #images list
labels, #labels list
pred_show, #pred_show list
index, #start visualizing the prediction from the 'index'th image
num=10): #display 'num=10' images at a time
fig = plt.gcf() #获取当前图表(get current figure)
fig.set_size_inches(10, 12) #单位-英寸,1英寸=2.54cm
if num > 25:
num = 25 #最多显示25个子图
for i in range(0, num):
ax = plt.subplot(5, 5, i+1) #获取当前要处理的子图
ax.imshow(np.reshape(images[index], (28, 28)), cmap="binary") #显示第index个image
title = "label=" + str(np.argmax(labels[index])) #构建该图要显示的title
if len(pred_show) > 0: #若pred_show不为空,则显示
title += ", predict=" + str(pred_show[index])
ax.set_title(title, fontsize=10) #显示图上的title信息
ax.set_xticks([]) #不显示坐标轴
index += 1
def main(argv=None):
mnist = load_data() #load data
pred_show = train(mnist)
visualize_prediction(mnist.test.images, mnist.test.labels, pred_show, 10, 25)
if __name__=='__main__':
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment