Higu`s diary

新米データサイエンティストのブログ。技術についてゆるく書きます〜

Pytorch-lightning+Hydra+wandbで作るNN実験レポジトリ

Kaggle Advent Calender2020の 11日目の記事です。

昨日はhmdhmdさんのこちらの記事です! 2020年、最もお世話になった解法を紹介します - Qiita

明日はarutema47さんの記事です! (後ほどリンクはります)

本記事では、深層学習プロジェクトで使用すると便利なライブラリ、
Pytorch-lightningとHydraとwandb(Weights&Biases)について紹介したいと思います。

f:id:zerebom:20201211164010p:plain

対象読者

  • Pytorchのボイラープレートコードを減らせないか考えている
  • 下記ライブラリについては聞いたことあるけど、試すのは億劫でやってない

書いてあること

  • 各ライブラリの役割と簡単な使い方
  • 各ライブラリを組み合わせて使う方法
  • 各ライブラリのリファレンスのどこを読めばよいか、更に勉強するにはどうすればよいか

また、上記3つのライブラリを使用したレポジトリを用意しました。 ブログと一緒に見ていただくとわかりやすいかと思います! github.com

はじめに各ライブラリを個別に解説し、次に上記レポジトリに注釈を入れながら説明したいと思います。

Pytorch-lightning

www.pytorchlightning.ai

概要

Pytorch-lightningはPytorchの軽量ラッパーです。 ボイラープレートコードを排除しつつ、可読性を向上させることが出来ます。

特徴(Lightning Philosophy)

公式GitHubを見ると以下の原則を念頭に置いて設計されているようです。

  1. 最大限の柔軟性を持てる
  2. ボイラープレートコードを抽象化しつつ、必要があればアクセスできる
  3. システムに必要なものをすべてを保持し、自己完結できる
  4. 以下の4要素に分割し、整理する
    1. 研究コード(Lightning Module)
    2. エンジニアリングコード(Trainer)
    3. 非必須の研究コード(Callback)
    4. データ (Pytorch DataloaderかLightningDataModule)

どのようなモデルを実装しても似たインターフェースになり、他人(≒過去の自分)のコードでもすぐに理解できる。 MultiGPUやTPUでも実装の変更が殆どないといった点が素晴らしいと思います。

Install

pip install pytorch-lightning

使い方

上記Lightning Philosophyにあるように、4つのパートについてそれぞれ説明していきます。

Trainer

Trainerは後述するLightningModule, Callbacks, DataModuleを引数にとり、Training loop を司るクラスです。いわゆる親玉。
CPU・MultiGPU・TPUなどの実行環境や、デバックモードやエポック数など、
細かい設定を引数に取ることが出来ます。

例)

# 関数にconfig(cfg)を渡すデコレータ(後述)
@hydra.main(config_path='../config', config_name='pix2pix')
def main(cfg):
    
    # モデルの動的呼び出し
    model = hydra.utils.instantiate(cfg.model.instance,hparams=hparams, cfg=cfg)
    
    dm = DataModule(cfg)
    dm.setup()
    
    trainer = pl.Trainer(
            logger = wandb_logger,
            checkpoint_callback=model_checkpoint,
            callbacks=[lr_logger,early_stopping,wandb_callback],
            **cfg.trainer.args,
        )
    
 # 学習開始
    trainer.fit(model, dm)

どのような引数を取ることができるかは下記リファレンスに記載されています。
trainer - PyTorch Lightning 1.0.8 documentation

LightningModule

参考: LightningModule - PyTorch Lightning 1.0.8 documentation

モデルの学習に関するデータフローなどを記載するクラス。
親クラスにtorch.nn.Module を継承しており、高機能な nn.Module と捉えるとわかりやすいです。
各フェーズごとにメソッドが定義されており、その中に処理を書くことで実行されます。

また、Trainerにわたすことで実行環境に応じたコードを自動で実行してくれます。
つまり、x.cuda()x.to(device) を呼び出したり、
DataloaderDistributedSampler(data) を渡す必要がなくなります。

生えてるメソッドの説明

  • forward
    nn.Module と同じ

  • XX_step
    引数 batch にdataloaderの中身が格納されているので、 これをモデルに通して、誤差を計算します。 Loggerやtqdm barに渡したい値はself.logに渡すことで記録されます。

  • XX_epoch_end
    epochが終わったときにどんな処理を行うかを書きます。 epoch間のlossやmetricの平均値などを記録すると良いと思います。

  • configure_optimizers
    optimizer,schedulerの初期化をします。

例)

import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl

class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
                
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
    
    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding
        
        # dataloaderの返り値, indexが格納されている
    def training_step(self, batch, batch_idx):
        # training_step defined the train loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
                
       # 追跡したい値はself.log取ることができる
        self.log('train_loss', loss)
        return loss
        

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

Callbacks

参考:Lightning in 2 steps - PyTorch Lightning 1.0.8 documentation

学習に非必須な処理を書くクラス。 Logger , EarlyStopping LearningRateScheduler など。

一般的によく使われる処理はすでに用意されています。
自分で書く場合は、LightningModuleと同じメソッドが生えているので、 各フェーズで何をするかを書きます。
こちらもTrainerに引数として渡します。

例)

class DecayLearningRate(pl.Callback)

    def __init__(self):
        self.old_lrs = []

    def on_train_start(self, trainer, pl_module):
        # track the initial learning rates
        for opt_idx in optimizer in enumerate(trainer.optimizers):
            group = []
            for param_group in optimizer.param_groups:
                group.append(param_group['lr'])
            self.old_lrs.append(group)

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        for opt_idx in optimizer in enumerate(trainer.optimizers):
            old_lr_group = self.old_lrs[opt_idx]
            new_lr_group = []
            for p_idx, param_group in enumerate(optimizer.param_groups):
                old_lr = old_lr_group[p_idx]
                new_lr = old_lr * 0.98
                new_lr_group.append(new_lr)
                param_group['lr'] = new_lr
             self.old_lrs[opt_idx] = new_lr_group

DataModule

データにまつわる全てのコードを集約するためのクラスです。 DatasetとDataLoaderの呼び出しをします。 (Pytorch-lightning 1.0以前はデータにまつわる処理も LightningModule に書く必要がありました。)

例)

class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
    def setup(self):
        # make assignments here (val/train/test split)
        # called on every process in DDP
    def train_dataloader(self):
        train_split = Dataset(...)
        return DataLoader(train_split)
    def val_dataloader(self):
        val_split = Dataset(...)
        return DataLoader(val_split)
    def test_dataloader(self):
        test_split = Dataset(...)
        return DataLoader(test_split)

pytorch-lightning.readthedocs.io

Hydra

hydra.cc

概要

Facebook Open Sourceから公開されている設定管理ツールです。 YAMLやDataclassから設定変数を階層構造をもたせて動的に呼び出すことができます。

特徴

  • 複数のソースから階層的に設定を構築できる
  • コマンドラインから設定の指定や上書きが可能
  • 1つのコマンドで複数のジョブを実行できる

Install

pip install hydra-core --upgrade

使い方

最小限の使い方

configファイルにyaml形式で設定を書きます。

# config.yaml
db:
  driver: mysql
  user: omry
  pass: secret

参考:YAML Syntax - Ansible Documentation

設定を呼び出す側には呼び出したい関数に @hydra.main デコレータを渡します。 各要素にはドットノーテーションでアクセスできます。

# main.py
@hydra.main(config_path="./config",config_name="config")
def my_app(cfg : DictConfig) -> None:
    print(cfg)
        print(cfg.db.driver)

Config の構造化

configディレクトリを階層構造にすることで、configも階層構造でもたせる事が可能になります。 例えば、以下のようにディレクトリを構築します。

config
├── config.yaml
├── data
│   ├── cifar10.yaml
│   └── mnist.yaml
└── model
    ├── resnet_18.yaml
    └── resnet_50.yaml

そしてrootとなる config.yaml を下記のように記載すると、 data, modelそれぞれ指定したyamlファイルを読み込んでくれます。

# config.yaml
defaults:
  - data: cifar10
  - model: resnet_18

呼び出される側の設定ファイルの1行目に # @package _group_ と記載すると、 そのディレクトリ名経由でドットノーテーションアクセスできます。

# @package _group_
shape: [1,28,28]
batch_size: 8

./config/data/default_data.yaml に書いたなら、cfg.data.shape とアクセスできます。
非常に便利な機能ですが、どの単位でフォルダを分割するかが結構難しいです...!
良いアイディアがあったら教えてほしいです🙇

コマンドラインでの上書き

コマンドライン引数にわたすことで呼び出すファイルや設定を上書きできます。

# 設定ファイルの入れ替え
python train.py data=cifar10
# 値の変更
python train.py trainer.min_epoch=100

インスタンスの動的呼び出し

参考: Instantiating objects with Hydra | Hydra

Hydraの強力な機能にオブジェクトの動的呼び出しがあります。 _target_ に呼び出したいオブジェクトを指定し、引数を列挙することで動的に呼び出すことが出来ます。

# ./config/callbacks/default_callbacks.yaml

# @package _group_
EarlyStopping:
    # 呼び出したいオブジェクト名
  _target_: pytorch_lightning.callbacks.EarlyStopping
    # 第2要素以降はオブジェクトの引数
  monitor: ${trainer.metric}
  mode: ${trainer.mode}
# 呼び出し側。コード内のローカル変数を引数にしたい場合はキーワード引数で渡す。
model_checkpoint = instantiate(cfg.callbacks.ModelCheckpoint,patience=patience)

ユニットテストやnotebookで呼び出す方法

initialize_config_dir メソッドを使うことで呼び出すことが出来ます。

from hydra.experimental import initialize_config_dir, compose
with initialize_config_dir(config_dir=config_dir):
    cfg = compose(config_name=config_name)

Wandb

概要

www.wandb.com

Wandbは機械学習プロジェクトの実験のトラッキング、ハイパーパラメータの最適化、
モデルやデータのバージョンニングを行うライブラリです。

install・会員登録

pip install wandb 公式サイトの右上のLoginリンクからsign upすれば使えます。 GithubGoogleアカウントを連携すればOKです。

Weights & Biases - Developer tools for ML

特徴

大きく分けて下記の機能があるようです。 (自分はまだDashboardくらいしか使っておりません🙇)

  • Dashboard: 数行のコード追加で実験logを記録
  • Sweeps: 複数の条件のモデルを一度に実行
  • Artifacts: モデルやデータなどをバージョンニングフォルダのように管理
  • Reports: 自分の実験を他者に見やすく公開

複数のPCの実験結果をWebブラウザから一括で見られるのは非常に便利だと感じました。

使い方

最小限の使い方

主にPytorch-lightningで使用する方法について説明します。
参考: PyTorch LightningとWandbの連動方法

使い方は非常に簡単で、 pytorch_lightning.loggers.WandbLogger を呼び出して、Trainer に渡すだけです。 tagsやnameを指定すると、project内の実験をソートしたり条件を絞るのに便利。

from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(
        name="ResNet18-cifar10",
        project=”ImageCrassfication”,
        tags=["ResNet18","cifar10"])

trainer = pl.Trainer(logger=wandb_logger,**trainer.args)

f:id:zerebom:20201211171805p:plain

Traceしたい値は、LightningModuleで self.log に指定するだけで自動で追加してくれます。

def training_step(self, batch, batch_idx):
        loss, k_dice, c_dice = self.calc_loss_and_dice(batch)

        self.log("k_dice", k_dice)
        self.log("c_dice", c_dice)
        return loss

f:id:zerebom:20201211171741p:plain

さらに、 watch メソッドを呼ぶだけで、モデルの重みの分布を記録してくれます。 wandb_logger.watch(model, log='gradients', log_freq=100)

f:id:zerebom:20201211170457p:plain

画像の出力

GANのGeneratorの生成結果や、Segmentの結果などを出力する機能もあります。 Pytorch-lightningで使用するには Callback を書き、Trainerに渡すことで保存できます。

class ImageSegmentationLogger(Callback):
    def __init__(self, val_samples, num_samples=8,log_interval=5):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples
        self.log_interval = log_interval

    def on_validation_epoch_end(self, trainer, pl_module):
        # Bring the tensors to CPU
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)

        #[B,C,Z,Y,X]
        pred_probs = pl_module(val_imgs)
        #[B,Z,Y,X] -> [B,Y,X]
        preds = torch.argmax(pred_probs, 1)[:,0,...].cpu().numpy()
        val_labels =torch.argmax(val_labels, 1) [:,0,...].cpu().numpy()

        class_labels = {
            0: "gd",
            1: "kidney",
            2: "cancer",
            3: "cyst"
        }

        # Log the images as wandb Image
        trainer.logger.experiment.log({
            "examples": [wandb.Image(x, masks={
                "predictions": {
                    "mask_data": pred,
                    "class_labels": class_labels
                },
                "groud_truth": {
                    "mask_data": y,
                    "class_labels": class_labels
                }
            })
                for x, pred, y in zip(val_imgs[:self.num_samples],
                                    preds[:self.num_samples],
                                    val_labels[:self.num_samples])]
        })

f:id:zerebom:20201211173456p:plain

サンプルレポジトリの解説

GitHub - zerebom/hydra-pl-wandb-sample-project

config dir

実験パラメータは ./config/default_config.yamlを通して渡されます。

├── config
│   ├── callbacks
│   │   └── default_callbacks.yaml
│   ├── data
│   │   └── default_data.yaml
│   ├── default_config.yaml
│   ├── env
│   │   └── default_env.yaml
│   ├── logger
│   │   └── wandb_logger.yaml
│   ├── model
│   │   └── autoencoder.yaml
│   └── trainer
│       └── default_trainer.yaml

train.py

実行する train.py は基本的にpl.Trainerの組み立てだけ行います。

@hydra.main(config_path='../config', config_name='default_config')
def train(cfg: DictConfig) -> None:
    model = instantiate(cfg.model.instance,cfg=cfg)

    dm = DataModule(cfg.data)
    dm.setup()

    wandb_logger = instantiate(cfg.logger)
    wandb_logger.watch(model, log='gradients', log_freq=100)

    early_stopping = instantiate(cfg.callbacks.EarlyStopping)
    model_checkpoint = instantiate(cfg.callbacks.ModelCheckpoint)
    wandb_image_logger = instantiate(cfg.callbacks.WandbImageLogger,
                            val_imgs=next(iter(dm.val_dataloader()))[0])

    trainer = pl.Trainer(
        logger = wandb_logger,
        checkpoint_callback = model_checkpoint,
        callbacks=[early_stopping,wandb_image_logger],
        **cfg.trainer.args,
    )

    trainer.fit(model, dm)

それぞれのパーツは ./src/factory dirに入っています。

└── src
    ├── factory
    │   ├── dataset.py
    │   ├── logger.py
    │   └── networks
    │       └── autoencoder.py

output dir

poetry run python train.py を実行すると、output dirが作成されます。
ここに重みやlogが格納されます。

output
└── sample-project
    └── simple-auto-encoder-1

output dir の出力先は default_config.yaml で設定できます。

hydra:
  run:
    dir: ${env.save_dir}

# env.save_dir: ${env.root_dir}/output/${project}/${name}-${version}

おわりに

ここまで読んでいただきありがとうございます!
(スコープを広げすぎてとっ散らかってしまった気がする...)

実験管理ディレクトリの作成はなかなか奥が深く、人によって個性が出ると思います。
exp dirを作って1実験1ファイルで管理する方法も人気が高いみたいですね。今度そのような方法にも挑戦してみたいです。

今後、もっと複雑なGANモデルを実装して公開したいと思ってます。
(このGANを書き直す予定)
lucidrains/lightweight-gan

参考文献

レポジトリ

Pytorch-lightning

Hydra

Wandb

  • wandb Gallery
    WandbのReport機能を使って公開された実験のギャラリー
    The Gallery by Weights & Biases

  • pytorch-lightning with wandb
    上のGalleryの中でPytorch-lightningを一緒に使っているライブラリ
    PyTorch Lightning

google-site-verification: google1c6f931fc8723fac.html