# generatorCompleted.py
# Committed by Kiryuu Sakuya.
import numpy as np
import os
import cv2


class batchGenerator:
    """Serve shuffled mini-batches of images and one-hot labels from a folder.

    Every entry of ``basePath`` is treated as one sample.  The class of a
    sample is the part of its file name before the first ``'('`` (with
    surrounding whitespace stripped) and must be one of ``self.labels``.
    """

    def __init__(self, basePath='data/processed/train_224/', batchSize=256):
        """Index the samples found in ``basePath`` and shuffle them.

        Args:
            basePath: Directory whose entries are the sample files.
            batchSize: Maximum number of samples returned by ``getBatch``.
        """
        self.basePath = basePath
        # os.listdir returns names in arbitrary order; sorting first makes the
        # shuffled order reproducible when the NumPy RNG is seeded.
        self.fileList = sorted(os.listdir(self.basePath))
        # A single shuffle already yields a uniformly random permutation.
        np.random.shuffle(self.fileList)
        # Total number of samples.
        self.num_files = len(self.fileList)
        # Cursor into fileList; moves forward after each batch and returns to
        # 0 once the whole list has been served (one epoch).
        self.curIndex = 0
        # Samples per batch; the last batch of an epoch holds whatever is left
        # and may therefore be smaller.
        self.batchSize = batchSize
        self.labels = ['bus', 'family sedan', 'fire engine', 'racing car']
        self.num_labels = len(self.labels)

    def _oneHotLabel(self, fileName):
        """Return the one-hot label vector encoded in ``fileName``."""
        cur_type = fileName.split('(')[0].strip()
        if cur_type not in self.labels:
            raise ValueError(
                'file name error: {!r} does not start with one of {}'.format(
                    fileName, self.labels))
        curLabel = np.zeros(self.num_labels)
        curLabel[self.labels.index(cur_type)] = 1
        return curLabel

    def _readImage(self, fileName):
        """Load ``fileName`` from ``basePath`` with ``cv2.imread``."""
        path = os.path.join(self.basePath, fileName)
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise OSError('could not read image: {!r}'.format(path))
        return image

    def getBatch(self):
        """Return the next batch and advance the cursor.

        When fewer than ``batchSize`` samples remain past the cursor (or
        exactly that many), the batch takes all of them, the file list is
        reshuffled and the cursor returns to 0.  The cursor is only updated
        after the whole batch was loaded successfully.

        Returns:
            A tuple ``(images, labels)`` of NumPy arrays.  ``images`` is
            ``np.array`` applied to the arrays returned by ``cv2.imread``;
            ``labels`` has one one-hot row of length ``num_labels`` per
            sample, in the same order as ``images``.

        Raises:
            ValueError: A file name does not start with a known label.
            OSError: ``cv2.imread`` could not read a file.
        """
        endIndex = self.curIndex + self.batchSize
        if endIndex >= self.num_files:
            # None makes the slice below run to the end of the list and marks
            # this batch as the last one of the epoch.
            endIndex = None

        # Slicing copies, so shuffling the batch leaves fileList untouched.
        curSampleList = self.fileList[self.curIndex:endIndex]
        np.random.shuffle(curSampleList)

        curBatchX = []
        curBatchY = []
        for fileName in curSampleList:
            # Validate the name first: it is cheap and fails before decoding.
            curLabel = self._oneHotLabel(fileName)
            curBatchX.append(self._readImage(fileName))
            curBatchY.append(curLabel)

        if endIndex is None:
            # Epoch finished: reshuffle and start over.
            np.random.shuffle(self.fileList)
            self.curIndex = 0
        else:
            self.curIndex = endIndex

        return np.array(curBatchX), np.array(curBatchY)