Verified commit 479c1957, authored by Kiryuu Sakuya 🎵
Browse files

Add exam.3

parent 6f4e1295
import numpy as np
import os
import cv2
class batchGenerator:
    """Serve an image directory as shuffled mini-batches, epoch after epoch.

    File names are expected to start with the class name followed by
    ``(``, e.g. ``bus (1).jpg`` or ``bus (1)_flipx.jpg``; the part before
    the first ``(`` selects the one-hot label.
    """

    def __init__(self, basePath='data/processed/train_224/', batchSize=256):
        """Index the dataset directory and reset the read cursor.

        :param basePath: directory that contains the image files
        :param batchSize: number of samples returned per ``getBatch`` call
            (the last batch of an epoch holds whatever samples remain)
        """
        self.basePath = basePath
        # Every entry of the directory is treated as a sample.
        self.fileList = os.listdir(self.basePath)
        # Shuffle the file order. The ten passes are kept so the number of
        # draws taken from NumPy's global RNG stays what it has always been.
        for _ in range(10):
            np.random.shuffle(self.fileList)
        # Total number of samples.
        self.num_files = len(self.fileList)
        # Cursor into fileList: advanced after every batch and reset to 0
        # once an epoch has been fully consumed.
        self.curIndex = 0
        self.batchSize = batchSize
        self.labels = ['bus', 'family sedan', 'fire engine', 'racing car']
        self.num_labels = len(self.labels)

    def _oneHotLabel(self, fileName):
        """Return the one-hot label vector encoded in ``fileName``.

        Prints a diagnostic and terminates the process (``SystemExit``) when
        the prefix before ``(`` is not one of ``self.labels``.
        """
        cur_type = fileName.split('(')[0].strip()
        try:
            labelIndex = self.labels.index(cur_type)
        except ValueError:
            # list.index raises only ValueError here, so nothing else is
            # swallowed. SystemExit is raised directly instead of calling
            # the site-provided exit() helper.
            print('file name error')
            print(fileName)
            raise SystemExit
        curLabel = np.zeros(self.num_labels)
        curLabel[labelIndex] = 1
        return curLabel

    def getBatch(self):
        """Return the next batch and advance (or wrap) the cursor.

        :return: ``(X, Y)`` where ``X`` is the array of decoded images and
            ``Y`` the array of matching one-hot labels
        :raises OSError: if an image file cannot be decoded
        :raises SystemExit: if a file name does not map to a known label
        """
        endIndex = self.curIndex + self.batchSize
        if endIndex >= self.num_files:
            # Last batch of the epoch: slice to the end of the list.
            endIndex = None

        # Slicing already yields a new list, so shuffling the batch does not
        # disturb fileList itself.
        curSampleList = self.fileList[self.curIndex:endIndex]
        np.random.shuffle(curSampleList)

        curBatchX = []
        curBatchY = []
        for fileName in curSampleList:
            # Resolve the label first so a bad file name is reported before
            # any time is spent decoding the image.
            curLabel = self._oneHotLabel(fileName)

            file = os.path.join(self.basePath, fileName)
            image = cv2.imread(file)
            if image is None:
                # cv2.imread signals failure by returning None, not raising.
                raise OSError('cannot read image file: ' + file)

            curBatchX.append(image)
            curBatchY.append(curLabel)

        if endIndex is None:
            # Epoch finished: reshuffle and start over from the beginning.
            np.random.shuffle(self.fileList)
            self.curIndex = 0
        else:
            self.curIndex = endIndex
        return np.array(curBatchX), np.array(curBatchY)
import numpy as np
import os
import cv2
class batchGenerator:
    '''
    Training-data generator; you need to complete the two functions below.
    '''
    def __init__(self, basePath='data/processed/train_224/', batchSize=256):
        '''
        The dataset contains four classes of images: 'bus', 'family sedan', 'fire engine', 'racing car'.
        Each file name has the form "XXX (id).jpg" or "XXX (id)_flipx.jpg", e.g. "bus (1).jpg", "bus (1)_flipx.jpg".
        :param basePath: path to the dataset
        :param batchSize: number of images fetched per call
        '''
        #********** Begin **********#
        #********** End **********#
    def getBatch(self):
        '''
        Iterate over the dataset in a loop.
        The class can be obtained by splitting the file name; it must then be converted to a one-hot representation, e.g. [0,0,1,0] stands for 'fire engine'.
        If fewer than batchSize samples remain at the end of a pass, return just the remaining samples.
        :return: batch of image data (batchSize,224,224,3) and the label for each image (batchSize,4)
        '''
        #********** Begin **********#
        #********** End **********#
\ No newline at end of file
from generatorCompleted import batchGenerator as correctG
from generatorForUsers import batchGenerator as userG
import numpy as np
# Grader: compares the user's batchGenerator (userG) with the reference
# implementation (correctG) on the validation directory.

# Step 1: both generators must at least be constructible.
# SystemExit is caught together with Exception because generator code may
# call exit() on failure; KeyboardInterrupt is deliberately left alone.
try:
    correct_g = correctG(basePath='data/processed/valid_224', batchSize=80)
    user_g = userG(basePath='data/processed/valid_224', batchSize=80)
except (Exception, SystemExit):
    print('未能通过本关测试,无法正确新建对象!')
    exit()

# Step 2: read one batch of 80 from each generator (intended to cover the
# whole validation set in one call) and compare the totals. The order inside
# a batch is random, so sums are compared rather than element positions.
try:
    X_c, Y_c = correct_g.getBatch()
    X_u, Y_u = user_g.getBatch()
except (Exception, SystemExit):
    print('未能通过本关测试,无法正确调用 getBatch() !')
    exit()

if not (len(X_c) == len(X_u) and len(Y_c) == len(Y_u)):
    print('未能通过本关测试,返回数据长度不正确!')
    # Stop here; otherwise the script could still reach the success message.
    exit()
if not (np.sum(X_c) == np.sum(X_u)):
    print('未能通过本关测试,返回的图片数据有误!')
    exit()
# Per-class label counts must match.
if not (np.sum(Y_c, axis=0) == np.sum(Y_u, axis=0)).all():
    print('未能通过本关测试,返回的标签数据有误!')
    exit()

# Step 3: check that the final batch of an epoch holds only the remainder.
user_g_2 = userG(basePath='data/processed/valid_224', batchSize=70)
# First call: 70 samples.
X_u, Y_u = user_g_2.getBatch()
# Second call: only the remaining 10 samples are expected.
X_u_rest, Y_u_rest = user_g_2.getBatch()
# Both X and Y must have length 10; a mismatch in either one is a failure.
if not (len(X_u_rest) == 10 and len(Y_u_rest) == 10):
    print('未能通过本关测试,样本数不足batchSize时没能正确返回剩余样本!')
    exit()

# Step 4: if the cursor was reset correctly, the next call starts a new pass
# and returns a full batch of 70 again.
X_u_new, Y_u_new = user_g_2.getBatch()
if not (len(X_u_new) == 70 and len(Y_u_new) == 70):
    print('未能通过本关测试,数据集全部读完后,不能正确开始第二次循环读取!')
    exit()

# Everything above passed, so the implementation is almost certainly correct.
print('恭喜你通过本关测试!',end='')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment