
Analysis of CNN+LSTM with Attention

I created a CNN+LSTM model that has Attention layers on the LSTM side.

The CNN plays the role of the encoder and the LSTM that of the decoder.

Here, I'll write this up as a log of my analysis of how to use the CNN output as the LSTM input, and of the two kinds of Attention: Self-Attention and Source-Target Attention.

This is just like my diary, so it's simple and short.

Contents
1. How to use CNN output as LSTM input
2. Analysis result of Attention
3. Self-Attention
4. Source-Target Attention
5. Summary


1. How to use CNN output as LSTM input

This time, the input data are images and the aim is meta-label classification, such as image colors and shape types.

First, the image is fed into the CNN encoder, and its output is used as the LSTM input.

[Figure: CNN+LSTM overall architecture]


CNN output and LSTM input part 

"""params"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LSTM_UNITS = 200
embed_size = 1024
H = W = 224
in_size = 3
input_size = (1, in_size, H, W)
inp1 = torch.rand(input_size, dtype=torch.float32).to(device)

"""CNN encoder"""
avgpool = nn.AdaptiveAvgPool2d(1)
lstm = nn.LSTM(embed_size, LSTM_UNITS, bidirectional=True, batch_first=True)
x = self.cnn(x)
print(x.shape)
# torch.Size([1, 512, 14, 14])
x = avgpool(x)
# torch.Size([1, 512, 1, 1])
x = torch.reshape(x, (1, embed_size, 1, 1))
b,f,_,_ = x.shape
embedding = x.reshape(1,b,f) 
print(embedding.shape)
#torch.Size([1, 1, 1024])


"""LSTM decoder """
lstm.flatten_parameters()
h_lstm1, (hidden1, cell1) = lstm1(embedding)
print(h_lstm1.shape, hidden1.shape)
#torch.Size([1, 1, 400]) torch.Size([2, 1, 200])
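
In other words, the pooled CNN feature map is treated as a sequence of length 1 and fed to the bidirectional LSTM, which is why its output h_lstm1 has 2 * LSTM_UNITS = 400 features.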


2. Analysis result of Attention

Attention can be divided into Self-Attention and Source-Target Attention.

Self-Attention has only one input; it can be used in many fields and is easy to customize.

Source-Target Attention has two inputs, the input and the memory (the input serves as the query, the memory as the keys and values).
It is mainly used for sequence data and text embedding vectors, and is typically applied after the LSTM output.
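
To make the query/key/value roles concrete, here is a minimal sketch of generic scaled dot-product attention in PyTorch. The tensor names and sizes are my own illustration, not taken from the model in this post.

import torch
import torch.nn.functional as F

# illustrative sizes (assumptions, not the model's actual dimensions)
batch, src_len, dim = 1, 5, 400
query = torch.rand(batch, 1, dim)         # "input"  -> query
memory = torch.rand(batch, src_len, dim)  # "memory" -> keys and values

# scaled dot-product attention: softmax(QK^T / sqrt(d)) V
scores = torch.bmm(query, memory.transpose(1, 2)) / dim ** 0.5  # (batch, 1, src_len)
weights = F.softmax(scores, dim=-1)                             # attention weights over the memory
context = torch.bmm(weights, memory)                            # (batch, 1, dim)
print(context.shape)  # torch.Size([1, 1, 400])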


This time, I analyzed these two kinds of Attention and applied them to the LSTM.

3. Self-Attention

Self-Attention is mainly used for NLP, in networks such as LSTM, seq2seq, and the Transformer.
But these days, as ML has advanced, Attention layers are also used with CNNs (for image classification).

I'll try an Attention layer on the CNN side another time; this time I used it on the LSTM side to improve accuracy.


Self-Attention structure

import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, lstm_dim):
        super(SelfAttention, self).__init__()
        self.lstm_dim = lstm_dim * 2  # size of the bidirectional LSTM output
        # two-layer MLP that computes a weight for each feature of the LSTM output
        self.attn_weight = nn.Sequential(
            nn.Linear(lstm_dim * 2, lstm_dim * 2),
            nn.Tanh(),
            nn.Linear(lstm_dim * 2, lstm_dim * 2)
        )

    def forward(self, lstm_output):
        # lstm_output: (batch, seq_len, lstm_dim * 2)
        attn_weight = self.attn_weight(lstm_output)
        attn_weight = F.softmax(attn_weight, dim=2)
        # concatenate the LSTM output with its attention weights along the feature axis
        feats = torch.cat([lstm_output, attn_weight], 2)
        return feats.squeeze(0)  # (seq_len, lstm_dim * 4); here (1, 800)

LSTM_UNITS = 200
num_cls = 12
input_size = (1, 1, 400)

# layers
self_attention = SelfAttention(LSTM_UNITS).to(device)
attention_linear = nn.Linear(LSTM_UNITS * 4, num_cls).to(device)
final_activate = nn.Softmax(dim=1)

# dummy input standing in for the bidirectional LSTM output
hidden = torch.rand(input_size, dtype=torch.float32).to(device)
print(hidden.shape)  # torch.Size([1, 1, 400])


# in the LSTM model, this comes after "hidden = h_lstm1 + h_lstm2"
hidden = self_attention(hidden)
print(hidden.shape)  # torch.Size([1, 800])
output = attention_linear(hidden)
output = final_activate(output)
print(output.shape)  # torch.Size([1, 12])


In an LSTM model, Self-Attention is usually placed around the LSTM's final layer. Since the attention weights are concatenated with the LSTM output, the feature size doubles from 400 to 800, which is why the final linear layer takes LSTM_UNITS * 4 inputs.
I referred to Kaggle sample code and the PyTorch GitHub repository.


4. Source-Target Attention

This time, I applied this after the LSTM's output.
The Transformer uses many Source-Target Attention heads in parallel, which is called "Multi-head Attention".

"Source-Target Attention" is harder to customize than Self-Attention, but it depends on your technique and ideas (it's up to you).

"Source-Target Attention" structure

def source_target_attention(lstm_output, final_state):
    # lstm_output -> keys and values, final_state -> query
    lstm_output = lstm_output.permute(1, 0, 2)  # reorder to (batch, seq_len, hidden) for bmm
    querys = final_state.squeeze(0)             # (batch, hidden)
    # attention scores: dot product between each key and the query
    logits = torch.bmm(lstm_output, querys.unsqueeze(2)).squeeze(2)
    attn_weights = F.softmax(logits, dim=1)
    # weighted sum of the values using the attention weights
    new_hidden_state = torch.bmm(lstm_output.transpose(1, 2),
                                 attn_weights.unsqueeze(2)).squeeze(2)
    return new_hidden_state                     # (batch, hidden)

# dummy LSTM outputs
lstm_output = torch.rand((1, 1, 400), dtype=torch.float32)
hidden1 = torch.rand((2, 1, 200), dtype=torch.float32)

# in the LSTM model, these come from "h_lstm1, (hidden1, cell1) = self.lstm1(embedding)"
hidden1 = hidden1.reshape(1, 1, 400)  # merge the two directions into a single 400-dim query
lstm_output = source_target_attention(lstm_output, hidden1).unsqueeze(0)
print(lstm_output.shape)
# torch.Size([1, 1, 400])
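
As a side note, the "Multi-head Attention" mentioned above is available as a built-in PyTorch layer. Below is a minimal sketch of the same source-target pattern using nn.MultiheadAttention; the sizes and num_heads value are my own illustration, not part of the model above, and it is not a drop-in replacement for the function, just the same idea with several heads.

import torch
import torch.nn as nn

query = torch.rand(1, 1, 400)   # (batch, target_len, embed_dim): the final hidden state
memory = torch.rand(1, 1, 400)  # (batch, source_len, embed_dim): the LSTM output sequence

# 4 heads is an arbitrary choice; embed_dim must be divisible by num_heads
# (batch_first=True requires PyTorch 1.9 or later)
mha = nn.MultiheadAttention(embed_dim=400, num_heads=4, batch_first=True)
attn_output, attn_weights = mha(query, memory, memory)  # query, key, value
print(attn_output.shape)   # torch.Size([1, 1, 400])
print(attn_weights.shape)  # torch.Size([1, 1, 1])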


5. Summary

This time I realized that CNN+LSTM with Attention works well for image label classification, not only for NLP tasks.

The good accuracy comes from converting the image into sequence data (3D to 1D).
I roughly drew a diagram of the whole network used this time, showing where I placed the two Attention layers, as follows.


The loss and final activation that worked better this time are:

This time
Loss : nn.CrossEntropyLoss()
Final activation : nn.Softmax(dim=1)


The task was to classify a color label (11 colors) and a shape label (2 shape types), i.e. meta-data labels.

If you use this for an NLP task, the loss and final activation should be as follows (a short sketch of both setups comes after this).

NLP task
Loss : nn.NLLLoss()
Final activation : nn.LogSoftmax(dim=1)
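
As a note on how these pair up in PyTorch: nn.CrossEntropyLoss applies log-softmax internally and expects raw logits, so the Softmax is mainly for reading out probabilities, while nn.NLLLoss expects the log-probabilities produced by nn.LogSoftmax. A minimal sketch with made-up tensor sizes:

import torch
import torch.nn as nn

logits = torch.rand(1, 12)   # raw model output, e.g. from attention_linear
target = torch.tensor([3])   # dummy class index

# image classification setup used this time
ce_loss = nn.CrossEntropyLoss()(logits, target)  # takes raw logits directly
probs = nn.Softmax(dim=1)(logits)                # probabilities for prediction

# NLP-style setup
log_probs = nn.LogSoftmax(dim=1)(logits)
nll_loss = nn.NLLLoss()(log_probs, target)

# the two losses are mathematically equivalent
print(ce_loss.item(), nll_loss.item())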

The accuracy was almost as good as I expected. If I find a better architecture, I'll improve it further.


Reference sites

Attention Is All You Need (jp)
【self attention】簡単に予測理由を可視化できる文書分類モデルを実装する (Implementing a document classification model whose prediction reasoning can easily be visualized with self-attention)