Higu`s diary

新米データサイエンティストのブログ。技術についてゆるく書きます〜

GitHubに画像解析用のKerasディレクトリを公開しました。

お久しぶりです、ひぐです!

大学院に入ってからニューラルネットワークを使って、医療画像の補完を行う研究をするようになりました。

そこで今日は自分が普段使ってる研究用のコードを紹介したいと思います!
結構KaggleやQiitaとかでNNライブラリ用のソースコードを検索すると、使い切りなコードが多くないですか?

今回は繰り返し実験できるようにディレクトリごとコードを公開しました!
モデルの保存や、出力結果の記録をクラス単位で実装しています。よかったら参考にしてください。

なかなか上手に書けない部分もあるので、ご教授いただければ幸いです。。。笑
というかまだまだ絶賛修正中なので温かく見守ってくださいw

URLはこちら github.com

コードの概要

全体的にはこんなイメージです。 f:id:zerebom:20190518110310j:plain

主な機能は以下のようになっています。
* main.pyを走らせると自動で、loss関数のグラフ、出力画像を自動作成
* batch size training rateなどのハイパーパラメータははmain.py の引数で渡せる
* データをジェネレーターで読み込むのでデータ量が多くなってもメモリエラーにならない
* 結果を出力するディレクトリに使用したModelの名前とlossの値を使用する(ので見やすい)

main.py

import いろいろ(省略)

INPUT_SIZE = (256, 256)
CONCAT_LEFT_RIGHT=True
CHANGE_SLIDE2_FILL = True
def train(parser):
    configs = json.load(open('./settings.json'))
    reporter = Reporter(parser=parser)
    loader = Loader(configs['dataset_path2'], parser.batch_size)
    
    if CHANGE_SLIDE2_FILL:
        loader.change_slide2fill()
        reporter.add_log_documents('Done change_slide2fill.')

    if CONCAT_LEFT_RIGHT:
        loader.concat_left_right()
        reporter.add_log_documents('Done concat_left_right.')


    train_gen, valid_gen, test_gen = loader.return_gen()
    train_steps, valid_steps, test_steps = loader.return_step()

    # ---------------------------model----------------------------------

    input_channel_count = parser.input_channel
    output_channel_count = 3
    first_layer_filter_count = 32

    network = UNet(input_channel_count, output_channel_count, first_layer_filter_count)
    model = network.get_model()

    model.compile(optimizer='adam', loss='mse')
    model.summary()

    # ---------------------------training----------------------------------
    batch_size = parser.batch_size
    epochs = parser.epoch

    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.9
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # fit_generatorのコールバック関数の指定・TensorBoardとEarlyStoppingの指定

    logdir = os.path.join('./logs', dt.today().strftime("%Y%m%d_%H%M"))
    os.makedirs(logdir, exist_ok=True)
    tb_cb = TensorBoard(log_dir=logdir, histogram_freq=1, write_graph=True, write_images=True)

    es_cb = EarlyStopping(monitor='val_loss', patience=parser.early_stopping, verbose=1, mode='auto')

    print("start training.")
    # Pythonジェネレータ(またはSequenceのインスタンス)によりバッチ毎に生成されたデータでモデルを訓練します.
    history = model.fit_generator(
        generator=train_gen,
        steps_per_epoch=train_steps,
        epochs=epochs,
        validation_data=valid_gen,
        validation_steps=valid_steps,
        # use_multiprocessing=True,
        callbacks=[es_cb, tb_cb])

    print("finish training. And start making predict.")

    train_preds = model.predict_generator(train_gen, steps=train_steps, verbose=1)
    valid_preds = model.predict_generator(valid_gen, steps=valid_steps, verbose=1)
    test_preds = model.predict_generator(test_gen, steps=test_steps, verbose=1)

    print("finish making predict. And render preds.")

    # ==========================report====================================
    reporter.add_val_loss(history.history['val_loss'])
    reporter.add_model_name(network.__class__.__name__)
    reporter.generate_main_dir()
    reporter.plot_history(history)
    reporter.save_params(parser, history)

    input_img_list = []
    # reporter.plot_predict(train_list, Left_RGB, Right_RGB, train_preds, INPUT_SIZE, save_folder='train')
    reporter.plot_predict(loader.train_list, loader.Left_slide, loader.Left_RGB,
                          train_preds, INPUT_SIZE, save_folder='train')
    reporter.plot_predict(loader.valid_list, loader.Left_slide, loader.Left_RGB,
                          valid_preds, INPUT_SIZE, save_folder='valid')
    reporter.plot_predict(loader.test_list, loader.Left_slide, loader.Left_RGB,
                          test_preds, INPUT_SIZE, save_folder='test')
    model.save("model.h5")


def get_parser():
    parser = argparse.ArgumentParser(
        prog='generate parallax image using U-Net',
        usage='python main.py',
        description='This module generate parallax image using U-Net.',
        add_help=True
    )

    parser.add_argument('-e', '--epoch', type=int,
                        default=200, help='Number of epochs')
    parser.add_argument('-b', '--batch_size', type=int,
                        default=32, help='Batch size')
    parser.add_argument('-t', '--trainrate', type=float,
                        default=0.85, help='Training rate')
    parser.add_argument('-es', '--early_stopping', type=int,
                        default=20, help='early_stopping patience')

    parser.add_argument('-i', '--input_channel', type=int,
                        default=7, help='input_channel')

    parser.add_argument('-a', '--augmentation',
                        action='store_true', help='Number of epochs')

    return parser


if __name__ == '__main__':
    parser = get_parser().parse_args()
    train(parser)

ディレクトリのパスはsetting.jsonで一括管理しています。
trainという巨大な関数にargparserで引数を渡して、ハイパーパラメータを用いています。
自分の研究では、入力、教師データどちらにも画像を使うため独自のジェネレータを作成しています。

repoter.py

importあれこれ

class Reporter:
    ROOT_DIR = "Result"
    IMAGE_DIR = "image"
    LEARNING_DIR = "learning"
    INFO_DIR = "info"
    MODEL_DIR = "model"
    PARAMETER = "parameter.txt"
    IMAGE_PREFIX = "epoch_"
    IMAGE_EXTENSION = ".png"
    
    def __init__(self, result_dir=None, parser=None):
        self._root_dir = self.ROOT_DIR
        self.create_dirs()
        self.parameters = list()
    # def make_main_dir(self):

    def add_model_name(self, model_name):
        if not type(model_name) is str:
            raise ValueError('model_name is not str.')

        self.model_name = model_name
    def add_val_loss(self, val_loss):
        self.val_loss = str(round(min(val_loss)))

    def generate_main_dir(self):
        main_dir = self.val_loss + '_' + dt.today().strftime("%Y%m%d_%H%M") + '_' + self.model_name
        self.main_dir = os.path.join(self._root_dir, main_dir)
        os.makedirs(self.main_dir, exist_ok=True)

    def create_dirs(self):
        os.makedirs(self._root_dir, exist_ok=True)

    def plot_history(self,history,title='loss'):
        # 後でfontsize変える
        plt.rcParams['axes.linewidth'] = 1.0  # axis line width
        plt.rcParams["font.size"] = 24  # 全体のフォントサイズが変更されます。
        plt.rcParams['axes.grid'] = True  # make grid
        plt.plot(history.history['loss'], linewidth=1.5, marker='o')
        plt.plot(history.history['val_loss'], linewidth=1., marker='o')
        plt.tick_params(labelsize=20)

        plt.title('model loss')
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.legend(['loss', 'val_loss'], loc='upper right', fontsize=18)
        plt.tight_layout()

        plt.savefig(os.path.join(self.main_dir, title + self.IMAGE_EXTENSION))
        if len(history.history['val_loss'])>=10:
            plt.xlim(10, len(history.history['val_loss']))
            plt.ylim(0, int(history.history['val_loss'][9]*1.1))

        plt.savefig(os.path.join(self.main_dir, title +'_remove_outlies_'+ self.IMAGE_EXTENSION))

    def add_log_documents(self, add_message):
        self.parameters.append(add_message)


    
    def save_params(self,parser,history):
        
        #early_stoppingを考慮
        self.parameters.append("Number of epochs:" + str(len(history.history['val_loss'])))
        self.parameters.append("Batch size:" + str(parser.batch_size))
        self.parameters.append("Training rate:" + str(parser.trainrate))
        self.parameters.append("Augmentation:" + str(parser.augmentation))
        self.parameters.append("input_channel:" + str(parser.input_channel))
        self.parameters.append("min_val_loss:" + str(min(history.history['val_loss'])))
        self.parameters.append("min_loss:" + str(min(history.history['loss'])))

        # self.parameters.append("L2 regularization:" + str(parser.l2reg))
        output = "\n".join(self.parameters)
        filename=os.path.join(self.main_dir,self.PARAMETER)

        with open(filename, mode='w') as f:
            f.write(output)

    @staticmethod
    def get_concat_h(im1, im2):
        dst = Image.new('RGB', (im1.width + im2.width, im1.height))
        dst.paste(im1, (0, 0))
        dst.paste(im2, (im1.width, 0))
        return dst

    def plot_predict(self, img_num_list, Left_RGB, Right_RGB, preds, INPUT_SIZE, max_output=20,save_folder='train'):
        if len(img_num_list) > max_output:
            img_num_list=img_num_list[:max_output]
        for i, num in enumerate(img_num_list):
            if i == 1:
                print(preds[i].astype(np.uint8))
                        
            pred_img = array_to_img(preds[i].astype(np.uint8))
           
            train_img = load_img(Left_RGB[num], target_size=INPUT_SIZE)
            teach_img = load_img(Right_RGB[num], target_size=INPUT_SIZE)
            concat_img = self.get_concat_h(train_img, pred_img)
            concat_img = self.get_concat_h(concat_img, teach_img)
            os.makedirs(os.path.join(self.main_dir,save_folder), exist_ok=True)
            array_to_img(concat_img).save(os.path.join(self.main_dir, save_folder, f'pred_{num}.png'))

データの保存を担っています。
Kerasではfit関数を動かすとその返り値にhistoryオブジェクトという出力のログが入ったインスタンスを返します。
これと、parserをmain.pyから受け取ってデータを保存しています。
保存先はResult dir以下に、使用パラメータ・出力結果・lossグラフなどをまとめて格納します。

f:id:zerebom:20190518112328p:plain

loader.py

import あれこれ

config = json.load(open('./settings.json'))


class Loader(object):
    # コンストラクタ
    def __init__(self, json_paths, batch_size, init_size=(256, 256)):
        self.size = init_size
        self.DATASET_PATH = json_paths
        self.add_member()
        self.batch_size = batch_size



    # def define_IO(self):
    def add_member(self):
        """
        jsonファイルに記載されている、pathをクラスメンバとして登録する。
        self.Left_RGBとかが追加されている。
        """
        for key, path in self.DATASET_PATH.items():
            setattr(self, key, glob.glob(os.path.join(path, '*png')))
    
    #左右の画像を結合してデータを二倍にする
    def concat_left_right(self):
        self.Left_slide += self.Right_slide
        self.Left_RGB += self.Right_RGB
        self.Left_disparity += self.Left_disparity
        self.Right_disparity += self.Right_disparity
        self.Left_bin += self.Left_bin
        self.Right_bin += self.Right_bin
        print('Done concat_left_right.')
    
    #入力で使う画像を平均値で埋めた画像にする
    def change_slide2fill(self):
        self.Left_slide = self.Left_fill
        self.Right_slide = self.Right_fill


    def return_gen(self):
        self.imgs_length = len(self.Left_RGB)
        # self.train_paths = (self.Left_slide, self.Right_disparity, self.Left_disparity)
        # sel = self.Left_RGB
        self.train_list, self.valid_list, self.test_list = self.train_valid_test_splits(self.imgs_length)
        self.train_steps = math.ceil(len(self.train_list) / self.batch_size)
        self.valid_steps = math.ceil(len(self.valid_list) / self.batch_size)
        self.test_steps = math.ceil(len(self.test_list) / self.batch_size)
        self.train_gen = self.generator_with_preprocessing(self.train_list, self.batch_size)
        self.valid_gen = self.generator_with_preprocessing(self.valid_list, self.batch_size)
        self.test_gen = self.generator_with_preprocessing(self.test_list, self.batch_size)
        return self.train_gen, self.valid_gen, self.test_gen

    def return_step(self):
        return self.train_steps, self.valid_steps, self.test_steps

    @staticmethod
    def train_valid_test_splits(imgs_length: 'int', train_rate=0.8, valid_rate=0.1, test_rate=0.1):
        data_array = list(range(imgs_length))
        tr = math.floor(imgs_length * train_rate)
        vl = math.floor(imgs_length * (train_rate + valid_rate))

        random.shuffle(data_array)
        train_list = data_array[:tr]
        valid_list = data_array[tr:vl]
        test_list = data_array[vl:]

        return train_list, valid_list, test_list

    def load_batch_img_array(self, batch_list, prepro_callback=False,use_bin=True):
        teach_img_list = []
        input_img_list = []
        for i in batch_list:
            LS_img = img_to_array(
                load_img(self.Left_slide[i], color_mode='rgb', target_size=self.size)).astype(np.uint8)
            LD_img = img_to_array(
                load_img(self.Left_disparity[i], color_mode='grayscale', target_size=self.size)).astype(np.uint8)
            RD_img = img_to_array(
                load_img(self.Right_disparity[i], color_mode='grayscale', target_size=self.size)).astype(np.uint8)

            if use_bin:
                LB_img = img_to_array(
                    load_img(self.Left_bin[i], color_mode='grayscale', target_size=self.size)).astype(np.uint8)
                # LB_img=np.where(LB_img==255,1,LB_img)

                RB_img = img_to_array(
                    load_img(self.Right_bin[i], color_mode='grayscale', target_size=self.size)).astype(np.uint8)
                # RB_img = np.where(RB_img == 255, 1, RB_img)

                input_img = np.concatenate((LS_img, LD_img, RD_img, LB_img, RB_img), 2).astype(np.uint8)
            else:
                input_img=np.concatenate((LS_img, LD_img, RD_img), 2).astype(np.uint8)


            teach_img = img_to_array(
                load_img(self.Left_RGB[i], color_mode='rgb', target_size=self.size)).astype(np.uint8)
               
            input_img_list.append(input_img)
            teach_img_list.append(teach_img)

        return np.stack(input_img_list), np.stack(teach_img_list)

    def generator_with_preprocessing(self, img_id_list, batch_size):#, *input_paths
        while True:
            for i in range(0, len(img_id_list), batch_size):
                batch_list = img_id_list[i:i + batch_size]
                batch_input, batch_teach = self.load_batch_img_array(batch_list)

                yield(batch_input, batch_teach)

class DataSet:
    pass

Data dirからデータ(画像)を読み取ってmain.pyにジェネレータ形式で渡します。
このコードは特にまだまだ改善の余地があります...

実験ごとに入力チャンネル数を変えたいのですが、
ジェネレータに読みだした後、それらを結合するとことが難しく、悩んでいます。

jsonからディレクトリのパスを受け取って、その直下の画像ファイルをすべてクラスメンバにして
読み込むようにしているのがおしゃれポイントです

        for key, path in self.DATASET_PATH.items():
            setattr(self, key, glob.glob(os.path.join(path, '*png')))

おわりに

ザックリですが以上になります!
わからないところや修正したほうがいいと思う部分がありましたら、連絡いただけたら幸いです!

今後は他の人でも使えるように、どんなタスクでも使えるように、調整して再度公開したいです。

就職したときにも、恥ずかしくないようにきれいで再利用性の高いコードをかけるように頑張っていきたいです!
では~

google-site-verification: google1c6f931fc8723fac.html