アプリとサービスのすすめ

アプリやIT系のサービスを中心に書いていきます。たまに副業やビジネス関係の情報なども気ままにつづります

prednetで未来の画像予測、天気画像の予測コンペのログ【機械学習、python】

signateの雲画像予測コンペに初参加したので、そのログ。
個人の備忘録と技術メモのつもりで簡単にやったことつづってきます。

まず結果は17位。1位との精度差は10%くらいだった。初参加にしては上々なできかな。レベル高くて驚いた。

コンペのデータ構成

データは雲画像とmetaデータの2種類。metaデータは湿度、温度、風の向きとかの格子状の2次元データ(グレースケール画像)。


【データの特徴把握のため使った方法】

・月ごとの雲量をヒストグラムで可視化

・TSNEで雲量を分布を可視化

・類似画像検索でmetaデータと雲画像と相関関係の深いデータを調べた

・天気関係の本で気象庁がどうやって雲を予測しているかとか、雲の発生する流れの理解


metaデータの中で、湿度(RH)と風(VGRD)が雲画像と一番関係しているけど、活用方法が思いつかなったので保留した。

使うアルゴリズムは、動画のフレーム予測を予測するGAN系ネットワーク(VideoGAN, Recycle-GAN)とかを調べた中で、convLSTMの発展系のprednetが一番使えそうなので、改良して使用。

f:id:trafalbad:20191117144850j:plain



prednetを改良

以下の点でprednetを改良

・効率的に学習させるためgeneratorを改良

・prednet にresnetを追加することでextrapolationや精度をよくした。

・訓練回数を増やしてなるべく精度を上げた

・prednetの活性化関数のA_activation='relu'なのを、差分をとりたいなら"tanh"にすると有効だった。

・画像を簡単にするよりも、なるべく多くの画像群での一連の動作を何回も学習させた方が効果的。


resnetを入れて、時刻tの予測画像をt+1の予測画像にaddした。そうすることで一から学習することなく、前の画像をもらいながら予測できる。


【extrapolationで24時間分の画像予測部分】

f:id:trafalbad:20191117145006j:plain


公開されてたprednetは複雑でネットワークがかなり改造しにくかったので、もっと改造しやすいprednetを実装すればよかった。


【convLSTM系ネットワーク改造例】
f:id:trafalbad:20191117145027j:plain



【他にも色々試して無駄だったこと】

・画像を簡単にする(黒の要素が強い0付近のピクセルは全部ゼロにして雲を強調する)

opencvのaddWeightメソッドで合成する

・cycleGANで予測した画像の補正とかUnetで予測するとかのアイデアと実装してしまったこと


精度を上げるには、metaデータをもっと使ったり、過去データを上手く活用する必要があったかもしれない。




結果まとめ

【一位の人の画像】

f:id:trafalbad:20191117162711j:plain

等高線の軌跡があるので、metaデータと雲量の相関関係を出して、関係のあるmetaデータを予測して、雲と関係の高い値のところをヒートみたいに染めるとかの手法でも使ったのか?

【予測する前の3h分の画像のgif(GroundTruth)】

f:id:trafalbad:20191117162731g:plain





【そこから自分が予測した24h分の画像のgif】

f:id:trafalbad:20191117162751g:plain





はじめてのコンペだったので、下調べに時間使いすぎたり、上手く精度を上げる方法がわからなかった。

投稿した回数は数回で、結果は17位。1位との精度差は10%くらいだった。


とにかく「まず投稿→feedback→改善→投稿」の繰り返しでPDCAサイクルを効率よく回せるかが、精度上げてくののカギっぽい。上位の人ほど投稿数が多い傾向にあった。

ちなみに自分は少ししか投稿しなかったので、何が精度アップに直結するのか把握する余裕がなかったのが後悔した点。


初参加で学んだこと

・下準備や画像生成に直結しない、無駄な画像の特徴把握に時間使いすぎて、1回目の投稿がかなり遅れたこと。

・とりあえずゴミのような駄作でも一度submitまでやってしまうこと。PDCAサイクルが回しやすくなる。

・アイデアは先人の知恵を奪うこと

・なるべく高得点を出せるパターンを見つけてPDCAサイクルを回して改善していく

・とりあえずやったデータやアイデアのログは後から見返せるように、dropboxに保存したり、ノートに記録しておいた。

・いい手法を思いつくには大量に良質なインプットして考えまくるのは変わらない

・今の手法がいいか迷ったらとにかくsubmitして得点だし、feedbackを得ること




はじめてのコンペでわからないことも多かったので、勝つより、経験を積むための初戦だった気がする(もちろん勝つ気でいたけど、当たり前的にレベルが高かった)。

こんなに考えたのは機械学習の中でも1、2位を争うくらいだったので、スキルも上がった。何より発想力、技術力がメッチャ上がるし、気軽に始められるし、勝てればリターンでかいし、いいことずくめだと思った。
(基本、土日も考えるので終わった後の疲労感はやばい)


参考サイト

PredNetを用いた混雑レーダーの未来予測

PredNetのGitHub

九州旅-福岡県添田町大藪集落廃村めぐりとかいろいろ

九州の一人旅。福岡県市内から100kmは離れた添田町の大藪集落というところに行ってきた。

googleで見ると大藪峠というのがあるらしいけど、どうやら下付近のダムの中に、昔は家が沢山あって今は林に埋れてしまった場所があるらしい。

お目当の大藪集落の分校後はいくら探しても見つからなかったので、そのダムの中に入って行ったらきっとあるんだろう。
空き家はちょこちょこあってもう50年近く経過してるとのことで、かなりの荒れようだった。

廃村めぐりとしては慌ただしく、タクシーで3時間くらいしか回れなかった。
もっと調べて2日くらいかければもっとゆとりある旅行できた。

トラファルガーやミホークが島を別荘にしてるように、九州は由布院にもかなり空き家があった。しかも普通に住めるやつ。


金かけなくても、ローやミホークみたいに空き家に勝手に住めるようにすれば、ノマドライフできるなとか考えた。

旅中にpythonを書き留めとける神アプリ「pythonist」に出会えた。これでアイデアが浮かんだらiPad にコード書き留めとける。

iPad Pro買おうか悩んだけど、別に必要ないかな今は。「あれば便利」は使わなくなるから、スマホレベルで「ないと困る」じゃないと買わないと思う。

確かにiPad Proからawsにログインできたりは理想だけど、メモ書き留めて、実行は母体のPCでやるなら買う必要なし。



旅は今のうちにしとかんとやばいなと思う。今の環境はいずれ享受できなくなるし、サイコパスみたいに数十年後に東京一極集中が加速すると今みたいに行かなくてなるし、老いぼれになるときついし。
想いも旅路も巡った九州旅だった。平日に3泊4日の旅行けるなんてエンジニアの特権だなとひたすら思った。

宮崎の方、九州内陸には平家の落ち武者が住んだレベルの廃村がかなりあるらしく、行ってみてーと思た。大藪でもかなりなのにレベルはさらに上らしい。
最後に思い出写真。

f:id:trafalbad:20191101133056j:plain

f:id:trafalbad:20191101133051j:plain

f:id:trafalbad:20191101133048j:plain

f:id:trafalbad:20191101133106j:plain

f:id:trafalbad:20191101133110j:plain

f:id:trafalbad:20191101133928j:plain

f:id:trafalbad:20191101133120j:plain

植林で表彰されたのが住んでたらしく、賞状がいっぱいあった。犬小屋もあったから、犬もいたんだろう。

畳は陥没寸前でこち亀の日暮の部屋みたいになってた。

f:id:trafalbad:20191101133110j:plain

f:id:trafalbad:20191101133100j:plain

f:id:trafalbad:20191101133103j:plain

由布院にも空き家はかなりあった。

これも由布院の空き家。人が出てった後で、掃除すれば普通に住めるやつやった。

f:id:trafalbad:20191101133230j:plain

f:id:trafalbad:20191101133233j:plain

今は使われてない由布院厚生年金保養ホーム

f:id:trafalbad:20191101133226j:plain
病棟のアパートとか普通にあった。東京の3万の物件より良かったんですけど。

f:id:trafalbad:20191101133241j:plain

f:id:trafalbad:20191101133237j:plain


中心街は観光地として栄えてだけど、これから高齢化、老いぼれ他界化が進めばさらに空き家は増えるだろうなと思う。

anoGANで画像の高精度の異常検知(anoGAN, metric learning, VAE) 【機械学習】

今回は画像から異常を検知するタスクをやった。データセットはドイツの異常検知コンペのデータセットDAGM 2007」。

人が人為的に微細な傷をつけたかなりガチの異常検知用データセット
f:id:trafalbad:20190904213001j:plain

精度出すのにやった過程とかを備忘録も兼ねてまとめてく。


目次
1.やったことない内容をまず何からはじめたか
2.一通り動かしてフィードバックを得る
3.ひたすらインプットして、集中して考えまくる
4.飛び道具は使わず、オールマイティな手法の汎用性を上げる工夫をした
5.まとめ



1.やったことない内容をまず何からはじめたか


とにかく下準備8割は自分の経験則なので、準備は念入りにした。

・主な一般的な・人気なトレンド手法を論文、kaggle、qiitaとかいろんなサイトで探す

アルゴリズムでスターの多いgithubコードをリサーチ

・大まかな工程を決める

オライリーとの機械学習の画像系の本読んだ

等、下準備をかなりやった。



主に調べたサイト

そのあとはとにかく失敗を恐れず、コードを書いて動くものを作ってみた。

たいてい論文みたいな精度が出ることはないので、動かす。そして問題点を洗い出し、フィードバックをもらって対策を考えた。





2.一通り動かしてフィードバックを得る

チーム内で問題点をざっと共有。主にslackに投げるとアドバイスや意見とかもらえることもあるし、自分でやったこと見直すこともできるのし、対策を考えるのにも役立つ。

問題点を書き出して、人目につくところに投稿してフィードバックを得る or まとめるというのは重要な作業。

anoGANの問題点は例えばこんだけあった。


・普通のGANでは生成画像の精度が悪い

・3次元(カラー)とグレースケールの使い分けが悪い可能性あり

・画像サイズが小さくない?(小さすぎて異常がわからない)

・異常画像と生成画像の差分の取り方、重み付けに使うopencvのメソッドのパラメータの値に問題あるかも

閾値の設定してないのやばいと思う

など



それからフィードバックもらい、anoGANの改善案の一部

問題点:生成精度が悪い

解決策:マスキングする、もっと精度のいいGAN使う、画像サイズを上げる


3.ひたすらインプットして、集中して考えまくる

次は他の同業者の事例からアイデアを盗む、関連事例で成功してる手法、kaggleとかで上手くいってる手法、論文で参考になる部分、問題点を解決する手法、必要な書籍の購入、等とにかく調べてインプットしまくる。

大量のインプットしたら、調べた知識を整理して、とにかく集中して考える。


だいたい1週間以上ぶっ続けで考えて、トライアンドエラーを繰り返すと、上手く行くこと行かないこと、対策法とかがわかってくる。

この大量にインプット→集中して考えまくるという工程はこの記事でも共通してる。


techlife.cookpad.com


いろいろ考えて試したのち、styleGANを使ったanoGANの手法が上手く行くことがわかった。

これで異常検知用データセットを使ったところ、95%以上の精度を出して、無事プロジェクト終了。




主に考えた・使える手法


手法1.anoGAN+LOFやマハラノビス距離を使ったanomaly scoreの算出

図で見るとわかるように、かなり綺麗に分かれてる。ので、ここからLOFやマハラノビス距離でanomaly scoreを算出。

f:id:trafalbad:20190904151355p:plain



手法2.anoGAN+類似画像検索の手法

下の混合行列を見ればわかるように99%、誤差一枚という驚異的な精度。
f:id:trafalbad:20190904151627p:plain





不採用の手法



1.VAEのloss値を使った手法

→VAEの異常検知のサイトみたいに画像をカットせずに、256ピクセルのままloss値を利用する手法でも十分いけたけど、anoGANには及ばないし、使い勝手悪いのでやめた。


2.metric learning
qiitaとか論文では普通に話題だったけど、Resnet50をベースモデルに一から学習させてもそんな精度出なかった。
多分、lossでいろんな手法があるから、もっと深追いすれば上手く行く方法もあるだろう。けど、anoGANの方が精度いいので別にいいやってなった。



3.類似画像生成して差分から検知
下のようにクエリの異常画像とGANで生成した画像を1次元にして引算した後、画像に戻して、閾値超えたら検知する方法。
一枚しか一度にさばけないし、大量の画像の異常検知はできないのでやめた。


1次元にして差分を出し、画像に戻した

f:id:trafalbad:20190904151527j:plain


opencvのcv2.addWeightedメソッドを用いた差分の画像

f:id:trafalbad:20190904151545p:plain


画像を1次元にするアイデアを閃いた手法
f:id:trafalbad:20190904151800j:plain


4.飛び道具は使わず、オールマイティな手法の汎用性を上げる工夫をした

anoGANを超えたmodelが論文で結構出てるが、やってることはmnistとかcifar-10レベルのデータの精度検証であって、そのまま実用レベルとか、同じ精度で使えることはまずない。

大事なのは一般的な手法(例えばanoGANとか)で、どんな状況でも適用できるようにする工夫が出来ることが重要。


料理漫画とかで、仰々しいマシーン(分子ガストロミーの機械とか)ですごい料理作れるより、フライパンとか油のような一般的・汎用的な道具と知恵と工夫でそれに勝てる料理を作れる主人公が強いのがいい例。

あと、アイデアはかなり大事なので、

大量にinputする

・アイデアを盗む

アイシュタインみたいに一つのことについて考え続けると、思いもよらないideaが生まれる(普通の人はそこまで考えないから)

とにかくアイデアを手当たり次第紙に書き出し(失敗とか気にしない)、その中からよさそうなの試して、絞ったりする



とアイデアに凝って、インプットをしまくるとアウトプットも増えていいものができた。
良質なアウトプットには大量のインプット・良質なインプットが大事」というのが今回の教訓。






5.まとめ

ガチの異常検知、特に画像の異常検知は初めてで、一番難しいタスクだった気がする。

opencvの画像補正もやったので、かなりいい経験と勉強になった。あと画像系のこういうニーズは本当に多いなと思うし、今後も増えそう。


参考にしたサイト


画像を1d変換する時に参考にした記事

いろんなデータセット

【精度対決】リアルな画像で異常検知

BERTで6感情の感情分析モデルを作ってみた【機械学習、自然言語処理】

画像と違って文章から感情を予測すること(emotion prediction from text)は未だ自然言語処理NLP)界隈では、うまくいった事例が少ない。

特に、単純なネガポジ判定ではなく、6感情(怒り、驚き、幸せ、嫌悪、恐れ、悲しみ)を分析する感情分析は、研究が頻繁に行われてる。

今回はBERTでなるべく精度の高い感情分析モデルを作ってみた。

f:id:trafalbad:20190901145030j:plain

目次
・感情分析について
1.twitterからスクレイピングしてデータセット作成したcase
2.スクレイピングした映画レビューからデータセットを作ったcase
3.気づいたこと
4.まとめ

感情分析について

感情分析は英語でも日本語でも未だにうまくいってなくて、論文が頻繁にでてる分野。


難しい理由の一因は「データセットの作成が難しい」とか「ノイズの多い日本語のような難解な言語での感情判定が困難」だから。

比較的処理しやい英語でも、kaggleのIMDBの5段階ネガポジ判定で精度68%くらいだった。

なのでノイズ表現(” ~したいンゴ ”、 “~みが強い”、” インスタ蠅 ”)みたいな意味不な言葉が増えた、かつ難解な日本語の6感情の感情分析ならなおさらむずい。





極性分析な主なデータセットの作り方

①極性分析(主にネガポジ判定)では公開用の極性辞書を使い、ラベルをつけて作成。

②EkmanみたいなAPIで文章にラベルづけして作成

③極性辞書を自作してラベルをつけて作成

④どっからからスクレイピングして、感情ラベルの代わりにする(iPhoneスタンプとか)

⑤人手で一からしっかりデータセット作る

①②は極性辞書やAPI作成者の「どのように感情判定するか」の基準が如実に反映されてるので、個々のタスクごとに最良の結果が出るとは言えない。
なので③~⑤が各タスクのメインな手法な気がする。

Microsoftの例
f:id:trafalbad:20190901145342j:plain




今回は

・感情スタンプ付きのツイートをtwitterからスクレイピング(④)

・映画レビューをスクレイピングして自分で簡単なデータセットを作る(⑤)

の2つを試した。

twitterは6感情でよく使うiPhoneスタンプを含んだツイート、映画レビューは6感情をよく表す映画のレビューから自分でラベルをつけて、データセットを作った。






1.twitterからスクレイピングしてデータセット作成したcase

今回はなるべくいいネットワークを使うため、BERTを選択。よく理解した上で使った。
trafalbad.hatenadiary.jp



友達にアンケートとって6感情でよく使うiPhoneスタンプを教えてもらって、そのスタンプ含んだツイートをスクレイピング

run.sh

#!/bin/bash
# angry
twitterscraper 😠 --lang ja -o angry.json &
twitterscraper 😡 --lang ja -o angry2.json &
twitterscraper 😤 --lang ja -o angry3.json &
# disgust
twitterscraper 🤮 --lang ja -o disgust.json &
twitterscraper 😣  --lang ja -o disgust2.json &
# fear
twitterscraper 😨 --lang ja -o fear.json &
twitterscraper 😰 --lang ja -o fear2.json &
twitterscraper 😱 --lang ja -o fear3.json &
# happy
twitterscraper 😄 --lang ja -o happy.json &
twitterscraper 😆 --lang ja -o happy2.json &
twitterscraper 😂 --lang ja -o happy3.json &

# 以下略
wait;

echo "Done!:twitterscraper"

スクレイピング実行コマンド

$ chmod +x run.sh
$ ./run.sh &


EC2インスタンスGPUでも一日かかった。

データセット作成

# get tweet text and emotion label
emotions = [["angry", "angry1", "angry2"], ["disgust", "disgust2"], ["fear", "fear2", "fear3"], ["happy", "happy2", "happy3"],
            "sad", ["surprise", "surprise2", "surprise3"]]
dir_path = "sentiment_sh"

size = 60000
df = []
for i, es in enumerate(emotions):
    if isinstance(es, list):
        for e in es:
            try:
                data = shuffle(pd.read_json(join(dir_path, "{}.json".format(e)))).iloc[:int(size/len(es))]
                data['label'] = i
                df.append(data)
            except ValueError:
                continue
                
    else:
        data = shuffle(pd.read_json(join(dir_path, "{}.json".format(es)))).iloc[:int(size)]
        data['label'] = i
        df.append(data)
        
df = pd.concat(df)
df = shuffle(df)
text_df = df['text']
label_df = df['label']

dff=pd.concat([text_df, label_df], axis=1)
# save to csv
dff.to_csv('tweet.csv')

とりあえず、アルファベット、絵文字や顔文字とか日本語に関係ない文字が多すぎて、ほぼ文章じゃなかった。

なので、正規化して出来るだけまともな形にした後、BERTで転移学習。

正規化してもほぼ日本語じゃない形で、しかも感情を表す要因が、文章に反映されてない(嬉しい系のツイートでも悲しいスタンプ😢があったり)。

結果、データセットとしてかなり質が悪く、BERTでも精度は43%。





2.スクレイピングした映画レビューからデータセットを作ったcase

映画サイトから、Beautifulsoupでレビューをスクレイピング & 自分の直感でラベル付与して、データセット作った。

映画サイトはURLの形式がパターン化されてるのでスクレイピングしやすい。

ジブリ系(悲しい、幸せ)、ハングオーバー(笑い)、ランペイジ-巨獣大戦争(嫌悪)など6感情を愚直に反映してる映画の7このレビュー文をスクレイピング

rating = []
reviews =[]
first_url = 'https://******/movies/82210'
next_urls = 'https://******/movies/82210?page='
for i in range(1,200):
  if i==1:
    next_url = first_url
  else:
    next_url = next_urls+str(i)
    
  result = requests.get(next_url)
  c = result.content
  # HTMLを元に、オブジェクトを作る
  soup = BeautifulSoup(c, "lxml")
  # リストの部分を切り出し
  sums = soup.find("div",{'class':'l-main'})
  com = sums.find_all('div', {'class':'p-mark'})

  # get review
  for rev in com:
    reviews.append(rev.text)
  # get rating
  for crate in com:
    for rate in crate.find_all('div', {'class':'c-rating__score'}):
      rating.append(rate.text)
  # print(i)

# save review data as DataFrame
rev_list = Series(reviews)
rate_list = Series(rating)
print(len(rev_list), len(rate_list))

movie_df = pd.concat([rev_list, rate_list],axis=1)
movie_df.columns=['review','rating']
movie_df.to_csv('movie_review.csv', sep = '\t',encoding='utf-16')

レビュー文は割と長めのしっかりしたレビューを選択。

・trainデータ:15000
・testデータ:1000

trainデータはスピード重視で、特に感情判定にルール設定はせずに、6感情のうち当てはまりそうな感情のラベルを直感でつけた。

testデータは映画「天気の子」からスクレイピングして、mecabで正規化時に「名詞、形容詞、動詞」のいずれかを10個以上含むしっかりとした長さの文章をランダムに取り出した。


精度は72%。直感で感情判定したが精度はtwitterと比べるとかなり高め。データセットの質が精度に影響してるのがよくわかる。

あとこの精度は「映画のレビュー & 割としっかりした長さの文章」という条件下での精度なので、他のドメインの文章(医療、経済、メディア.etc)に同じ精度は出ない可能性は高い。

やっぱり、特定のドメインの文章には特定のドメインに特化したモデルを作るのがベストだと思う。





3.気づいたこと

・日本語は「2チャンネルやtwitterのような崩壊寸前の文章」、「メディア系のお堅い文章」の最低2つのドメインは確実に存在する。


・なぜこの文章が「怒り、驚き、幸せ、嫌悪、恐れ、悲しみ」の感情に分類されるか、ちゃんとルールを設けること(ルールベース)

ルールの例
ex1.「びっくり」の名詞を含む=驚き

ex2.「びびり」の表現=嫌悪
とか


・「お堅い文章」、「2ちゃんやtwitterのような崩壊した日本語の文章」で大きく2つにドメインが分かれるので、一つで完璧な感情分析言語モデルを作るのは難しい


・完璧な感情分析モデルは作成困難なので、ドメイン別にモデルを作るのがベター


・感情判定にルールを設けた上で、人手できちんとデータセットを作るべし(人間が理解できないものは機械学習でもできないし)


・人手で作るなら質の高いクラウドソーシングのAmazon Mechanical Turkを利用するのがおすすめ


4.まとめ

ラベルをつけるとき、「なぜこの文章はこの感情になるか」のルールを決めるとさらに精度は上がることは間違いない。

2ちゃんや、twitterみたいな日本語が崩壊してるレベルの文章(仲間同士でしか使わない隠語、最近の意味不な表現が多数ある文章)は、映画のレビュー文とは全然違う。
かといって映画とかのレビュー文も、メディアみたいにお堅い文章とは違った。

ともあれ日本語には最低でも上の、2ドメインは確実に存在するので、タスクごとのドメインに特化したモデル作るのが得策だと思う。



メモ:感情分析モデルの活用方法


・患者の診断応答から感情予測して、感情が症状回復に関係ある場合に使える。
→この場合のベストプラクティスは医療ドメイン用文章のデータセットを作り、モデルを作るべきかなと思う



・エンタメで感情ごとに似たエンタメをレコメンドする。
→映画とかで悲しい映画がみたいとき、悲しいを表す似た映画をレコメンドする(作るの難しいし、実用性低いだろうけど)



参考site


日本語ツイートをEkmanの基本6感情で評価

Emotion Detection and Recognition from Text Using Deep Learning

・感情分析に関する情報
https://qr.ae/TWyb8i

自然言語処理タクスでよく使うAttentionの出力のAttention weightを可視化してみた【機械学習】

Attentionといえば、すでに自然言語処理モデルではなくてはならない存在。
カニズムは割愛。別名で注意機構とか呼ばれる。

Attentionの仕組みは、(個人的に理解してる範囲では)簡単に言うと以下のポイントがある。

・人間が特定のことに集中(注意)する仕組みと同じ

・Attentionの仕組みはAttention自体が特定の単語に注意(注目)する

・Attentionの挙動は人間の直感に近い

今回はそのAttentionが「どの単語を注意して見てるのか」わかるように、Attentionの出力結果Attention weightを可視化してみた。


こんな感じ
f:id:trafalbad:20190804133028j:plain


その過程を備忘録も兼ねてまとめてく。

目次
・今回の記事の概略
1.データ読み込み
2.モデル構築・訓練
3.Attention可視化用に訓練したモデルを再構築
4.Attention weightの可視化
・まとめ




今回の記事の概略

タスクとデータセットは、前回の日本語版BERTの記事で使った" livedoorニュースコーパス "を使ったトピック分類で、「Sports、トピックニュース」のトピックを分類するタスク。

BERTでのAttention可視化は

・BERTではマスク処理がある

・MaltiHeadAttentionを使ってる

等の理由で基本的にBERTでのAttentionの可視化はできないっぽいので、簡易モデルを作ってAttentionがどの単語に注意を払ってるのか可視化してみた。


AttentionにはMaltiHeadAttentionとか、いろいろ種類があるが、可視化にはselfAttentionが使われる。
f:id:trafalbad:20190804133614j:plain




selfAttentionを含めてAttentionの仕組みは下記サイトに詳しく載ってる。
qiita.com


幸い、kerasにpipでinstallできるselfAttentionがあるので、それを使ってAttentionの出力のAttention weightを可視化してみる。





1.データ読み込み

前回記事の日本語版BERTで使用した
・前処理済みテキストデータ

・idベクトル化してない日本語のテキストデータ(all_txet.npy)

を使う。

# 読み込み
train_x = np.load('train_xs.npy')
train_y = np.load('train_label.npy')
test_x = np.load('test_xs.npy')
test_y = np.load('test_label.npy')
# id化してない日本語の文章も読み込み
all_text = np.load('all_text.npy')

# one-hot表現
n_labels = len(np.unique(train_y))
train_y=np.eye(n_labels)[train_y] 
train_y = np.array(train_y)


2.モデル構築・訓練

BERTは上述の通り、マスク処理とMaltiHeadAttentionを使ってるので、Attentionの可視化はできなかった。

なのでAttention可視化用に、双方向LSTMを使った簡易モデルを作成して、学習した

h_dim=356
seq_len = 691
vocab_size = 23569+1

inp = Input(batch_shape = [None, seq_len])
emb = Embedding(vocab_size, 300)(inp) # (?,128,32)
att_layer = SeqSelfAttention(name='attention')(emb)  # embbedingレイヤーの後にselfattentionを配置
out = Bidirectional(LSTM(h_dim))(att_layer)
output = Dense(2, activation='softmax')(out)  # shape=(?, 2)
model = Model(inp, output)
model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['acc', 'mse', 'mae'])
model.summary()
>>>
==================================================
input_1 (InputLayer) (None, 691) 0 _________________________________________________________________ embedding_1 (Embedding) (None, 691, 300) 7071000 _________________________________________________________________ attention (SeqSelfAttention) (None, 691, 300) 19265 _________________________________________________________________ bidirectional_1 (Bidirection (None, 712) 1871136 _________________________________________________________________ dense_1 (Dense) (None, 2) 1426 ======================================================



# train
model.fit(train_x, train_y, epochs=1, batch_size=10)

# 予測
predicts = model.predict(test_x, verbose=True).argmax(axis=-1)
print(np.sum(test_y == predicts) / test_y.shape[0])

正解率は、前処理とAttentionのおかげで、BERT並みの92%





3.Attention可視化用に訓練したモデルを再構築

訓練したモデルでAttention weight可視化用にモデルの再構築。

selfAttentionのレイヤーの出力

最後の出力層の出力

の両方をModelのoutputに追加

emodel = Model(inputs=model.input, outputs=[model.output, model.layers[2].output])
emodel.summary()

>>>
======================================================
 input_2 (InputLayer) (None, 691) 0 _________________________________________________________________ embedding_2 (Embedding) (None, 691, 300) 7071000 _________________________________________________________________ attention (SeqSelfAttention) (None, 691, 300) 19265 _________________________________________________________________ bidirectional_2 (Bidirection (None, 712) 1871136 _________________________________________________________________ dense_2 (Dense) (None, 2) 1426 ====================================================

精度も申し分ないので、後はどこの単語をAttentionが注目してるのかを可視化するだけ。




4.Attention weightの可視化

まずkeras のselfAttentionをこのサイトからインストール。

$ pip install keras-self-attention


Attention weightを可視化


どの単語に注目してるかの重みの総和を計算。


from keras_self_attention import SeqSelfAttention
import pandas as pd

# 予測後、Attentionの出力((batch, words, 300)=(batch, 691, 300))を取り出す
predict=emodel.predict(test_x)

token = all_text[700]

# 対象の文章(1batch)の中の175個のwords一つ一つから、3次元目(300dim)のmax値とる  =>shape=(1, 175, 1)
weight = [w.max() for w in predict[1][0][:175]]  # test_x[0][:176]

# pandasに入れてまず数値化。そのあとHTML形式にして、jupyter上で可視化
df = pd.DataFrame([token, weight], index=['token', 'weight'])
mean = np.array(weight).mean()
print(df.shape, mean)
df = df.T

df['rank'] = df['weight'].rank(ascending=False)

# 各wordsのmax値から全max値の平均を引き Attention  weightを計算。マイナスの値は0扱い
df['normalized'] = df['weight'].apply(lambda w: max(w - mean, 0))
df['weight'] = df['weight'].astype('float32')
df['attention'] = df['normalized'] > 0



【pandasの可視化結果】
f:id:trafalbad:20190804133443j:plain

ちなみにHTML形式でjupyter上で可視化するときは下のメソッドを使った。

from IPython.core.display import display, HTML
# HTMLで可視化
display(HTML(html))


【トピックを「Sports」と予測できたときのattention weight & その文章の可視化結果】
f:id:trafalbad:20190804133028j:plain




データセットが少なく、語彙数が少ないのもあるが、「ドラフト、日本ハム、指名、会議」とか、スポーツ(野球)に関連しそうなワードが赤いので、そこにAttentionが注目してるのがわかる。

噂通り、割と人間の直感に近い感じの語彙に注目してる。





Attention weight可視化で気づいたことメモ



①可視化するselfattentionレイヤーの出力はshapeは3次元でもいい


②selfattentionレイヤーはembbedingレイヤーの後に配置するのが定石っぽい


③BERTでのAttention可視化は無理(っぽい)
→マスク処理してる

→MaltiHeadAttentionを使ってる

→工夫次第ではできそう



④matplotlibで可視化もできる。
f:id:trafalbad:20190804133632j:plain


まとめ

はじめはAttentionを書籍で読んだり、調べたりしてもサッパリだった。

けど、実務で使って考えまくることで、仕組み・種類、使い方、なんで精度高くなるのとかかなり理解できた。

経験に勝る知識なしっていう格言のいい経験。



参考サイト
selfattentionを簡単に予測理由を可視化できる文書分類モデルを実装する

自然言語処理で使われるAttentionのWeightを可視化する

学習済み英語版keras BERTのfine tuning(転移学習)でネガポジ判定の2値分類をしてみた【機械学習・自然言語処理】

google自然言語処理の高性能モデルBERTを使ってfine tuning(転移学習)をやった。BERT用のデータセットではなく、一般に公開されてるIMDBデータセット(映画レビュー)を使用。
2値分類用にBERTモデルを再構築して、ネガポジ判定したので、その過程をまてめてく。

目次
・今回のタスク
・データセットの作成と中身
・学習済みBERTのload & 2値分類用に再構築
・正解率
・まとめ & BERTの全体像


今回のタスク

タスクは2値分類のネガポジ判定。データセットIMDBデータセットで中身は「映画のレビューとラベル(negative=0, positive=1)」。

BERTには事前学習と転移学習の二つの使い道があり、普通は事前学習(一から学習)はせず、転移学習で十分高性能が出せるし、

word2vecの拡張でAttentionで文脈を考慮したネガポジ判定ができ、注目した部分をヒートマップで染めることもできる。
f:id:trafalbad:20190720080557j:plain


今回はチュートリアルのこのサイトをほとんど真似させてもらったので、コードはサイトを参照してほしい。Git cloneしたkeras bertを使って、転移学習をやった。

colabratoryでtensorflowのversionが2019/6月から1.14になってて、ネットに出回ってるTPUの使用方法が使えなくなってた。

このサイト通りにやるのもいいけど、keras bertのように、modelを再構築する時にはよくわからないエラーがでまくったので、tensorflowをversion==1.13.1に戻した。

!pip uninstall tensorflow && pip install tensorflow==1.13.1


データセットの作成と中身

データセットの呼び出しはサイト通りなので、その中身をまとめていく。

1 def load_data(path):
2     global tokenizer
3     indices, sentiments = [], []
4     for folder, sentiment in (('neg', 0), ('pos', 1)):
5         folder = os.path.join(path, folder)
6         for name in tqdm(os.listdir(folder)):
7             with open(os.path.join(folder, name), 'r') as reader:
8                   text = reader.read()
9             ids, segments = tokenizer.encode(text, max_len=SEQ_LEN)
10            indices.append(ids)
11            sentiments.append(sentiment)
12     items = list(zip(indices, sentiments))
13     np.random.shuffle(items)
14     indices, sentiments = zip(*items)
15     indices = np.array(indices)
16     return [indices, np.zeros_like(indices)], np.array(sentiments)

データセットの中身



・4行目folder
'/root/.keras/datasets/aclImdb/train/neg'、 ‘/root/.keras/datasets/aclImdb/train/pos'
で文字列の’neg’と’pos’を受け取る。

・4行目のsentiment
[0,1] 。for文から返る右側の0, 1を受け取る

・6行目のname
中身が
['12398_2.txt', '12239_1.txt', '8140_2.txt', ‘9293_1.txt',~]
のtextファイル25000このリスト

・8行目のtext中身(の一部)

Look, it\'s the third one, so you already know it\'s bad. And "Maniac Cop" wasn\'t good enough to warrant the second installment, so you know it\'s even worse. But how much worse? Awful, approaching God-awful.<br /><br />When Maniac Cop goes on a killing spree, a reporter exclaims, "What happened here can ONLY be described as a black rainbow of death."<br /><br />1-- Rainbows are not black, and can never be. 2-- Rainbows are harmless, and can never inflict pain or death. 3-- A news reporter, one valuable to his agency, might find another way to describe the aftermath of a killing spree. "A black rainbow of death" is not the ONLY way to describe the given situation.<br /><br />This is what you\'re in for.’~

・10行目のindices
長さ128のlistが25000こ。len(indices[0])=>128, len(indices)=>25000
[101, 2298,1010, 2009, 1005, 1055, 29310,~,102]

・11行目のsentiments
中身が0, 1のlist、len(sentiments)=>25000
[0,1,0,0,0,1,0,~]

・15行目のindices = np.array(indices)
shape=(25000, 128)のarray。中身は10行目と同じ。

・16行目のnp.array(sentiments)
shape=(25000,)のarray。中身は11行目と同じ

・16行目の[indices, np.zeros_like(indices)]
np.zeros_like()は np.zeros(indices.shape)=np.zeros_like(indices)
中身は下のようになってて、shape=(25000, 128)が2つある。

[array([[ 101, 1996, 2466, ...,    0,    0,    0],
        [ 101, 1037, 2158, ..., 4287, 1996,  102],
        [ 101, 1045, 2001, ..., 2023, 2499,  102],
        ...,
        [ 101, 7929, 2182, ..., 7987, 1013,  102],
        [ 101, 2672, 1045, ..., 1999, 1996,  102],
        [ 101, 1996, 3213, ..., 2046, 1037,  102]]),
 array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]])]


学習済みBERTのload & 2値分類用に再構築

inputs = bert.inputs[:2]
dense = bert.get_layer('NSP-Dense').output
outputs = keras.layers.Dense(units=2, activation='softmax')(dense)
# rebuild BERT model 
model = keras.models.Model(inputs, outputs)


違うのは最後の出力層だけ

Encoder-12-FeedForward-Norm (La (None, 128, 768)     1536        Encoder-12-FeedForward-Add[0][0] 
__________________________________________________________________________________________________
Extract (Extract)               (None, 768)          0           Encoder-12-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
NSP-Dense (Dense)               (None, 768)          590592      Extract[0][0]                    
__________________________________________________________________________________________________
dense (Dense)                   (None, 2)            1538        NSP-Dense[0][0]  

正解率

88%で転移学習でも十分高性能。



まとめ & BERTの全体像

今回はBERTで転移学習をしてみた。BERT専用データセットじゃなく一般公開用のを使ったから、かなり勉強になった。

keras_bertを使えば、modelを改造することで、様々な入力形式の自然言語処理タスクで高性能が出せるようだ。

参考サイトKeras BERTでファインチューニングしてみる


BERT全体像


Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
Input-Token (InputLayer)        (None, 128)          0                                            
__________________________________________________________________________________________________
Input-Segment (InputLayer)      (None, 128)          0                                            
__________________________________________________________________________________________________
Embedding-Token (TokenEmbedding [(None, 128, 768), ( 23440896    Input-Token[0][0]                
__________________________________________________________________________________________________
Embedding-Segment (Embedding)   (None, 128, 768)     1536        Input-Segment[0][0]              
__________________________________________________________________________________________________
Embedding-Token-Segment (Add)   (None, 128, 768)     0           Embedding-Token[0][0]            
                                                                 Embedding-Segment[0][0]          
__________________________________________________________________________________________________
Embedding-Position (PositionEmb (None, 128, 768)     98304       Embedding-Token-Segment[0][0]    
__________________________________________________________________________________________________
Embedding-Dropout (Dropout)     (None, 128, 768)     0           Embedding-Position[0][0]         
__________________________________________________________________________________________________
Embedding-Norm (LayerNormalizat (None, 128, 768)     1536        Embedding-Dropout[0][0]          
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 128, 768)     2362368     Embedding-Norm[0][0]             
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-1-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 128, 768)     0           Embedding-Norm[0][0]             
                                                                 Encoder-1-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-1-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-1-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-1-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-1-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-1-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-1-FeedForward[0][0]      
__________________________________________________________________________________________________
Encoder-1-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-1-MultiHeadSelfAttention-
                                                                 Encoder-1-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-1-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-1-FeedForward-Add[0][0]  
__________________________________________________________________________________________________
Encoder-2-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-1-FeedForward-Norm[0][0] 
__________________________________________________________________________________________________
Encoder-2-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-2-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-2-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-1-FeedForward-Norm[0][0] 
                                                                 Encoder-2-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-2-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-2-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-2-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-2-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-2-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-2-FeedForward[0][0]      
__________________________________________________________________________________________________
Encoder-2-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-2-MultiHeadSelfAttention-
                                                                 Encoder-2-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-2-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-2-FeedForward-Add[0][0]  
__________________________________________________________________________________________________
Encoder-3-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-2-FeedForward-Norm[0][0] 
__________________________________________________________________________________________________
Encoder-3-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-3-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-3-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-2-FeedForward-Norm[0][0] 
                                                                 Encoder-3-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-3-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-3-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-3-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-3-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-3-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-3-FeedForward[0][0]      
__________________________________________________________________________________________________
Encoder-3-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-3-MultiHeadSelfAttention-
                                                                 Encoder-3-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-3-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-3-FeedForward-Add[0][0]  
__________________________________________________________________________________________________
Encoder-4-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-3-FeedForward-Norm[0][0] 
__________________________________________________________________________________________________
Encoder-4-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-4-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-4-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-3-FeedForward-Norm[0][0] 
                                                                 Encoder-4-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-4-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-4-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-4-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-4-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-4-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-4-FeedForward[0][0]      
__________________________________________________________________________________________________
Encoder-4-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-4-MultiHeadSelfAttention-
                                                                 Encoder-4-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-4-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-4-FeedForward-Add[0][0]  
__________________________________________________________________________________________________
Encoder-5-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-4-FeedForward-Norm[0][0] 
__________________________________________________________________________________________________
Encoder-5-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-5-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-5-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-4-FeedForward-Norm[0][0] 
                                                                 Encoder-5-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-5-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-5-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-5-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-5-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-5-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-5-FeedForward[0][0]      
__________________________________________________________________________________________________
Encoder-5-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-5-MultiHeadSelfAttention-
                                                                 Encoder-5-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-5-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-5-FeedForward-Add[0][0]  
__________________________________________________________________________________________________
Encoder-6-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-5-FeedForward-Norm[0][0] 
__________________________________________________________________________________________________
Encoder-6-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-6-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-6-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-5-FeedForward-Norm[0][0] 
                                                                 Encoder-6-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-6-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-6-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-6-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-6-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-6-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-6-FeedForward[0][0]      
__________________________________________________________________________________________________
Encoder-6-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-6-MultiHeadSelfAttention-
                                                                 Encoder-6-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-6-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-6-FeedForward-Add[0][0]  
__________________________________________________________________________________________________
Encoder-7-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-6-FeedForward-Norm[0][0] 
__________________________________________________________________________________________________
Encoder-7-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-7-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-7-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-6-FeedForward-Norm[0][0] 
                                                                 Encoder-7-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-7-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-7-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-7-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-7-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-7-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-7-FeedForward[0][0]      
__________________________________________________________________________________________________
Encoder-7-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-7-MultiHeadSelfAttention-
                                                                 Encoder-7-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-7-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-7-FeedForward-Add[0][0]  
__________________________________________________________________________________________________
Encoder-8-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-7-FeedForward-Norm[0][0] 
__________________________________________________________________________________________________
Encoder-8-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-8-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-8-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-7-FeedForward-Norm[0][0] 
                                                                 Encoder-8-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-8-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-8-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-8-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-8-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-8-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-8-FeedForward[0][0]      
__________________________________________________________________________________________________
Encoder-8-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-8-MultiHeadSelfAttention-
                                                                 Encoder-8-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-8-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-8-FeedForward-Add[0][0]  
__________________________________________________________________________________________________
Encoder-9-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-8-FeedForward-Norm[0][0] 
__________________________________________________________________________________________________
Encoder-9-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-9-MultiHeadSelfAttention[
__________________________________________________________________________________________________
Encoder-9-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-8-FeedForward-Norm[0][0] 
                                                                 Encoder-9-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-9-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-9-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-9-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-9-MultiHeadSelfAttention-
__________________________________________________________________________________________________
Encoder-9-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-9-FeedForward[0][0]      
__________________________________________________________________________________________________
Encoder-9-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-9-MultiHeadSelfAttention-
                                                                 Encoder-9-FeedForward-Dropout[0][
__________________________________________________________________________________________________
Encoder-9-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-9-FeedForward-Add[0][0]  
__________________________________________________________________________________________________
Encoder-10-MultiHeadSelfAttenti (None, 128, 768)     2362368     Encoder-9-FeedForward-Norm[0][0] 
__________________________________________________________________________________________________
Encoder-10-MultiHeadSelfAttenti (None, 128, 768)     0           Encoder-10-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-10-MultiHeadSelfAttenti (None, 128, 768)     0           Encoder-9-FeedForward-Norm[0][0] 
                                                                 Encoder-10-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-10-MultiHeadSelfAttenti (None, 128, 768)     1536        Encoder-10-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-10-FeedForward (FeedFor (None, 128, 768)     4722432     Encoder-10-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-10-FeedForward-Dropout  (None, 128, 768)     0           Encoder-10-FeedForward[0][0]     
__________________________________________________________________________________________________
Encoder-10-FeedForward-Add (Add (None, 128, 768)     0           Encoder-10-MultiHeadSelfAttention
                                                                 Encoder-10-FeedForward-Dropout[0]
__________________________________________________________________________________________________
Encoder-10-FeedForward-Norm (La (None, 128, 768)     1536        Encoder-10-FeedForward-Add[0][0] 
__________________________________________________________________________________________________
Encoder-11-MultiHeadSelfAttenti (None, 128, 768)     2362368     Encoder-10-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
Encoder-11-MultiHeadSelfAttenti (None, 128, 768)     0           Encoder-11-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-11-MultiHeadSelfAttenti (None, 128, 768)     0           Encoder-10-FeedForward-Norm[0][0]
                                                                 Encoder-11-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-11-MultiHeadSelfAttenti (None, 128, 768)     1536        Encoder-11-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-11-FeedForward (FeedFor (None, 128, 768)     4722432     Encoder-11-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-11-FeedForward-Dropout  (None, 128, 768)     0           Encoder-11-FeedForward[0][0]     
__________________________________________________________________________________________________
Encoder-11-FeedForward-Add (Add (None, 128, 768)     0           Encoder-11-MultiHeadSelfAttention
                                                                 Encoder-11-FeedForward-Dropout[0]
__________________________________________________________________________________________________
Encoder-11-FeedForward-Norm (La (None, 128, 768)     1536        Encoder-11-FeedForward-Add[0][0] 
__________________________________________________________________________________________________
Encoder-12-MultiHeadSelfAttenti (None, 128, 768)     2362368     Encoder-11-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
Encoder-12-MultiHeadSelfAttenti (None, 128, 768)     0           Encoder-12-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-12-MultiHeadSelfAttenti (None, 128, 768)     0           Encoder-11-FeedForward-Norm[0][0]
                                                                 Encoder-12-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-12-MultiHeadSelfAttenti (None, 128, 768)     1536        Encoder-12-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-12-FeedForward (FeedFor (None, 128, 768)     4722432     Encoder-12-MultiHeadSelfAttention
__________________________________________________________________________________________________
Encoder-12-FeedForward-Dropout  (None, 128, 768)     0           Encoder-12-FeedForward[0][0]     
__________________________________________________________________________________________________
Encoder-12-FeedForward-Add (Add (None, 128, 768)     0           Encoder-12-MultiHeadSelfAttention
                                                                 Encoder-12-FeedForward-Dropout[0]
__________________________________________________________________________________________________
Encoder-12-FeedForward-Norm (La (None, 128, 768)     1536        Encoder-12-FeedForward-Add[0][0] 
__________________________________________________________________________________________________
MLM-Dense (Dense)               (None, 128, 768)     590592      Encoder-12-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
MLM-Norm (LayerNormalization)   (None, 128, 768)     1536        MLM-Dense[0][0]                  
__________________________________________________________________________________________________
Extract (Extract)               (None, 768)          0           Encoder-12-FeedForward-Norm[0][0]
__________________________________________________________________________________________________
MLM-Sim (EmbeddingSimilarity)   (None, 128, 30522)   30522       MLM-Norm[0][0]                   
                                                                 Embedding-Token[0][1]            
__________________________________________________________________________________________________
Input-Masked (InputLayer)       (None, 128)          0                                            
__________________________________________________________________________________________________
NSP-Dense (Dense)               (None, 768)          590592      Extract[0][0]                    
__________________________________________________________________________________________________
MLM (Masked)                    (None, 128, 30522)   0           MLM-Sim[0][0]                    
                                                                 Input-Masked[0][0]               
__________________________________________________________________________________________________
NSP (Dense)                     (None, 2)            1538        NSP-Dense[0][0]          

高性能版GANの「styleGAN」で本物そっくりの画像を生成してみた【keras・機械学習】

今回は論文で紹介されてたNVIDIAが開発したstyleGANを実装してみた。

普通のGANとは生成過程も違うし、生成画像の出来の精度も比較にならないぐらい高くて、驚いた。
仕事で使う機会があったので、その生成過程をまとめてく。


目次
1.styleGANについて
2.styleGANコード詳細
3.訓練
4.生成画像
5.まとめ


1.styleGANについて

styleGANはNVIDIAが開発した、本物と見分けがつかないくらいの画像が作れる、超高精度のGAN。
qiita.com


"Progressive-Growing of GANs”というGANの亜種のgeneratorの部分を発展させたもの。

メインのメカニズムは、低い解像度の層から順に学習して、高精度の画像の生成していく仕組みらしい。


【従来のGANとstyleGANの違い】
f:id:trafalbad:20190707152347j:plain

generatorの部分がAdaptive Instance Normalization (AdaIN)でかなり改造してあるのが普通のGANと大きく異なる点。

ちなみにdiscriminatorは普通のGANと同じ。


【Generator部分】
f:id:trafalbad:20190707152336j:plain




2.styleGANコード詳細

input画像



入力画像はgoogle検索でもってきた景色の画像3枚を120枚に増幅。
f:id:trafalbad:20190707154813p:plain


Macの動画ソフト「Quick time player」で動画にして、ffmpegで静止画に変換。120枚くらいに増やした。

$ ffmpeg -i 元動画.avi -ss 144 -t 148 -r 24 -f image2 %06d.jpg

かなり簡単に画像が大量に作れるので便利。

他にもopencvで動画から、静止画を作るやり方もある。

input画像等、ハイパーパラメータの条件はこんな感じ

・画像の合計120枚

・shape=(256, 256, 3)

・Batch =10

・255で割って正規化

・1000~2000エポック

各景色画像30枚ずつで、計120枚を10batchずつ回して訓練してく。

ちなみにoptimizerを含めて、今回のパラメータは、サイズが256の時にこのパラメータでうまくいった。

けど、512とか1024は同じパラメータでは上手くいくかわからない。




Generator



generatorはinputが3つあって、2つは論文にあるように、ノイズを入れるところになってる。

AdaINが普通のGANとの違いがでかい。

from AdaIN import AdaInstanceNormalization

im_size = 256
latent_size = 512

def g_block(inp, style, noise, fil, u = True):

    b = Dense(fil)(style)
    b = Reshape([1, 1, fil])(b)
    g = Dense(fil)(style)
    g = Reshape([1, 1, fil])(g)

    n = Conv2D(filters = fil, kernel_size = 1, padding = 'same', kernel_initializer = 'he_normal')(noise)

    if u:
        out = UpSampling2D(interpolation = 'bilinear')(inp)
        out = Conv2D(filters = fil, kernel_size = 3, padding = 'same', kernel_initializer = 'he_normal')(out)
    else:
        out = Activation('linear')(inp)

    out = AdaInstanceNormalization()([out, b, g])
    out = add([out, n])
    out = LeakyReLU(0.01)(out)

    b = Dense(fil)(style)
    b = Reshape([1, 1, fil])(b)
    g = Dense(fil)(style)
    g = Reshape([1, 1, fil])(g)

    n = Conv2D(filters = fil, kernel_size = 1, padding = 'same', kernel_initializer = 'he_normal')(noise)

    out = Conv2D(filters = fil, kernel_size = 3, padding = 'same', kernel_initializer = 'he_normal')(out)
    out = AdaInstanceNormalization()([out, b, g])
    out = add([out, n])
    out = LeakyReLU(0.01)(out)

    return out

def generator():

    inp_s = Input(shape = [latent_size])
    sty = Dense(512, kernel_initializer = 'he_normal')(inp_s)
    sty = LeakyReLU(0.1)(sty)
    sty = Dense(512, kernel_initializer = 'he_normal')(sty)
    sty = LeakyReLU(0.1)(sty)

    inp_n = Input(shape = [im_size, im_size, 1])
    noi = [Activation('linear')(inp_n)]
    curr_size = im_size
    while curr_size > 4:
        curr_size = int(curr_size / 2)
        noi.append(Cropping2D(int(curr_size/2))(noi[-1]))

    inp = Input(shape = [1])
    x = Dense(4 * 4 * 512, kernel_initializer = 'he_normal')(inp)
    x = Reshape([4, 4, 512])(x)
    x = g_block(x, sty, noi[-1], 512, u=False)

    if(im_size >= 1024):
        x = g_block(x, sty, noi[7], 512) # Size / 64
    if(im_size >= 512):
        x = g_block(x, sty, noi[6], 384) # Size / 64
    if(im_size >= 256):
        x = g_block(x, sty, noi[5], 256) # Size / 32
    if(im_size >= 128):
        x = g_block(x, sty, noi[4], 192) # Size / 16
    if(im_size >= 64):
        x = g_block(x, sty, noi[3], 128) # Size / 8

    x = g_block(x, sty, noi[2], 64) # Size / 4
    x = g_block(x, sty, noi[1], 32) # Size / 2
    x = g_block(x, sty, noi[0], 16) # Size
    x = Conv2D(filters = 3, kernel_size = 1, padding = 'same', activation = 'sigmoid')(x)
    return Model(inputs = [inp_s, inp_n, inp], outputs = x)


Discriminator



Discriminatorは従来のと変わらない。

今回は、1024サイズにも対応できるようにするために、有名なGANのDiscriminatorを使った。

def discriminator():
    inp = Input(shape = [im_size, im_size, 3])

    x = d_block(inp, 16) #Size / 2
    x = d_block(x, 32) #Size / 4
    x = d_block(x, 64) #Size / 8

    if (im_size > 32):
       x = d_block(x, 128) #Size / 16

    if (im_size > 64):
        x = d_block(x, 192) #Size / 32

    if (im_size > 128):
        x = d_block(x, 256) #Size / 64

    if (im_size > 256):
        x = d_block(x, 384) #Size / 128

    if (im_size > 512):
        x = d_block(x, 512) #Size / 256

    x = Flatten()(x)
    x = Dense(128)(x)
    x = Activation('relu')(x)
    x = Dropout(0.6)(x)
    x = Dense(1)(x)
    
    return Model(inputs = inp, outputs = x)


disganとadganの実装


disganでdiscriminatorを訓練。adganでgeneratorを訓練する。

G = generator()
D = discriminator()

# Dは更新して、Gは更新しない
D.trainable = True
for layer in D.layers:
    layer.trainable = True

G.trainable = False
for layer in G.layers:
    layer.trainable = False

ri = Input(shape = [im_size, im_size, 3])
dr = D(ri)

gi = Input(shape = [latent_size])
gi2 = Input(shape = [im_size, im_size, 1])
gi3 = Input(shape = [1])
df = D(G([gi, gi2, gi3]))
da = D(ri)
disgan = Model(inputs=[ri, gi, gi2, gi3], outputs=[dr, df, da])

# Gは更新、Dは更新しない
D.trainable = False
for layer in D.layers:
    layer.trainable = False

G.trainable = True
for layer in G.layers:
    layer.trainable = True

gi = Input(shape = [latent_size])
gi2 = Input(shape = [im_size, im_size, 1])
gi3 = Input(shape = [1])
df = D(G([gi, gi2, gi3]))
adgan = Model(inputs = [gi, gi2, gi3], outputs = df)



lossとoptimizer


lossはGANには珍しくMSEを使った。

optimizerはAdamで学習率(lr)は下のように設定。

・Disgan => lr=0.0002

・Adage => lr=0.0001

両方0.0001よりdisganを少し大きくした方が精度がいい結果になった。

def gradient_penalty_loss(y_true, y_pred, averaged_samples, weight):
    gradients = K.gradients(y_pred, averaged_samples)[0]
    gradients_sqr = K.square(gradients)
    gradient_penalty = K.sum(gradients_sqr,
                              axis=np.arange(1, len(gradients_sqr.shape)))
    return K.mean(gradient_penalty * weight)

partial_gp_loss = partial(gradient_penalty_loss, averaged_samples = ri, weight = 5)

disgan.compile(optimizer=Adam(lr=0.0002, beta_1 = 0, beta_2 = 0.99, decay = 0.00001), loss=['mse', 'mse', partial_gp_loss])

adgan.compile(optimizer = Adam(lr=0.0001, beta_1 = 0, beta_2 = 0.99, decay = 0.00001), loss = 'mse')

3.訓練

主にdisganとadganのtrain。
従来のGANと訓練方法は同じだけど、generatorが違う分、disganとadgan共に独特な値を入れてる。

# Noise 
def noise(n):
    return np.random.normal(0.0, 1.0, size = [n, latent_size])

def noiseImage(n):
    return np.random.uniform(0.0, 1.0, size = [n, im_size, im_size, 1])

# train disgan
real_images  = x_train[idx*batch_size:(idx+1)*batch_size]
train_data = [real_images, noise(batch_size), noiseImage(batch_size), np.ones((batch_size, 1))]
d_loss = disgan.train_on_batch(train_data, [np.ones((batch_size, 1)),
                                                      -np.ones((batch_size, 1)), np.ones((batch_size, 1))])

# train adgan
g_loss = adgan.train_on_batch([noise(batch_size), noiseImage(batch_size), np.ones((batch_size, 1))], 
                                        np.zeros((batch_size, 1), dtype=np.float32))


4.生成画像

生成した画像はこんな感じ。論文通り、本物とそっくりでびっくりの超高精度。

【生成画像】f:id:trafalbad:20190707155211p:plain

生成過程は、初めは黒から始まって、ゴッホみたいな油絵みたいになっていき、本物に近づいてく感じ。


f:id:trafalbad:20190707155057g:plain




5.まとめ

同じ画像が多いとかなり早くできる。今回は3枚×40=120枚で、700エポックにはもう本物っぽいのが出来始めてた。


ただバリエーションが増えると(例えば100枚のスニーカーで、100枚全部違う画像の場合)、とんでもなく時間がかかる。

なるべく生成したい画像は同じやつを何枚も入れとくと早く生成できる。論文みたいな顔のバリエーションだと、とんでもない時間(10万エポックくらい)かかる気がする。

GANの訓練はかなり繊細なので、styleGANのパラメータ調整は普通のGAN以上に繊細だった。