Higu`s diary

新米データサイエンティストのブログ。技術についてゆるく書きます〜

【個人開発】Railsで近くのラーメンを1タップで探せるiOSアプリ「ちかめん」のAPIを作った話

まえがき

近くのラーメンを1タップで探せるiOSアプリ「ちかめん」を作っています💪
まだまだ友人:iOS, 自分:バックエンドと2人で協力して制作中ですが
一旦、本記事では自分の担当であるバックエンド側のAPI開発備忘録を紹介します。
個人開発をしてみたいと思っている方の参考になればと思います!

想定している読者

  • プログラミングはやったことあるけどWeb開発はやったことない人
  • これからWebアプリケーション作成に挑戦してみたい人

どんなAPI

下記のようにデータを返します(Gifクリックで拡大します)

Image from Gyazo

URLは下記のような構成になっています。

           Prefix Verb    URI Pattern                                                                              Controller#Action
        api_v1_shops_near GET    /api/v1/shops/near(.:format)                                                             api/v1/shops#sort_by_near
      api_v1_shop_reviews GET    /api/v1/shops/:shop_id/reviews(.:format)                                                 api/v1/reviews#index
                          POST   /api/v1/shops/:shop_id/reviews(.:format)                                                 api/v1/reviews#create
       api_v1_shop_review GET    /api/v1/shops/:shop_id/reviews/:id(.:format)                                             api/v1/reviews#show
                          PATCH  /api/v1/shops/:shop_id/reviews/:id(.:format)                                             api/v1/reviews#update
                          PUT    /api/v1/shops/:shop_id/reviews/:id(.:format)                                             api/v1/reviews#update
                          DELETE /api/v1/shops/:shop_id/reviews/:id(.:format)                                             api/v1/reviews#destroy
    api_v1_shop_addresses GET    /api/v1/shops/:shop_id/addresses(.:format)                                               api/v1/addresses#index
                          POST   /api/v1/shops/:shop_id/addresses(.:format)                                               api/v1/addresses#create
      api_v1_shop_address GET    /api/v1/shops/:shop_id/addresses/:id(.:format)                                           api/v1/addresses#show
                          PATCH  /api/v1/shops/:shop_id/addresses/:id(.:format)                                           api/v1/addresses#update
                          PUT    /api/v1/shops/:shop_id/addresses/:id(.:format)                                           api/v1/addresses#update
                          DELETE /api/v1/shops/:shop_id/addresses/:id(.:format)                                           api/v1/addresses#destroy
       api_v1_shop_photos GET    /api/v1/shops/:shop_id/photos(.:format)                                                  api/v1/photos#index
                          POST   /api/v1/shops/:shop_id/photos(.:format)                                                  api/v1/photos#create
        api_v1_shop_photo GET    /api/v1/shops/:shop_id/photos/:id(.:format)                                              api/v1/photos#show
                          PATCH  /api/v1/shops/:shop_id/photos/:id(.:format)                                              api/v1/photos#update
                          PUT    /api/v1/shops/:shop_id/photos/:id(.:format)                                              api/v1/photos#update
                          DELETE /api/v1/shops/:shop_id/photos/:id(.:format)                                              api/v1/photos#destroy
             api_v1_shops GET    /api/v1/shops(.:format)                                                                  api/v1/shops#index
                          POST   /api/v1/shops(.:format)                                                                  api/v1/shops#create
              api_v1_shop GET    /api/v1/shops/:id(.:format)                                                              api/v1/shops#show
                          PATCH  /api/v1/shops/:id(.:format)                                                              api/v1/shops#update
                          PUT    /api/v1/shops/:id(.:format)                                                              api/v1/shops#update
                          DELETE /api/v1/shops/:id(.:format)                                                              api/v1/shops#destroy

主なURLと機能はこんな感じです

  • 緯度経度を入力すると近くのラーメンデータをDBに保存して返す near URL
  • nearで保存したリソースを返すshop, reviews , photos , addresses URL
  • 過去にアクセスされた緯度経度を保存し、再び近隣でアクセスした場合はDBからキャッシュを返す

データ取得はGoogle Map APIから行っています。
このAPIは月々2万円分は無料で使用できるため、必要なデータだけ取得し、
再度必要になる場合は保存したデータから返すような機構にして、リクエスト数を抑えています。

なんで作ったか

RailsやWebアプリケーションの基礎を身につけ、社内でのコミニュケーションを取りやすくするためです。
自分はWeb企業にデータサイエンティストとして入社したのですが、Webアプリの基礎知識がないと、 MLモデル導入のインタフェースなど業務上で齟齬が生まれると感じたからです。
また、社内の勉強会や雑談はWebの知識を前提としていることが多く、そこで話が通じるようになりたいと思っていました。
具体的には下記のような知識を得られると思い作成しました。

  • RailsなどのWebフレームワークの使い方
  • Webアプリケーション開発の全体像
  • GCP, AWSなどのクラウドの使い方
  • チーム開発におけるコミニュケーション方法(Pull Request, issueの作り方等)

どんな実装になってるか

(雑ですが)全体像を示します。
f:id:zerebom:20210602083135p:plain

サーバーをAWSにアップロードしており、URLにアクセスすると必要に応じてGCPまたはDBからデータを取得するようになっています。

DBには下記のような構成でデータが格納されており、上に載せたようなURLでアクセスできます。 f:id:zerebom:20210531085624p:plain

必要なデータはGoogle MAP APIから取得して、これをDBに整形してから格納しています。 このAPIは月2万円分までは無料で使えるため、その限度を超えないようにデータを保存しています。
サーバーの構成などは殆ど下記のURLを参考に作成しました。

作成の流れ

2020/11/10から作成しはじめました。
自分達と友人が使えるアプリがいいねということで、何個か案をだし、つくば市(当時住んでいた)のラーメンを探せるアプリ「つくめん」を作ることにしました。 最初にAPIのURL設計とデータの形式を決め、モックを作ってそれぞれ個別に作業を進めました。
お互い引っ越すタイミングで、「来年つくばにいなくない?」ということに気付き、現在地から近辺のデータを取得する「ちかめん」に変更しました。

実働は30~50時間くらいで、コミットログからどの時期にどれくらい作業していたかがわかります。 f:id:zerebom:20210602075333p:plain

一通り実装が終わったので友人に共有したところ、データが意図しない型、重複、nullになっていたりと穴だらけだったので、
一旦はもっとシンプルな「つくめん」としてアプリをリリースしようと現在作業を進めています。
ちかめんのリリースは少し時間がかかりそうなので、今までやったことを忘れないようにブログを書いた次第です。

どうやって知識をキャッチアップしたか

Railsは日本語で無料の良質な情報がWeb上にたくさんあるので、知りたい情報がなくて困ることはなかったです。

  • Rails tutorialを雑にやる
  • Railsリファレンスを使いながら調べながらすすめる
  • 必要に応じて書籍も確認する

という感じで進めました。
自分のような初心者の方がコードを書く場合、良質な教材がある・気軽に聞ける人がいる言語で書き始めるとよいのかなーと思いました。

得られた知識

得られた知識はこんな感じです。

  • Rails, Rubyの基本文法とそれぞれの強み
  • RSpecを使ったTDD開発
  • クラウド上にサーバーをデプロイするノウハウ

また、自分がバックエンドエンジニアとして働くには以下の事を更に勉強する必要がありそうだと実感できました。

  • 複雑性を避ける設計
  • DB・サーバー間の分離などのインフラ構成・通信
  • データの信頼度の担保
  • デバッグ・追加検証しやすいログの設計

得られた体験

機能追加を話すのは楽しい。実装は大変。

なんてことない機能も、実際に動くものを作るのは想像より遥かに大変でした。
飽きないように、そしてちゃんとユーザーに使ってもらえるように、届けたい価値はなにかを考えてMVPで実装することが大事だと感じました。
このあたりをちゃんと考えると実務のプロダクト開発にも活かせると思うので次作るときは、 下記の本とか参考に意識したいです。🤲

フレームワークの恩恵を得られる構成で作るとラク

RailsActiveRecordにより、基本的なCRUD機能やルーティングを少量のコードで実装することが出来ます。
フレームワークの特性を知り、自分たちが提供したいアプリの機能をそこにマッピングしていくとコスパよく実装できると感じました。

リリースできる品質にするのは難しい

GCPからデータを取ってきてDBに入れるだけでも、データの重複、欠損値などをvalidateするのに苦労しました。
またデプロイ時には開発環境の差異で落ちたり、ネットワークの知識が足りずポートが開いていなかったりといろいろ大変でした汗
事業に成り立たせるにはサービスを落とさないようにしたり、大量のデータをさばいたりと更に考えることがいっぱいなのだなと実感し、世の中のエンジニアに敬意を払いたいと思いました...笑

お世話になったサイト・書籍

railstutorial.jp

prog-8.com

railsguides.jp

qiita.com

終わり

やっぱり動くものができあがるのはすごく楽しいと感じました。
これからWeb開発をしたい!という人の参考になったらうれしいです!
アプリをリリースできたらまた記事を書きたいと思います! では〜

2020年の振り返りと2021年の目標

こんにちは、ひぐです。
年の瀬なので、今年を振り返りたいと思います!

ということで、早速去年の目標を採点してみました。 f:id:zerebom:20201231172704p:plain

凄惨たる結果です...
まあ振り返ることが大事だと思うのでやっていきます。。

できたこと

研究した

学部からテーマを変えた上に、去年はかなり就活に時間を振ってたので
今年は頑張って研究しました。

  • 国際学会オーラル発表 1(研究結果引き継いでまとめただけなのですごくない)
  • 国内学会オーラル発表 1
  • 国際ジャーナル投稿 1 (査読中)
  • 修士論文

ですが、結果を出すのがなかなか難しかったです。
ニューラルネットの研究は再現するのが大変なので大変でした()

MLエンジニアとしての基礎知識を身に着けた

来年度から社会人になるため、今年はまとまった時間を使って、
MLエンジニアに必要な基礎知識をがっつり勉強するぞ!というのがテーマでした。

興味が広範に渡ってしまったためどれも中途半端気味ですが、割と勉強できたと思います。
特に良かった本に☆をつけました

統計・数学

統計検定準1級・1級を取るために勉強しました。
友人と朝8~11時Zoomで輪講できたのが良かったです。(続けたかった)

読んだ本

プログラミング設計

インターンで設計の大切さを知った&研究コードのスパゲッティぷりに、
辟易していたので色々と読みました。

学生っぽいコードから少し脱却できたような気がします。

読んだ本

Webバックエンド

内定者インターンを4ヶ月取り組みました。

あたりを少し理解しました!

Railsを使って友達と個人開発をしているのでこれも卒業までに形にしたいです。

ビジネス

読み物としてどれも面白かったです。 日常生活で特に活かせなかったのが反省。

読んだ本

自己啓発

一人で過ごしたり、課題に取り組む時間が多かったので、 セルフマネジメント的なことは以前より考える時間は増えて、できるようになった気がします。

読んだ本

  • issueから始めよ
  • 金持ち父さん貧乏父さん
  • コンサル1年目が学ぶこと
  • 独学大全☆

心身ともに健康に過ごせた

コロナで全然外に出なくなったので、友達とテニスをたくさんするようになりました。 散歩したり風呂に入ったりでストレスを無理なく調整できるようになった気がします。

インターンのお賃金で部屋のQOLを上げることができました。

そのた

  • vim, alfredなどを使ってPC操作力が上がった
  • Notionで読んだ本をメモするようになった f:id:zerebom:20201231182233p:plain
  • ブログがホットエントリに載った

  • 学費をちゃんと稼いだ

  • テニスのバックハンド・ボレーがちょっと上達
  • AtCoder茶色になった

できなかったこと

データ分析コンペ

目標Kaggle Master!と息巻いた割に、そもそもコンペに参加もせず全然だめでした。 参加コンペ数とかサブミット数とか自分で制御できる目標にするべきでした。

英語

TOEIC 850点以上は未達でした。(むしろ下がった)

英語は筋トレみたいに続けることが大事だと思うので、 スタディサプリに課金して毎日やることにしました。

登壇

コロナでイベントがないから、、、と言い訳して出ませんでした。
LT枠で参加してから考える的な思い切りの良さが大事かも。

来年の目標

社会人1年目なのでキャリアの設計が〜とか考え込みすぎず、 とりあえず興味があることにガンガン取り組んでいきたいです。

  • お仕事がんばる
  • 実現可能な目標を建てて有言実行する
  • イベントやコンペは迷ったら参加
  • (とりあえず)統計検定準1級

おわり。

Pytorch-lightning+Hydra+wandbで作るNN実験レポジトリ

Kaggle Advent Calender2020の 11日目の記事です。

昨日はhmdhmdさんのこちらの記事です! 2020年、最もお世話になった解法を紹介します - Qiita

明日はarutema47さんの記事です! (後ほどリンクはります)

本記事では、深層学習プロジェクトで使用すると便利なライブラリ、
Pytorch-lightningとHydraとwandb(Weights&Biases)について紹介したいと思います。

f:id:zerebom:20201211164010p:plain

対象読者

  • Pytorchのボイラープレートコードを減らせないか考えている
  • 下記ライブラリについては聞いたことあるけど、試すのは億劫でやってない

書いてあること

  • 各ライブラリの役割と簡単な使い方
  • 各ライブラリを組み合わせて使う方法
  • 各ライブラリのリファレンスのどこを読めばよいか、更に勉強するにはどうすればよいか

また、上記3つのライブラリを使用したレポジトリを用意しました。 ブログと一緒に見ていただくとわかりやすいかと思います! github.com

はじめに各ライブラリを個別に解説し、次に上記レポジトリに注釈を入れながら説明したいと思います。

Pytorch-lightning

www.pytorchlightning.ai

概要

Pytorch-lightningはPytorchの軽量ラッパーです。 ボイラープレートコードを排除しつつ、可読性を向上させることが出来ます。

特徴(Lightning Philosophy)

公式GitHubを見ると以下の原則を念頭に置いて設計されているようです。

  1. 最大限の柔軟性を持てる
  2. ボイラープレートコードを抽象化しつつ、必要があればアクセスできる
  3. システムに必要なものをすべてを保持し、自己完結できる
  4. 以下の4要素に分割し、整理する
    1. 研究コード(Lightning Module)
    2. エンジニアリングコード(Trainer)
    3. 非必須の研究コード(Callback)
    4. データ (Pytorch DataloaderかLightningDataModule)

どのようなモデルを実装しても似たインターフェースになり、他人(≒過去の自分)のコードでもすぐに理解できる。 MultiGPUやTPUでも実装の変更が殆どないといった点が素晴らしいと思います。

Install

pip install pytorch-lightning

使い方

上記Lightning Philosophyにあるように、4つのパートについてそれぞれ説明していきます。

Trainer

Trainerは後述するLightningModule, Callbacks, DataModuleを引数にとり、Training loop を司るクラスです。いわゆる親玉。
CPU・MultiGPU・TPUなどの実行環境や、デバックモードやエポック数など、
細かい設定を引数に取ることが出来ます。

例)

# 関数にconfig(cfg)を渡すデコレータ(後述)
@hydra.main(config_path='../config', config_name='pix2pix')
def main(cfg):
    
    # モデルの動的呼び出し
    model = hydra.utils.instantiate(cfg.model.instance,hparams=hparams, cfg=cfg)
    
    dm = DataModule(cfg)
    dm.setup()
    
    trainer = pl.Trainer(
            logger = wandb_logger,
            checkpoint_callback=model_checkpoint,
            callbacks=[lr_logger,early_stopping,wandb_callback],
            **cfg.trainer.args,
        )
    
 # 学習開始
    trainer.fit(model, dm)

どのような引数を取ることができるかは下記リファレンスに記載されています。
trainer - PyTorch Lightning 1.0.8 documentation

LightningModule

参考: LightningModule - PyTorch Lightning 1.0.8 documentation

モデルの学習に関するデータフローなどを記載するクラス。
親クラスにtorch.nn.Module を継承しており、高機能な nn.Module と捉えるとわかりやすいです。
各フェーズごとにメソッドが定義されており、その中に処理を書くことで実行されます。

また、Trainerにわたすことで実行環境に応じたコードを自動で実行してくれます。
つまり、x.cuda()x.to(device) を呼び出したり、
DataloaderDistributedSampler(data) を渡す必要がなくなります。

生えてるメソッドの説明

  • forward
    nn.Module と同じ

  • XX_step
    引数 batch にdataloaderの中身が格納されているので、 これをモデルに通して、誤差を計算します。 Loggerやtqdm barに渡したい値はself.logに渡すことで記録されます。

  • XX_epoch_end
    epochが終わったときにどんな処理を行うかを書きます。 epoch間のlossやmetricの平均値などを記録すると良いと思います。

  • configure_optimizers
    optimizer,schedulerの初期化をします。

例)

import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl

class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
                
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
    
    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding
        
        # dataloaderの返り値, indexが格納されている
    def training_step(self, batch, batch_idx):
        # training_step defined the train loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
                
       # 追跡したい値はself.log取ることができる
        self.log('train_loss', loss)
        return loss
        

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

Callbacks

参考:Lightning in 2 steps - PyTorch Lightning 1.0.8 documentation

学習に非必須な処理を書くクラス。 Logger , EarlyStopping LearningRateScheduler など。

一般的によく使われる処理はすでに用意されています。
自分で書く場合は、LightningModuleと同じメソッドが生えているので、 各フェーズで何をするかを書きます。
こちらもTrainerに引数として渡します。

例)

class DecayLearningRate(pl.Callback)

    def __init__(self):
        self.old_lrs = []

    def on_train_start(self, trainer, pl_module):
        # track the initial learning rates
        for opt_idx in optimizer in enumerate(trainer.optimizers):
            group = []
            for param_group in optimizer.param_groups:
                group.append(param_group['lr'])
            self.old_lrs.append(group)

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        for opt_idx in optimizer in enumerate(trainer.optimizers):
            old_lr_group = self.old_lrs[opt_idx]
            new_lr_group = []
            for p_idx, param_group in enumerate(optimizer.param_groups):
                old_lr = old_lr_group[p_idx]
                new_lr = old_lr * 0.98
                new_lr_group.append(new_lr)
                param_group['lr'] = new_lr
             self.old_lrs[opt_idx] = new_lr_group

DataModule

データにまつわる全てのコードを集約するためのクラスです。 DatasetとDataLoaderの呼び出しをします。 (Pytorch-lightning 1.0以前はデータにまつわる処理も LightningModule に書く必要がありました。)

例)

class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
    def setup(self):
        # make assignments here (val/train/test split)
        # called on every process in DDP
    def train_dataloader(self):
        train_split = Dataset(...)
        return DataLoader(train_split)
    def val_dataloader(self):
        val_split = Dataset(...)
        return DataLoader(val_split)
    def test_dataloader(self):
        test_split = Dataset(...)
        return DataLoader(test_split)

pytorch-lightning.readthedocs.io

Hydra

hydra.cc

概要

Facebook Open Sourceから公開されている設定管理ツールです。 YAMLやDataclassから設定変数を階層構造をもたせて動的に呼び出すことができます。

特徴

  • 複数のソースから階層的に設定を構築できる
  • コマンドラインから設定の指定や上書きが可能
  • 1つのコマンドで複数のジョブを実行できる

Install

pip install hydra-core --upgrade

使い方

最小限の使い方

configファイルにyaml形式で設定を書きます。

# config.yaml
db:
  driver: mysql
  user: omry
  pass: secret

参考:YAML Syntax - Ansible Documentation

設定を呼び出す側には呼び出したい関数に @hydra.main デコレータを渡します。 各要素にはドットノーテーションでアクセスできます。

# main.py
@hydra.main(config_path="./config",config_name="config")
def my_app(cfg : DictConfig) -> None:
    print(cfg)
        print(cfg.db.driver)

Config の構造化

configディレクトリを階層構造にすることで、configも階層構造でもたせる事が可能になります。 例えば、以下のようにディレクトリを構築します。

config
├── config.yaml
├── data
│   ├── cifar10.yaml
│   └── mnist.yaml
└── model
    ├── resnet_18.yaml
    └── resnet_50.yaml

そしてrootとなる config.yaml を下記のように記載すると、 data, modelそれぞれ指定したyamlファイルを読み込んでくれます。

# config.yaml
defaults:
  - data: cifar10
  - model: resnet_18

呼び出される側の設定ファイルの1行目に # @package _group_ と記載すると、 そのディレクトリ名経由でドットノーテーションアクセスできます。

# @package _group_
shape: [1,28,28]
batch_size: 8

./config/data/default_data.yaml に書いたなら、cfg.data.shape とアクセスできます。
非常に便利な機能ですが、どの単位でフォルダを分割するかが結構難しいです...!
良いアイディアがあったら教えてほしいです🙇

コマンドラインでの上書き

コマンドライン引数にわたすことで呼び出すファイルや設定を上書きできます。

# 設定ファイルの入れ替え
python train.py data=cifar10
# 値の変更
python train.py trainer.min_epoch=100

インスタンスの動的呼び出し

参考: Instantiating objects with Hydra | Hydra

Hydraの強力な機能にオブジェクトの動的呼び出しがあります。 _target_ に呼び出したいオブジェクトを指定し、引数を列挙することで動的に呼び出すことが出来ます。

# ./config/callbacks/default_callbacks.yaml

# @package _group_
EarlyStopping:
    # 呼び出したいオブジェクト名
  _target_: pytorch_lightning.callbacks.EarlyStopping
    # 第2要素以降はオブジェクトの引数
  monitor: ${trainer.metric}
  mode: ${trainer.mode}
# 呼び出し側。コード内のローカル変数を引数にしたい場合はキーワード引数で渡す。
model_checkpoint = instantiate(cfg.callbacks.ModelCheckpoint,patience=patience)

ユニットテストやnotebookで呼び出す方法

initialize_config_dir メソッドを使うことで呼び出すことが出来ます。

from hydra.experimental import initialize_config_dir, compose
with initialize_config_dir(config_dir=config_dir):
    cfg = compose(config_name=config_name)

Wandb

概要

www.wandb.com

Wandbは機械学習プロジェクトの実験のトラッキング、ハイパーパラメータの最適化、
モデルやデータのバージョンニングを行うライブラリです。

install・会員登録

pip install wandb 公式サイトの右上のLoginリンクからsign upすれば使えます。 GithubGoogleアカウントを連携すればOKです。

Weights & Biases - Developer tools for ML

特徴

大きく分けて下記の機能があるようです。 (自分はまだDashboardくらいしか使っておりません🙇)

  • Dashboard: 数行のコード追加で実験logを記録
  • Sweeps: 複数の条件のモデルを一度に実行
  • Artifacts: モデルやデータなどをバージョンニングフォルダのように管理
  • Reports: 自分の実験を他者に見やすく公開

複数のPCの実験結果をWebブラウザから一括で見られるのは非常に便利だと感じました。

使い方

最小限の使い方

主にPytorch-lightningで使用する方法について説明します。
参考: PyTorch LightningとWandbの連動方法

使い方は非常に簡単で、 pytorch_lightning.loggers.WandbLogger を呼び出して、Trainer に渡すだけです。 tagsやnameを指定すると、project内の実験をソートしたり条件を絞るのに便利。

from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(
        name="ResNet18-cifar10",
        project=”ImageCrassfication”,
        tags=["ResNet18","cifar10"])

trainer = pl.Trainer(logger=wandb_logger,**trainer.args)

f:id:zerebom:20201211171805p:plain

Traceしたい値は、LightningModuleで self.log に指定するだけで自動で追加してくれます。

def training_step(self, batch, batch_idx):
        loss, k_dice, c_dice = self.calc_loss_and_dice(batch)

        self.log("k_dice", k_dice)
        self.log("c_dice", c_dice)
        return loss

f:id:zerebom:20201211171741p:plain

さらに、 watch メソッドを呼ぶだけで、モデルの重みの分布を記録してくれます。 wandb_logger.watch(model, log='gradients', log_freq=100)

f:id:zerebom:20201211170457p:plain

画像の出力

GANのGeneratorの生成結果や、Segmentの結果などを出力する機能もあります。 Pytorch-lightningで使用するには Callback を書き、Trainerに渡すことで保存できます。

class ImageSegmentationLogger(Callback):
    def __init__(self, val_samples, num_samples=8,log_interval=5):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples
        self.log_interval = log_interval

    def on_validation_epoch_end(self, trainer, pl_module):
        # Bring the tensors to CPU
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)

        #[B,C,Z,Y,X]
        pred_probs = pl_module(val_imgs)
        #[B,Z,Y,X] -> [B,Y,X]
        preds = torch.argmax(pred_probs, 1)[:,0,...].cpu().numpy()
        val_labels =torch.argmax(val_labels, 1) [:,0,...].cpu().numpy()

        class_labels = {
            0: "gd",
            1: "kidney",
            2: "cancer",
            3: "cyst"
        }

        # Log the images as wandb Image
        trainer.logger.experiment.log({
            "examples": [wandb.Image(x, masks={
                "predictions": {
                    "mask_data": pred,
                    "class_labels": class_labels
                },
                "groud_truth": {
                    "mask_data": y,
                    "class_labels": class_labels
                }
            })
                for x, pred, y in zip(val_imgs[:self.num_samples],
                                    preds[:self.num_samples],
                                    val_labels[:self.num_samples])]
        })

f:id:zerebom:20201211173456p:plain

サンプルレポジトリの解説

GitHub - zerebom/hydra-pl-wandb-sample-project

config dir

実験パラメータは ./config/default_config.yamlを通して渡されます。

├── config
│   ├── callbacks
│   │   └── default_callbacks.yaml
│   ├── data
│   │   └── default_data.yaml
│   ├── default_config.yaml
│   ├── env
│   │   └── default_env.yaml
│   ├── logger
│   │   └── wandb_logger.yaml
│   ├── model
│   │   └── autoencoder.yaml
│   └── trainer
│       └── default_trainer.yaml

train.py

実行する train.py は基本的にpl.Trainerの組み立てだけ行います。

@hydra.main(config_path='../config', config_name='default_config')
def train(cfg: DictConfig) -> None:
    model = instantiate(cfg.model.instance,cfg=cfg)

    dm = DataModule(cfg.data)
    dm.setup()

    wandb_logger = instantiate(cfg.logger)
    wandb_logger.watch(model, log='gradients', log_freq=100)

    early_stopping = instantiate(cfg.callbacks.EarlyStopping)
    model_checkpoint = instantiate(cfg.callbacks.ModelCheckpoint)
    wandb_image_logger = instantiate(cfg.callbacks.WandbImageLogger,
                            val_imgs=next(iter(dm.val_dataloader()))[0])

    trainer = pl.Trainer(
        logger = wandb_logger,
        checkpoint_callback = model_checkpoint,
        callbacks=[early_stopping,wandb_image_logger],
        **cfg.trainer.args,
    )

    trainer.fit(model, dm)

それぞれのパーツは ./src/factory dirに入っています。

└── src
    ├── factory
    │   ├── dataset.py
    │   ├── logger.py
    │   └── networks
    │       └── autoencoder.py

output dir

poetry run python train.py を実行すると、output dirが作成されます。
ここに重みやlogが格納されます。

output
└── sample-project
    └── simple-auto-encoder-1

output dir の出力先は default_config.yaml で設定できます。

hydra:
  run:
    dir: ${env.save_dir}

# env.save_dir: ${env.root_dir}/output/${project}/${name}-${version}

おわりに

ここまで読んでいただきありがとうございます!
(スコープを広げすぎてとっ散らかってしまった気がする...)

実験管理ディレクトリの作成はなかなか奥が深く、人によって個性が出ると思います。
exp dirを作って1実験1ファイルで管理する方法も人気が高いみたいですね。今度そのような方法にも挑戦してみたいです。

今後、もっと複雑なGANモデルを実装して公開したいと思ってます。
(このGANを書き直す予定)
lucidrains/lightweight-gan

参考文献

レポジトリ

Pytorch-lightning

Hydra

Wandb

  • wandb Gallery
    WandbのReport機能を使って公開された実験のギャラリー
    The Gallery by Weights & Biases

  • pytorch-lightning with wandb
    上のGalleryの中でPytorch-lightningを一緒に使っているライブラリ
    PyTorch Lightning

データ分析でコードをクリーンに保つ技術

こんにちは、ひぐです。
最近データサイエンティストのための良いコーディング習慣という記事を読みました。
www.thoughtworks.com

こうした方がいいよなという自分の経験則が綺麗に言語化されていてよかったです。
ここではデータ分析でコードをクリーンに保つ技術について、記事の内容と自分の取り組みを合わせて紹介したいと思います。

自分はまだチームでの開発経験などが浅いため、間違っている部分もあるかもしれません。
あらかじめご了承ください汗

コードが汚くなる要因

f:id:zerebom:20200610210812p:plain:w300
コードが解くべき問題の複雑さを増長させている時、そのコードは汚いと言えます。
汚いコードは汚い部屋で探し物をする時などと同じく、簡単な作業を困難にしてしまいます。

では、どのような書き方をするとコードが汚くなるのでしょうか。

元記事には下記のような例が記載されています。

  1. 関数やクラスを使って処理を抽象化しない
  2. 一つの関数に長く複数の処理を書く
  3. ユニットテストを書かず、リファクタリング時に1から書き直す

部屋で例えると、

  • 一つの収納箱にあれこれ詰め込む
  • 物の定位置を決めず、空いているところに収納する
  • 整理してない収納箱を全てひっくり返して、再配置する

といった振る舞いと似てそうです。

処理が1箇所に纏まっていないことや、
どこに何が書いてあるかわからないことが複雑さを冗長させていると言えます。

jupyter notebookはコードを煩雑にしやすい

さらにデータ分析でおなじみのjupyter notebookは

  1. df.head()/describe()などデータの内部を確認できる機能が豊富
  2. 上下のセルから変数の中身が引き継がれる

といった特徴から、プロジェクト序盤は素早いフィードバックを得られて便利ですが、
これらの特徴は裏を返せば

  1. 変数の影響範囲が広くなりやすい
  2. 処理に影響を及ぼさないコードが増えやすい

とも言え、コード量が増えると急速に煩雑になってしまいます。

インテリアデザイナーには「平たい場所は乱雑さを蓄積しやすい」 という通説があるそうですが、
何でも1箇所に書けてしまう"notebook"は、データ分析における平たい場所であると言えます。

良いコードにする振る舞い

では良いコードにするにはどのようにすれば良いのでしょうか。
元記事では下記の5点が紹介されていました。

コードを綺麗に保つ

データ分析に限らず、綺麗なコードを書くセオリーがあります。
たとえば

  • DEAD CODEを消す
  • 処理の内容が明快にわかる変数名をつける
  • 似た記述はまとめる(DRYである)

データ分析も例外ではなく、これらのセオリーには従うべきです。
リーダブルコードなどの書籍にまとめられていて、目を通しておくべきでしょう。

関数を使って実装を抽象化する

一つの関心ごとに対しては一つの関数でまとめ、処理を抽象化するべきです。
そうすることで、以下のメリットが得られます。

  • 読みやすい
  • テストしやすい
  • 再利用しやすい(引数を与えて、ハードコーディングを防げる)

これは例を見てみるとわかりやすいです。

# bad example
pd.qcut(df['Fare'], q=4, retbins=True)[1] # returns array([0., 7.8958, 14.4542, 31.275, 512.3292])


df.loc[ df['Fare'] <= 7.90, 'Fare'] = 0
df.loc[(df['Fare'] > 7.90) & (df['Fare'] <= 14.454), 'Fare'] = 1
df.loc[(df['Fare'] > 14.454) & (df['Fare'] <= 31), 'Fare']   = 2
df.loc[ df['Fare'] > 31, 'Fare'] = 3
df['Fare'] = df['Fare'].astype(int)
df['FareBand'] = df['Fare']

# good example (after refactoring into functions)
df['FareBand'] = categorize_column(df['Fare'], num_bins=4)

good exampleの例であれば、pd.qcutの意味や引数を覚えていなくても、
連続値の'Fare'列をbin詰めする処理ができます。

イメージはこんな感じです笑
f:id:zerebom:20200611102441j:plain:w300 f:id:zerebom:20200611102458j:plain:w300 f:id:zerebom:20200611102508j:plain:w300

引用元(?):
ヘヴィメタバンド、スティール・パンサーのサッチェル氏→機材紹介がハードロック過ぎる - Togetter

なるべく早い段階でjupyter notebookを.pyに移行する

上で言及したように、notebookはコード量に伴い煩雑さが急速に増してしまいます。
したがって、コード量が増えてきたらなるべく早く.pyコードに書き換えるべきです。

notebookからpyファイルに書き換える手順は元記事で下記のように紹介されています。 f:id:zerebom:20200611102550p:plain 引用元: clean-code-ml/refactoring-process.md at master · davified/clean-code-ml · GitHub

  1. notebookが正しく動くか確認する
  2. 自動変換でpyファイルに出力する
  3. debugコードを取り除く(df.head()など)
  4. Code smell(直したい部分)をリストアップする
  5. 一纏めにしたい部分を特定する
  6. ユニットテストを書く
  7. テストを通すように記述する
  8. importを整理する
  9. 動作を確認し、commitする
  10. 繰り返す

テスト駆動開発で行う

データ分析業務もソフトウェア開発の例外にもれず、テストを書くべきです。

例えばモデルの挙動を調べるテストでは、
ベースラインで想定想定スコアを超えない場合はエラーとみなすコードを書くと良いそうです。

テストコードについては自分の知識も浅いので、またいつか改めて記事を書きたいと思います。

こまめなコミットをする

コミットを小さく頻繁に行うことで、下記のメリットが得られます。

  • 自分がどの問題に取り組んでいるか簡単に理解できる
  • 処理のロールバックが簡単にできる

自分なりの工夫点

最後に自分なりのコードを綺麗にする工夫点をいくつか紹介します。

業務ごとにコードをまとめスニペット化する

データ分析では、タスクが変わっても似たような処理を書くことが多いです。
コードをスニペットとして保存しておくと、似たような処理が必要になった時少ない作業量で書き終えることができます。

また、スニペットにすることを意識しながらコードを書くことで
自然と汎用性の高いコードが書けるようになります。

自分はGitHub GistとDashを使ってスニペットを保存しています。

gist.github.com

https://kapeli.com/dash

自分なりのルールを設ける

自分なりのルールを設けて、いつも似たコードを書くようにしています。
そうすることで他のスニペットとの互換性を良くしたり、素早くコードを書くことができます。

また自分はNN系のコードを書くときはhydraとpytorch-lightningを
使うことでいつも同じステップで書けるようにしています。

github.com
hydra.cc

データ分析のコードはあまり高級なラッパーを使うと、すぐ破綻してしまうので
その塩梅が難しいですが、うまく使えば綺麗にかけるでしょう。

メソッド名、I/Oなどを組み込み関数や有名ライブラリに近づける

sklearnやpandasなどの有名なライブラリの入出力と対応させてコードを書くことで、
他者とのコミニュケーションコストを抑えることができます。

綺麗な人のコードを見る

Kaggleなどデータ分析コンペティションでは、上位の人が解法を公開していることが多いので、
それを眺めると良いと思います。

他には、nyanpさんのnyaggleなど参考にさせていただいていますm(__)m GitHub - nyanp/nyaggle: Code for Kaggle and Offline Competitions

まとめ

以上です!
振り返ってみると当たり前のことばかりですが、全部を常に実践するのは難しい...!
綺麗なコードが書けるということはエンジニアの技量としてかなり本質的なものだと思っているので、
今後も頑張っていきたいと思います。

久しぶりにブログを書いたら、文章書くのが難しすぎてびっくりしました。
こっちも頑張っていきたいです。では〜

2019年を振り返りと2020年の目標

こんにちは、ひぐです。

もう年の瀬ですね〜
今年のトピックは大きく分けて、研究、就活、プログラミングの勉強の3つという感じでした。

今年のよかったこと、反省点を踏まえて来年も頑張りたいので、
それぞれまとめていきたいと思います!

概要

自分を4行でまとめる

2019年4月に工学系学部を卒業して同大学大学院に進学。
プログラミングは1年半前に始める。
研究内容は、学部:政治と自然言語処理→院生:深層学習を用いた医療画像における腫瘍の自動識別
21卒で就活中

今年をざっくりまとめる

就活

インターンシップ

下記インターンシップに参加しました。

zerebom.hatenablog.com

参加することで、企業の風土や開発環境を知れるだけでなく、 自分の目標となる先輩や、優秀な同期と出会うことができました。
1月からCAのAI事業本部で長期インターンをさせていただけることになったので、こちらも頑張りたいです。

自己分析

企業との面接でどういうエンジニアになりたい?という質問になかなか答えられず大変でした。 そのため、10月以降は特に自分の将来を真剣に考えました。

多動な人間なのでその時その時でやりたいことは常にたくさんあるのですが、
もっと軸足を定めて考えるべきだったと反省してます。

ざっくりですが、
- 社会的に正しいことをする
- 頭とコミニュケーション両方を使う
- 自分の学びを同業者・同期に還元していく

この3つは自分の中で大きな軸だなぁと思っています。 面接で取り繕って話すことは絶対したくないので、
ちゃんと考えて言葉にできるようにしていきたいです。

研究

学会発表

新規研究

4年次の研究を経て、もっと深層学習ちゃんと勉強したくてテーマを変えました。 大変ですが、後悔はしてないです。

プログラミングの勉強・成果

その他

  • TOEIC 755点になった
  • Twitterのフォロワーが610人になった
  • ブログの月間PVが1500~2500くらいになった
  • 筋トレが4ヶ月くらい続いた、10回*3セットで上がる重量が増えた
    ベンチプレス30->45kg
    デッドリフト40kg->65kg
    スクワット60kg->75kg

今年できたこと

  • たくさん行動する経験する
  • 尊敬できる人に会いにいく
  • 規則正しい生活
  • 人を傷つけない

今年できなかったこと

  • 目標に到達する前にやめてしまった(KaggleとかKaggleとか)
  • ブログ以外のアウトプット(LT/論文投稿など)
  • 部屋をきれいに保つこと
  • 一つ一つを丁寧にこなすこと(手広くやりすぎた)

来年の目標(抽象)

  • 将来の夢を考える(人生の目標)
  • 基礎を固める(線形代数統計学・CS)
  • 専門分野を深く学ぶ

来年の目標(具体)

  • 論文投稿
  • LT登壇
  • TOEIC 850以上
  • Kaggle Master
  • 統計検定1級

まとめ

今年1年間は自分に対して向き合って、たくさん勉強できたなと思っています。
勉強や就活をする上でいろんな人に話を聞きに行ったり、新たな友達ができたのも大きな成果でした。

その一方でエンジニア・院生以外の人とは殆ど会わず、
世間一般から遠ざかったような気持ちもしました。

今年は就活、学校、研究全部あったのでしょうがない部分もありますが、
小粒の成果がいっぱいって感じになってしまいました。 来年以降は一つの目標に対してじっくり取り組んで大きな成果を出していきたいです。

勉強のための勉強ではなく、
なんのために勉強するかも今まで以上にしっかり考えていきたいです。

来年もよろしくお願いします〜

おわり

マイナビ × SIGNATE Student Cup2019に参加して9位でした

こんにちは、ひぐです。

先日マイナビ × SIGNATE Student Cupに参加し、9/342位になりました!
この記事ではどんな取り組みを行ったかを書きたいと思います。
なるべく本コンペに参加してない人にも内容がわかる記事にしたいと思います。

本コンペの基礎情報

マイナビ × SIGNATE Student Cupとは年に1度開かれる、学生のみが参加できるデータ分析のオンライン大会です。
お題とデータが渡され、機械学習を用いて目的変数を予測し精度を競う大会です。

本コンペのテーマは「東京都23区の賃貸物件の家賃予測」です。
各物件に対し、「面積、方角、所在階」などの情報が与えれ、そのデータを元に家賃を予測する、といった内容でした。

f:id:zerebom:20191108215858p:plain
引用:https://signate.jp/competitions/182

データ量はTrain,Testどちらとも3万程度でした。

本コンペの特徴

本コンペのデータは以下4つの特徴があり、これらをうまく取り扱うことが精度向上の鍵になったかと思います。

  • 外れ値がある上に、評価指標がRMSE
    目的変数である賃料は非常に右に裾が長い分布であり、最も高い物件の賃料は250万円もするものでした。
    評価指標がRMSE(二乗平均平方根誤差)であるため、これらの高級物件の誤差をいかに小さくするかが重要でした。

    f:id:zerebom:20191108224222p:plain
    目的変数の分布

  • データが汚い&数値ではなく文字データとして与えられている
    データがすぐに使える形で与えられていなかったため、正規表現等を駆使して情報を取り出す必要がありました。
    また、欠損値や書き間違いも多く含まれ、丁寧に処理する必要がありました。

  • Train,Testで同じようなデータが含まれている
    物件データの中には、同じアパートの別の部屋などが含まれており、 普通に学習するより、学習データと同じ値で埋めるほうが精度が高くなりました。

    f:id:zerebom:20191108224133p:plain
    同じようなデータ群

  • 外部データ使用可能
    このコンペでは外部データの使用が認められていました。
    土地データのオープンデータは非常に多く、どれをどのように使うかが大事になったかと思います。

弊チームの取り組み

最終的なPipelineは以下の通りです。

f:id:zerebom:20191109000508p:plain
Pipeline

基本的に「各物件のデータから推論より、同じ・似た物件データの賃料からキャリブレーションする」 というつもりで進めていました。

コンペ全期間の大まかな流れと精度の変化は以下の通りです。

  • 3人とも個別で学習(17000~16000程度)
  • チームマージしアンサンブル。LogをとってMAEで学習(15000程度)
  • K-meansで近傍データの作成(14500程度)
  • 住所の修正、単位面積あたりの賃料を推定(13000程度)
  • パラメータ調整、SeedAverage(12400程度)

特に住所の修正、単位面積あたりの賃料の推定が大きかったと思います。
順を追って説明します。

Plotlyを用いて予測誤差の原因を追求

簡単に予測モデルを作ってからは、出力誤差をPlotlyを用いて地図上にMapして、どういった物件が誤差が大きいか確認しました。
Plotlyはインタラクティブに描画されるため、ズームしながら一つ一つ確認できました。
詳しくはQiitaに記事を書いたので良ければ見てください↓
qiita.com

同じ住所なのに、違うアパートが含まれていること、
またそういった物件の誤差が大きいことを確認しました。

f:id:zerebom:20191108225251p:plain

欠損値、異常値補完

今回配布されたデータの物件の所在地には
A:「東京都〇〇区××n丁目x-yy」と正確に記載されているデータもあれば、
B:「東京都〇〇区××n丁目」と丁までしか含まれていないデータも多くありました。

これらを注意深く観察すると、同じアパート(賃料・面積などから判断)でも
Aの形で所在地が埋められてるデータもあればBの形のデータもあることがわかりました。

そこで、面積、所在階、室内設備などの複数条件が同じであれば、同一アパートとみなし、
Bの形で所在地が記載されているデータを同じ物件のAの形の所在地に変換しました。
図にするとこんな感じです。
f:id:zerebom:20191109115533p:plain

こうすることで住所や緯度経度をkeyとした集約特徴量が正確な値になり、精度が向上しました。

外部データの使用

今回収集した外部データは以下の通りです。
- 地価データ
- 駅データ
- 路線数
- 1日の利用者数
- 緯度経度

これらから作成した以下のデータは精度向上に寄与しました。
- 物件とその物件から最も近い駅の距離
- 物件から最も近い距離にある公開されている地価
- 上記の地価の2012年から2017年の変化率
- 六本木ヒルズからの距離

K-meansを用いた近傍データの使用

地域によって賃料が全然違うことから近傍データが効くと考えられました。
そこで、緯度、経度、築年数を元にK-meansでクラスタリングし、 このCategorycal変数から以下のような特徴量を作成しました。

  • 同一クラスタ内の平均地価(賃料/面積)
  • 同一クラスタ内の平均地価×自身の面積
  • 同一クラスタ内の平均地価と自身の地価の差分・比率
  • 同一クラスタ内の平均築年数と自身の築年数の差分・比率

差分や比率を入れることで、各物件がクラスタの中でどのような位置付けがわかります。

前処理を丁寧に行なったこともあり、強力な特徴量となりました。 築年数をk-meansの判断材料に入れることにより、より似た性質の物件を同じクラスタに入れることができました。

外れ値に強いモデルの作成

外れ値も外れ値でない値も正確に予測したかったため、
Logをとってmaeで学習をしました。

また、賃料は面積との相関が強かったため、単位面積あたりの賃料を予測するモデルも作成しました。
面積で割り、さらに築年数を考慮したクラスタリングで特徴量を作ることで、
賃料という立地×築年数×面積×その他要員という複雑な変数をモデルに理解させることができたと思っています。

最終予測結果はlightgbm、Kfold、k-meansのシードを1ずつ変えて
30シード×2モデル×10Foldの600個のモデルから作成しました。

k-meansのSeedによって大きく精度が変わってしまっていたのですが平均をとることで
大きくshakedownすることのない頑健な出力結果となりました。

チームでのコミニュケーション方法

チームメンバーはそれぞれ就活や修論で忙しかったため、
それぞれが進められるときに進めて行きました。

Github,Line,Trelloでやりとりを進めていたのですが、特にTrelloが便利でした。 25MB以下のファイルはほぼ無制限に共有できること、各人の取り組んでる内容、進捗状況がすぐにわかったので、非常にスムーズにコミニュケーションを取れました。

f:id:zerebom:20191108233928p:plain
使い倒されるtrello

参考にしたサイト

飯田コンペ上位手法 signate.jp

Lightgbmのパラメータ調整 nykergoto.hatenablog.jp

Kaggle本 Stacking, Validationの考えをしっかり学べました

感想・まとめ

良かった取り組み

  • チームを想定してコードを書いた
    早い段階でチームで関数の書き方にルールを作ったのでコードのマージが楽だった。(引数も返り値もtrain,testをまとめたDataFrameにする等)
    前処理担当、モデル担当、外部データ担当と分けることで責任感を持ちつつ作業ができた。

  • 一度使ったらおしまいのコードを書かないようにした。
    よく使う関数はクラス、関数化した(target_encoding,save_data,lgb_predictorなど)

  • Lightgbmのバージョンを上げる
    なんと精度が上がります

改善するべき取り組み

  • どんなコンペにも対応できる柔軟なPipelineコードを作っておく。
  • 実験のログをもっと綺麗にとる
  • lightgbmに詳しくなる(最後まで気づかなくて、max_depth=8,num_leaves=31とかだった)

まとめ

今までこれほど良い順位でコンペを終えられたことがなかったので嬉しい反面、
入賞する気概で取り組んでいたので9位という結果は非常に悔しいです。

個人で取り組むと、だれてしまったり諦めてしまいがちなコンペもチームでやればモチベーションも上がる上に、
他の人のアイデアから異なるアイデアが浮かんだりと、アンサンブル学習の威力を実感でき、非常に楽しかったです。

最後3日で順位が20位くらい上がったこともあり、停滞期で諦めないことも大事だなと思いました。 (とはいえ、上位の人たちはずっと上位だったので地力の差も感じました)

今後は今回学んだことをしっかり復習してKaggleでメダルを取れるように頑張っていきたいと思います。 また研究や、企業でデータ分析を生かして社会に貢献できるようにも頑張りたいです。

それでは最後までご覧いただきありがとうございました!

よければTwitterのフォローもよろしくお願いします( ^ω^ )

Pythonにおける勾配ブースティング予測モデルをラクチンに作成するラッパーを公開しました

こんばんは、ひぐです。

今回はPythonの勾配ブースティングライブラリを使いやすくしたラッパーについて紹介します。 今回紹介するラッパーを使うと以下のメリットがあります。

  • PandasのDataFrameといくつかの引数を渡すだけで予測結果が返ってくる
  • 本来はそれぞれ使い方が微妙に異なるライブラリを、全く同じ記法で使える
  • 出力した予測値を正解データとすぐに比較できる、可視化メソッドが使える
  • パラメータやValidationの分け方を簡単に指定できる
  • ターゲットエンコーディングが必要な場合、列と関数を渡すことで自動でリークしないように計算してくれる

機械学習を用いたデータ分析で必ず必要になる予測モデルを作成するプロセスですが、
これらをいつも同じ使い方で使用できるのは大きなメリットだと思います!

よければ是非使ってください!

使い方

用意するもの
- train/testデータ(DataFrame)
- ハイパーパラメータ(dict)

まず使用するハイパーパラメータを定義します。

from script import RegressionPredictor
cat_params = {

    'loss_function': 'RMSE',
    'num_boost_round': 5000,
    'early_stopping_rounds': 100,
}

xgb_params = {
        'num_boost_round':5000,
        'early_stopping_rounds':100,
        'objective': 'reg:linear',
        'eval_metric': 'rmse',
    }

lgbm_params = {
    'num_iterations': 5000,
    'learning_rate': 0.05,
    'objective': 'regression',
    'metric': 'rmse',
    'early_stopping_rounds': 100}

そしてインスタンスの呼び出し、学習します。

catPredictor = RegressionPredictor(train_df, train_y, test_df, params=cat_params,n_splits=10, clf_type='cat')
catoof, catpreds, catFIs=catPredictor.fit()


xgbPredictor = RegressionPredictor(train_df, train_y, test_df, params=xgb_params,n_splits=10, clf_type='xgb')
xgboof, xgbpreds, xgbFIs = xgbPredictor.fit()


lgbPredictor = RegressionPredictor(train_df, train_y, test_df, params=lgbm_params,n_splits=10, clf_type='lgb')
lgboof, lgbpreds, lgbFIs = lgbPredictor.fit()

rfPredictor = RegressionPredictor(train_df, train_y, test_df, sk_model=RandomForestRegressor(rf_params), n_splits=10, clf_type='sklearn')
rfoof, rfpreds, rfFIs = rfPredictor.fit()

これだけです!
fitの返り値はそれぞれ、trainの予測データ、testの予測データ、Feature Importanceのnumpy arrayです 。 Kaggleなどのデータ分析の場合、これらをcsvにするだけですぐに提出できるようになります。

予測値についてデータ可視化したい場合は以下のようにします。

lgbPredictor.plot_FI(50)
lgbPredictor.plot_pred_dist()

ソースコード

class RegressionPredictor(object):
    '''
    回帰をKfoldで学習するクラス。
    TODO:分類、多クラス対応/Folderを外部から渡す/predictのプロット/できれば学習曲線のプロット
    '''
    def __init__(self, train_X, train_y, split_y, test_X, params=None, Folder=None, sk_model=None, n_splits=5, clf_type='xgb'):
        self.kf = Folder(n_splits=n_splits)
        self.columns = train_X.columns.values
        self.train_X = train_X
        self.train_y = train_y
        self.test_X = test_X
        self.params = params
        self.oof = np.zeros((len(self.train_X),))
        self.preds = np.zeros((len(self.test_X),))
        if clf_type == 'xgb':
            self.FIs = {}
        else:
            self.FIs = np.zeros(self.train_X.shape[1], dtype=np.float)
        self.sk_model = sk_model
        self.clf_type = clf_type

    @staticmethod
    def merge_dict_add_values(d1, d2):
        return dict(Counter(d1) + Counter(d2))
   
    def rmse(self):
        return int(np.sqrt(mean_squared_error(self.oof, self.train_y)))
    
    def get_model(self):
        return self.model

    def _get_xgb_callbacks(self):
        '''nround,early_stopをparam_dictから得るためのメソッド'''
        nround=1000
        early_stop_rounds=10
        if self.params['num_boost_round']:
            nround=self.params['num_boost_round']
            del self.params['num_boost_round']

        if self.params['early_stopping_rounds']:
            early_stop_rounds=self.params['early_stopping_rounds']
            del self.params['early_stopping_rounds']
        return nround ,early_stop_rounds

    def _get_cv_model(self, tr_X, val_X, tr_y, val_y, val_idx):

        if self.clf_type == 'cat':
            clf_train =Pool(tr_X,tr_y)
            clf_val =Pool(val_X,val_y)
            clf_test =Pool(self.test_X)
            self.model=CatBoost(params=self.params)
            self.model.fit(clf_train,eval_set=[clf_val])
            self.oof[val_idx]=self.model.predict(clf_val)
            self.preds += self.model.predict(clf_test) / self.kf.n_splits
            self.FIs += self.model.get_feature_importance()

        elif self.clf_type == 'lgb':
            clf_train = lgb.Dataset(tr_X, tr_y)
            clf_val = lgb.Dataset(val_X, val_y, reference=lgb.train)
            self.model = lgb.train(self.params, clf_train, valid_sets=clf_val)
            self.oof[val_idx] = self.model.predict(val_X, num_iteration=self.model.best_iteration)
            self.preds += self.model.predict(self.test_X, num_iteration=self.model.best_iteration) / self.kf.n_splits
            self.FIs += self.model.feature_importance()

        elif self.clf_type == 'xgb':
            clf_train = xgb.DMatrix(tr_X, label=tr_y, feature_names=self.columns)
            clf_val = xgb.DMatrix(val_X, label=val_y, feature_names=self.columns)
            clf_test = xgb.DMatrix(self.test_X, feature_names=self.columns)
            evals = [(clf_train, 'train'), (clf_val, 'eval')]
            evals_result = {}

            nround,early_stop_rounds=  self._get_xgb_callbacks()
            self.model = xgb.train(self.params,
                                    clf_train,
                                    num_boost_round=nround,
                                    early_stopping_rounds=early_stop_rounds,
                                    evals=evals,
                                    evals_result=evals_result)

            self.oof[val_idx] = self.model.predict(clf_val)
            self.preds += self.model.predict(clf_test) / self.kf.n_splits
            self.FIs = self.merge_dict_add_values(self.FIs, self.model.get_fscore())

        elif self.clf_type == 'sklearn':
            self.model = self.sk_model
            self.model.fit(tr_X, tr_y)
            self.oof[val_idx] = self.model.predict(val_X)
            self.preds += self.model.predict(self.test_X) / self.kf.n_splits
            self.FIs += self.model.feature_importances_
        else:
            raise ValueError('clf_type is wrong.')

    def fit(self):
        for train_idx, val_idx in self.kf.split(self.train_X, self.train_y):
            X_train = self.train_X.iloc[train_idx, :]
            X_val = self.train_X.iloc[val_idx, :]
            y_train = self.train_y[train_idx]
            y_val = self.train_y[val_idx]
            self._get_cv_model(X_train, X_val, y_train, y_val, val_idx)
        print('this self.model`s rmse:',self.rmse())

        return self.oof, self.preds, self.FIs

    def plot_FI(self, max_row=50):
        plt.figure(figsize=(10, 20))
        if self.clf_type == 'xgb':
            df = pd.DataFrame.from_dict(self.FIs, orient='index').reset_index()
            df.columns = ['col', 'FI']
        else:
            df = pd.DataFrame({'FI': self.FIs, 'col': self.columns})
        df = df.sort_values('FI', ascending=False).reset_index(drop=True).iloc[:max_row, :]
        sns.barplot(x='FI', y='col', data=df)
        plt.show()
    
    def plot_pred_dist(self):
        fig, axs = plt.subplots(1, 2, figsize=(18, 5))
        sns.distplot(self.oof, ax=axs[1], label='oof')
        sns.distplot(self.train_y, ax=axs[0], label='train_y')
        sns.distplot(self.preds, ax=axs[0], label='test_preds')
        plt.show()

以上です!
未実装な部分はいっぱいあるので逐次修正していきたいと思います!
ゆくゆくは親クラスを作って、分類回帰でクラスを分けて継承していくみたいにしたいと思います。
こういうふうに実装した方がいいよなど知見があればコメント頂けたら幸いです。

最後まで読んでいただきありがとうございました~

google-site-verification: google1c6f931fc8723fac.html