やることは大雑把に言うと、SARの散乱強度から地上の植物の生育状況を予測するためにSAR画像とground画像(植生状況を表す地上の画像)の相関関係を調べる。(データはQGISで作成)
概要の類似問題の詳細はこのサイトを参考に。
主な概要
・SAR画像は散乱強度を、groundは植物の生育段階を表す画像。各値はpixel値に埋め込まれてる
・groundの生育段階は0, 1, 2, 3 の4つ
・SARとgroundの画像サイズは10×10pixel
・画像を回帰問題かつテーブルデータとして使って相関関係を計測
・アルゴリズムはTabnetで、画像はテーブルデータに変換して使う
今回は画像同士の相関関係を「画像をテーブルデータとして扱って、Tabnetを使い相関を調べる」、という特殊なやり方でやったのでその備忘録としてやったことをまとめてく。
目次
1. Tabnetとは
2. 画像をテーブルデータとして扱う
3. Tabnetの回帰用の画像テーブルデータを作る
4. Tabnetで学習
5. 評価:結果を見てみる
1. Tabnetとは
Tabnetは決定木系のアルゴリズムのDNN版で入力はxgboostやlightbgmと同じテーブルデータ。
最近はkaggleやいろんなコンペで上位勢が使ってるアルゴリズム。
主な特徴は使った所管で簡単にまとめるとこんな感じ
2. 決定木特有の逐次学習で効率的に学習できる
3. 特徴量の重要度、選択マスクの視覚化など解釈しやすい指標が多く使える
事前学習の仕方
左の図のエンコーダ部分で特徴の一部にマスクをして学習し、デコーダ部分でマスクの予測を行わせて事前学習をします。その後、右の本学習で、事前学習で得た重みを用いて転移学習をします。
2. 画像をテーブルデータとして扱う
まず、どうやって相関関係をみるかは下の順番で調べる
画像をテーブルデータとして扱うときに必須の工程は
・回帰問題かクラス分類のどちらかで解く
画像のテーブルデータでの扱い方
画像テーブルデータを使った分類問題の例はMNISTを使った例があった。
回帰問題は1次元配列にしたtarget用の画像の全部のpixel値を予測する。
今回は11ヶ月分のSAR画像とground画像(植物の生育データ)を学習データとして、テストデータを12月分のSAR画像とground画像とした。
データ概要
訓練データ:1〜11月のSARとground
テストデータ:12月のSARとground
説明変数用画像:SAR画像(散乱強度)
target用の画像:ground(植生状況)
3. Tabnetの回帰用の画像テーブルデータを作る
今回はpytorch版のtabnetを使った。
下のコードで、Tabnet用の回帰問題用の画像のテーブルデータを作成した。
訓練データの作成
def log(x) return np.log10(x)+3 def set_pixel(sar, ground): sar[ground==0]=0 return sar def create_data(vv, vh, gr): # tif画像の読み込み vv = cv2.imread(vv, -1) vh = cv2.imread(vh, -1) gr = cv2.imread(gr, -1) # SAR画像はマイナスにしないため対数をとる vv, vh = log(vv), log(vh) # groudのピクセルが0のところはSARも0にする vv = set_pixel(vv, gr) vh = set_pixel(vh, gr) # 画像を1次元に変換 vv, vh, gr = np.ravel(vv).reshape(1, -1), np.ravel(vh).reshape(1, -1), np.ravel(gr).reshape(1, -1) return vv, vh, gr vv = 'VV.tif' vh = 'VH.tif' gr = 'ground.tif' vv, vh, gr = create_data(vv, vh, gr) print(vv.shape) >>> (1, 100)
これを11ヶ月分繰り返し結合して、下のshape=(700, 8)のvh, VV, grを作成
特徴量が足らないのでSAR画像を「足し算、引き算、かけ算、割り算」で増やした。
add = vv + vh sub = vv-vh mul = vv * vh dev1 = vv / vh dev2= vh / vv
11ヶ月分結合して、訓練データを可視化してみる。
# 結合してnumpyから pandasに変換 arr = np.vstack([vv, vh, add, sub, mul, dev1, dev2, dr]) X_train = pd.DataFrame(data=arr.T)#.plot()#.get_figure().savefig('df1.png') X_train.columns=['vv', 'vh', 'add', 'sub', 'mul', 'vv/vh', 'vh/vv', 'ground'] # Nanを0で置換 X_train = X_train.fillna(0) print(X_train.shape) X_train.plot() >>>> (700, 8)
pandasのDataFrameは下の感じ
df.head()
testデータの作成
testデータも同じように作った。
テストデータは1ヶ月分なので少ない、shapeは(100, 8)
可視化結果
あとはgroundをtargetデータとして、どのくらいSAR画像から、正確にgroundを予測できるかをTabnetで検証してみる。
4. Tabnetで学習
Tabnetをクラス化したやつ
class TBNet: def __init__(self, train_df, target_name, epochs=3): self.epochs = epochs self.target_name=target_name self.local_interpretability_step = 3 self.tabnet_params = dict(n_d=15, n_a=15, n_steps=8, gamma=0.2, seed=10, lambda_sparse=1e-3, optimizer_fn=torch.optim.Adam, optimizer_params=dict(lr=2e-2, weight_decay=1e-5), mask_type="entmax", scheduler_params=dict( max_lr=0.05, steps_per_epoch=int(X_train.shape[0] / 256), epochs=self.epochs, is_batch_level=True, ), verbose=5, ) self.src_df = train_df def train(self, X_train, X_test, y_train, y_test, model_type='classification'): # model if model_type=='classification': self.model = TabNetClassifier(**self.tabnet_params) elif model_type=='regression': self.model = TabNetRegressor(**self.tabnet_params) self.model.fit( X_train=X_train, y_train=y_train, eval_set=[(X_test, y_test)], eval_metric=['rmsle', 'mae', 'rmse', 'mse'], max_epochs=self.epochs, patience=30, batch_size=256, virtual_batch_size=128, num_workers=2, drop_last=False, loss_fn=torch.nn.functional.l1_loss) print('show feature importance') self.plot_metric() self.feature_importances() self.local_interpretability(X_test) print('prediction') pred = self.model.predict(X_test) scores = self.calculate_scores(pred, y_test) print(scores) return self.model, pred, y_test def plot_metric(self): for param in ['loss', 'lr', 'val_0_rmsle', 'val_0_mae', 'val_0_rmse', 'val_0_mse']: plt.plot(self.model.history[param], label=param) plt.xlabel('epoch') plt.grid() plt.legend() plt.show() def feature_importances(self): df = self.src_df feature_name =[str(col) for col in df.columns if col!=self.target_name] print(len(feature_name)) feat_imp = pd.DataFrame(self.model.feature_importances_, index=feature_name) feature_importance = feat_imp.copy() feature_importance["imp_mean"] = feature_importance.mean(axis=1) feature_importance = feature_importance.sort_values("imp_mean") plt.tick_params(labelsize=18) plt.barh(feature_importance.index.values, feature_importance["imp_mean"]) plt.title("feature_importance", fontsize=18) def local_interpretability(self, X_test): """どの特徴量を使うか decision making するのに用いた mask(Local interpretability)""" n_steps = self.local_interpretability_step explain_matrix, masks = self.model.explain(X_test) fig, axs = plt.subplots(n_steps, 1, figsize=(21, 3*n_steps)) for i in range(n_steps): axs[i].imshow(masks[i][:50].T) axs[i].set_title(f"mask {i}") axs[i].set_yticks(range(len(self.src_df.columns[:-1]))) axs[i].set_yticklabels(list(self.src_df.columns[:-1])) def calculate_scores(self, true, pred): scores = {} scores = pd.DataFrame( { "R2": r2_score(true, pred), "MAE": mean_absolute_error(true, pred), "MSE": mean_squared_error(true, pred), "RMSE": np.sqrt(mean_squared_error(true, pred)), }, index=["scores"], ) return scores
さっきのpandasデータからnumpy配列の訓練、テストデータを作る。
y_train = X_train['ground'].values.reshape(700, 1) y_test = X_test['ground'].values.reshape(100, 1) print(y_train.shape, y_test.shape) X_trains = X_train.drop('ground', axis=1).values X_tests = X_test.drop('ground', axis=1).values print(X_trains.shape, X_tests.shape) >>>> (700, 1) (100, 1) (700, 7) (100, 7)
Tabnetでtrainして評価。
tbnet = TBNet(X_train, target_name='drone', epochs=50) model, pred, y_test = tbnet.train(X_trains, X_tests, y_train, y_test, model_type='regression')
5.評価:結果を見てみる
loss関数推移
accuracy(rmse)の推移
各predictionの値
R2 MAE MSE RMSE scores 0.822602 0.12724 0.08228 0.286844
特徴量の重要度
Local interpretability
Tabnetではそれに加えて、どの特徴量を使うか決定(decision making)するのに用いた mask というのを見ることができます。Local interpretabilityとも呼ばれます。mask は n_steps の数だけあり、ここでは予測したデータの先頭 50 個についてのみ図示してみます。
画像で予測データとテストデータでどれだけ正確に予測できてる図示してみる。
RMSE= 0.286844
RMSEが0.3だとかなりよく予測できるんだなと思った。
lightbgmとか決定木系のアルゴリズムなら全部、このやり方で相関がみれるはず。
画像をテーブルデータとして扱った相関関係の調べ方でした。
参考サイト
・Tabnetはどのように使えるのか・TabNet-pytorch
・学習帳9_B:SARの画像(SARデータ利用に進む前に)