
[Computer Vision] CLIP from Scratch (PyTorch Implementation)

johyeongseob 2025. 6. 26. 17:16

📌 About this post
This post is a translation of the original notebook written by Moein Shariatnia on Kaggle.
The original was released under the Apache License 2.0, and this blog post follows the same license.

Original author: Moein Shariatnia
Original source: Kaggle Notebook
License: Apache License 2.0 (view full text)

This translation is for non-commercial, educational purposes and respects the original author's copyright and license.

Code: https://github.com/johyeongseob/from-scratch-ai
Dataset: https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset


Installing the libraries

conda create -n clip-env python=3.10
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
pip install timm
pip install transformers==4.30.2
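
Before moving on, a quick sanity check of the environment can save time later. The snippet below is a minimal sketch that only assumes the version pins above; the expected version strings in the comments are illustrative.

# Minimal environment check (assumes the versions pinned above)
import torch
import timm
import transformers

print(torch.__version__)              # e.g. 2.1.0+cu121
print(torch.cuda.is_available())      # True if the CUDA 12.1 build matches your driver
print(timm.__version__, transformers.__version__)   # timm version may vary; transformers 4.30.2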

Introduction

In January 2021, OpenAI announced CLIP, a multi-modal model that connects texts and images. In this post, we implement CLIP in PyTorch.

What does CLIP do? Why is it fun?

In the paper Learning Transferable Visual Models From Natural Language Supervision, OpenAI introduces a model called CLIP (Contrastive Language-Image Pre-training). The model learns the relationship between sentences and the images they describe. Once trained, given an input sentence, the model can retrieve the images most relevant to that sentence. The trained CLIP also shows impressive zero-shot performance on ImageNet classification without any task-specific fine-tuning.

As a teaser, the results of the CLIP model covered in this post are shown below.


Config

import torch

debug = False
image_path = "./flickr30k_images"
captions_path = "./captions.csv"

batch_size = 32
num_workers = 4
head_lr = 1e-3
image_encoder_lr = 1e-4
text_encoder_lr = 1e-5
weight_decay = 1e-3
patience = 1
factor = 0.8
epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = 'resnet50'
image_embedding = 2048
text_encoder_model = "distilbert-base-uncased"
text_embedding = 768
text_tokenizer = "distilbert-base-uncased"
max_length = 40  # covers 99.5% of Flickr30k captions, max=88

pretrained = True
trainable = True
temperature = 1.0

size = 224

num_projection_layers = 1
projection_dim = 256
dropout = 0.1
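
The comment next to max_length can be verified directly. The sketch below (assuming ./captions.csv with a caption column, as used later in the Data Utils section) tokenizes every caption with the DistilBERT tokenizer and prints length percentiles; max_length = 40 should truncate only the longest fraction of a percent of captions.

# Sketch: check Flickr30k caption lengths in DistilBERT tokens
# (assumes ./captions.csv with a "caption" column, as in the Data Utils section)
import numpy as np
import pandas as pd
from transformers import DistilBertTokenizer

df = pd.read_csv("./captions.csv")
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
lengths = [len(tokenizer.encode(c)) for c in df["caption"]]
print(np.percentile(lengths, [50, 95, 99.5]), max(lengths))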

Utils

class AvgMeter:
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        self.avg, self.sum, self.count = [0] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]
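
AvgMeter keeps a running average weighted by batch size, which is exactly what the training loop feeds it later. A tiny usage example:

# Example: the running average is weighted by the batch size passed as count
meter = AvgMeter(name="train_loss")
meter.update(2.0, count=32)   # batch of 32, loss 2.0
meter.update(1.0, count=16)   # batch of 16, loss 1.0
print(meter)                  # train_loss: 1.6667  == (2.0*32 + 1.0*16) / 48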

Dataset

The model needs both image and text embeddings, so we build a dataset that supplies both. For the text embeddings, the author uses a pretrained DistilBERT (HuggingFace); the dataset therefore tokenizes the captions with the matching DistilBERT tokenizer.

import torch
from torch.utils.data import Dataset
import config
from PIL import Image


class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        image_filenames and captions must have the same length; so, if there are
        multiple captions for each image, the image_filenames must have repetitive
        file names
        """

        self.image_filenames = image_filenames  # Each image is repeated 5 times for 5 captions in captions.csv
        self.captions = list(captions)  # class: 'pandas.core.series.Series' -> 'list'
        self.encoded_captions = tokenizer(  # Encode all captions from the CSV file using the tokenizer
            list(captions),
            padding=True,
            truncation=True,
            max_length=config.max_length
        )
        self.transforms = transforms

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        item = {}

        # Extract the encoded token (e.g., input_ids, attention_mask) for the current index
        for key, values in self.encoded_captions.items():
            item[key] = torch.tensor(values[idx])

        image = Image.open(f"{config.image_path}/{self.image_filenames[idx]}").convert("RGB")  # directory + filename
        item['image'] = self.transforms(image)
        item['caption'] = self.captions[idx]

        return item
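
A quick smoke test of the dataset helps confirm what each item contains. The sketch below assumes captions.csv has image and caption columns (as used in the Data Utils section) and that the Flickr30k images live in config.image_path.

# Sketch: inspect one item from CLIPDataset
import config
import pandas as pd
from torchvision import transforms
from transformers import DistilBertTokenizer
from CLIPDataset import CLIPDataset

df = pd.read_csv(config.captions_path)
tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)
tfms = transforms.Compose([transforms.Resize((config.size, config.size)), transforms.ToTensor()])

dataset = CLIPDataset(df["image"].values, df["caption"].values, tokenizer, tfms)
item = dataset[0]
print(item.keys())              # input_ids, attention_mask, image, caption
print(item["image"].shape)      # torch.Size([3, 224, 224])
print(item["input_ids"].shape)  # padded to the longest (truncated) caption, <= config.max_length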

Data Utils

import torch
import numpy as np
import pandas as pd
import config
from CLIPDataset import CLIPDataset
from torch.utils.data import Dataset
from torchvision import transforms


def make_train_valid_dfs():
    dataframe = pd.read_csv(f"{config.captions_path}")
    max_id = dataframe["id"].max() + 1 if not config.debug else 100
    image_ids = np.arange(0, max_id)  # 31782
    np.random.seed(42)
    valid_ids = np.random.choice(
        image_ids, size=int(0.2 * len(image_ids)), replace=False
    )
    train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
    return train_dataframe, valid_dataframe

def build_loaders(dataframe, tokenizer, mode):
    transforms = get_transforms(mode=mode)
    dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=transforms,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        shuffle=True if mode == "train" else False,
    )
    return dataloader


# You can add augmentations in the train mode if needed
def get_transforms(mode="train"):
    if mode == "train":
        return transforms.Compose(
            [
                transforms.Resize((config.size, config.size)),
                transforms.ToTensor()
            ]
        )
    else:
        return transforms.Compose(
            [
                transforms.Resize((config.size, config.size)),
                transforms.ToTensor()
            ]
        )
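
Putting the split and the loaders together, one training batch looks like this (a sketch; on Windows it must run under an if __name__ == "__main__": guard because num_workers > 0, as discussed in the Train section):

# Sketch: build the loaders and look at a single batch
import config
from transformers import DistilBertTokenizer
from data_utils import make_train_valid_dfs, build_loaders

if __name__ == "__main__":
    train_df, valid_df = make_train_valid_dfs()
    tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)
    train_loader = build_loaders(train_df, tokenizer, mode="train")

    batch = next(iter(train_loader))
    print(batch["image"].shape)      # torch.Size([32, 3, 224, 224])
    print(batch["input_ids"].shape)  # torch.Size([32, L]), L <= config.max_length
    print(len(batch["caption"]))     # 32 raw caption strings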

CLIP

Image Encoder

The author uses a pretrained ResNet50 (from timm) as the image encoder. The final fc layer is removed (num_classes=0) so that the network is used purely as a feature extractor producing vectors $\in \mathbb{R}^{2048}$.

class ImageEncoder(nn.Module):
    """
    Encode images to a fixed size vector
    """

    def __init__(
            self,
            model_name=config.model_name,
            pretrained=config.pretrained,
            trainable=config.trainable
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name,
            pretrained,
            num_classes=0,  # self.fc = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
            global_pool="avg"
        )

        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)
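
A shape check (a small sketch with a random tensor) confirms that the encoder returns pooled 2048-dim features, matching image_embedding in the config:

# Shape check: ResNet50 with num_classes=0 and global_pool="avg" returns 2048-dim features
import torch

encoder = ImageEncoder()
dummy = torch.randn(4, 3, 224, 224)
print(encoder(dummy).shape)   # torch.Size([4, 2048])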

Text Encoder

As mentioned above, the author uses DistilBERT as the text encoder. Of its outputs, only the CLS token representation $\in \mathbb{R}^{768}$ is used as the sentence embedding.

class TextEncoder(nn.Module):
    def __init__(
            self,
            model_name=config.text_encoder_model,
            pretrained=config.pretrained,
            trainable=config.trainable
    ):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name)
        else:
            self.model = DistilBertModel(config=DistilBertConfig())

        for p in self.model.parameters():
            p.requires_grad = trainable

        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]
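
Likewise, a quick check shows that the text encoder returns a single 768-dim vector per sentence (the CLS token representation), matching text_embedding in the config:

# Shape check: only the CLS token (768-dim for DistilBERT) is returned per sentence
from transformers import DistilBertTokenizer

text_encoder = TextEncoder()
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
enc = tokenizer(["a dog sitting on the grass"], return_tensors="pt")
out = text_encoder(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(out.shape)   # torch.Size([1, 768])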

Projection Head

The embeddings obtained from the image and text encoders are projected into a single shared space. After projection, the embeddings $\in \mathbb{R}^{256}$ of matching image-text pairs are pulled together, while those of non-matching pairs are pushed apart.

class ProjectionHead(nn.Module):
    def __init__(
            self,
            embedding_dim,
            projection_dim=config.projection_dim,
            dropout=config.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x
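
Each modality gets its own ProjectionHead, so the 2048-dim image features and 768-dim text features both end up in the same 256-dim space:

# Shape check: both modalities are projected into the shared 256-dim space
import torch

image_proj = ProjectionHead(embedding_dim=2048)   # for ResNet50 features
text_proj = ProjectionHead(embedding_dim=768)     # for DistilBERT CLS features
print(image_proj(torch.randn(4, 2048)).shape)     # torch.Size([4, 256])
print(text_proj(torch.randn(4, 768)).shape)       # torch.Size([4, 256])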

CLIP

Now the image encoder, text encoder, and projection heads are combined into the CLIP model. An interesting detail here is that soft labels are used as the target. Because of the dataset (Flickr30k), each image comes with multiple captions (5), so near-identical image-caption pairs can appear in the same batch; training them together with hard (one-hot) labels is not possible, which is why the targets are built from the similarity matrices instead.

import torch
import torch.nn as nn
import timm
import config
import torch.nn.functional as F
from transformers import DistilBertModel, DistilBertConfig


class CLIPModel(nn.Module):
    def __init__(
        self,
        temperature=config.temperature,
        image_embedding=config.image_embedding,
        text_embedding=config.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        # Getting Image and Text Features
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Calculating the Loss
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(  # soft label
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
        return loss.mean()


def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
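
To see what the soft targets do, here is a toy example (a sketch with hand-made 2-dim embeddings, temperature = 1): when two rows of a batch are near-duplicates, the softmax target spreads probability mass over both rows instead of forcing a one-hot label on each.

# Sketch: soft targets on a toy batch of 3 (rows 0 and 1 are near-duplicates)
import torch
import torch.nn.functional as F

image_embeddings = torch.tensor([[1.00, 0.00],
                                 [0.99, 0.10],
                                 [0.00, 1.00]])
text_embeddings = image_embeddings.clone()   # pretend the captions embed like their images

images_similarity = image_embeddings @ image_embeddings.T
texts_similarity = text_embeddings @ text_embeddings.T
targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)
print(targets.round(decimals=2))
# rows 0 and 1 get almost identical target rows that split the probability mass
# between them, which a hard one-hot label could not express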

Train

The training loop is built from the handy functions below. Since I run the experiments on Windows, the script is launched under the if __name__ == "__main__": guard shown below so that the DataLoader's num_workers (see the Data Utils section) can be used.

if __name__ == "__main__":
    main()

import config
from utils import AvgMeter, get_lr
import itertools
from tqdm import tqdm
from data_utils import make_train_valid_dfs, build_loaders
from model import CLIPModel
import torch
from transformers import DistilBertTokenizer


def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    loss_meter = AvgMeter()
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    for batch in tqdm_object:
        batch = {k: v.to(config.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
    return loss_meter


def valid_epoch(model, valid_loader):
    loss_meter = AvgMeter()

    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    for batch in tqdm_object:
        batch = {k: v.to(config.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch)

        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(valid_loss=loss_meter.avg)
    return loss_meter


def main():
    train_df, valid_df = make_train_valid_dfs()
    tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)
    train_loader = build_loaders(train_df, tokenizer, mode="train")
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(config.device)
    params = [
        {"params": model.image_encoder.parameters(), "lr": config.image_encoder_lr},
        {"params": model.text_encoder.parameters(), "lr": config.text_encoder_lr},
        {"params": itertools.chain(
            model.image_projection.parameters(), model.text_projection.parameters()
        ), "lr": config.head_lr, "weight_decay": config.weight_decay}
    ]
    optimizer = torch.optim.AdamW(params, weight_decay=0.)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=config.patience, factor=config.factor
    )
    step = "epoch"

    best_loss = float('inf')
    for epoch in range(config.epochs):
        print(f"Epoch: {epoch + 1}")
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)

        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            torch.save(model.state_dict(), "best.pt")
            print("Saved Best Model!")

        lr_scheduler.step(valid_loss.avg)


if __name__ == "__main__":
    main()

Training log

100%|██████████| 3973/3973 [07:42<00:00,  8.59it/s, lr=0.0001, train_loss=1.4]
100%|██████████| 994/994 [01:24<00:00, 11.74it/s, valid_loss=2.15]
  0%|          | 0/3973 [00:00<?, ?it/s]Saved Best Model!
Epoch: 2
100%|██████████| 3973/3973 [07:45<00:00,  8.54it/s, lr=0.0001, train_loss=0.455]
100%|██████████| 994/994 [01:02<00:00, 15.82it/s, valid_loss=2.21]

Inference

Using the trained model, we extract image embeddings and then retrieve the images most similar to a text query.

import os
import config
from tqdm import tqdm
from data_utils import make_train_valid_dfs, build_loaders
from model import CLIPModel
import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizer
import matplotlib.pyplot as plt
import cv2

def get_image_embeddings(valid_df, model_path, save_path=None):
    tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(config.device)
    model.load_state_dict(torch.load(model_path, map_location=config.device))
    model.eval()

    valid_image_embeddings = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            image_features = model.image_encoder(batch["image"].to(config.device))
            image_embeddings = model.image_projection(image_features)
            valid_image_embeddings.append(image_embeddings)

    all_embeddings = torch.cat(valid_image_embeddings)
    if save_path:
        torch.save(all_embeddings, save_path)
    return model, all_embeddings


def find_matches(model, image_embeddings, query, image_filenames, n=9):
    tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to(config.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
    matches = [image_filenames[idx] for idx in indices[::5]]

    _, axes = plt.subplots(3, 3, figsize=(10, 10))
    for match, ax in zip(matches, axes.flatten()):
        image = cv2.imread(f"{config.image_path}/{match}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ax.imshow(image)
        ax.axis("off")

    plt.show()


if __name__ == "__main__":
    _, valid_df = make_train_valid_dfs()
    embedding_path = "valid_image_embeddings.pt"
    model_path = "best.pt"

    # Prepare the model
    model = CLIPModel().to(config.device)
    model.load_state_dict(torch.load(model_path, map_location=config.device))
    model.eval()

    # Load saved image embeddings or compute them from scratch
    if os.path.exists(embedding_path):
        print(f"Loading image embeddings from {embedding_path}")
        image_embeddings = torch.load(embedding_path).to(config.device)
    else:
        print(f"No saved embeddings found. Generating...")
        _, image_embeddings = get_image_embeddings(valid_df, model_path, save_path=embedding_path)

    find_matches(model,
                 image_embeddings,
                 query="one dog sitting on the grass",
                 image_filenames=valid_df['image'].values,
                 n=9)

Given the query "one dog sitting on the grass", the model retrieves images that fit the query well, as shown below.

100%|██████████| 994/994 [01:01<00:00, 16.12it/s]

That wraps up this post. I hope it has helped readers actually implement CLIP for themselves. Finally, thanks to Khalid Salama (Keras implementation), whose work was a great help for this PyTorch implementation.