Keras에서 custom dataset을 불러오기 위해 data generator를 사용한다. batch-by-batch에 data를 순회하면서 불러오도록 해준다.
여기서 Keras data generator는 infinite해야 한다. 안그러면 StopIteration을 발생한다.
* 코드 / Code
def subtract_mean_gen(x_source,y_source,avg_image,batch):
batch_list_x=[]
batch_list_y=[]
for line,y in zip(x_source,y_source):
x=line.astype('float32')
x=x-avg_image
batch_list_x.append(x)
batch_list_y.append(y)
if len(batch_list_x) == batch:
yield (np.array(batch_list_x),np.array(batch_list_y))
batch_list_x=[]
batch_list_y=[]
model = resnet.ResnetBuilder.build_resnet_18((img_channels, img_rows, img_cols), nb_classes)
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
val = subtract_mean_gen(X_test,Y_test,avg_image_test,batch_size)
model.fit_generator(subtract_mean_gen(X_train,Y_train,avg_image_train,batch_size), steps_per_epoch=X_train.shape[0]//batch_size,epochs=nb_epoch,validation_data = val,
validation_steps = X_test.shape[0]//batch_size)
* 에러 / Error
239/249 [===========================>..] - ETA: 60s - loss: 1.3318 - acc: 0.8330Exception in thread Thread-1:
Traceback (most recent call last):
File "/usr/lib/python2.7/threading.py", line 801, in __bootstrap_inner
self.run()
File "/usr/lib/python2.7/threading.py", line 754, in run
self.__target(*self.__args, **self.__kwargs)
File "/usr/local/lib/python2.7/dist-packages/keras/utils/data_utils.py", line 560, in data_generator_task
generator_output = next(self._generator)
StopIteration
240/249 [===========================>..] - ETA: 54s - loss: 1.3283 - acc: 0.8337Traceback (most recent call last):
File "cifa10-copy.py", line 125, in <module>
validation_steps = X_test.shape[0]//batch_size)
File "/usr/local/lib/python2.7/dist-packages/keras/legacy/interfaces.py", line 87, in wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1809, in fit_generator
generator_output = next(output_generator)
StopIteration
위 처럼 for문으로 정해진 batch와 step에 맞게만 코드를 짜도 이상이 없지 않느냐 라고 생각할 수 있다.
하지만 Keras는 마지막 step일지라도 queue에 다음 batch를 저장해둔다. 따라서 위의 에러는 keras가 다음 batch에 대한 데이터를 얻어오려고 하는데, data generator는 이미 끝이 난 것이다.
* 수정 / Modified Code
def subtract_mean_gen(x_source,y_source,avg_image,batch):
while True:
batch_list_x=[]
batch_list_y=[]
for line,y in zip(x_source,y_source):
x=line.astype('float32')
x=x-avg_image
batch_list_x.append(x)
batch_list_y.append(y)
if len(batch_list_x) == batch:
yield (np.array(batch_list_x),np.array(batch_list_y))
batch_list_x=[]
batch_list_y=[]
따라서 이와 같이 for문으로 batch 수에 맞게 돌더라고 while True나 while 1을 통해 infinite한 루프를 만들어 주어야 한다.
그래서 train하는 코드에서 keras의 data generator 함수를 사용할 때, 인자로 epochs와 step_per_epochs를 전달해주는 이유도 이와 같다.
keras data generator는 infinite loop이기 때문에 epochs와 batch가 적절하게 주어저도 step_per_epochs를 통해 batch를 몇 번 수행할지 steps을 인자로 전달받아, 그 수(step_per_epochs)만큼 batch를 돌기 때문이다.
'딥러닝 프레임워크 > Keras' 카테고리의 다른 글
Keras Progbar(Progress-bar) new-line update 해결하기 (0) | 2019.10.30 |
---|