Keras study notes, part 2: data_generator
Output one batch per call, based on keras.utils.Sequence
Base object for fitting to a sequence of data, such as a dataset. Every Sequence must implement the __getitem__ and the __len__ methods. If you want to modify your dataset between epochs you may implement on_epoch_end. The method __getitem__ should return a complete batch.
Notes: Sequences are a safer way to do multiprocessing. This structure guarantees that the network will only train once on each sample per epoch, which is not the case with generators.
Sequence example: https://keras.io/utils/#sequence
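As a minimal sketch of the contract described above, the class below serves batches from in-memory NumPy arrays and reshuffles the sample order in on_epoch_end. The class name ArraySequence, the shuffle and seed parameters, and the fallback to a plain object when Keras cannot be imported are all assumptions made for illustration; none of them come from the Keras docs or from these notes.

```python
import numpy as np

try:
    from keras.utils import Sequence
except ImportError:
    # Assumption: fall back to a plain base class so the sketch runs without Keras.
    Sequence = object


class ArraySequence(Sequence):
    """Hypothetical example: batches from in-memory arrays, reshuffled per epoch."""

    def __init__(self, x, y, batch_size, shuffle=True, seed=0):
        super().__init__()
        self.x = np.asarray(x)
        self.y = np.asarray(y)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        self.indices = np.arange(len(self.x))
        self.on_epoch_end()  # shuffle once before the first epoch

    def __len__(self):
        # Number of batches per epoch; the last one may be smaller.
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        # Return one complete batch (inputs and targets) for batch index idx.
        batch_idx = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        return self.x[batch_idx], self.y[batch_idx]

    def on_epoch_end(self):
        # Change the sample order between epochs.
        if self.shuffle:
            self.rng.shuffle(self.indices)


seq = ArraySequence(np.arange(10).reshape(10, 1), np.arange(10), batch_size=4)
# Walking over every batch index once visits each sample exactly once.
seen = np.concatenate([seq[i][1] for i in range(len(seq))])
```

Because the batches are slices of one shuffled index array, each sample appears exactly once per epoch, which is the guarantee the note above refers to.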
#!/usr/bin/env python
# coding: utf-8
from keras.utils import Sequence
import numpy as np
from keras.preprocessing import image
from skimage.io import imread


class My_Custom_Generator(Sequence):
    def __init__(self, image_filenames, labels, batch_size):
        self.image_filenames = image_filenames
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        # np.int was removed in NumPy 1.24, so use the built-in int instead.
        return int(np.ceil(len(self.image_filenames) / float(self.batch_size)))

    def __getitem__(self, idx):
        # Slice out the labels and file paths that belong to batch idx.
        batch_y = self.labels[idx * self.batch_size : (idx+1) * self.batch_size]
        batch_x = self.image_filenames[idx * self.batch_size : (idx+1) * self.batch_size]

        batch_seq = []
        for x in batch_x:  # x: one sequence of image paths
            seq_img = []
            for img in x:  # img: one image path in the sequence
                seq_img.append(image.img_to_array(imread(img)))
            batch_seq.append(seq_img)
        batch_seq_list = np.array(batch_seq)

        return batch_seq_list, np.array(batch_y)
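To make the arithmetic in __len__ and __getitem__ concrete, here is a small sketch with made-up numbers (10 samples, batch size 4). The helper batch_bounds is hypothetical and is not part of the generator above; it only lists which slice each batch index covers.

```python
import math


def batch_bounds(num_samples, batch_size):
    """Return the (start, stop) slice bounds of every batch in one epoch."""
    num_batches = math.ceil(num_samples / batch_size)
    return [
        (idx * batch_size, min((idx + 1) * batch_size, num_samples))
        for idx in range(num_batches)
    ]


bounds = batch_bounds(num_samples=10, batch_size=4)
print(bounds)  # [(0, 4), (4, 8), (8, 10)]
```

The generator does not need the explicit min, because Python slicing clips at the end of the list; the last batch simply comes out shorter.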
Two ways to output the data as a numpy.array
Convert a list to numpy.array
Fast; pay attention to how the dimensions change when the list is converted to an array.
'''
list
batch_x = X_train_filenames[idx * batch_size : (idx+1) * batch_size]
batch_seq = []
for x in batch_x:  # x: one sequence of image paths
    seq_img = []
    for img in x:  # img: one image path in the sequence
        seq_img.append(image.img_to_array(imread(img)))
    batch_seq.append(seq_img)
batch_seq_list = np.array(batch_seq)
'''
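The dimension change mentioned above is easy to check on small stand-in data. The sketch below uses zero arrays in place of decoded images (2 sequences of 3 images, each 4x4 RGB; the sizes are arbitrary assumptions) and shows how wrapping a list in one more pair of brackets adds a leading axis.

```python
import numpy as np

# Stand-ins for decoded images: 2 sequences of 3 images, each 4x4 RGB.
batch_seq = [
    [np.zeros((4, 4, 3), dtype=np.float32) for _ in range(3)]
    for _ in range(2)
]

seq_img = batch_seq[0]
one_seq = np.array(seq_img)        # shape (3, 4, 4, 3)
wrapped_seq = np.array([seq_img])  # shape (1, 3, 4, 4, 3): extra leading axis
batch = np.array(batch_seq)        # shape (2, 3, 4, 4, 3)

# np.stack raises ValueError when the image shapes disagree, which makes
# it a stricter alternative to np.array for this step.
stacked = np.stack([np.stack(seq) for seq in batch_seq])

print(one_seq.shape, wrapped_seq.shape, batch.shape, stacked.shape)
```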
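Use np.empty
Slow, because np.append copies the whole array on every call; you only need to fix the batch dimensions before starting.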
'''
numpy
batch_x = X_train_filenames[idx * batch_size : (idx+1) * batch_size]
batch_seq = np.empty((0, 25, 224, 224, 3), float)
for x in batch_x:  # x: one sequence of 25 image paths
    seq_batch = np.empty((0, 224, 224, 3), float)
    for item in x:  # item: one image path
        seq_batch = np.append(seq_batch, np.expand_dims(image.img_to_array(imread(item)), axis=0), axis=0)
    # Accumulate into batch_seq itself so that every sequence is kept.
    batch_seq = np.append(batch_seq, np.expand_dims(seq_batch, axis=0), axis=0)
'''
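One possible variation, not taken from the notes above, is to allocate the full batch with np.empty once and write each image into its slot by index, which avoids the repeated copies made by np.append. In this sketch load_stub is a hypothetical stand-in for image.img_to_array(imread(path)), and the file names and the 8x8 image size are made up so that the example runs without any image files.

```python
import numpy as np


def load_stub(path, height=8, width=8, channels=3):
    """Hypothetical stand-in for img_to_array(imread(path))."""
    # Fill the "image" with the length of the path so images are distinguishable.
    return np.full((height, width, channels), float(len(path)), dtype=np.float32)


# Made-up batch: 3 sequences of 2 image paths each.
batch_x = [["a.png", "bb.png"], ["ccc.png", "dddd.png"], ["e.png", "ff.png"]]
seq_len, height, width, channels = 2, 8, 8, 3

# Allocate the whole batch once, then write each image into its slot.
batch_seq = np.empty(
    (len(batch_x), seq_len, height, width, channels), dtype=np.float32
)
for i, x in enumerate(batch_x):
    for j, path in enumerate(x):
        batch_seq[i, j] = load_stub(path)

# The list-based route gives the same array.
from_list = np.array([[load_stub(p) for p in x] for x in batch_x])
print(batch_seq.shape, np.array_equal(batch_seq, from_list))
```

This keeps the advantage of fixing the batch dimensions up front, at the cost of having to know the sequence length and image size before the loop starts.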