📌 About this post
This post is a Korean translation of an article originally written by Moein Shariatnia on Kaggle.
The original is released under the Apache License 2.0, and this blog follows the same license.
Original author: Moein Shariatnia
Original source: Kaggle Notebook
License: Apache License 2.0 (view full text)
This translation is for non-commercial, educational purposes and respects the original author's copyright and license.
Code: https://github.com/johyeongseob/from-scratch-ai
Dataset: https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset
Library installation
conda create -n clip-env python=3.10
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
pip install timm
pip install transformers==4.30.2
Introduction
In January 2021, OpenAI released CLIP, a multi-modal model that connects texts and images. In this post, we implement CLIP from scratch in PyTorch.
What does CLIP do? Why is it fun?
In the paper Learning Transferable Visual Models From Natural Language Supervision, OpenAI introduces CLIP (Contrastive Language-Image Pre-training), a model that learns the relationship between sentences and the images they describe. Once trained, given an input sentence, the model can retrieve the images most relevant to that sentence. The original CLIP also performs zero-shot ImageNet classification on par with a fully supervised ResNet-50.
As a teaser, here is a preview of what the CLIP model built in this post can do.
Config
The values below live in a config.py module, which the rest of the code imports with "import config".
import torch
debug = False
image_path = "./flickr30k_images"
captions_path = "./captions.csv"
batch_size = 32
num_workers = 4
head_lr = 1e-3
image_encoder_lr = 1e-4
text_encoder_lr = 1e-5
weight_decay = 1e-3
patience = 1
factor = 0.8
epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = 'resnet50'
image_embedding = 2048
text_encoder_model = "distilbert-base-uncased"
text_embedding = 768
text_tokenizer = "distilbert-base-uncased"
max_length = 40 # covers 99.5% of Flickr30k captions, max=88
pretrained = True
trainable = True
temperature = 1.0
size = 224
num_projection_layers = 1
projection_dim = 256
dropout = 0.1
Utils
class AvgMeter:
def __init__(self, name="Metric"):
self.name = name
self.reset()
def reset(self):
self.avg, self.sum, self.count = [0] * 3
def update(self, val, count=1):
self.count += count
self.sum += val * count
self.avg = self.sum / self.count
def __repr__(self):
return f"{self.name}: {self.avg:.4f}"
def get_lr(optimizer):
for param_group in optimizer.param_groups:
return param_group["lr"]
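For example, a brief usage sketch of AvgMeter with made-up values (assuming the class above):
meter = AvgMeter(name="train_loss")
meter.update(0.9, count=32)  # a batch with loss 0.9 over 32 samples
meter.update(0.7, count=32)  # a second batch with loss 0.7
print(meter)                 # train_loss: 0.8000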
Dataset
The model needs both image and text embeddings, so we build a dataset that provides them. For the text embeddings, the author uses a pretrained DistilBERT (HuggingFace).
import torch
from torch.utils.data import Dataset
import config
from PIL import Image
class CLIPDataset(torch.utils.data.Dataset):
def __init__(self, image_filenames, captions, tokenizer, transforms):
"""
        image_filenames and captions must have the same length; so, if there are
multiple captions for each image, the image_filenames must have repetitive
file names
"""
self.image_filenames = image_filenames # Each image is repeated 5 times for 5 captions in captions.csv
self.captions = list(captions) # class: 'pandas.core.series.Series' -> 'list'
self.encoded_captions = tokenizer( # Encode all captions from the CSV file using the tokenizer
list(captions),
padding=True,
truncation=True,
max_length=config.max_length
)
self.transforms = transforms
def __len__(self):
return len(self.captions)
def __getitem__(self, idx):
item = {}
# Extract the encoded token (e.g., input_ids, attention_mask) for the current index
for key, values in self.encoded_captions.items():
item[key] = torch.tensor(values[idx])
image = Image.open(f"{config.image_path}/{self.image_filenames[idx]}").convert("RGB") # directory + filename
item['image'] = self.transforms(image)
item['caption'] = self.captions[idx]
return item
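To see what self.encoded_captions holds, here is a minimal sketch of the tokenizer output; the two captions are hypothetical:
from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
encoded = tokenizer(
    ["A dog runs on the grass", "Two children play soccer"],
    padding=True, truncation=True, max_length=40
)
print(list(encoded.keys()))     # ['input_ids', 'attention_mask']
print(encoded["input_ids"][0])  # token ids for the first caption, starting with 101 ([CLS])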
Data Utils
import torch
import numpy as np
import pandas as pd
import config
from CLIPDataset import CLIPDataset
from torch.utils.data import Dataset
from torchvision import transforms
def make_train_valid_dfs():
dataframe = pd.read_csv(f"{config.captions_path}")
max_id = dataframe["id"].max() + 1 if not config.debug else 100
image_ids = np.arange(0, max_id) # 31782
np.random.seed(42)
valid_ids = np.random.choice(
image_ids, size=int(0.2 * len(image_ids)), replace=False
)
train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
return train_dataframe, valid_dataframe
def build_loaders(dataframe, tokenizer, mode):
transforms = get_transforms(mode=mode)
dataset = CLIPDataset(
dataframe["image"].values,
dataframe["caption"].values,
tokenizer=tokenizer,
transforms=transforms,
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=config.batch_size,
num_workers=config.num_workers,
shuffle=True if mode == "train" else False,
)
return dataloader
# You can add augmentations in the train mode if needed
def get_transforms(mode="train"):
if mode == "train":
return transforms.Compose(
[
transforms.Resize((config.size, config.size)),
transforms.ToTensor()
]
)
else:
return transforms.Compose(
[
transforms.Resize((config.size, config.size)),
transforms.ToTensor()
]
)
CLIP
Image Encoder
The author uses a pretrained ResNet50 (timm) as the image encoder. The classification fc layer is removed, so the model serves purely as a feature extractor whose output is $\in \mathbb{R}^{2048}$.
class ImageEncoder(nn.Module):
"""
Encode images to a fixed size vector
"""
def __init__(
self,
model_name=config.model_name,
pretrained=config.pretrained,
trainable=config.trainable
):
super().__init__()
self.model = timm.create_model(
model_name,
pretrained,
num_classes=0, # self.fc = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
global_pool="avg"
)
for p in self.model.parameters():
p.requires_grad = trainable
def forward(self, x):
return self.model(x)
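As a quick sanity check (not part of the original notebook), timm's resnet50 with num_classes=0 and average pooling returns a 2048-dimensional feature vector:
import timm
import torch

backbone = timm.create_model("resnet50", pretrained=False, num_classes=0, global_pool="avg")
with torch.no_grad():
    features = backbone(torch.randn(2, 3, 224, 224))  # a dummy batch of 2 images
print(features.shape)  # torch.Size([2, 2048])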
Text Encoder
As mentioned above, the author uses DistilBERT as the text encoder. Only the CLS token representation $\in \mathbb{R}^{768}$ is used as the sentence embedding.
class TextEncoder(nn.Module):
def __init__(
self,
model_name=config.text_encoder_model,
pretrained=config.pretrained,
trainable=config.trainable
):
super().__init__()
if pretrained:
self.model = DistilBertModel.from_pretrained(model_name)
else:
self.model = DistilBertModel(config=DistilBertConfig())
for p in self.model.parameters():
p.requires_grad = trainable
# we are using the CLS token hidden representation as the sentence's embedding
self.target_token_idx = 0
def forward(self, input_ids, attention_mask):
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
last_hidden_state = output.last_hidden_state
return last_hidden_state[:, self.target_token_idx, :]
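A small shape check of the CLS-token slicing, using a randomly initialized DistilBERT and dummy inputs (so nothing needs to be downloaded):
import torch
from transformers import DistilBertModel, DistilBertConfig

bert = DistilBertModel(DistilBertConfig())   # random weights, shapes only
input_ids = torch.randint(0, 1000, (2, 40))  # dummy batch: 2 sequences of length 40
attention_mask = torch.ones_like(input_ids)
out = bert(input_ids=input_ids, attention_mask=attention_mask)
print(out.last_hidden_state.shape)           # torch.Size([2, 40, 768])
print(out.last_hidden_state[:, 0, :].shape)  # torch.Size([2, 768]), the CLS embedding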
Projection Head
The embeddings obtained from the image and the text are projected into a shared space. After projection, related embeddings $\in \mathbb{R}^{256}$ are pulled together while unrelated ones are pushed apart.
class ProjectionHead(nn.Module):
def __init__(
self,
embedding_dim,
projection_dim=config.projection_dim,
dropout=config.dropout
):
super().__init__()
self.projection = nn.Linear(embedding_dim, projection_dim)
self.gelu = nn.GELU()
self.fc = nn.Linear(projection_dim, projection_dim)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(projection_dim)
def forward(self, x):
projected = self.projection(x)
x = self.gelu(projected)
x = self.fc(x)
x = self.dropout(x)
x = x + projected
x = self.layer_norm(x)
return x
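For example, a quick shape check (a sketch, assuming torch is imported and the ProjectionHead class above with its default projection_dim of 256) that both encoders' features land in the same space:
image_proj = ProjectionHead(embedding_dim=2048)  # for ResNet50 features
text_proj = ProjectionHead(embedding_dim=768)    # for DistilBERT CLS features
img_emb = image_proj(torch.randn(4, 2048))
txt_emb = text_proj(torch.randn(4, 768))
print(img_emb.shape, txt_emb.shape)  # torch.Size([4, 256]) torch.Size([4, 256])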
CLIP
Now we combine the image encoder, the text encoder, and the projection heads into the CLIP model. An interesting detail is that soft labels are used as the targets: in this dataset (Flickr30k), each image comes with five captions, so when such pairs appear together in a batch, hard one-hot labels cannot be used; see the small numeric example after the code below.
import torch
import torch.nn as nn
import timm
import config
import torch.nn.functional as F
from transformers import DistilBertModel, DistilBertConfig
class CLIPModel(nn.Module):
def __init__(
self,
temperature=config.temperature,
image_embedding=config.image_embedding,
text_embedding=config.text_embedding,
):
super().__init__()
self.image_encoder = ImageEncoder()
self.text_encoder = TextEncoder()
self.image_projection = ProjectionHead(embedding_dim=image_embedding)
self.text_projection = ProjectionHead(embedding_dim=text_embedding)
self.temperature = temperature
def forward(self, batch):
# Getting Image and Text Features
image_features = self.image_encoder(batch["image"])
text_features = self.text_encoder(
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
)
# Getting Image and Text Embeddings (with same dimension)
image_embeddings = self.image_projection(image_features)
text_embeddings = self.text_projection(text_features)
# Calculating the Loss
logits = (text_embeddings @ image_embeddings.T) / self.temperature
images_similarity = image_embeddings @ image_embeddings.T
texts_similarity = text_embeddings @ text_embeddings.T
targets = F.softmax( # soft label
(images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
)
texts_loss = cross_entropy(logits, targets, reduction='none')
images_loss = cross_entropy(logits.T, targets.T, reduction='none')
loss = (images_loss + texts_loss) / 2.0 # shape: (batch_size)
return loss.mean()
def cross_entropy(preds, targets, reduction='none'):
log_softmax = nn.LogSoftmax(dim=-1)
loss = (-targets * log_softmax(preds)).sum(1)
if reduction == "none":
return loss
elif reduction == "mean":
return loss.mean()
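To make the soft-label idea concrete, here is a tiny hypothetical example with temperature = 1: the same image embedding appears twice (one image, two captions), so the target distribution spreads probability over both copies instead of forcing a one-hot identity matrix.
import torch
import torch.nn.functional as F

image_embeddings = torch.tensor([[1.0, 0.0], [1.0, 0.0]])  # the same image repeated twice
text_embeddings = torch.tensor([[1.0, 0.0], [0.9, 0.1]])   # two different captions for it

images_similarity = image_embeddings @ image_embeddings.T
texts_similarity = text_embeddings @ text_embeddings.T
targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)
print(targets)  # each row is roughly [0.51, 0.49] rather than a hard 0/1 label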
Train
Below are some handy functions for training. Since I run the experiments on Windows, the script's entry point is wrapped in the guard below so that the DataLoader's num_workers option (see the Data Utils section) works correctly:
if __name__ == "__main__":
main()
import config
from utils import AvgMeter, get_lr
import itertools
from tqdm import tqdm
from data_utils import make_train_valid_dfs, build_loaders
from model import CLIPModel
import torch
from transformers import DistilBertTokenizer
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
loss_meter = AvgMeter()
tqdm_object = tqdm(train_loader, total=len(train_loader))
for batch in tqdm_object:
batch = {k: v.to(config.device) for k, v in batch.items() if k != "caption"}
loss = model(batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step == "batch":
lr_scheduler.step()
count = batch["image"].size(0)
loss_meter.update(loss.item(), count)
tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
return loss_meter
def valid_epoch(model, valid_loader):
loss_meter = AvgMeter()
tqdm_object = tqdm(valid_loader, total=len(valid_loader))
for batch in tqdm_object:
batch = {k: v.to(config.device) for k, v in batch.items() if k != "caption"}
loss = model(batch)
count = batch["image"].size(0)
loss_meter.update(loss.item(), count)
tqdm_object.set_postfix(valid_loss=loss_meter.avg)
return loss_meter
def main():
train_df, valid_df = make_train_valid_dfs()
tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)
train_loader = build_loaders(train_df, tokenizer, mode="train")
valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
model = CLIPModel().to(config.device)
params = [
{"params": model.image_encoder.parameters(), "lr": config.image_encoder_lr},
{"params": model.text_encoder.parameters(), "lr": config.text_encoder_lr},
{"params": itertools.chain(
model.image_projection.parameters(), model.text_projection.parameters()
), "lr": config.head_lr, "weight_decay": config.weight_decay}
]
optimizer = torch.optim.AdamW(params, weight_decay=0.)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", patience=config.patience, factor=config.factor
)
step = "epoch"
best_loss = float('inf')
for epoch in range(config.epochs):
print(f"Epoch: {epoch + 1}")
model.train()
train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
model.eval()
with torch.no_grad():
valid_loss = valid_epoch(model, valid_loader)
if valid_loss.avg < best_loss:
best_loss = valid_loss.avg
torch.save(model.state_dict(), "best.pt")
print("Saved Best Model!")
lr_scheduler.step(valid_loss.avg)
if __name__ == "__main__":
main()
Execution results
Epoch: 1
100%|██████████| 3973/3973 [07:42<00:00, 8.59it/s, lr=0.0001, train_loss=1.4]
100%|██████████| 994/994 [01:24<00:00, 11.74it/s, valid_loss=2.15]
Saved Best Model!
Epoch: 2
100%|██████████| 3973/3973 [07:45<00:00, 8.54it/s, lr=0.0001, train_loss=0.455]
100%|██████████| 994/994 [01:02<00:00, 15.82it/s, valid_loss=2.21]
Inference
We now extract image embeddings with the trained model and retrieve the images most similar to a text query.
import os
import config
from tqdm import tqdm
from data_utils import make_train_valid_dfs, build_loaders
from model import CLIPModel
import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizer
import matplotlib.pyplot as plt
import cv2
def get_image_embeddings(valid_df, model_path, save_path=None):
tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)
valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
model = CLIPModel().to(config.device)
model.load_state_dict(torch.load(model_path, map_location=config.device))
model.eval()
valid_image_embeddings = []
with torch.no_grad():
for batch in tqdm(valid_loader):
image_features = model.image_encoder(batch["image"].to(config.device))
image_embeddings = model.image_projection(image_features)
valid_image_embeddings.append(image_embeddings)
all_embeddings = torch.cat(valid_image_embeddings)
if save_path:
torch.save(all_embeddings, save_path)
return model, all_embeddings
def find_matches(model, image_embeddings, query, image_filenames, n=9):
tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)
encoded_query = tokenizer([query])
batch = {
key: torch.tensor(values).to(config.device)
for key, values in encoded_query.items()
}
with torch.no_grad():
text_features = model.text_encoder(
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
)
text_embeddings = model.text_projection(text_features)
image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
dot_similarity = text_embeddings_n @ image_embeddings_n.T
    values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)  # each image appears 5 times (one per caption), so take 5x candidates
    matches = [image_filenames[idx] for idx in indices[::5]]  # step by 5 to skip the duplicated copies of each image
_, axes = plt.subplots(3, 3, figsize=(10, 10))
for match, ax in zip(matches, axes.flatten()):
image = cv2.imread(f"{config.image_path}/{match}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
ax.imshow(image)
ax.axis("off")
plt.show()
if __name__ == "__main__":
_, valid_df = make_train_valid_dfs()
embedding_path = "valid_image_embeddings.pt"
model_path = "best.pt"
    # Prepare the model
model = CLIPModel().to(config.device)
model.load_state_dict(torch.load(model_path))
model.eval()
    # Load saved image embeddings or compute them from scratch
if os.path.exists(embedding_path):
print(f"Loading image embeddings from {embedding_path}")
image_embeddings = torch.load(embedding_path).to(config.device)
else:
print(f"No saved embeddings found. Generating...")
_, image_embeddings = get_image_embeddings(valid_df, model_path, save_path=embedding_path)
find_matches(model,
image_embeddings,
query="one dog sitting on the grass",
image_filenames=valid_df['image'].values,
n=9)
Given the query "one dog sitting on the grass", the model retrieves images that match the query well, as shown below.
That concludes this post. I hope this walkthrough helped you implement CLIP yourself. Finally, thanks to Khalid Salama, whose Keras implementation was a great help for this PyTorch version.