アプリとサービスのすすめ

アプリやIT系のサービスを中心に書いていきます。たまに副業やビジネス関係の情報なども気ままにつづります

Yolov7にSinkhorn lossを使って実験してみた

OTA-loss とかOCcostとか話題になってる昨今に絡んで、「Sinkhorn」とかいう手法が気になってた。

何でもlossかなんかのmatrixを最適化する手法でSinkhorn lossとかいうのもある。Yolov7を使ってる最中だったので、なんとかYolov7にSinkhornを使ってみたかった。

なので、物体検出の最新版Yolov7にSinkhorn使って検証してみた。その備忘録。

目次
1. Sinkhornとは
2. Sinkhornと Yolov7のOTA-lossの融合
3. 検証結果


1. Sinkhornとは

Sinkhornとは二つの変数で求めたcost matrixを最小化する手法で最適化輸送問題に使われる。

簡単にどう使うかというとpredictionとground truthで求めたlossをcost matrixとして使って、Sinkhornで最適化して、lossをさらに改善しようという話。


predictionとground truthを確立分布とみなして同じに近づけるので、Kullback-Leibler Divergence lossと仕組みは近いと思う。

Sinkhornとlossの仕組み


細かい説明はした記事を参照。

OTA(Optimal Transport Assignment for Object Detection)
最適化輸送問題




Sinkhorn loss

import torch
import torch.nn as nn

# Adapted from https://github.com/gpeyre/SinkhornAutoDiff
class SinkhornDistance(nn.Module):
    def __init__(self, model, eps, max_iter, reduction='none'):
        super(SinkhornDistance, self).__init__()
        self.device = next(model.parameters()).device
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction

    def forward(self, cost, pred, truth):
        '''
    	We can easily see that the optimal transport corresponds to assigning each point 
    	in the support of pred(x) to the point of truth(y)
    	'''
        x, y = pred, truth
        # The Sinkhorn algorithm takes as input three variables :
        C = cost  # Wasserstein cost function
        x_points = x.shape[-2]
        y_points = y.shape[-2]
        if x.dim() == 2:
            batch_size = 1
        else:
            batch_size = x.shape[0]

        # both marginals are fixed with equal weights
        mu = torch.empty(batch_size, x_points, dtype=torch.float,
                         requires_grad=False, device=self.device).fill_(1.0 / x_points).squeeze()
        nu = torch.empty(batch_size, y_points, dtype=torch.float,
                         requires_grad=False, device=self.device).fill_(1.0 / y_points).squeeze()

        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)
        # To check if algorithm terminates because of threshold
        # or max iterations reached
        actual_nits = 0
        # Stopping criterion
        thresh = 1e-1

        # Sinkhorn iterations
        for i in range(self.max_iter):
            u1 = u  # useful to check the update
            u = self.eps * (torch.log(mu+1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu+1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u1).abs().sum(-1).mean()

            actual_nits += 1
            if err.item() < thresh:
                break

        U, V = u, v
        # Transport plan pi = diag(a)*K*diag(b)
        pi = torch.exp(self.M(C, U, V))
        # Sinkhorn distance
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()

        return cost #, pi, C

    def M(self, C, u, v):
        "Modified cost for logarithmic updates"
        "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1

2. Sinkhornと Yolov7のOTA-lossの融合

Yolov7のOTA-lossに Sinkhornを使ってみた。object lossとiou lossは3次元でなかったり、相性が悪かったので、class lossだけに使うことにした。

Sinkhornを使ったyolov7のOTA-loss

class ComputeLossOTA:

          if self.use_cost:
               lcls_cost = self.BCEcls(ps[:, 5:], t)
               lcls += self.sinkhorn(cost=lcls_cost.unsqueeze(2), pred=ps[:, 5:].unsqueeze(2), truth=t.unsqueeze(2))
          else:
               lcls += self.BCEcls(ps[:, 5:], t)  # BCE


3. 検証結果

普通のYolov7

confusion matrix

loss curve

P_curve

F1 curve


Sinkhornを使ったYolov7

confusion matrix

loss curve

P_curve

F1 curve

うまい具合にいってるけど、全体で的に見るとSinkhornに入れる引数のshapeが3次元の必要性から、使うドメインは限られてくる。

今回はclass lossにしか使ってなかったから、object lossとiou lossに使えばもっと改善する可能性もある。

少なくともlossの改善には悪い影響は与えないんじゃないかと思う。

参照記事

yolov7
SinkHorn