GitHubに画像解析用のKerasディレクトリを公開しました。
お久しぶりです、ひぐです!
大学院に入ってからニューラルネットワークを使って、医療画像の補完を行う研究をするようになりました。
そこで今日は自分が普段使ってる研究用のコードを紹介したいと思います!
結構KaggleやQiitaとかでNNライブラリ用のソースコードを検索すると、使い切りなコードが多くないですか?
今回は繰り返し実験できるようにディレクトリごとコードを公開しました!
モデルの保存や、出力結果の記録をクラス単位で実装しています。よかったら参考にしてください。
なかなか上手に書けない部分もあるので、ご教授いただければ幸いです。。。笑
というかまだまだ絶賛修正中なので温かく見守ってくださいw
URLはこちら github.com
コードの概要
全体的にはこんなイメージです。
主な機能は以下のようになっています。
* main.pyを走らせると自動で、loss関数のグラフ、出力画像を自動作成
* batch size training rateなどのハイパーパラメータははmain.py の引数で渡せる
* データをジェネレーターで読み込むのでデータ量が多くなってもメモリエラーにならない
* 結果を出力するディレクトリに使用したModelの名前とlossの値を使用する(ので見やすい)
main.py
import いろいろ(省略) INPUT_SIZE = (256, 256) CONCAT_LEFT_RIGHT=True CHANGE_SLIDE2_FILL = True def train(parser): configs = json.load(open('./settings.json')) reporter = Reporter(parser=parser) loader = Loader(configs['dataset_path2'], parser.batch_size) if CHANGE_SLIDE2_FILL: loader.change_slide2fill() reporter.add_log_documents('Done change_slide2fill.') if CONCAT_LEFT_RIGHT: loader.concat_left_right() reporter.add_log_documents('Done concat_left_right.') train_gen, valid_gen, test_gen = loader.return_gen() train_steps, valid_steps, test_steps = loader.return_step() # ---------------------------model---------------------------------- input_channel_count = parser.input_channel output_channel_count = 3 first_layer_filter_count = 32 network = UNet(input_channel_count, output_channel_count, first_layer_filter_count) model = network.get_model() model.compile(optimizer='adam', loss='mse') model.summary() # ---------------------------training---------------------------------- batch_size = parser.batch_size epochs = parser.epoch config = tf.ConfigProto() config.gpu_options.per_process_gpu_memory_fraction = 0.9 config.gpu_options.allow_growth = True sess = tf.Session(config=config) # fit_generatorのコールバック関数の指定・TensorBoardとEarlyStoppingの指定 logdir = os.path.join('./logs', dt.today().strftime("%Y%m%d_%H%M")) os.makedirs(logdir, exist_ok=True) tb_cb = TensorBoard(log_dir=logdir, histogram_freq=1, write_graph=True, write_images=True) es_cb = EarlyStopping(monitor='val_loss', patience=parser.early_stopping, verbose=1, mode='auto') print("start training.") # Pythonジェネレータ(またはSequenceのインスタンス)によりバッチ毎に生成されたデータでモデルを訓練します. history = model.fit_generator( generator=train_gen, steps_per_epoch=train_steps, epochs=epochs, validation_data=valid_gen, validation_steps=valid_steps, # use_multiprocessing=True, callbacks=[es_cb, tb_cb]) print("finish training. And start making predict.") train_preds = model.predict_generator(train_gen, steps=train_steps, verbose=1) valid_preds = model.predict_generator(valid_gen, steps=valid_steps, verbose=1) test_preds = model.predict_generator(test_gen, steps=test_steps, verbose=1) print("finish making predict. And render preds.") # ==========================report==================================== reporter.add_val_loss(history.history['val_loss']) reporter.add_model_name(network.__class__.__name__) reporter.generate_main_dir() reporter.plot_history(history) reporter.save_params(parser, history) input_img_list = [] # reporter.plot_predict(train_list, Left_RGB, Right_RGB, train_preds, INPUT_SIZE, save_folder='train') reporter.plot_predict(loader.train_list, loader.Left_slide, loader.Left_RGB, train_preds, INPUT_SIZE, save_folder='train') reporter.plot_predict(loader.valid_list, loader.Left_slide, loader.Left_RGB, valid_preds, INPUT_SIZE, save_folder='valid') reporter.plot_predict(loader.test_list, loader.Left_slide, loader.Left_RGB, test_preds, INPUT_SIZE, save_folder='test') model.save("model.h5") def get_parser(): parser = argparse.ArgumentParser( prog='generate parallax image using U-Net', usage='python main.py', description='This module generate parallax image using U-Net.', add_help=True ) parser.add_argument('-e', '--epoch', type=int, default=200, help='Number of epochs') parser.add_argument('-b', '--batch_size', type=int, default=32, help='Batch size') parser.add_argument('-t', '--trainrate', type=float, default=0.85, help='Training rate') parser.add_argument('-es', '--early_stopping', type=int, default=20, help='early_stopping patience') parser.add_argument('-i', '--input_channel', type=int, default=7, help='input_channel') parser.add_argument('-a', '--augmentation', action='store_true', help='Number of epochs') return parser if __name__ == '__main__': parser = get_parser().parse_args() train(parser)
ディレクトリのパスはsetting.jsonで一括管理しています。
trainという巨大な関数にargparserで引数を渡して、ハイパーパラメータを用いています。
自分の研究では、入力、教師データどちらにも画像を使うため独自のジェネレータを作成しています。
repoter.py
importあれこれ class Reporter: ROOT_DIR = "Result" IMAGE_DIR = "image" LEARNING_DIR = "learning" INFO_DIR = "info" MODEL_DIR = "model" PARAMETER = "parameter.txt" IMAGE_PREFIX = "epoch_" IMAGE_EXTENSION = ".png" def __init__(self, result_dir=None, parser=None): self._root_dir = self.ROOT_DIR self.create_dirs() self.parameters = list() # def make_main_dir(self): def add_model_name(self, model_name): if not type(model_name) is str: raise ValueError('model_name is not str.') self.model_name = model_name def add_val_loss(self, val_loss): self.val_loss = str(round(min(val_loss))) def generate_main_dir(self): main_dir = self.val_loss + '_' + dt.today().strftime("%Y%m%d_%H%M") + '_' + self.model_name self.main_dir = os.path.join(self._root_dir, main_dir) os.makedirs(self.main_dir, exist_ok=True) def create_dirs(self): os.makedirs(self._root_dir, exist_ok=True) def plot_history(self,history,title='loss'): # 後でfontsize変える plt.rcParams['axes.linewidth'] = 1.0 # axis line width plt.rcParams["font.size"] = 24 # 全体のフォントサイズが変更されます。 plt.rcParams['axes.grid'] = True # make grid plt.plot(history.history['loss'], linewidth=1.5, marker='o') plt.plot(history.history['val_loss'], linewidth=1., marker='o') plt.tick_params(labelsize=20) plt.title('model loss') plt.xlabel('epoch') plt.ylabel('loss') plt.legend(['loss', 'val_loss'], loc='upper right', fontsize=18) plt.tight_layout() plt.savefig(os.path.join(self.main_dir, title + self.IMAGE_EXTENSION)) if len(history.history['val_loss'])>=10: plt.xlim(10, len(history.history['val_loss'])) plt.ylim(0, int(history.history['val_loss'][9]*1.1)) plt.savefig(os.path.join(self.main_dir, title +'_remove_outlies_'+ self.IMAGE_EXTENSION)) def add_log_documents(self, add_message): self.parameters.append(add_message) def save_params(self,parser,history): #early_stoppingを考慮 self.parameters.append("Number of epochs:" + str(len(history.history['val_loss']))) self.parameters.append("Batch size:" + str(parser.batch_size)) self.parameters.append("Training rate:" + str(parser.trainrate)) self.parameters.append("Augmentation:" + str(parser.augmentation)) self.parameters.append("input_channel:" + str(parser.input_channel)) self.parameters.append("min_val_loss:" + str(min(history.history['val_loss']))) self.parameters.append("min_loss:" + str(min(history.history['loss']))) # self.parameters.append("L2 regularization:" + str(parser.l2reg)) output = "\n".join(self.parameters) filename=os.path.join(self.main_dir,self.PARAMETER) with open(filename, mode='w') as f: f.write(output) @staticmethod def get_concat_h(im1, im2): dst = Image.new('RGB', (im1.width + im2.width, im1.height)) dst.paste(im1, (0, 0)) dst.paste(im2, (im1.width, 0)) return dst def plot_predict(self, img_num_list, Left_RGB, Right_RGB, preds, INPUT_SIZE, max_output=20,save_folder='train'): if len(img_num_list) > max_output: img_num_list=img_num_list[:max_output] for i, num in enumerate(img_num_list): if i == 1: print(preds[i].astype(np.uint8)) pred_img = array_to_img(preds[i].astype(np.uint8)) train_img = load_img(Left_RGB[num], target_size=INPUT_SIZE) teach_img = load_img(Right_RGB[num], target_size=INPUT_SIZE) concat_img = self.get_concat_h(train_img, pred_img) concat_img = self.get_concat_h(concat_img, teach_img) os.makedirs(os.path.join(self.main_dir,save_folder), exist_ok=True) array_to_img(concat_img).save(os.path.join(self.main_dir, save_folder, f'pred_{num}.png'))
データの保存を担っています。
Kerasではfit関数を動かすとその返り値にhistoryオブジェクトという出力のログが入ったインスタンスを返します。
これと、parserをmain.pyから受け取ってデータを保存しています。
保存先はResult dir以下に、使用パラメータ・出力結果・lossグラフなどをまとめて格納します。
loader.py
import あれこれ config = json.load(open('./settings.json')) class Loader(object): # コンストラクタ def __init__(self, json_paths, batch_size, init_size=(256, 256)): self.size = init_size self.DATASET_PATH = json_paths self.add_member() self.batch_size = batch_size # def define_IO(self): def add_member(self): """ jsonファイルに記載されている、pathをクラスメンバとして登録する。 self.Left_RGBとかが追加されている。 """ for key, path in self.DATASET_PATH.items(): setattr(self, key, glob.glob(os.path.join(path, '*png'))) #左右の画像を結合してデータを二倍にする def concat_left_right(self): self.Left_slide += self.Right_slide self.Left_RGB += self.Right_RGB self.Left_disparity += self.Left_disparity self.Right_disparity += self.Right_disparity self.Left_bin += self.Left_bin self.Right_bin += self.Right_bin print('Done concat_left_right.') #入力で使う画像を平均値で埋めた画像にする def change_slide2fill(self): self.Left_slide = self.Left_fill self.Right_slide = self.Right_fill def return_gen(self): self.imgs_length = len(self.Left_RGB) # self.train_paths = (self.Left_slide, self.Right_disparity, self.Left_disparity) # sel = self.Left_RGB self.train_list, self.valid_list, self.test_list = self.train_valid_test_splits(self.imgs_length) self.train_steps = math.ceil(len(self.train_list) / self.batch_size) self.valid_steps = math.ceil(len(self.valid_list) / self.batch_size) self.test_steps = math.ceil(len(self.test_list) / self.batch_size) self.train_gen = self.generator_with_preprocessing(self.train_list, self.batch_size) self.valid_gen = self.generator_with_preprocessing(self.valid_list, self.batch_size) self.test_gen = self.generator_with_preprocessing(self.test_list, self.batch_size) return self.train_gen, self.valid_gen, self.test_gen def return_step(self): return self.train_steps, self.valid_steps, self.test_steps @staticmethod def train_valid_test_splits(imgs_length: 'int', train_rate=0.8, valid_rate=0.1, test_rate=0.1): data_array = list(range(imgs_length)) tr = math.floor(imgs_length * train_rate) vl = math.floor(imgs_length * (train_rate + valid_rate)) random.shuffle(data_array) train_list = data_array[:tr] valid_list = data_array[tr:vl] test_list = data_array[vl:] return train_list, valid_list, test_list def load_batch_img_array(self, batch_list, prepro_callback=False,use_bin=True): teach_img_list = [] input_img_list = [] for i in batch_list: LS_img = img_to_array( load_img(self.Left_slide[i], color_mode='rgb', target_size=self.size)).astype(np.uint8) LD_img = img_to_array( load_img(self.Left_disparity[i], color_mode='grayscale', target_size=self.size)).astype(np.uint8) RD_img = img_to_array( load_img(self.Right_disparity[i], color_mode='grayscale', target_size=self.size)).astype(np.uint8) if use_bin: LB_img = img_to_array( load_img(self.Left_bin[i], color_mode='grayscale', target_size=self.size)).astype(np.uint8) # LB_img=np.where(LB_img==255,1,LB_img) RB_img = img_to_array( load_img(self.Right_bin[i], color_mode='grayscale', target_size=self.size)).astype(np.uint8) # RB_img = np.where(RB_img == 255, 1, RB_img) input_img = np.concatenate((LS_img, LD_img, RD_img, LB_img, RB_img), 2).astype(np.uint8) else: input_img=np.concatenate((LS_img, LD_img, RD_img), 2).astype(np.uint8) teach_img = img_to_array( load_img(self.Left_RGB[i], color_mode='rgb', target_size=self.size)).astype(np.uint8) input_img_list.append(input_img) teach_img_list.append(teach_img) return np.stack(input_img_list), np.stack(teach_img_list) def generator_with_preprocessing(self, img_id_list, batch_size):#, *input_paths while True: for i in range(0, len(img_id_list), batch_size): batch_list = img_id_list[i:i + batch_size] batch_input, batch_teach = self.load_batch_img_array(batch_list) yield(batch_input, batch_teach) class DataSet: pass
Data dirからデータ(画像)を読み取ってmain.pyにジェネレータ形式で渡します。
このコードは特にまだまだ改善の余地があります...
実験ごとに入力チャンネル数を変えたいのですが、
ジェネレータに読みだした後、それらを結合するとことが難しく、悩んでいます。
jsonからディレクトリのパスを受け取って、その直下の画像ファイルをすべてクラスメンバにして
読み込むようにしているのがおしゃれポイントです
for key, path in self.DATASET_PATH.items(): setattr(self, key, glob.glob(os.path.join(path, '*png')))
おわりに
ザックリですが以上になります!
わからないところや修正したほうがいいと思う部分がありましたら、連絡いただけたら幸いです!
今後は他の人でも使えるように、どんなタスクでも使えるように、調整して再度公開したいです。
就職したときにも、恥ずかしくないようにきれいで再利用性の高いコードをかけるように頑張っていきたいです!
では~