📌 이 글에 대하여
이 게시글은 Kaggle의 Sushant Kumar 님이 작성한 원문 글을 한국어로 번역한 것입니다.
원문은 Apache License 2.0 하에 공개되었으며, 이 블로그 역시 해당 라이선스를 따릅니다.
원문 저자: Sushant Kumar
원문 위치: Kaggle Notebook
라이선스: Apache License 2.0 전문 보기
본 번역은 비상업적/교육적 목적이며, 원문 저자의 저작권과 라이선스를 존중합니다.
코드: https://github.com/johyeongseob/from-scratch-ai
dependency conflict 확인
Windows11, Python 3.8.18, torch version: 2.4.1+cu121, CUDA: 12.1, GPU: NVIDIA GeForce RTX 3080 Ti
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install einops torchsummary pytorch-lightning lightning-bolts
이번 게시글에서 다룰 핵심 개념
- ViT를 활용한 이미지 패치의 시각화 (예시 이미지)
- ViT scratch 실험 (CIFAR-10에 대한 훈련 및 평가)
- 기본 원리에 대한 설명
예시이미지: Cat And Dog Classifier
Vision Transformer (ViT)
1. 라이브러리 불러오기
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import itertools
import torch.nn as nn
from PIL import Image
import torchvision.transforms as T
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
1.1. 이미지 및 패치 이미지 예시
이미지를 ViT의 입력으로 넣기 위해 16*16 해상도 크기 패치들로 분해한다.
img = Image.open("cat-and-dog-classifier/data/Cat/1200px-Cat03.jpg")
img.show()
transforms = T.Compose([
T.Resize((224, 224)),
T.ToTensor()
])
x = transforms(img)
x = x[None, ...]
print(f"x.shape: {x.shape}") # [1, 3, 224, 224]
patch_size = 16
patches = rearrange(
x,
'b c (h s1) (w s2) -> b (h w) (s1 s2 c)',
s1=patch_size,
s2=patch_size
)
print(f"patches.shape: {patches.shape}") # [1, 196, 768]
patches_d = rearrange(
x,
'b c (h s1) (w s2) -> b h w s1 s2 c',
s1 = 16,
s2 = 16
)
print(f"patches_d.shape: {patches_d.shape}") # [1, 14, 14, 16, 16, 3]
fig, axes = plt.subplots(nrows=14, ncols=14, figsize=(20, 20))
for i, j in itertools.product(range(14), repeat=2):
axes[i, j].imshow(patches_d[0, i, j])
axes[i, j].axis('off')
axes[i, j].set_title(f"patch ({i},{j})")
fig.tight_layout()
plt.show()
2. 패치 임베딩
ViT 모델의 이미지 처리에서 첫 번째 단계는 패치 임베딩 과정이다. 이는 이미지를 고정된 크기의 패치로 나누고 1-D 벡터로 flatten 이다. 해당 패치들은 linear project를 통해 token으로 변환된다. 여기에 cls token을 추가한다. 이후, position embedding을 element-wise sum하여 위치 정보를 추가한다. cls token은 학습 가능한 parameter이다.
class PatchEmbedding(nn.Module):
def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224):
super(PatchEmbedding, self).__init__()
assert img_size / patch_size % 1 == 0, "img_size must be integer multiple of patch_size"
self.projection = nn.Sequential(
Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
nn.Linear(patch_size * patch_size * in_channels, emb_size)
)
self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size)) # [1, 1, 768]
self.positional_emb = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size)) # [197, 768]
def forward(self, x):
B, *_ = x.shape # B: num of samples
x = self.projection(x) # [1, 196, 768]
cls_token = repeat(self.cls_token, '() p e -> b p e', b = B) # [1, 1, 768]
x = torch.cat([cls_token, x], dim=1)
x += self.positional_emb
return x
3. Transformer Encoder
ViT는 Transformer의 encoder로 구성되어 있다. 하나씩 살펴보자
3.1. Multi-head Attention (MHA)
attention 모듈은 총 3개의 입력 query, key, value 를 사용한다. 이 모듈은 query와 key의 attention matrix를 계산하여 attention 값을 구하고 value에 적용한다. 모듈에 대해 자세히 살펴보자. 모듈은 총 4개의 fully connected layer를 가진다. 3개는 query, key, value에 사용되고, 마지막은 최종 결과값에 사용된다. attention matrix는 dot product를 사용한다. 이는 각 요소가 다른 요소에게 "얼마나 많이" 관여하고 있는지를 학습한다. 이후 sequence 내 요소들에 대한 attention 결과을 scaling하고 softmax 연산을 적용한다. attention 값 (energy)은 다시 value에 적용하여 중요도를 학습한다.
class MultiHeadAttention(nn.Module):
def __init__(self, emb_size=768, num_heads=8, dropout=0):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.emb_size = emb_size
self.query = nn.Linear(emb_size, emb_size)
self.key = nn.Linear(emb_size, emb_size)
self.value = nn.Linear(emb_size, emb_size)
self.projection = nn.Linear(emb_size, emb_size)
self.attn_dropout = nn.Dropout(dropout)
self.scaling = (self.emb_size // num_heads) ** -0.5
def forward(self, x, mask=None):
rearrange_heads = 'batch seq_len (num_head h_dim) -> batch num_head seq_len h_dim'
# [batch, seq_len, **emb_size**] → [batch, **num_head**, seq_len, **h_dim**]
queries = rearrange(self.query(x), rearrange_heads, num_head=self.num_heads) # [Batch, heads, seq_len, h_dim]
keys = rearrange(self.key(x), rearrange_heads, num_head=self.num_heads) # [Batch, heads, seq_len, h_dim]
values = rearrange(self.value(x), rearrange_heads, num_head=self.num_heads) # [Batch, heads, seq_len, h_dim]
# sum up over the last axis
energies = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
if mask is not None:
fill_value = torch.finfo(energies.dtype).min # -infinity (inf)
energies.mask_fill(~mask, fill_value)
attention = F.softmax(energies * self.scaling, dim=-1)
attention = self.attn_dropout(attention)
# sum up over the third axis
out = torch.einsum('bhas, bhsd -> bhad', attention, values)
out = rearrange(out, 'batch num_head seq_length dim -> batch seq_length (num_head dim)')
out = self.projection(out)
return out
3.2. Residuals
Transformer는 각 세부모듈 전에 normaliztion을 수행하고 세부모듈 후에 residual block을 사용한다. 세부모듈은 MHA 와 MLP 이다.
class ResidualAdd(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs): # keyword arguments: 유연한 함수 인자 전달 방식
res = x
x = self.fn(x, **kwargs)
x += res
return x
FeedForwardBlock=lambda emb_size=768, expansion=4, drop_p=0.: nn.Sequential(
nn.Linear(emb_size, expansion * emb_size),
nn.GELU(),
nn.Dropout(drop_p),
nn.Linear(expansion * emb_size, emb_size)
)
class TransformerEncoderBlock(nn.Sequential):
def __init__(self, emb_size=768, drop_p=0., forward_expansion=4, forward_drop_p=0, **kwargs):
super(TransformerEncoderBlock, self).__init__(
ResidualAdd(
nn.Sequential(
nn.LayerNorm(emb_size),
MultiHeadAttention(emb_size, **kwargs),
nn.Dropout(drop_p)
)
),
ResidualAdd(
nn.Sequential(
nn.LayerNorm(emb_size),
FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
nn.Dropout(drop_p)
)
)
)
3.3. Transformer
Transformer의 encoder는 L개의 identical layers로 구성된다.
class TransformerEncoder(nn.Sequential):
def __init__(self, depth=12, **kwargs):
super(TransformerEncoder, self).__init__(
*(TransformerEncoderBlock(**kwargs) for _ in range(depth))
)
4. Classification Head
ViT는 이미지가 Transformer encoder를 거쳐 생성된 feature map (NLP에서는 embedding vector라는 표현을 사용)를 사용하여 classfication logit을 추출한다. (기본적으로 cls token을 사용하나, 해당 코드에서는 모든 token의 평균값을 사용)
class ClassificationHead(nn.Sequential):
def __init__(self, emb_size=768, num_classes=1000):
super(ClassificationHead, self).__init__(
Reduce('batch_size seq_len emb_dim -> batch_size emb_dim', reduction='mean'), # mean pooling head
nn.LayerNorm(emb_size),
nn.Linear(emb_size, num_classes)
)
5. Vision Transformer (ViT)
지금까지 설명한 Patch Embedding, Transformer Encoder, Classification Head를 결합하여 ViT 모델을 구현한다.
class ViT(nn.Sequential):
def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224, depth=12, num_classes=1000, **kwargs):
super(ViT, self).__init__(
PatchEmbedding(in_channels, patch_size, emb_size, img_size,),
TransformerEncoder(depth, emb_size=emb_size, **kwargs),
ClassificationHead(emb_size, num_classes)
)
CIFAR-10을 이용한 VIT모델 훈련 및 검증
1. 라이브러리 불러오기
from pl_bolts.datamodules import CIFAR10DataModule
from torch.optim.lr_scheduler import OneCycleLR
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pytorch_lightning.callbacks import ModelCheckpoint
import torchvision.transforms as T
import torch.nn.functional as F
import pytorch_lightning as pl
from ViT import ViT
from torchmetrics.functional import accuracy
import torch
import warnings
warnings.filterwarnings("ignore")
torch.set_float32_matmul_precision('medium') # option
1.1 train configuration 정의 및 CIFAR-10 데이터셋 다운로드
config = {
"data_dir": ".",
"batch_size": 256,
"num_workers": 2,
"num_classes": 10,
"lr": 1e-4, "max_lr": 1e-3
}
train_transforms = T.Compose(
[
T.RandomCrop(32, padding=4), # perturbation
T.RandomHorizontalFlip(), # perturbation
T.ToTensor(),
cifar10_normalization()
]
)
test_transforms = T.Compose(
[
T.ToTensor(),
cifar10_normalization()
]
)
# Train: 45,000, valid: 5,000, test: 10,000
cifar10_dm = CIFAR10DataModule(
data_dir=config["data_dir"],
batch_size=config["batch_size"],
num_workers=config["num_workers"],
train_transforms=train_transforms,
test_transforms=test_transforms,
val_transforms=test_transforms
)
2. 모델 정의
class LitViT(pl.LightningModule):
# Initialize model and hyperparameters
def __init__(self, lr=0.05):
super().__init__()
self.save_hyperparameters()
self.model = ViT(in_channels=3, patch_size=4, emb_size=128, img_size=32, depth=12, num_classes=10)
# Forward pass
def forward(self, x):
out = self.model(x)
return F.log_softmax(out, dim=1)
# Training step
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x) # model forward pass
loss = F.nll_loss(logits, y) # negative log likelihood loss
self.log("train_loss", loss)
return loss
# Shared evaluation logic for validation and test
def evaluate(self, batch, stage=None):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
preds = torch.argmax(logits, dim=1)
acc = accuracy(preds, y, task="multiclass", num_classes=config["num_classes"])
if stage:
self.log(f"{stage}_loss", loss, prog_bar=True)
self.log(f"{stage}_acc", acc, prog_bar=True)
# Validation step
def validation_step(self, batch, batch_idx):
self.evaluate(batch, "val")
# Test step
def test_step(self, batch, batch_idx):
self.evaluate(batch, "test")
# Configure optimizer and learning rate scheduler
def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.parameters(),
lr=self.hparams.lr,
)
steps_per_epoch = 45_000 // config["batch_size"]
scheduler_dict = {
"scheduler": OneCycleLR(
optimizer,
max_lr=config["max_lr"],
epochs=self.trainer.max_epochs,
steps_per_epoch=steps_per_epoch,
),
"interval": "step",
}
return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
3. 훈련 및 평가
if __name__ == "__main__":
model = LitViT(lr=config["lr"])
checkpoint_callback = ModelCheckpoint(
monitor="val_acc",
mode="max",
save_top_k=1,
filename="vit-best",
dirpath="checkpoints"
)
trainer = pl.Trainer(max_epochs=100, accelerator="auto", callbacks=[checkpoint_callback])
# Auto LR finder
lr_finder = trainer.tuner.lr_find(model, cifar10_dm)
model.hparams.lr = lr_finder.suggestion()# Auto-find model LR is: 0.000630957344480193
trainer.fit(model, cifar10_dm)
ckpt_path = "checkpoints/vit-best.ckpt"
model = LitViT.load_from_checkpoint(ckpt_path)
trainer = pl.Trainer(accelerator="auto")
trainer.test(model, datamodule=cifar10_dm)
3.1. 실험 결과
Testing DataLoader 0: 100%|██████████| 40/40 [00:01<00:00, 34.32it/s]
──────────────────────────────────────────────────────────────────────────────
Test metric DataLoader 0
──────────────────────────────────────────────────────────────────────────────
test_acc 0.8058000206947327
test_loss 1.023311972618103
──────────────────────────────────────────────────────────────────────────────
'딥러닝 > 컴퓨터비전' 카테고리의 다른 글
[컴퓨터비전] 기초부터 시작하는 CLIP (Pytorch 구현) (1) | 2025.06.26 |
---|---|
[컴퓨터비전] 데이터 증강 종류 및 코드 (Pytorch, Albumentations, Imgaug) (1) | 2024.12.06 |
[컴퓨터비전] Cityscapes annotation을 COCO (.json)로 변경하는 방법 (2) | 2024.09.19 |
[컴퓨터비전] KITTI dataset label (.txt) 파일을 PASCAL VOC label (.xml)로 변경하는 방법 (0) | 2024.09.19 |
[컴퓨터비전] 윈도우 환경에서 detectron2 설치하는 방법 (1) | 2024.09.19 |