[논문 리뷰]/컴퓨터비전

[논문 리뷰] SupCon: Supervised Constrastive Learning

johyeongseob 2024. 9. 3. 19:42

논문: https://proceedings.neurips.cc/paper_files/paper/2020/hash/d89a66c7c80a29b1bdbab0f2a1a94af8-Abstract.html

 

저자: Prannay Khosla (Google Research), Piotr Teterwak(Boston University), ChenWang (Snap Inc.), Aaron Sarna (Google Research, corresponding author), Yonglong Tian (MIT), Phillip Isola (MIT), Aaron Maschinot (Google Research), Ce Liu (Google Research), Dilip Krishnan (Google Research)

 

인용: Khosla, Prannay, et al. "Supervised contrastive learning." Advances in neural information processing systems 33 (2020): 18661-18673.

 

깃허브: https://github.com/HobbitLong/SupContrast

 

0. 초록(Abstract)

Contrastive learning은 self-supervised representation learning에서 최근 몇 년간 좋은 성능을 가져왔다. 다만 label을 사용하지 않기에 저자들은 contrastive learning을 fully-supervised setting으로 접근하려고 시도하였다. Embedding space에서 같은 클래스는 당기고, 다른 클래스는 밀어냄으로서 클러스터링을 잘 만들 수 있었다. 이는 ImageNet dataset에서 SOTA를 달성할 수 있었다. (그림 1)

그림1: SupCon이 다른 방법들(cross-entropy loss기반 augmentation) 보다 좋은 성능을 나타낸다.

 

1. 서론 (Introduction)

최근 몇 년간 contrastive learning (CL)이 self-supervised representation learning 분야에서 많은 성과를 보여주었다. CL의 기본적인 아이디어는 embedding space에서 "positive" sample은 당기고, "negative" sample은 밀어내는 것이다. 기존 CL은 레이블 정보가 없기 때문에 anchor 이미지 (기준이 되는 이미지, 자기 자신)에서 augmented image (증강 이미지)를 positive로 보고, 나머지 이미지들을 negative로 보았다. (그림 2 왼쪽) 이는 같은 레이블의 정보도 negative sample 로 인식하여 잘못된 학습이 진행될 가능성이 있다.

 

그림2: 기존 contrastive learning (왼쪽) 과 해당 논문이 제안하는 SupCon (오른쪽)

 

본 논문의 연구 (work)는 레이블 정보를 이용하므로 지도학습이 가능하면서, 동시에 CL기반 손실함수이다. 같은 클래스로부터 나온 embeddings 끼리는 당겨주고, 다른 클래스로부터 나온 embeddings는 밀어주는 것이다. (그림 2 오른쪽) 해당 연구의 novelty는 하나의 anchor에 많은 positives와 negatives를 가지는 것이다. SupCon은 구현이 쉽고 훈련에 적합하며, ResNet-50과 ResNet-200 모델을 사용하여 ImageNet dataset에서 SOTA를 달성하였다. 본 논문의 Contribution은 다음과 같다.

 

  1. Contrastive learning을 지도학습 방법에 사용할 수 있게 확장하였다.
  2. SupCon은 다양한 데이터셋에서 top-1 accuracy를 달성하였다.
  3. SupCon은 hard positive 와  hard negative에 대한 학습을 이끌어내기 위해 손실함수에 대해 자세히 설명하였다.
  4. SupCon은 cross-entropy loss에 비해 하이퍼파라미터에 덜 민감하다.

 

3. 방법 (Method)

SupCon은 다음과 같은 방식으로 진행된다. 우선 입력 배치 데이터에 두 번의 증강(augmentation)을 진행한다. 두 개의 증강 데이터셋은 encoder 네트워크를 통과하여 2048-dimenstion embedding이 된다. 훈련이 진행되는 동안, 2048-D embedding은 다시 projection 네트워크를 통과하여 128-D embedding이 된다. 여기서 SupCon은 계산을 진행한다. 충분히 학습이 진행된 후에는 projection 네트워크를 사용하지 않고 오직 encoder 네트워크만을 사용한다. encoder 네트워크를 얼리고 (frozen) 선형 분류기 (linear classifier) 를 cross-entropy loss를 이용하여 훈련한다.

 

3.1 용어 설명 (Representation Learning Framework)

  • 데이터 증강 모듈, Aug(): 입력 샘플 x 에 대해 $\tilde{x}$  = Aug(x)이다.  
  • Encoder 네트워크, Enc(): 입력 샘플 x 를 표현 벡터 r로 변환한다. r = Enc(x) $\in R^{D_E}$. ${D_E}$ = 2048
  • Projection 네트워크, Proj(): r 벡터를 z 벡터로 변환한다. z = Proj(r) $\in R^{D_P}$. ${D_P}$ = 128. inference에서는 proj() 대신 같은 수의 파라미터를 가진 linear classifier를 사용한다.

3.2 Contrastive Loss Functions

우선 기존 self-supervised contrastive loss를 알아보고 SupCon에 대해 살펴보자.

N개로 이루어진 무작위 샘플들과 라벨들 세트 $\{x_k, y_k\}_{k=1...N}$ 를 2번 증강하면 다음과 같다. $\{\tilde{x}_k, \tilde{y}_k\}_{l=1...2N}$. 여기서 $\tilde{x}_{2k}$와 $\tilde{x}_{2k-1}$는 두 개의 증강 데이터이다. 라벨은 같다. $\tilde{y}_{2k}$ = $\tilde{y}_{2k-1}$ = $\tilde{y}_{k}$ 기존 CL 수식은 아래와 같다.

 

\begin{align*}
    & \mathcal{L}^{self}=\sum_{i\in I}\mathcal{L}^{self}_{i} = -\sum_{i\in I}\log \frac{\exp{(z_i \cdot z_{j(i)}/\tau )}}{\sum_{a \in A(i)}\exp{(z_i \cdot z_{a}/\tau )}}
    ,
    \tag{1}
\end{align*}

 

여기서 $ i\in I $는 배치 내 모든 샘플이다. $z_l = Proj(Enc(\tilde{x}_l)) \in R^{D_P}$ 이고,  '$\cdot$' 심볼은 내적 (inner product)이다. $\tau$는 scalar temperature parameter이다. $A(i)\equiv I \setminus \{i\}$이다. 여기서 \ 는 차집합이다. 즉, A(i)는 {i} 인덱스를 제외한 나머지 집합이다. 인덱스 i 는 anchor이다. j(i)는 positive이고 증강을 제외한 2(N-1)은 negative이다.

 

이제 SupCon을 보자. 수식은 다음과 같다. (논문에서 제시한 $\mathcal{L}^{sup}_{out}$과 $\mathcal{L}^{sup}_{in}$ 중 성능이 좋은 $\mathcal{L}^{sup}_{out}$ 을 본 포스팅에서 설명한다.)

 

\begin{align*}
    &  \mathcal{L}^{sup}_{out} = \sum_{i\in I}\mathcal{L}^{sup}_{out,i} = \sum_{i\in I} \frac{-1}{|P(i)|} \sum_{p\in P(i)} \log \frac{\exp{(z_i \cdot z_{j(i)}/\tau )}}{\sum_{a \in A(i)}\exp{(z_i \cdot z_{a}/\tau )}}
    ,
    \tag{2}
\end{align*}

 

여기서 $P(i) \equiv \{p \in A(i) : \tilde{y}_p=\tilde{y}_i\}$는 anchor i 에 대한 positives들의 집합이다. |P(i)|는 집합 P의 원소 개수 (cardinality)이다.

 

기존 CL 손실함수나 SupCon 손실함수 모두 분자를 최대로, 분모를 최소로 만들어야 Loss가 낮아짐을 알 수 있다.

 

그림3: (a) 기본 cross-entropy 손실함수, (b) 기존 contrastive learning (c) SupCon

 

4. 실험 결과 (Experiments)

SupCon은 분류 벤치마크 CIFAR-10, CIFAR-100, ImageNet을 사용하여 성능을 보여주었다. Encoder 네트워크는 ResNet-50, ResNet-101, ResNet-200을 사용하였다. 증강(augmentation)은 AutoAugment, RandAugment, SimAugment, Stacked RandAugment를 사용하였다. temperature는 기본적으로 $\tau = 0.1$로 지정한다.

 

데이터 증강에 대한 내용은 논문 supplementary (14.2 Data Augmentation)에서 간략하게 소개한다.

 

표 2 는 ResNet-50을 사용한 SupCon이 다른 모델들보다 top-1 분류 성능이 높음을 보여준다. 

 

 

표 3 은 ImageNet에 대해 ResNet-50과 ResNet-200을 사용한 모델에서 SupCon이 다른 모델보다 top-1과 top-5 성능이 우수함을 보여준다.

 

 

주) top-1 accuracy와 top-5 accuracy에 대한 설명: https://www.kaggle.com/discussions/questions-and-answers/164379

 

5. SupCon Pytorch 코드

아래 코드는 논문의 저자가 작성한 코드이다. 

 

1. 라이브러리 불러오기

"""
Author: Yonglong Tian (yonglong@mit.edu)
Date: May 07, 2020
"""
from __future__ import print_function

import torch
import torch.nn as nn

 

2. 클래스 설정

class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

 

contrast mode 는 anchor의 개수를 몇 개로 지정할 지 고르는 선택지이다. 

 

3. forward 함수 및 device 설정

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

 

여기서 device는 features 가 어디에 있는지로 결정된다. 즉, features가 GPU에 올려져 있으면 device도 GPU로 설정한다.

 

4. features 차원에 대한 오류 방지

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

 

만약 features가 3차원 (batch_size, 증강된 view의 개수, embedding vector의 크기) 보다 작으면, ValueError를 생성한다. (오류 처리)

 

5. label 또는 mask 둘 중 하나만 사용

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

 

레이블 (실제 클래스) 혹은 마스크(샘플간의 관계, self-supervised CL에 사용) 하나만 사용하도록 조정. 만약 둘 다 없으면 기존 CL로 설정. 즉, 단위행렬을 생성하여 자기자신만 positive로 파악.

 

레이블이 존재한다면 레이블의 텐서를 [batch_size, 1]로 변환한다. 예를 들어, 만약 labels가 [0, 1, 2, 0]처럼 1차원 배열로 주어졌다면, 이를 [ [0], [1], [2], [0] ]와 같은 형태로 변환한다. 다음으로 레이블을 기반으로 마스크 (positive, negative 구분)를 생성하고 device로 옮긴다.

 

6. 앵커 개수 설정

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

 

contrast_feature는 [batch_size, n_views, feature_dim] 형태의 텐서를 n_views 개의 [batch_size, feature_dim] 텐서로 나눈 후, 텐서들을 배치 차원(첫 번째 차원)을 따라 다시 연결하여 [batch_size * n_views, feature_dim]로 만든다. 이는 모든 샘플의 모든 뷰를 하나의 큰 배치로 만든다.

 

다음으로 조건문은 anchor를 각 샘플의 첫 번째 view만 사용할 지, 모든 view를 사용할 지 정한다.

 

7. logits 계산

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

 

anchor_dot_contrast는 anchor_feature와 contrast_feature간 내적을 진행한 후, temperatue 파라미터인 $\tau$를 나누어준다. $(z_i \cdot z_{j(i)}/\tau )$ 여기까지 계산이 완료되었다.

 

다음으로 mask는 anchor 개수와 contrast 개수만큼 늘려준다. 예를 들어 (위의 mask 예시를 가져와서) mask가 4x4의 크기이고, anchor가 2개 contrast가 2개라면, mask의 크기는 (4*2 x 4*2)인 8x8의 크기가 된다.

 

다음으로 mask는 자기 자신의 위치 (대각성분)에 0을 삽입한다. 

 

8. log_probability 계산

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point.
        # Edge case e.g.:-
        # features of shape: [4,1,...]
        # labels:            [0,1,1,2]
        # loss before mean:  [nan, ..., ..., nan]
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

 

우리가 구한 logit에 지수함수를 계산하고 $\exp{(z_i \cdot z_{j(i)}/\tau )}$, 자기 자신의 내적 값은 제외한다.

 

log_prob는 $\log \frac{\exp{(z_i \cdot z_{j(i)}/\tau )}}{\sum_{a \in A(i)}\exp{(z_i \cdot z_{a}/\tau )}}$를 계산한다. logits (분자)은 log에 exp를 취하면 원래 값이 나오므로 그대로 사용한다. logit에 torch.log(분모)를 빼면 log 내에서 나누기로 변환된다.

 

다음으로 mask에 1의 개수 (positive 개수, |P(i)|)를 구한 다음, log_prob에서 positive값들만을 취해 더한 다음 |P(i)|로 나누어 준다. 

 

지금까지 다음과 같은 부분이 수행되었다.

 

\begin{align*}
    &  \mathcal{L}^{sup}_{out,i} = \frac{1}{|P(i)|} \sum_{p\in P(i)} \log \frac{\exp{(z_i \cdot z_{j(i)}/\tau )}}{\sum_{a \in A(i)}\exp{(z_i \cdot z_{a}/\tau )}}
\end{align*}

 

9. 최종 손실함수값 (loss) 계산

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

 

$ \mathcal{L}^{sup}_{out,i} $에 마이너스 부호 (-)를 추가하고 i에 대해 합을 취한다. 그러면 비로소 다음과 같은 최종 손실함수 값을 구할 수 있다.

 

\begin{align*}
    &  \mathcal{L}^{sup}_{out} = \sum_{i\in I}\mathcal{L}^{sup}_{out,i} = \sum_{i\in I} \frac{-1}{|P(i)|} \sum_{p\in P(i)} \log \frac{\exp{(z_i \cdot z_{j(i)}/\tau )}}{\sum_{a \in A(i)}\exp{(z_i \cdot z_{a}/\tau )}}
    ,
    \tag{2}
\end{align*}

 

 

0. 번외: 커스텀 SupCon 코드

 

지금까지의 SupCon 손실함수 코드는 논문의 저자가 직접 작성한 코드이다. 매우 훌륭한 코드이지만 직접 사용하기에는 배워야 할 부분이 많다. 그래서 나는 내가 사용할 코드를 새로 작성하였다. 수정된 코드의 조건은 다음과 같다.

  1. 데이터 증강은 없다.
  2. 하이퍼 파라미터는 temperature 하나만 사용한다.
  3. 레이블은 존재하여야 한다. (지도학습) <-- 논문의 저자 코드는 레이블 정보가 없으면 SimCLR로 진행하지만 내 코드는 레이블이 존재한다는 가정하에 작동한다.
  4. 현재 가지고 있는 이미지의 embeddings에서 레이블 정보를 활용하여 SupCon을 진행한다.

아래는 수정된 코드이다.

"""
Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
Author: Yonglong Tian (yonglong@mit.edu)
Date: May 07, 2020
"""
import torch
import torch.nn as nn


class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """
        손실함수 값을 계산하기 위한 모델

        Args:
            features: embeddings 벡터값 [batch_size, embeddings_dim], ex) [32, 128].
            labels: 레이블값 [batch_size, label], ex) [32, 1].
            mask: 'positive'와 'negative'를 구분하는 마스크 [labels, labels]
            mask_{i,j}=1 만약 j가 i와 같은 클래스이면 1이다.

        Returns:
            손실함수 결과 스칼라 값
        """

        # 장치 설정
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # 마스크 생성
        labels = labels.view(-1, 1) # dim: [batch_size, 1]
        mask = torch.eq(labels, labels.T).float().to(device) # dim: [batch_size, batch_size]

        # 내적(dot product) 계산 및 스케일링
        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T),
            self.temperature) # dim: [batch_size, batch_size]

        # 수치적 안정성 확보: logit 값들을 0 또는 음수로 만들어 softmax 계산을 원할하게 한다.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach() # dim: [batch_size, batch_size]

        # 자기 자신과의 비교 제거: 대각 성분이 0이고 나머지 성분이 1인 논리_마스크
        batch_size = features.shape[0]
        logits_mask = ~torch.eye(batch_size, device=device).bool()
        mask = mask * logits_mask

        """
        로그 확률 계산
        # log 내부 값이 0이 되는 것을 방지하기 위해 epsilon을 추가
        # 자신을 제외한 나머지 features 간의 내적 값: sigma(a in A(i)) {z_i * z_a/ t}
        """
        epsilon = 1e-8
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + epsilon)

        # 개별 Loss^{sup}_{out,i} 계산
        mask_pos_pairs = mask.sum(1) # 앵커별 |P(i)|의 값을 계산: 각 행의 요소들 합을 계산
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs) # |P(i)| = 0 일시, 1로 변환
        log_prob = mask * log_prob # 최종 log_prob 계산
        mean_log_prob_pos = (log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = -mean_log_prob_pos
        loss = loss.view(1, batch_size).mean() # 배치 내 anchor (i)에 대한 평균값

        return loss

if __name__ == '__main__':

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_classes = 3
    feat_dim = 128
    batch_size = 32

    embeddings = torch.randn(batch_size, feat_dim).to(device)
    labels = torch.randint(0, num_classes, (batch_size,)).to(device)
    print(f"embeddings: {embeddings.shape}")
    print(f"labels: {labels}")

    criterion = SupConLoss(temperature=0.5)

    loss = criterion(embeddings, labels)

    print(f'loss = {loss}')