Unverified Commit 5a6075a9 authored by Kiryuu Sakuya's avatar Kiryuu Sakuya 🎵
Browse files

Finish 06

parent 9f8762fb
This diff is collapsed.
# -*- coding: utf-8 -*-
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
import datetime as dt
datasets = tfds.load("cifar10")
train_dataset, test_dataset = datasets["train"], datasets["test"]
assert isinstance(train_dataset, tf.data.Dataset)
cifar10_builder = tfds.builder("cifar10")
# See the information on the dataset
# info = cifar10_builder.info
# print(info)
for batch in train_dataset.batch(50000):
x_train = batch['image']
y_train = batch['label'].numpy().astype('uint8')
for batch in test_dataset.batch(10000):
x_test = batch['image']
y_test = batch['label'].numpy().astype('uint8')
# Normalize pixel values to be between 0 and 1
x_test = x_test / 255
x_train = x_train / 255
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=(32, 32, 3)),
tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.summary()
callbacks = [
# Write TensorBoard logs to `./logs` directory
tf.keras.callbacks.TensorBoard(log_dir='logs/{}'.format(dt.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")))
]
history = model.fit(x_train, y_train, epochs=50, validation_data=(x_test, y_test), callbacks=callbacks)
# plt.plot(history.history['accuracy'], label='accuracy')
# plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
# plt.xlabel('Epoch')
# plt.ylabel('Accuracy')
# plt.ylim([0.5, 1])
# plt.legend(loc='lower right')
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(test_acc)
\ No newline at end of file
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment