アプリとサービスのすすめ

アプリやIT系のサービスを中心に書いていきます。たまに副業やビジネス関係の情報なども気ままにつづります

名言メモ

なんでもできるは何もできない

なんでも屋は中途半端で何にもできないのと変わらないという捉え方もできるということ

ロジックはインスピレーションを殺す

ロジカルに物を考えていては、インスピレーション的な思考はその分、思い浮かばなくなるということ。

狙撃の秘訣は?習熟だ

何事も練習。シモヘイヘの言葉

教訓

サークル入る時はスポーツで楽しむことを第一に考える。人間関係はスポーツで楽しめれば勝手についてくる。人間関係をメインにしたら失敗確率が高い

呪詛式

強い呪いをかければ、その分強い呪いが、自分にもふりかかるということ。
現実でも本来のものをいじると、それが崩れたときの反動はより大きくなる。
なるべくシンプルに自然のままにを維持すれば、仮にそれが崩れても反動は小さくて済む。
髪の毛とかもいじりすぎると、崩れた時もとに戻らなくなるのがいい例。他にもいろんな例に当てはまるのが呪詛式。

参考:シャーマンキングのハオの式神の呪詛。アンナに破られたときハオに呪いが帰ってきたシーン

家宝は練って待て

一年の計は年末年始、一月にあり。

任天堂圧縮技術の巻

技術の進化をベースとした進化はもうやめよう。
スペックによる力による競争ではなく、アイデアによる独創性で勝利した。
本来娯楽って枯れた技術を使って人が驚けばいいわけです。
別に最先端がどうかではなく、人が驚くかどうかが問題なのだから。
技術者が満足するまで開発期間伸ばす。

独創 驚き 娯楽に徹せよ
枯れた技術の水平志向

GitHub key アクセスチェーン削除

sudo open /System/Applications/Utilities/"Keychain Access.app"

で開いてgithubのkeyを削除

GitHubでhttpsのパスワード認証が廃止された。Please use a personal access token instead. - Qiita

avoid companies taking japanese like working style

有給休暇について、バカ過ぎます。

有給休暇が権利であり、どちらかというと取るのが当たり前で取らないのはオカシイという義務に近い権利なのが外資でした。なので、年間の有給休暇は外資に移ってから毎年使い切っていました。

有給は、出せば通ります。当り前ですね。いつ出されても良い様にリソース管理をするのがマネージャの仕事です。だからマネージャは高い給料をもらっています、日本は違うようですけど。

マネージャは仕事を振って来るけど、その仕事に最適な環境も提供してくれました。それがマネージャ職なのだと、外資に移って知りました。無理難題を頼む時は、必ずそれ相応のインセンティブ、つまりお金が付いてきます。

結果が全てですから、それに即した仕事のやり方をしていますし、他人は他人、自分は自分、評価もそのようになっていますので、実に働きやすいのです。

そのうえ、私の居た職場にはバカがいませんでした。

バカは直ぐにクビになるので残った人は精鋭ぞろい。ひとこと言っただけで、バッチリ物事が進みます。そして結果はピカイチです。そんな人しか残れない職場だと、給料は良いし、風通しは最高だし、実に清々しく毎日の仕事に向かうことが出来ます。

日本の陰鬱とした職場になど、倍の給料もらっても働きたくはありません。

SAR画像(VV, VH)の散乱強度と地上の植生状況(tif画像)との相関を調べてみた

やることは大雑把に言うと、SARの散乱強度から地上の植物の生育状況を予測するためにSAR画像とground画像(植生状況を表す地上の画像)の相関関係を調べる。(データはQGISで作成)

f:id:trafalbad:20211208160252p:plain

概要の類似問題の詳細はこのサイトを参考に。



主な概要

・SAR画像を説明変数に、ground画像をtarget変数にする(SARからgroundを予測)
・SAR画像は散乱強度を、groundは植物の生育段階を表す画像。各値はpixel値に埋め込まれてる
・groundの生育段階は0, 1, 2, 3 の4つ
・SARとgroundの画像サイズは10×10pixel
・画像を回帰問題かつテーブルデータとして使って相関関係を計測
アルゴリズムはTabnetで、画像はテーブルデータに変換して使う

今回は画像同士の相関関係を「画像をテーブルデータとして扱って、Tabnetを使い相関を調べる」、という特殊なやり方でやったのでその備忘録としてやったことをまとめてく。


目次
1. Tabnetとは
2. 画像をテーブルデータとして扱う
3. Tabnetの回帰用の画像テーブルデータを作る
4. Tabnetで学習
5. 評価:結果を見てみる



1. Tabnetとは

Tabnetは決定木系のアルゴリズムのDNN版で入力はxgboostやlightbgmと同じテーブルデータ。
最近はkaggleやいろんなコンペで上位勢が使ってるアルゴリズム


f:id:trafalbad:20211208160634p:plain




主な特徴は使った所管で簡単にまとめるとこんな感じ

1. データの前処理なしでDNN特有の、勾配降下法を使った最適化のend-to-endの学習が可能

2. 決定木特有の逐次学習で効率的に学習できる

3. 特徴量の重要度、選択マスクの視覚化など解釈しやすい指標が多く使える

4. データセットをマスクした特徴の教師なし事前学習を行ったあとに、教師あり学習を行うことで大幅に精度向上

事前学習の仕方



左の図のエンコーダ部分で特徴の一部にマスクをして学習し、デコーダ部分でマスクの予測を行わせて事前学習をします。その後、右の本学習で、事前学習で得た重みを用いて転移学習をします。

f:id:trafalbad:20211208160716p:plain



2. 画像をテーブルデータとして扱う

まず、どうやって相関関係をみるかは下の順番で調べる

1. 回帰問題かつテーブルデータ形式にした画像データを使う
2. 1〜11月分のデータを学習
3. 12月のデータを予測して、予測精度で相関関係があるか計測する



画像をテーブルデータとして扱うときに必須の工程は

・1次元に変換する

・回帰問題かクラス分類のどちらかで解く


画像のテーブルデータでの扱い方
f:id:trafalbad:20211208165450p:plain

画像テーブルデータを使った分類問題の例はMNISTを使った例があった。




回帰問題は1次元配列にしたtarget用の画像の全部のpixel値を予測する。

今回は11ヶ月分のSAR画像とground画像(植物の生育データ)を学習データとして、テストデータを12月分のSAR画像とground画像とした。


データ概要

全データ:1~12月の間の1ヶ月ごとのSAR画像HH, HVとground(地上の植生データ)

訓練データ:1〜11月のSARとground

テストデータ:12月のSARとground

説明変数用画像:SAR画像(散乱強度)

target用の画像:ground(植生状況)



3. Tabnetの回帰用の画像テーブルデータを作る

今回はpytorch版のtabnetを使った。
下のコードで、Tabnet用の回帰問題用の画像のテーブルデータを作成した。


訓練データの作成



def log(x)
  return np.log10(x)+3

def set_pixel(sar, ground):
  sar[ground==0]=0
  return sar

def create_data(vv, vh, gr):
  # tif画像の読み込み
  vv = cv2.imread(vv, -1)
  vh = cv2.imread(vh, -1)
  gr = cv2.imread(gr, -1)
  # SAR画像はマイナスにしないため対数をとる
  vv, vh = log(vv), log(vh)
  # groudのピクセルが0のところはSARも0にする
  vv = set_pixel(vv, gr)
  vh = set_pixel(vh, gr)
  #  画像を1次元に変換
  vv, vh, gr = np.ravel(vv).reshape(1, -1), np.ravel(vh).reshape(1, -1), np.ravel(gr).reshape(1, -1)
  return vv, vh, gr

vv = 'VV.tif'
vh = 'VH.tif'
gr = 'ground.tif'

vv, vh, gr = create_data(vv, vh, gr)
print(vv.shape)
>>>
(1, 100)


これを11ヶ月分繰り返し結合して、下のshape=(700, 8)のvh, VV, grを作成


特徴量が足らないのでSAR画像を「足し算、引き算、かけ算、割り算」で増やした。

add = vv + vh
sub = vv-vh
mul = vv * vh
dev1 = vv / vh
dev2= vh / vv

11ヶ月分結合して、訓練データを可視化してみる。

# 結合してnumpyから pandasに変換
arr = np.vstack([vv, vh, add, sub, mul, dev1, dev2, dr])
X_train = pd.DataFrame(data=arr.T)#.plot()#.get_figure().savefig('df1.png')
X_train.columns=['vv', 'vh', 'add', 'sub', 'mul', 'vv/vh', 'vh/vv', 'ground']
# Nanを0で置換
X_train = X_train.fillna(0)
print(X_train.shape)
X_train.plot()

>>>>
(700, 8)

f:id:trafalbad:20211208164654p:plain

pandasのDataFrameは下の感じ

df.head()

f:id:trafalbad:20211208164737p:plain

testデータの作成


testデータも同じように作った。

テストデータは1ヶ月分なので少ない、shapeは(100, 8)

可視化結果
f:id:trafalbad:20211208164816p:plain


あとはgroundをtargetデータとして、どのくらいSAR画像から、正確にgroundを予測できるかをTabnetで検証してみる。



4. Tabnetで学習

Tabnetをクラス化したやつ

class TBNet:
    def __init__(self, train_df, target_name, epochs=3):
         
        self.epochs = epochs
        self.target_name=target_name
        self.local_interpretability_step = 3
        self.tabnet_params = dict(n_d=15, n_a=15,
                                n_steps=8,
                                gamma=0.2,
                                seed=10,
                                lambda_sparse=1e-3,
                                optimizer_fn=torch.optim.Adam,
                                optimizer_params=dict(lr=2e-2, weight_decay=1e-5),
                                mask_type="entmax",
                                scheduler_params=dict(
                                    max_lr=0.05,
                                    steps_per_epoch=int(X_train.shape[0] / 256),
                                    epochs=self.epochs,
                                    is_batch_level=True,
                                ),
                                verbose=5,
                            )
        self.src_df = train_df
        
    def train(self, X_train, X_test, y_train, y_test, model_type='classification'):
        # model
        if model_type=='classification':
            self.model = TabNetClassifier(**self.tabnet_params)
        elif model_type=='regression':
            self.model = TabNetRegressor(**self.tabnet_params)

        self.model.fit(
            X_train=X_train, y_train=y_train,
            eval_set=[(X_test, y_test)],
            eval_metric=['rmsle', 'mae', 'rmse', 'mse'],
            max_epochs=self.epochs,
            patience=30,
            batch_size=256,
            virtual_batch_size=128,
            num_workers=2,
            drop_last=False,
            loss_fn=torch.nn.functional.l1_loss)
        
        
        print('show feature importance')
        self.plot_metric()
        self.feature_importances()
        self.local_interpretability(X_test) 
        print('prediction')
        pred = self.model.predict(X_test)
        scores = self.calculate_scores(pred, y_test)
        print(scores)
        return self.model, pred, y_test
        
    def plot_metric(self):
        for param in ['loss', 'lr', 'val_0_rmsle', 'val_0_mae', 'val_0_rmse', 'val_0_mse']:
            plt.plot(self.model.history[param], label=param)
            plt.xlabel('epoch')
            plt.grid()
            plt.legend()
            plt.show()
            
    def feature_importances(self):
        df = self.src_df
        feature_name =[str(col) for col in df.columns if col!=self.target_name]
        print(len(feature_name))
        feat_imp = pd.DataFrame(self.model.feature_importances_, index=feature_name)
        feature_importance = feat_imp.copy()

        feature_importance["imp_mean"] = feature_importance.mean(axis=1)
        feature_importance = feature_importance.sort_values("imp_mean")

        plt.tick_params(labelsize=18)
        plt.barh(feature_importance.index.values, feature_importance["imp_mean"])
        plt.title("feature_importance", fontsize=18)
        
    def local_interpretability(self, X_test):
        """どの特徴量を使うか decision making するのに用いた mask(Local interpretability)"""
        n_steps = self.local_interpretability_step
        explain_matrix, masks = self.model.explain(X_test)
        fig, axs = plt.subplots(n_steps, 1, figsize=(21, 3*n_steps))
        
        for i in range(n_steps):
            axs[i].imshow(masks[i][:50].T)
            axs[i].set_title(f"mask {i}")
            axs[i].set_yticks(range(len(self.src_df.columns[:-1])))
            axs[i].set_yticklabels(list(self.src_df.columns[:-1]))
            
    def calculate_scores(self, true, pred):
        scores = {}
        scores = pd.DataFrame(
            {
                "R2": r2_score(true, pred),
                "MAE": mean_absolute_error(true, pred),
                "MSE": mean_squared_error(true, pred),
                "RMSE": np.sqrt(mean_squared_error(true, pred)),
            },
            index=["scores"],
        )
        return scores

さっきのpandasデータからnumpy配列の訓練、テストデータを作る。

y_train = X_train['ground'].values.reshape(700, 1)
y_test = X_test['ground'].values.reshape(100, 1)
print(y_train.shape, y_test.shape)

X_trains = X_train.drop('ground', axis=1).values
X_tests = X_test.drop('ground', axis=1).values
print(X_trains.shape, X_tests.shape)
>>>>
(700, 1) (100, 1)
(700, 7) (100, 7)


Tabnetでtrainして評価。

tbnet = TBNet(X_train, target_name='drone', epochs=50)
model, pred, y_test = tbnet.train(X_trains, X_tests, y_train, y_test, model_type='regression')

5.評価:結果を見てみる

loss関数推移
f:id:trafalbad:20211208164910p:plain


accuracy(rmse)の推移
f:id:trafalbad:20211208164939p:plain

各predictionの値

              R2      MAE      MSE      RMSE
scores  0.822602  0.12724  0.08228  0.286844

特徴量の重要度
f:id:trafalbad:20211208165031p:plain

Local interpretability
f:id:trafalbad:20211208165122p:plain

Tabnetではそれに加えて、どの特徴量を使うか決定(decision making)するのに用いた mask というのを見ることができます。Local interpretabilityとも呼ばれます。mask は n_steps の数だけあり、ここでは予測したデータの先頭 50 個についてのみ図示してみます。

画像で予測データとテストデータでどれだけ正確に予測できてる図示してみる。

f:id:trafalbad:20211208163726p:plain
RMSE= 0.286844

RMSEが0.3だとかなりよく予測できるんだなと思った。
lightbgmとか決定木系のアルゴリズムなら全部、このやり方で相関がみれるはず。


画像をテーブルデータとして扱った相関関係の調べ方でした。

参考サイト

Tabnetはどのように使えるのか
TabNet-pytorch
学習帳9_B:SARの画像(SARデータ利用に進む前に)

pythonの並行処理・並列処理コード集の備忘録

pythonで並行処理・並列処理系のコードの備忘録


f:id:trafalbad:20211101112613j:plain

・multiprocessing
・concurrent.futures


thread

・古いPython2系のバージョンだとこのモジュールしかなかったりするものの、基本的には使い勝手が悪いので使わない。

・Python3系では間違って使わないように、_threadとアンダースコアがつけられているらしい。

・threading

・thread上位互換。並行処理のベーシックなビルトインモジュール。

インターフェイスが大分親切になった。

・Python3系はもちろんのこと、2.7系とかでももう使えるので、基本的にはthreadを使うくらいならthreadingを使うことになる。


・concurrent.futures

・Python3.2以降に登場。基本的にthreadingよりもさらに優秀。

・なお、futureは並列処理のFutureパターンに由来する。(1960~1970年代などに発展し、提案された結構昔からあるもの)

・スレッド数の上限を指定して、スレッドの使いまわしなどをしたりしてくれるらしい。(最初に同時に動かす最大数 max_workers を決めるとスレッドを使いまわしてくれるので上で紹介した普通のスレッドよりかしこい)

・また、マルチスレッドとマルチプロセスの切り替えも1行変える程度で、このモジュールで扱えるので、途中で変えたくなったり比較してみる際などにも便利
****
したがってCPUバウンドなピュアPythonコードを threading でマルチスレッド化しても速くならない。 subprocess による外部プログラム実行やI/OなどGIL外の処理を待つ場合には有効。
一方 multiprocessing は新しいインタプリタを os.fork() で立ち上げるので、 CPUバウンドなPythonコードもGILに邪魔されず並列処理できる。 ただし通信のため関数や返り値がpicklableでなければならない。
それらの低級ライブラリを使いやすくまとめたのが concurrent.futures (since 3.2) なので、とりあえずこれを使えばよい。

****

・threadingとどちらを使うか、という点に関しては、concurrent.futuresが使える環境(Python2.7.xなど)であればそちらを、使えない古い環境であればthreadingという選択で良さそう。

multiprocessing


1.multiprocessingのPoolで複数のcpuコアで処理


import time, os, sys
from multiprocessing import Pool
import multiprocessing

print("start worker={}", os.getpid())

def nijou(inputs):
    x = inputs
    print('input: %d' % (x))
    time.sleep(2)
    retValue = x * x
    print('double: %d' % (retValue))
    return(retValue)

if __name__ == "__main__":
    num_cpu = multiprocessing.cpu_count()
    case = int(sys.argv[1])
    values = [x for x in range(10)]
    if case==1:
        with Pool(processes=num_cpu) as p:
            print('case1')
            stime = time.time()
            print(values)
            # list is required
            result = p.map(nijou, values)
    else:
        with Pool(processes=num_cpu) as p:
            print('case2')
            stime = time.time()
            print(values)
            # list is required
            result = p.map(nijou, values)
    print(result)
    print('time is ', time.time() -stime)

もう少し複雑な処理。データのdownloadはこちら

import time, sys
import numpy as np
from multiprocessing import Pool

def cos_sim(v1, v2):
    v1_ = np.array(v1)
    v2_ = np.array(v2)
    return np.dot(v1_, v2_) / (np.linalg.norm(v1_) * np.linalg.norm(v2_))

class Sample:
    def __init__(self, user_size=943, item_size=1682, file_path="ml-100k/u.data", pool=True):
        self.file_path = file_path
        # user数×アイテム数のリスト
        self.eval_table = [[0 for _ in range(item_size)] for _ in range(user_size)]
        # user数×user数のcos類似度テーブル
        self.sim_table = [[0 for _ in range(user_size)] for _ in range(user_size)]
        
        self.pool = pool
    def distinguish_info(self, line):
        u_id, i_id, rating, timestamp = line.replace("\n", "").split("\t")
        # u_idとi_idはitemのindexを一つずらす
        return int(u_id)-1, int(i_id)-1, float(rating), timestamp


    def calc_cossim(self, target_u_id, target_user_eval):
        for u_id ,user_eval in enumerate(self.eval_table):
            self.sim_table[target_u_id][u_id] = cos_sim(target_user_eval, user_eval)
        if self.pool:
            return self.sim_table[target_u_id]
    def wrapper(self, args):
        return self.calc_cossim(*args)
        
    def run(self):
        f = open(self.file_path , 'r')
        # userとitemのテーブル作成
        start = time.time()
        for line in f:
            u_id, i_id, rating, _ = self.distinguish_info(line)
            self.eval_table[u_id][i_id] = rating

        # テーブルに基づいてcos類似度作成
        for target_u_id, target_user_eval in enumerate(self.eval_table):
            self.calc_cossim(target_u_id, target_user_eval)
        times = time.time()-start
        print("total time is:{}".format(times))
        
    def run_pool(self, processes=8):
        f = open(self.file_path , 'r')
        # userとitemのテーブル作成
        start = time.time()
        for line in f:
            u_id, i_id, rating, _ = self.distinguish_info(line)
            self.eval_table[u_id][i_id] = rating

        # テーブルに基づいてcos類似度作成
        tmp = [(target_u_id, target_user_eval) for target_u_id, target_user_eval in enumerate(self.eval_table)]
        with Pool(processes=processes) as pool:
            # 変更
            self.right_sim_table = pool.map(self.wrapper, tmp)
        times = time.time()-start
        print("total time is :{}".format(times))


if __name__ == "__main__":
    pool = str(sys.argv[1])
    if pool=='pool':
        s = Sample(pool=True)
        s.run_pool(processes=8)
    else:
        s = Sample(pool=None)
        s.run()
$ python3 complicated_pool.py pool
>>>>
total time is :40.1155219078064

$ python3 complicated_pool.py none
>>>>
total time is:194.3105709552765

2.ManagerでProcess間を順番通りに行う

from multiprocessing import Manager, Process

def f6(d, l):
    # 辞書型に値を詰め込みます.
    d[1] = '1'
    d["2"] = 2
    d[0.25] = None
    # 配列を操作します(ここでは逆順に).
    l.reverse()

if __name__ == "__main__":
    # マネージャーを生成します.
    with Manager() as manager:
        # マネージャーから辞書型を生成します.
        d = manager.dict()
        # マネージャーから配列を生成します.
        l = manager.list(range(10))
        # サブプロセスを作り実行します.
        p = Process(target=f6, args=(d,l))
        p.start()
        p.join()
        # 辞書からデータを取り出します.
        print(d)
        # 配列からデータを取り出します.
        print(l)


3.QueueでProcess間で値の受け渡し

import time
from multiprocessing import Queue, Process

def f2(q):
    time.sleep(3)
    q.put([42, None, "Hello"])

if __name__ == "__main__":
    q = Queue()
    # キューを引数に渡して、サブプロセスを作成
    p = Process(target=f2, args=(q,))
    p.start()
    # wqait for queue get()
    print(q.get())
    p.join()


4.Lockで制御

from multiprocessing import Lock, Process

def f4(lock, i):
    # ロックを取得します.
    lock.acquire()
    # ロック中は、他のプロセスやスレッドがロックを取得できません(ロックが解放されるまで待つ)
    try:
        print('Hello', i)
    finally:
        # ロックを解放します.
        lock.release()

if __name__ == "__main__":
    # ロックを作成します.
    lock = Lock()
    for num in range(10):
        Process(target=f4, args=(lock, num)).start()

5.apply_async(Poolと特に代わりない)

import time, sys, os
from multiprocessing import Pool, Process

def nijou(inputs):
    x = inputs
    print('input: %d' % x)
    time.sleep(2)
    retValue = x * x
    print('double: %d' % retValue)
    return(retValue)

class PoolApply:
    def __init__(self, processes):
        self.processes = processes
    def pool_apply(self):
        p = Pool(self.processes)
        stime = time.time()
        values = [x for x in range(10)]
        #print(values)
        # not list
        #for x in range(10):
        result = p.apply(nijou, args=[values[9]])
        print(result)
        print('time is ', time.time() -stime)
        p.close()
        
    def pool_apply_async(self):
        p = Pool(self.processes)
        stime = time.time()
        # プロセスを2つ非同期で実行
        values = [x for x in range(10)]
        result = p.apply_async(nijou, args=[values[9]])
        result2 = p.apply_async(nijou, args=[values[9]])
        print(result.get())
        print(result2.get())
        print('time is ', time.time() -stime)
        p.close()
        
if __name__ == "__main__":
    case_no = int(sys.argv[1])
    num_process = int(sys.argv[2])
    pool = PoolApply(num_process)
    if case_no==1:
        pool.pool_apply()
    elif case_no==2:
        pool.pool_apply_async()


concurrent.futures

1.MultiProcess

import math, time
import sys
import concurrent.futures

PRIMES = [
    112272535095293,
    112582705942171,
    112272535095293,
    115280095190773,
    115797848077099,
    1099726899285419]

def is_prime(n):
    if n < 2:
        return False
    if n == 2:
        return True
    if n % 2 == 0:
        return False

    sqrt_n = int(math.floor(math.sqrt(n)))
    for i in range(3, sqrt_n + 1, 2):
        if n % i == 0:
            return '0'
    return '{}'.format(n)

class MultiProcess():
    def run(self):
        for number, prime in zip(PRIMES, map(is_prime, PRIMES)):
            print(f'{number} is prime: {prime}')
    
    def multi_precoss_run(self):
        with concurrent.futures.ProcessPoolExecutor() as executor:
            for number, prime in zip(PRIMES, executor.map(is_prime, PRIMES)):
                print(f'{number} is prime: {prime}')
    

if __name__ == '__main__':
    multi = str(sys.argv[1])
    MULTI = MultiProcess()
    startTime = time.time()
    if multi=='m':
        print('multi process')
        MULTI.multi_precoss_run()
    else:
        print('No multi process')
        MULTI.run()
    runTime = time.time() - startTime
    print(f'Time:{runTime}[sec]')
$ python3  multi_process.py m
>>>>
multi process
112272535095293 is prime: 112272535095293
112582705942171 is prime: 112582705942171
112272535095293 is prime: 112272535095293
115280095190773 is prime: 115280095190773
115797848077099 is prime: 115797848077099
1099726899285419 is prime: 0
Time:0.5466821193695068[sec]
$ python3  multi_process.py. none
>>>>
No multi process
112272535095293 is prime: 112272535095293
112582705942171 is prime: 112582705942171
112272535095293 is prime: 112272535095293
115280095190773 is prime: 115280095190773
115797848077099 is prime: 115797848077099
1099726899285419 is prime: 0
Time:2.1444289684295654[sec]

2.MultiThread(Poolと同じ)

import concurrent.futures
import urllib.request
import time, sys, os

URLS = ['http://www.foxnews.com/',
        'http://www.cnn.com/',
        'http://europe.wsj.com/',
        'http://www.bbc.co.uk/',
        'http://some-made-up-domain.com/']

# Retrieve a single page and report the URL and contents
def load_url(url, timeout):
    with urllib.request.urlopen(url, timeout=timeout) as conn:
        return conn.read()

class ConcurrentFutures():
    def get_detail(self):
        # Start the load operations and mark each future with its URL
        for url in URLS:
            try:
                data = load_url(url,60)
            except Exception as exc:
                print(f'{url} generated an exception: {exc}')
            else:
                print(f'{url} page is len(data) bytes')
    
    def mlti_thread_get_detail(self):
        # We can use a with statement to ensure threads are cleaned up promptly
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            # Start the load operations and mark each future with its URL
            future_to_url = {executor.submit(load_url, url, 60): url for url in URLS}
            for future in concurrent.futures.as_completed(future_to_url):
                url = future_to_url[future]
                try:
                    data = future.result()
                except Exception as exc:
                    print(f'{url} generated an exception: {exc}')
                else:
                    print(f'{url} page is len(data) bytes')

def main():
    pool = str(sys.argv[1])
    CFthread = ConcurrentFutures()
    startTime = time.time()
    if pool=='pool':
        print('multi thread')
        CFthread.mlti_thread_get_detail()
    else:
        print('no thread')
        CFthread.get_detail()
    runTime = time.time() - startTime
    print (f'Time:{runTime}[sec]')

if __name__ == '__main__':
    main()
$ python3 mlti_thread.py pool
>>>
multi thread
〜〜〜〜
Time:7.765803098678589[sec]
$ python3 mlti_thread.py none
>>>>
no thread
〜〜〜
Time:9.271435022354126[sec]

衛生系のデータGeoTiffをいじるための「QGISツール」の操作 part1

今回衛生データをいじるためにQGISというツールになれるために色々いじってみた。
のでその備忘録part1。

目次
1.GDALをinstall
2.QGISツールのdownload
3.QGISをいじってみる-データdownload
4.QGISをいじってみる-プラグインの追加
5.QGISをいじってみる-ファイルを操作


1.GDALをinstall

GDALはGeoTiffを読み込むpythonのライブラリ。今回はanacondaにinstallした。


GDALをanacondaにinstall

$ conda install -c conda-forge gdal


世界の地域のtif画像がdownloadできるサイト「Natural Earth」からdownloadして試してみる。


test.py

from osgeo import gdal, gdalconst
import sys
if __name__ == "__main__":
    tif_path = str(sys.argv[1])
    print("gdal version", gdal.VersionInfo())
    src = gdal.Open(tif_path,
                    gdalconst.GA_ReadOnly)  # tifの読み込み (read only)
    print(type(src))  # "osgeo.gdal.Dataset"
$ python3 test.py NE1_HR_LC.tif
>>>>>
gdal version 3010200
<class 'osgeo.gdal.Dataset'>

超初歩】AnacondaにGDALをインストールしてみた

condaコマンド


2.QGISツールのdownload

地理情報を扱うためのツールQGISmacにdownloadした。
QGISのサイトからmac用のinstaller

最新リリース(機能が最も豊富):のversion3.20版」をdownload。


少し重いので数分かかる。





3.QGISをいじってみる-データdownload

基盤地図情報ダウンロードサービスからdownload。

・「数値標高モデル」を選択

FG-GML-5438-00-DEM5A.zipファイルをdownload

基盤地図情報ダウンロードサービスは登録が必要なので注意。






4.QGISをいじってみる-プラグインの追加

QGISを起動して上部のタグ「プラグイン」から「プラグインの管理とインストール」を選択。
追加するプラグインは「QuickDEM4JP」。

国土地理院が提供する基盤地図情報数値標高モデル(DEM)のXML形式及びそのZIPファイルを GeoTIFF形式のDEMとTerrain RGBに変換します。



searchから「QuickDEM4JP」で検索してinstall。
f:id:trafalbad:20211010202551p:plain


プラグインから「QuickDEM4JP」を起動して、ファイルを読み込む。


・「形式」は「'xml'または'xml'を含む'zip'」

・「DEM」項目の右側にある三点リーダをクリックしてファイル選択のダイアログが出ますので、下部のプルダウンを「*.zip」に変更し、ダウンロードしたファイル(zip)を選択

・「出力先」では画像を出力したいフォルダを選択。

・「CRS」は出力したい座標参照系を設定(デフォルトはプロジェクトのCRS)


下のようになれば問題なし。
f:id:trafalbad:20211010202437p:plain


OKをクリックして処理が終われば下のようにファイルがQGISで表示される。

f:id:trafalbad:20211010202518p:plain


5.QGISをいじってみる-ファイルを操作


後は右の操作で色々な画像を生成できます。

コントラストで「最小値・最大値までの範囲以外は無視」の場合



f:id:trafalbad:20211010202619p:plain


混合モードを「乗算」にしてみる



f:id:trafalbad:20211010202649p:plain


**さっき選択した保存先のdownloadフォルダにtiff形式の画像が保存される。

プロジェクトで「名前をつけて保存」で保存できる。

f:id:trafalbad:20211010202711p:plain

それを開けば続きができる。


画像を扱うのでQGISをいじる機会があった。調べてみるとQGISはかなり需要があるっぽいので習得しといて損ないなと思う。



参考サイト

国土地理院の標高データ(DEM)をQGIS上でサクッとGeoTIFFを作って可視化するプラグインを公開しました!(Terrain RGBもあるよ)

yolov5とDeepSortでマルチスレッドのリアルタイムtracking物体検出【機械学習】

yolov5とDeepsortとかいうtrackingのアルゴリズムを使ってtrackingの物体検出をしてみた。

最終的にpythonGUIツールtkinterでマルチスレッド化して動かした。

全体像

f:id:trafalbad:20211004023701p:plain

備忘録として使った技術をまとめてく。

目次
1.yolov5
2.trackingアルゴリズム「DeepSort」
3.GUIツールの「tkinter
4.pythonでマルチスレッド
5.出力結果

1.yolov5

yolov5はDarknetを使ってない。
今回はmac上で動かして、Pytorch のyolov5を使った。


yolov5は

Small(YOLOv5s) => Medium(YOLOv5m) => Large(YOLOv5l)=> Xlarge(YOLOv5x)

の順で大きさと精度が増してくらしい。


f:id:trafalbad:20211004021907p:plain

f:id:trafalbad:20211004021930p:plain



2.trackingアルゴリズム「DeepSort」

DeepSORTのアーキテクチャは物体検出とトラッキングに分かれてて、yolov5の推論後に、物体検出のbboxに番号をつけて検出物をtrackingする。

f:id:trafalbad:20211004022050p:plain

ここのgithubを参考にした
Yolov5_DeepSort_Pytorch
yolov4-deepsort


DeepSortとyolov5のtrackingコードの一部。

yolov5_tracking.py(の一部)

import sys
from pathlib import Path
import cv2, os
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import PIL.Image, PIL.ImageTk
from multiprocessing import Queue

def load_model(weights, device, half=False):
    global deepsort
    cfg = get_config()
    cfg.merge_from_file('Deepsort/configs/deep_sort.yaml')
    deepsort = DeepSort(cfg.DEEPSORT.REID_CKPT,
                        max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,
                        max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,
                        max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET,
                        use_cuda=True)
    yolov5 = attempt_load(weights, map_location=device)  # load FP32 model
    if half:
        yolov5.half()  # to FP16
    return yolov5

def yolov5_detection(q:Queue, opt, save_vid=False, show_vid=False, tkinter_is=False):
    initialize()
    weights, source, imgsz, conf_thres, iou_thres = opt.weights, opt.source, opt.imgsz, opt.conf_thres, opt.iou_thres
    save_txt, classes, agnostic_nms, augment = opt.save_txt, opt.classes, opt.agnostic_nms, opt.augment
    nosave, exist_ok = opt.nosave, opt.exist_ok

    # 〜省略〜

    # Load model
    w = weights[0] if isinstance(weights, list) else weights
    classify, suffix, suffixes = False, Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '']
    check_suffix(w, suffixes)  # check weights have acceptable suffix
    pt, onnx, tflite, pb, saved_model = (suffix == x for x in suffixes)  # backend booleans
    stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
    
    yolov5 = load_model(weights, device, half=False)
    stride = int(yolov5.stride.max())  # model stride
    names = yolov5.module.names if hasattr(yolov5, 'module') else yolov5.names  # get class names
    
    # Dataloader
    vid_path, vid_writer = None, None
    if webcam:
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz, stride=stride)
    else:
        dataset = LoadImages(source, img_size=imgsz, stride=stride)
    bs = 1  # batch_size

    # Run inference
    if pt and device.type != 'cpu':
        yolov5(torch.zeros(1, 3, *imgsz).to(device).type_as(next(yolov5.parameters())))  # run once
    dt, seen = [0.0, 0.0, 0.0], 0
    t0 = time.time()
    for frame_idx, (path, img, im0s, vid_cap) in enumerate(dataset):
        # Inference
        t1 = time_sync()
        img = preprocess_img(img, device, half=half)
        pred = yolov5(img, augment=opt.augment)[0]

        # Apply NMS
        pred = non_max_suppression(
            pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
        t2 = time_sync()

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            p, s, im0 = path, '', im0s

            s += '%gx%g ' % img.shape[2:]  # print string
            save_path = str(Path(project) / Path(p).name)

            annotator = Annotator(im0, line_width=2, pil=not ascii)
            deepsort, annotator, s = deepsort_detection(annotator, det, img, im0, names, s)
        
            # Print time (inference + NMS)
            print('%sDone. (%.3fs)' % (s, t2 - t1))

            # Stream results
            im0 = annotator.result()
            if show_vid:
                cv2.imshow(p, im0)
                if cv2.waitKey(1) == ord('q'):  # q to quit
                    raise StopIteration

            # Save results (image with detections)
            elif save_vid:
                vid_path = save_path
                if isinstance(vid_writer, cv2.VideoWriter):
                    vid_writer.release()  # release previous video writer
                if vid_cap:  # video
                    fps = vid_cap.get(cv2.CAP_PROP_FPS)
                    w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                else:  # stream
                    fps, w, h = 30, im0.shape[1], im0.shape[0]
                    save_path += '.mp4'

                vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
            elif tkinter_is:
                im_rgb = cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)
                im_rgb = cv2.resize(im_rgb, (1280, 720))
                q.put(im_rgb)

    print('Done. (%.3fs)' % (time.time() - t0))


3.GUIツールの「tkinter

pythonGUIツールはありすぎるけどtkinterが1, 2位の人気があるっぽい。(2021/09)

Google trandで一年間の推移を見ても全ての国で人気が衰えてないし、グラフ化してもはっきりわかる。
f:id:trafalbad:20211004022448p:plain

f:id:trafalbad:20211004022505j:plain

yolov5_tracking.pyで取得した画像をtkinterで表示。

tkinter_app.py

import tkinter as tk
from tkinter import ttk
import cv2
import PIL.Image, PIL.ImageTk
from tkinter import font
import time
from multiprocessing import Queue


class Application(tk.Frame):
    def __init__(self,master, q:Queue, video_source=0):
        super().__init__(master)

        self.master.geometry("1280x768")
        self.master.title("Tkinter with Video Streaming and Capture")
        
        self.q = q
        
        self.font_setup()
        self.vcap = cv2.VideoCapture( video_source )
        self.width = self.vcap.get( cv2.CAP_PROP_FRAME_WIDTH )
        self.height = self.vcap.get( cv2.CAP_PROP_FRAME_HEIGHT )

        self.create_widgets()
        self.create_frame_button(self.master)
        self.delay = 15 #[ms]
        self.update()


    def font_setup(self):
        self.font_frame = font.Font( family="Meiryo UI", size=15, weight="normal" )
        self.font_btn_big = font.Font( family="Meiryo UI", size=20, weight="bold" )
        self.font_btn_small = font.Font( family="Meiryo UI", size=15, weight="bold" )

        self.font_lbl_bigger = font.Font( family="Meiryo UI", size=45, weight="bold" )
        self.font_lbl_big = font.Font( family="Meiryo UI", size=30, weight="bold" )
        self.font_lbl_middle = font.Font( family="Meiryo UI", size=15, weight="bold" )
        self.font_lbl_small = font.Font( family="Meiryo UI", size=12, weight="normal" )
        
    def create_widgets(self):

        #Frame_Camera
        self.frame_cam = tk.LabelFrame(self.master, text = 'Camera', font=self.font_frame)
        self.frame_cam.place(x = 10, y = 10)
        self.frame_cam.configure(width = self.width+30, height = self.height+50)
        self.frame_cam.grid_propagate(0)

        #Canvas
        self.canvas1 = tk.Canvas(self.frame_cam)
        self.canvas1.configure( width= self.width, height=self.height)
        self.canvas1.grid(column= 0, row=0,padx = 10, pady=10)

    def create_frame_button(self, root):
        # Frame_Button
        self.frame_btn = tk.LabelFrame(root, text='Control', font=self.font_frame)
        self.frame_btn.place(x=10, y=650 )
        self.frame_btn.configure(width=self.width, height=120 )
        self.frame_btn.grid_propagate(0)

        # Close
        self.btn_close = tk.Button( self.frame_btn, text='Close', font=self.font_btn_big )
        self.btn_close.configure( width=15, height=1, command=self.press_close_button )
        self.btn_close.grid( column=1, row=0, padx=20, pady=10 )


    def update(self):
        frame = self.q.get() 
        self.photo = PIL.ImageTk.PhotoImage(image = PIL.Image.fromarray(frame))

        #self.photo -> Canvas
        self.canvas1.create_image(0,0, image= self.photo, anchor = tk.NW)
        self.master.after(self.delay, self.update)

    def press_close_button(self):
        self.master.destroy()
        self.vcap.release()
        self.canvas.delete("o")


4.pythonでマルチスレッド

c++でマルチスレッドをしたことがあるけど、phthonでしてみた。
yolov5で取得したimgをtkinterで出力する。


yolov5_tracking.pyのProcess部分

elif tkinter_is:
     im_rgb = cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)
     im_rgb = cv2.resize(im_rgb, (1280, 720))
     q.put(im_rgb)


tkinter_app.pyでyolov5_tracking.pyを取得して、canvasにplotする部分

def __init__(self,master, q:Queue, video_source=0):
        super().__init__(master)

        self.master.geometry("1280x768")
        self.master.title("Tkinter with Video Streaming and Capture")
        
        self.q = q

def update(self):
        frame = self.q.get() 
        self.photo = PIL.ImageTk.PhotoImage(image = PIL.Image.fromarray(frame))

        #self.photo -> Canvas
        self.canvas1.create_image(0,0, image= self.photo, anchor = tk.NW)
        self.master.after(self.delay, self.update)

最終的にtkinter_app.py とyolov5_tracking.pyをマルチスレッド化して動かす
main.py

#https://stackoverflow.com/questions/23599087/multiprocessing-python-core-foundation-error/23982497
from multiprocessing import Process, Queue

def tkapp_thread(q):
    import tkinter as tk
    from tkapp_thread import Application
    root = tk.Tk()
    app = Application(root, q, video_source=0)
    app.mainloop()
    
def yolov5_thread(q):
    from yolov5_detect import yolov5_detection
    from option_parser import get_parser
    opt = get_parser()
    yolov5_detection(q, opt, save_vid=False, show_vid=False, tkinter_is=True)
    
if __name__ == '__main__':
    q = Queue()
    p1 = Process(target = tkapp_thread, args=(q,))
    p2 = Process(target = yolov5_thread, args=(q,))
    p1.start()
    p2.start()
    p1.join()
    p2.join()


5.出力結果

Yolov5 + DeepSort + multi-thread + tkinter + 動画
で出力した結果はこちら

出力1
f:id:trafalbad:20211004023008g:plain


出力2
f:id:trafalbad:20211004023022g:plain

CPUのMac上だとかなり遅いけど、精度はかなりいい。
DeepSort + multi-thread + Tkinter
の組み合わせは面白かった。

衛星のSAR画像-セグメンテーションコンペの備忘録

衛星データのSAR画像を用いたセグメンテーションのコンペがあったので、その際の使ったコードとか手法の備忘録。

コンペ内容は事情により省略。手法だけまとめてきます。

大雑把に言うと、過去と現在の画像から特定の領域を0, 1でセグメンテーションするタスク。

目次
1. 使ったネットワーク「HRNet」
2.グレースケールのtifフォーマット画像の読み込み
3.tifフォーマット画像のrgb化
4.使える手法
5.予測したmask画像の0, 1化
6.クロップして予測してconcat
7.よかった手法、ダメだった手法



1.使ったネットワーク「HRNet」

使ったのは姿勢推定とかでSOTを出したネットワーク。すごい軽かったし、カスタマイズしやすかった。

f:id:trafalbad:20210905185837p:plain

reluをmishに変えたりしたのが効果的だった。

tensorflow版のmish活性化関数

import tensorflow_addons as tfa
x = tfa.activations.mish(x)

ちなみにクラスが1のセグメンテーションタスクだったので、アウトプットサイズも1でOK。


2.グレースケールのtifフォーマット画像の読み込み

グレースケールのtifフォーマットの画像は医療系データとか衛星データでよく見かける。

読み込みは普通のpngとかjpgとかとは違って工夫がいる。


pillowで読み込み

sar = Image.open('image.tif')

opencvでよみこみ

sar = cv2.imread('image.tif', -1)

tif用の画像の可視化関数。colablatelyで、使ってるものをjupyterで使えるようにした。

def cv2_imshow(a):
    """A replacement for cv2.imshow() for use in Jupyter notebooks.
    Args:
    a : np.ndarray. shape (N, M) or (N, M, 1) is an NxM grayscale image. shape
      (N, M, 3) is an NxM BGR color image. shape (N, M, 4) is an NxM BGRA color
      image.
    """
    a = a.clip(0, 255).astype('uint8')
    # cv2 stores colors as BGR; convert to RGB
    if a.ndim == 3:
        if a.shape[2] == 4:
            a = cv2.cvtColor(a, cv2.COLOR_BGRA2RGBA)
        else:
            a = cv2.cvtColor(a, cv2.COLOR_BGR2RGB)
    plt.imshow(a/255, 'gray'),plt.show()
    display.display(Image.fromarray(a))


f:id:trafalbad:20210905232204p:plain




3.tifフォーマット画像のrgb化

過去と現在の画像があるので
・差分をとる
・rgb化する方法

とかの方法がある。grayscaleで単に学習させるより、かなり効果的。

# load
im1 = load_tif('0VV.tif')
im2 = load_tif('0VH.tif')
im3 = load_tif('1VV.tif')

anno_test = cv2.imread('train.png', -1)

# RGB化
r = im1 * 255 
g = im2 * 255 
b = im3 * 255
rgb = np.dstack((r,g,b))
rgb = rgb(0, 255).astype('uint8')

f:id:trafalbad:20210905231932p:plain


4.使える手法

他の使える方法まとめ。

1.単純な引き算以外で差分をとる方法

diff = np.maximum(im1- im2], 0.5) - 0.5 
im_pred = np.heaviside(diff, 0)


2.とても細かい部分が多いので、細かい部分を除去する方法

def plot_correctness(im_truth, im_pred):
    r = im_truth
    g = im_pred 
    b = im_pred
    cv2_imshow(cv2.merge((r, g, b)))


def remove_blob(im_in, threshold_blob_area=25): 
    '''remove small blob from your image '''
    im_out = im_in.copy()
    contours, hierarchy = cv2.findContours(im_in.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    for i in range (1, len(contours)): 
        index_level = int(hierarchy[0][i][1]) 
        if index_level <= i:
            cnt = contours[i]
            area = cv2.contourArea(cnt)
            if area <= threshold_blob_area: 
                cv2.drawContours(im_out, [cnt], -1, 0, -1, 1)
    return im_out

im_out = remove_blob(im_pred * 255, threshold_blob_area=25)
plot_correctness(anno_test * 255, im_out.astype(np.uint8))


f:id:trafalbad:20210905232050p:plain



5.予測したmask画像の0, 1化

予測したgrayscaleの画像を0と1のマスクに変換する。
sigmoidを使うことがほとんどで、閾値で判別して、0か1を突っ込む。

閾値の設定が割と定まってないのが難儀。

def create_binary_mask(pred, threshold=0.1):
    mask = np.zeros_like(pred)
    mask[pred < threshold] = 0
    mask[pred > threshold] = 1
    return mask

# IOU計算
def calc_IoU(y, y_pred):
    by = np.array(y, dtype=bool)
    by_pred = np.array(y_pred, dtype=bool) 
    overlap = by * by_pred
    union = by + by_pred
    return overlap.sum() / float(union.sum())


6.クロップして予測してconcat

今回のコンペ は

・画像枚数が少ない(40枚以下)
・セグメンテーション領域が細かい
・画像サイズがでかい

ので、1枚の画像を複数にクロップして学習する。予測する画像もクロップしてから予測。そのあとくっつける(concat)。


こうすることで、細かい部分もかなり精密にセグメンテーションできるので、丸ごと画像を入れるより、かなり正確にセグメンテーションできる。(引用:qiita

f:id:trafalbad:20210905191827p:plain





7.よかった手法、ダメだった手法

効果があった手法

・差分とってrgb化
・augmentation :horizontal flip, vertical flip,
horizontal and vertical flip
SGD
・bce dice loss
・cropして予測した後concat(896にリサイズ後にサイズ448×448にクロップして4枚にした)
・unet, efficientunet
ヒストグラム平均化
・sigmoid
・0〜1で正規化

ダメだった手法

・adam
・focal loss, jacard loss
・小さい塊を除去する
・softmax
・標準化

終了2週間前に参加して、IOUが45%でブロンズに入れたので、よくできた方だと思う。手法は過去コンペとか読み漁ったのが良かった。

あとは試行回数とアイデア勝負。引き出しの大きさが重要だな感じた。

あとコンペ のdiscussionとかを読み込むことで、コンペ の概要を手っ取り早くしれるのであと出しの参加でも十分勝負できた。

MobilenetベースのBlazefaceをResnetベースに蒸留して性能を移植してみた【機械学習】

Blazefacegoogleの訓練済みのデバイス用の顏/手の認識モデルで速くて高性能なモデル。
役割は顔/手の位置とキーポイントを高速に検出する機械学習モデル。


BlazeFaceはMobileNetをバックボーンをベースにして作られてる。

今回はやることは

1.tfliteのblazefaceをpytorch用の重みに変換する。

2.それで、本来のMobileNetバックボーンに加えて、Resnetバックボーンのblazefaceを自作する。

3.自作ResNetバックボーンにMobileNetバックボーン(本家)の性能を「蒸留(Distillation)」を使って移植し、再現。

なので今回はその備忘録。

目次
1.Distillationについて
2.今回使ったloss
3.Resnetベースのバックボーンモデル
4.蒸留の精度
5.精度曲線とloss曲線




1.Distillationについて

蒸留はかなりやり方が多様で、クリエイティビティな側面が大きい。
個人的に下の神記事のWavenetとやり方とkaggleのサイトがかなり参考になった。

Distillation Implementation with Freesound2
Deep Learningにおける知識の蒸留


BlazeFaceは出力が2つある。

・出力2はbboxとランドマークの位置情報
・出力1はクラスの確率

なので、出力1を蒸留用のlossに使うことにした。

あとは普通に出力1, 2の教師あり学習データは、pseudo-labelingで代用。

蒸留の仕組み
f:id:trafalbad:20210701113724p:plain

BlazeFaceの出力部分

class BlazeFace(nn.Module):

〜〜〜〜〜
  c = torch.cat((c1, c2), dim=1)  # (b, 896, 1)

        r1 = self.regressor_8(x)        # (b, 32, 16, 16)
        r1 = r1.permute(0, 2, 3, 1)     # (b, 16, 16, 32)
        r1 = r1.reshape(b, -1, 16)      # (b, 512, 16)

        r2 = self.regressor_16(h)       # (b, 96, 8, 8)
        r2 = r2.permute(0, 2, 3, 1)     # (b, 8, 8, 96)
        r2 = r2.reshape(b, -1, 16)      # (b, 384, 16)

        r = torch.cat((r1, r2), dim=1)  # (b, 896, 16)
        return [r, c]

〜〜〜〜〜

以下の部分が蒸留で大事なのではと思った。

・KL Divergence_lossに入れるのはネットワークの出力に近いやつほどいい。(確率分布)

・Softmax with Temperature(SwT)はKL Divergence_lossと併用が一般的。

・なるべく蒸留はlossでは確率分布を学びやすいように目的に合わせて、前処理する。

教師あり学習はラベル問題とかセグメンテーションタスクみたいにリアルなデータを使うのが普通だが、教師データがなくてもpseudo-labelingでもいけた。

・lossはMSE以外にも出力が似るように収束するlossならOK。





2.今回使ったloss

蒸留用lossと教師ありデータ用lossでうまくいったのは下の2つ。蒸留用lossは一般的なnn.KLDivLoss()を使って、確率分布を前処理する形にした。
教師ありも本家のlossとは違い、出力が似やすいように前処理した。



1つ目:best性能loss。前処理重視でMAEnn.L1Loss()とかMSEnn.MSELoss(), F.binary_cross_entropyを使った

import torch
import torch.nn as nn
import torch.nn.functional as F

def kl_divergence_loss(logits, target):
    T = 0.01
    alpha = 0.6
    thresh = 100
    criterion = nn.L1Loss()
    # c : preprocess for distillation
    log2div = logits[1].clamp(-thresh, thresh).sigmoid().squeeze(dim=-1)
    tar2div = target[1].clamp(-thresh, thresh).sigmoid().squeeze(dim=-1)
    closs = nn.KLDivLoss(reduction="batchmean")(F.log_softmax((log2div / T), dim = 1), F.softmax((tar2div / T), dim = 1))*(alpha * T * T) + criterion(log2div, tar2div) * (1-alpha)
    
    # r
    anchor = load_anchors("src/anchors.npy")
    rlogits = decode_boxes(logits[0], anchor)
    rtarget = decode_boxes(target[0], anchor)
    rloss = criterion(rlogits, rtarget) 
     
    return closs + rloss
def kl_divergence_loss(logits, target):
    T = 0.01
    alpha = 0.6
    thresh = 100
    criterion = nn.MSELoss()
    # c : preprocess for distillation
    log2div = logits[1].clamp(-thresh, thresh).sigmoid().squeeze(dim=-1)
    tar2div = target[1].clamp(-thresh, thresh).sigmoid().squeeze(dim=-1).detach()
    closs = nn.KLDivLoss(reduction="batchmean")(F.log_softmax((log2div / T), dim = 1), F.softmax((tar2div / T), dim = 1))*(alpha * T * T) + F.binary_cross_entropy(log2div, tar2div) * (1-alpha)
    
    # r
    anchor = load_anchors("src/anchors.npy")
    rlogits = decode_boxes(logits[0], anchor)
    rtarget = decode_boxes(target[0], anchor)
    rloss = criterion(rlogits, rtarget) 
     
    return closs + rloss


Tを0.01〜10にしてもあんまり影響なく、いかに蒸留と教師あり学習のlossを上手く学びやすい形にするかが重要。


2つ目:あんまり性能出なかったloss。前処理しないと多分、blazefaceみたいな特殊な出力はうまく蒸留できない。

def kl_divergence_loss(logits, target):
    T = 0.01
    alpha = 0.6
    thresh = 100
    criterion = nn.L1Loss()
    # c : preprocess for distillation
    log2div = logits[1].clamp(-thresh, thresh).sigmoid().squeeze(dim=-1)
    tar2div = target[1].clamp(-thresh, thresh).sigmoid().squeeze(dim=-1)
    closs = nn.KLDivLoss(reduction="batchmean")(F.log_softmax((log2div / T), dim = 1), F.softmax((tar2div / T), dim = 1))*(alpha * T * T) + criterion(logits[1], target[1]) * (1-alpha)
    
    # r
    rloss = criterion(logits[0], target[0]) 
    return closs + rloss
def kl_divergence_loss(logits, target):
    T = 0.01
    alpha = 0.6
    thresh = 100
    criterion = nn.MSELoss()
    # c : preprocess for distillation
    log2div = logits[1].clamp(-thresh, thresh).sigmoid().squeeze(dim=-1)
    tar2div = target[1].clamp(-thresh, thresh).sigmoid().squeeze(dim=-1)
    closs = nn.KLDivLoss(reduction="batchmean")(F.log_softmax((log2div / T), dim = 1), F.softmax((tar2div / T), dim = 1))*(alpha * T * T) + criterion(logits[1].squeeze(dim=-1), target[1].squeeze(dim=-1)) * (1-alpha)
    
    # r
    anchor = load_anchors("src/anchors.npy")
    rlogits = decode_boxes(logits[0], anchor)
    rtarget = decode_boxes(target[0], anchor)
    rloss = criterion(rlogits, rtarget) 


3.Resnetベースのバックボーンモデル


Resnetベースはresnet18の構造はMobilenetのモデルにresidual構造を加えたネットワークを自作。

residual構造にした部分

class BlazeFace(nn.Module):
〜〜〜
if self.resnet_backborn:
            print('resnet base')
            inputs = x
            o1 = self.conv(inputs)
            o2 = self.relu(o1)
            o3 = self.b1(o2)
            o4 = self.b2(o3) + self.resconv1(o2)
            o5 = self.b3(o4)
            o6 = self.b4(o5)
            o7 = self.b5(o6) + self.resconv2(o5)
            o8 = self.b6(o7)
            o9 = self.b7(o8)
            o10 = self.b8(o9)
            o11 = self.b9(o10)
            o12 = self.b10(o11) + self.resconv3(o8)
            x = self.backbone1(o12) # (b, 88, 16, 16)
            x1_ = self.b11(x)
            x2_ = self.b12(x1_)
            x3_ = self.b13(x2_) + x1_
            x4_ = self.b14(x3_)
            h = self.backbone2(x4_) + x1_  # (b, 96, 8, 8)
〜〜〜


4.蒸留の精度

実験1.

MobileNetベース(googleのやつ)



f:id:trafalbad:20210701112406p:plain


f:id:trafalbad:20210701112421p:plain


Resnetベース(蒸留したやつ)

f:id:trafalbad:20210701112451p:plain


f:id:trafalbad:20210701112506p:plain

ほとんど本家のblazefaceと同じ精度。蒸留ってすご。





5.精度曲線とloss曲線

精度曲線(MAE)

R出力(出力1)
f:id:trafalbad:20210701074609p:plain


C出力(出力2)
f:id:trafalbad:20210701074639p:plain

loss曲線

train loss
f:id:trafalbad:20210701074654p:plain

validation loss
f:id:trafalbad:20210701074709p:plain


初めは全然違うが段々、値が似てくる。評価用のテストスクリプトでテストできるようにして、なるべく高速で多くPDCAが回しやすいようにした。

tensorboardのグラフでデータがかなり客観的でにとらえられた。








今回初めてだったけど、かなり上手くいった。

教師データがないのが痛かったけどpseudo-labelingで上手いように学習できたのは前処理でさんざん悩んだおかげ。

蒸留は、ハードウェアでは、The・トレンド感な技術なだけあって、今回みたいな変な出力を蒸留するのは相当難かった。


精度はかなり忠実に移植できたので、相当いい結果になったつもり。



参考サイト

tensorflow-BlazeFaceのGitHub

Deep Learningにおける知識の蒸留

モデルの蒸留を実装し freesound2019 コンペで検証してみた。

knowledge-distillation-pytorch