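I used BERT, Google's high-performance natural language processing model, for fine-tuning (transfer learning). Instead of a dataset prepared specifically for BERT, I used the publicly available IMDB dataset (movie reviews).
I rebuilt the BERT model for binary classification and ran negative/positive sentiment classification with it, so this post summarizes the process.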
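Contents
・The task
・Building the dataset and its contents
・Loading pretrained BERT & rebuilding it for binary classification
・Accuracy
・Summary & overall structure of BERT
The task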
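The task is binary negative/positive classification. The dataset is the IMDB dataset, which consists of movie reviews and labels (negative=0, positive=1). BERT can be used in two ways, pretraining and transfer learning. Normally you skip pretraining (training from scratch), because transfer learning alone gives high enough performance.
Going beyond word2vec, BERT uses attention to take context into account when judging sentiment, and the parts it attends to can be colored as a heatmap.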
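To make the attention remark a bit more concrete, here is a toy sketch of how attention weights can be computed for one query token: a softmax over scaled dot products. The tokens, the 2-dimensional vectors, and the function names are all made up for illustration. Real BERT uses learned 768-dimensional projections and multiple heads, so treat this only as the general idea behind the values a heatmap would color.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(query, keys):
    """Scaled dot-product attention weights of one query over all keys."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    return softmax(scores)

# Hypothetical tokens and vectors, not taken from the real model.
tokens = ["awful", "movie", "."]
keys = [[2.0, 0.0], [0.5, 0.5], [0.0, 0.1]]
query = [1.0, 0.0]

weights = attention_weights(query, keys)
for token, weight in zip(tokens, weights):
    print(token, round(weight, 3))
```

The weights sum to 1, and the token whose key is most similar to the query gets the largest share.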
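For this post I mostly followed this tutorial site, so please refer to the site for the full code. I did the transfer learning with keras bert obtained via git clone.
On Colaboratory, the TensorFlow version has been 1.14 since June 2019, and the TPU usage instructions circulating online no longer worked.
Following this site as-is is one option, but when rebuilding a model, as is done with keras bert, I ran into a stream of errors I couldn't make sense of, so I rolled TensorFlow back to version 1.13.1.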
!pip uninstall -y tensorflow && pip install tensorflow==1.13.1
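After the reinstall it is worth confirming that the notebook really picked up the pinned version (if TensorFlow was already imported, the runtime may need a restart first). The helpers below are a small hypothetical sketch of such a check, not part of TensorFlow or keras bert. In the notebook you would pass tf.__version__ instead of the literal strings.

```python
import re

def version_tuple(version):
    """Extract (major, minor, patch) from a string such as '1.13.1'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("unrecognized version string: %r" % version)
    return tuple(int(part) for part in match.groups())

def is_pinned(installed, expected="1.13.1"):
    """True when the installed version matches the pinned one."""
    return version_tuple(installed) == version_tuple(expected)

# In the notebook: is_pinned(tf.__version__) after `import tensorflow as tf`.
print(is_pinned("1.13.1"))  # True
print(is_pinned("1.14.0"))  # False
```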
Building the dataset and its contents
The dataset is loaded exactly as on the site, so here I go through what each variable contains.
 1  def load_data(path):
 2      global tokenizer
 3      indices, sentiments = [], []
 4      for folder, sentiment in (('neg', 0), ('pos', 1)):
 5          folder = os.path.join(path, folder)
 6          for name in tqdm(os.listdir(folder)):
 7              with open(os.path.join(folder, name), 'r') as reader:
 8                  text = reader.read()
 9              ids, segments = tokenizer.encode(text, max_len=SEQ_LEN)
10              indices.append(ids)
11              sentiments.append(sentiment)
12      items = list(zip(indices, sentiments))
13      np.random.shuffle(items)
14      indices, sentiments = zip(*items)
15      indices = np.array(indices)
16      return [indices, np.zeros_like(indices)], np.array(sentiments)
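Line 9 of load_data is where each review becomes a fixed-length list of ids. The following is a rough sketch of that step under my own assumptions, not keras bert's actual implementation: it assumes 101 and 102 are the ids that open and close each sequence and that 0 is the padding id, which matches the values seen in this dataset. The function name to_fixed_length and the short max_len are made up for the demo.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # assumed special ids

def to_fixed_length(token_ids, max_len=128):
    """Wrap ids in the start/end markers, truncate, then pad with zeros."""
    body = token_ids[: max_len - 2]      # leave room for the two markers
    ids = [CLS_ID] + body + [SEP_ID]
    return ids + [PAD_ID] * (max_len - len(ids))

# A short review gets padded, a long one gets truncated.
short = to_fixed_length([2298, 1010, 2009], max_len=8)
long_ = to_fixed_length(list(range(1000, 1200)), max_len=8)
print(short)  # [101, 2298, 1010, 2009, 102, 0, 0, 0]
print(long_)  # [101, 1000, 1001, 1002, 1003, 1004, 1005, 102]
```

Either way every row ends up with exactly max_len entries, which is why the rows can later be stacked into one array.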
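Contents of the dataset
・folder on line 4
Receives the strings 'neg' and 'pos'. Line 5 then turns them into the paths
'/root/.keras/datasets/aclImdb/train/neg' and '/root/.keras/datasets/aclImdb/train/pos'.
・sentiment on line 4
[0, 1]. Receives the 0 and 1 on the right-hand side of each tuple in the for statement.
・name on line 6
Each file name from a list like
['12398_2.txt', '12239_1.txt', '8140_2.txt', '9293_1.txt', ...]
which covers 25,000 text files in total.
・text on line 8 (part of the contents)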
Look, it\'s the third one, so you already know it\'s bad. And "Maniac Cop" wasn\'t good enough to warrant the second installment, so you know it\'s even worse. But how much worse? Awful, approaching God-awful.<br /><br />When Maniac Cop goes on a killing spree, a reporter exclaims, "What happened here can ONLY be described as a black rainbow of death."<br /><br />1-- Rainbows are not black, and can never be. 2-- Rainbows are harmless, and can never inflict pain or death. 3-- A news reporter, one valuable to his agency, might find another way to describe the aftermath of a killing spree. "A black rainbow of death" is not the ONLY way to describe the given situation.<br /><br />This is what you\'re in for.’~
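・indices on line 10
25,000 lists of length 128. len(indices[0]) => 128, len(indices) => 25000
[101, 2298, 1010, 2009, 1005, 1055, 29310, ..., 102]
・sentiments on line 11
A list of 0s and 1s. len(sentiments) => 25000
[0, 1, 0, 0, 0, 1, 0, ...]
・indices = np.array(indices) on line 15
An array with shape=(25000, 128). The contents are the same as on line 10.
・np.array(sentiments) on line 16
An array with shape=(25000,). The contents are the same as on line 11.
・[indices, np.zeros_like(indices)] on line 16
np.zeros_like(indices) returns an array of zeros with the same shape and dtype as indices. It is like np.zeros(indices.shape), except that the integer dtype is kept.
The contents look like the following: two arrays with shape=(25000, 128).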
[array([[ 101, 1996, 2466, ...,    0,    0,    0],
        [ 101, 1037, 2158, ..., 4287, 1996,  102],
        [ 101, 1045, 2001, ..., 2023, 2499,  102],
        ...,
        [ 101, 7929, 2182, ..., 7987, 1013,  102],
        [ 101, 2672, 1045, ..., 1999, 1996,  102],
        [ 101, 1996, 3213, ..., 2046, 1037,  102]]),
 array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]])]
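The pair returned on line 16 can be reproduced on a tiny scale. This is a minimal sketch with two hypothetical, already padded rows (length 6 instead of 128), meant only to show how np.zeros_like differs from np.zeros and what the two-element input list looks like.

```python
import numpy as np

# Two made-up token-id rows standing in for the real (25000, 128) array.
indices = np.array([[101, 1996, 2466, 102, 0, 0],
                    [101, 1037, 2158, 4287, 1996, 102]])

segments = np.zeros_like(indices)      # same shape, same integer dtype
float_zeros = np.zeros(indices.shape)  # same shape, but float64 by default

model_inputs = [indices, segments]     # [token ids, segment ids]

print(segments.shape, segments.dtype == indices.dtype)
print(float_zeros.dtype)
```

Since every review is fed as a single segment, the second array is all zeros.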
Loading pretrained BERT & rebuilding it for binary classification
inputs = bert.inputs[:2]
dense = bert.get_layer('NSP-Dense').output
outputs = keras.layers.Dense(units=2, activation='softmax')(dense)
# rebuild BERT model
model = keras.models.Model(inputs, outputs)
The only difference from the original model is the final output layer.
Encoder-12-FeedForward-Norm (La (None, 128, 768)     1536        Encoder-12-FeedForward-Add[0][0]
Extract (Extract)               (None, 768)          0           Encoder-12-FeedForward-Norm[0][0]
NSP-Dense (Dense)               (None, 768)          590592      Extract[0][0]
dense (Dense)                   (None, 2)            1538        NSP-Dense[0][0]
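The new output layer produces two softmax probabilities per review. As a minimal sketch of how those could be turned into a label and an accuracy figure, here is a pure-Python example. The probabilities and true labels are made-up numbers, and predict_label and accuracy are hypothetical helpers, not keras bert functions.

```python
LABELS = {0: "negative", 1: "positive"}  # label mapping used by the dataset

def predict_label(probabilities):
    """Return the class index with the highest probability."""
    return max(range(len(probabilities)), key=lambda i: probabilities[i])

def accuracy(predicted, expected):
    """Fraction of predictions that match the true labels."""
    correct = sum(p == e for p, e in zip(predicted, expected))
    return correct / len(expected)

# Hypothetical softmax outputs for four reviews.
outputs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]
true_labels = [0, 1, 1, 1]

predicted = [predict_label(row) for row in outputs]
print([LABELS[i] for i in predicted])
print(accuracy(predicted, true_labels))  # 0.75
```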
Accuracy
Accuracy was 88%, so transfer learning alone gives high enough performance.
Summary & overall structure of BERT
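This time I tried transfer learning with BERT. Because I used a publicly available dataset rather than one made for BERT, I learned a lot. With keras_bert, it seems you can get high performance on natural language processing tasks with a variety of input formats by modifying the model.
Reference site: Keras BERTでファインチューニングしてみる
Overall structure of BERT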
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
Input-Token (InputLayer)        (None, 128)          0
Input-Segment (InputLayer)      (None, 128)          0
Embedding-Token (TokenEmbedding [(None, 128, 768), ( 23440896    Input-Token[0][0]
Embedding-Segment (Embedding)   (None, 128, 768)     1536        Input-Segment[0][0]
Embedding-Token-Segment (Add)   (None, 128, 768)     0           Embedding-Token[0][0]
                                                                 Embedding-Segment[0][0]
Embedding-Position (PositionEmb (None, 128, 768)     98304       Embedding-Token-Segment[0][0]
Embedding-Dropout (Dropout)     (None, 128, 768)     0           Embedding-Position[0][0]
Embedding-Norm (LayerNormalizat (None, 128, 768)     1536        Embedding-Dropout[0][0]
Encoder-1-MultiHeadSelfAttentio (None, 128, 768)     2362368     Embedding-Norm[0][0]
Encoder-1-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-1-MultiHeadSelfAttention[
Encoder-1-MultiHeadSelfAttentio (None, 128, 768)     0           Embedding-Norm[0][0]
                                                                 Encoder-1-MultiHeadSelfAttention-
Encoder-1-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-1-MultiHeadSelfAttention-
Encoder-1-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-1-MultiHeadSelfAttention-
Encoder-1-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-1-FeedForward[0][0]
Encoder-1-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-1-MultiHeadSelfAttention-
                                                                 Encoder-1-FeedForward-Dropout[0][
Encoder-1-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-1-FeedForward-Add[0][0]
Encoder-2-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-1-FeedForward-Norm[0][0]
Encoder-2-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-2-MultiHeadSelfAttention[
Encoder-2-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-1-FeedForward-Norm[0][0]
                                                                 Encoder-2-MultiHeadSelfAttention-
Encoder-2-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-2-MultiHeadSelfAttention-
Encoder-2-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-2-MultiHeadSelfAttention-
Encoder-2-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-2-FeedForward[0][0]
Encoder-2-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-2-MultiHeadSelfAttention-
                                                                 Encoder-2-FeedForward-Dropout[0][
Encoder-2-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-2-FeedForward-Add[0][0]
Encoder-3-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-2-FeedForward-Norm[0][0]
Encoder-3-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-3-MultiHeadSelfAttention[
Encoder-3-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-2-FeedForward-Norm[0][0]
                                                                 Encoder-3-MultiHeadSelfAttention-
Encoder-3-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-3-MultiHeadSelfAttention-
Encoder-3-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-3-MultiHeadSelfAttention-
Encoder-3-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-3-FeedForward[0][0]
Encoder-3-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-3-MultiHeadSelfAttention-
                                                                 Encoder-3-FeedForward-Dropout[0][
Encoder-3-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-3-FeedForward-Add[0][0]
Encoder-4-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-3-FeedForward-Norm[0][0]
Encoder-4-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-4-MultiHeadSelfAttention[
Encoder-4-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-3-FeedForward-Norm[0][0]
                                                                 Encoder-4-MultiHeadSelfAttention-
Encoder-4-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-4-MultiHeadSelfAttention-
Encoder-4-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-4-MultiHeadSelfAttention-
Encoder-4-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-4-FeedForward[0][0]
Encoder-4-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-4-MultiHeadSelfAttention-
                                                                 Encoder-4-FeedForward-Dropout[0][
Encoder-4-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-4-FeedForward-Add[0][0]
Encoder-5-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-4-FeedForward-Norm[0][0]
Encoder-5-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-5-MultiHeadSelfAttention[
Encoder-5-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-4-FeedForward-Norm[0][0]
                                                                 Encoder-5-MultiHeadSelfAttention-
Encoder-5-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-5-MultiHeadSelfAttention-
Encoder-5-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-5-MultiHeadSelfAttention-
Encoder-5-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-5-FeedForward[0][0]
Encoder-5-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-5-MultiHeadSelfAttention-
                                                                 Encoder-5-FeedForward-Dropout[0][
Encoder-5-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-5-FeedForward-Add[0][0]
Encoder-6-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-5-FeedForward-Norm[0][0]
Encoder-6-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-6-MultiHeadSelfAttention[
Encoder-6-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-5-FeedForward-Norm[0][0]
                                                                 Encoder-6-MultiHeadSelfAttention-
Encoder-6-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-6-MultiHeadSelfAttention-
Encoder-6-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-6-MultiHeadSelfAttention-
Encoder-6-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-6-FeedForward[0][0]
Encoder-6-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-6-MultiHeadSelfAttention-
                                                                 Encoder-6-FeedForward-Dropout[0][
Encoder-6-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-6-FeedForward-Add[0][0]
Encoder-7-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-6-FeedForward-Norm[0][0]
Encoder-7-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-7-MultiHeadSelfAttention[
Encoder-7-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-6-FeedForward-Norm[0][0]
                                                                 Encoder-7-MultiHeadSelfAttention-
Encoder-7-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-7-MultiHeadSelfAttention-
Encoder-7-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-7-MultiHeadSelfAttention-
Encoder-7-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-7-FeedForward[0][0]
Encoder-7-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-7-MultiHeadSelfAttention-
                                                                 Encoder-7-FeedForward-Dropout[0][
Encoder-7-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-7-FeedForward-Add[0][0]
Encoder-8-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-7-FeedForward-Norm[0][0]
Encoder-8-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-8-MultiHeadSelfAttention[
Encoder-8-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-7-FeedForward-Norm[0][0]
                                                                 Encoder-8-MultiHeadSelfAttention-
Encoder-8-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-8-MultiHeadSelfAttention-
Encoder-8-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-8-MultiHeadSelfAttention-
Encoder-8-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-8-FeedForward[0][0]
Encoder-8-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-8-MultiHeadSelfAttention-
                                                                 Encoder-8-FeedForward-Dropout[0][
Encoder-8-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-8-FeedForward-Add[0][0]
Encoder-9-MultiHeadSelfAttentio (None, 128, 768)     2362368     Encoder-8-FeedForward-Norm[0][0]
Encoder-9-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-9-MultiHeadSelfAttention[
Encoder-9-MultiHeadSelfAttentio (None, 128, 768)     0           Encoder-8-FeedForward-Norm[0][0]
                                                                 Encoder-9-MultiHeadSelfAttention-
Encoder-9-MultiHeadSelfAttentio (None, 128, 768)     1536        Encoder-9-MultiHeadSelfAttention-
Encoder-9-FeedForward (FeedForw (None, 128, 768)     4722432     Encoder-9-MultiHeadSelfAttention-
Encoder-9-FeedForward-Dropout ( (None, 128, 768)     0           Encoder-9-FeedForward[0][0]
Encoder-9-FeedForward-Add (Add) (None, 128, 768)     0           Encoder-9-MultiHeadSelfAttention-
                                                                 Encoder-9-FeedForward-Dropout[0][
Encoder-9-FeedForward-Norm (Lay (None, 128, 768)     1536        Encoder-9-FeedForward-Add[0][0]
Encoder-10-MultiHeadSelfAttenti (None, 128, 768)     2362368     Encoder-9-FeedForward-Norm[0][0]
Encoder-10-MultiHeadSelfAttenti (None, 128, 768)     0           Encoder-10-MultiHeadSelfAttention
Encoder-10-MultiHeadSelfAttenti (None, 128, 768)     0           Encoder-9-FeedForward-Norm[0][0]
                                                                 Encoder-10-MultiHeadSelfAttention
Encoder-10-MultiHeadSelfAttenti (None, 128, 768)     1536        Encoder-10-MultiHeadSelfAttention
Encoder-10-FeedForward (FeedFor (None, 128, 768)     4722432     Encoder-10-MultiHeadSelfAttention
Encoder-10-FeedForward-Dropout  (None, 128, 768)     0           Encoder-10-FeedForward[0][0]
Encoder-10-FeedForward-Add (Add (None, 128, 768)     0           Encoder-10-MultiHeadSelfAttention
                                                                 Encoder-10-FeedForward-Dropout[0]
Encoder-10-FeedForward-Norm (La (None, 128, 768)     1536        Encoder-10-FeedForward-Add[0][0]
Encoder-11-MultiHeadSelfAttenti (None, 128, 768)     2362368     Encoder-10-FeedForward-Norm[0][0]
Encoder-11-MultiHeadSelfAttenti (None, 128, 768)     0           Encoder-11-MultiHeadSelfAttention
Encoder-11-MultiHeadSelfAttenti (None, 128, 768)     0           Encoder-10-FeedForward-Norm[0][0]
                                                                 Encoder-11-MultiHeadSelfAttention
Encoder-11-MultiHeadSelfAttenti (None, 128, 768)     1536        Encoder-11-MultiHeadSelfAttention
Encoder-11-FeedForward (FeedFor (None, 128, 768)     4722432     Encoder-11-MultiHeadSelfAttention
Encoder-11-FeedForward-Dropout  (None, 128, 768)     0           Encoder-11-FeedForward[0][0]
Encoder-11-FeedForward-Add (Add (None, 128, 768)     0           Encoder-11-MultiHeadSelfAttention
                                                                 Encoder-11-FeedForward-Dropout[0]
Encoder-11-FeedForward-Norm (La (None, 128, 768)     1536        Encoder-11-FeedForward-Add[0][0]
Encoder-12-MultiHeadSelfAttenti (None, 128, 768)     2362368     Encoder-11-FeedForward-Norm[0][0]
Encoder-12-MultiHeadSelfAttenti (None, 128, 768)     0           Encoder-12-MultiHeadSelfAttention
Encoder-12-MultiHeadSelfAttenti (None, 128, 768)     0           Encoder-11-FeedForward-Norm[0][0]
                                                                 Encoder-12-MultiHeadSelfAttention
Encoder-12-MultiHeadSelfAttenti (None, 128, 768)     1536        Encoder-12-MultiHeadSelfAttention
Encoder-12-FeedForward (FeedFor (None, 128, 768)     4722432     Encoder-12-MultiHeadSelfAttention
Encoder-12-FeedForward-Dropout  (None, 128, 768)     0           Encoder-12-FeedForward[0][0]
Encoder-12-FeedForward-Add (Add (None, 128, 768)     0           Encoder-12-MultiHeadSelfAttention
                                                                 Encoder-12-FeedForward-Dropout[0]
Encoder-12-FeedForward-Norm (La (None, 128, 768)     1536        Encoder-12-FeedForward-Add[0][0]
MLM-Dense (Dense)               (None, 128, 768)     590592      Encoder-12-FeedForward-Norm[0][0]
MLM-Norm (LayerNormalization)   (None, 128, 768)     1536        MLM-Dense[0][0]
Extract (Extract)               (None, 768)          0           Encoder-12-FeedForward-Norm[0][0]
MLM-Sim (EmbeddingSimilarity)   (None, 128, 30522)   30522       MLM-Norm[0][0]
                                                                 Embedding-Token[0][1]
Input-Masked (InputLayer)       (None, 128)          0
NSP-Dense (Dense)               (None, 768)          590592      Extract[0][0]
MLM (Masked)                    (None, 128, 30522)   0           MLM-Sim[0][0]
                                                                 Input-Masked[0][0]
NSP (Dense)                     (None, 2)            1538        NSP-Dense[0][0]
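The Param # column can be sanity-checked with a little arithmetic. The sketch below recomputes several of the counts from the sizes visible in the summary (hidden size 768, sequence length 128, vocabulary 30522). The feed-forward inner size of 3072 does not appear in the summary, so it is my assumption, as is the reading of the attention layer as four 768x768 dense projections. Both happen to reproduce the listed numbers.

```python
HIDDEN = 768         # from the (None, 128, 768) output shapes
SEQ_LEN = 128        # from the (None, 128) input shape
VOCAB = 30522        # from the (None, 128, 30522) MLM-Sim output shape
FEED_FORWARD = 3072  # assumed inner size of the feed-forward block

def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

counts = {
    "Embedding-Token": VOCAB * HIDDEN,
    "Embedding-Segment": 2 * HIDDEN,
    "Embedding-Position": SEQ_LEN * HIDDEN,
    "LayerNormalization": 2 * HIDDEN,  # one gain and one bias per unit
    "MultiHeadSelfAttention": 4 * dense_params(HIDDEN, HIDDEN),
    "FeedForward": dense_params(HIDDEN, FEED_FORWARD)
                   + dense_params(FEED_FORWARD, HIDDEN),
    "NSP-Dense": dense_params(HIDDEN, HIDDEN),
    "2-class output": dense_params(HIDDEN, 2),
}

for name, value in counts.items():
    print(name, value)
```

The last entry also explains the 1538 parameters of the 2-unit output layer: 768 * 2 weights plus 2 biases.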