
Failing to fine-tune CLIP model

I'm trying to fine-tune CLIP on anime images and their tags from Danbooru. I used Hugging Face to load the pretrained model and tried to train it with the CLIP contrastive loss.
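For reference, this is roughly what I mean by loading the pretrained model and training with the CLIP loss (a minimal sketch: the image paths and tag strings below are just placeholders):

    from PIL import Image
    from transformers import CLIPModel, CLIPProcessor

    # load the pretrained checkpoint (weights included)
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # a toy batch of two image/tag pairs (placeholder files and tags)
    images = [Image.open("a.jpg"), Image.open("b.jpg")]
    texts = ["1girl, solo, long_hair", "1boy, short_hair, glasses"]
    inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # return_loss=True makes the model compute the CLIP contrastive loss over the batch
    outputs = model(**inputs, return_loss=True)
    print(outputs.loss)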

finetune.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import CLIPModel, CLIPConfig
from optimum.exporters.onnx import export_models
from transformers.models.clip import CLIPTextModelWithProjection, CLIPVisionModelWithProjection
from optimum.exporters.onnx.model_configs import CLIPTextOnnxConfig, ViTOnnxConfig
from typing import Dict
import random

from finetune_data import ImageTextPairData


class CLIPVisionOnnxConfig(ViTOnnxConfig):
    pass


class CLIPTextModelWithProjectionOnnxConfig(CLIPTextOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "text_embeds": {0: "batch_size"},
        }


class CLIPVisionModelWithProjectionOnnxConfig(CLIPVisionOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "image_embeds": {0: "batch_size"},
        }


sim_eval = nn.CosineSimilarity(dim=1, eps=1e-6)


def train(num_epochs, device, grad_acc=4, eval_steps=20):
    config = CLIPConfig.from_pretrained("openai/clip-vit-base-patch32")
    model = CLIPModel(config).to(device)
    train_dataset = ImageTextPairData()
    train_dataloader = DataLoader(train_dataset, batch_size=12, shuffle=True, pin_memory=True, num_workers=8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-6, weight_decay=0.001)
    pbar = tqdm(total=len(train_dataloader) * num_epochs)
    for epoch in range(num_epochs):
        for i, batch in enumerate(train_dataloader):
            images, tokens, attn_masks = batch
            images = images.to(device)
            tokens = tokens.to(device)
            attn_masks = attn_masks.to(device)
            # forward pass with the built-in CLIP contrastive loss
            response = model(input_ids=tokens, attention_mask=attn_masks, pixel_values=images, return_loss=True)
            loss = response.loss / grad_acc
            if loss != 0:
                loss.backward()
            # gradient accumulation: step the optimizer every grad_acc batches
            if (i + 1) % grad_acc == 0:
                optimizer.step()
                optimizer.zero_grad()
            # periodically check image/text similarity on a random training sample
            if (i + 1) % eval_steps == 0:
                image, token, attn_mask = random.choice(train_dataset)
                response = model(input_ids=token.to(device).unsqueeze(0), attention_mask=attn_mask.to(device).unsqueeze(0), pixel_values=image.to(device).unsqueeze(0))
                similarity = sim_eval(response.image_embeds, response.text_embeds)
                print("\nSample Distance", similarity[0].item())
            pbar.update(1)
            pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {loss.item():.4f}")
        optimizer.step()
        optimizer.zero_grad()
    model.save_pretrained("clip-finetuned", from_pt=True)


def merge_models(alpha):
    original_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    finetuned_model = CLIPModel.from_pretrained("clip-finetuned")
    for (param_orig, param_fint) in zip(original_model.parameters(), finetuned_model.parameters()):
        # Perform linear interpolation on weights
        param_fint.data = alpha * param_orig.data + (1 - alpha) * param_fint.data
    finetuned_model.save_pretrained("clip-finetuned", from_pt=True)


def export():
    text_model = CLIPTextModelWithProjection.from_pretrained("clip-finetuned")
    vision_model = CLIPVisionModelWithProjection.from_pretrained("clip-finetuned")
    export_models(
        models_and_onnx_configs={
            "text_model": (text_model, CLIPTextModelWithProjectionOnnxConfig(text_model.config)),
            "vision_model": (vision_model, CLIPVisionModelWithProjectionOnnxConfig(vision_model.config)),
        },
        output_dir="clip-finetune-onnx",
    )


if __name__ == '__main__':
    train(5, "cuda")
    # merge_models(1)
    export()
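As far as I understand, the loss returned by return_loss=True is roughly the symmetric contrastive loss below (a sketch based on my reading, not the exact transformers implementation):

    import torch
    import torch.nn.functional as F

    def clip_contrastive_loss(image_embeds, text_embeds, logit_scale):
        # normalize both embedding sets, then compare every text to every image;
        # logit_scale is the learned temperature (already exponentiated)
        image_embeds = F.normalize(image_embeds, dim=-1)
        text_embeds = F.normalize(text_embeds, dim=-1)
        logits_per_text = logit_scale * text_embeds @ image_embeds.t()
        # matching pairs sit on the diagonal, so the target class is the row index
        labels = torch.arange(len(logits_per_text), device=logits_per_text.device)
        loss_text = F.cross_entropy(logits_per_text, labels)
        loss_image = F.cross_entropy(logits_per_text.t(), labels)
        return (loss_text + loss_image) / 2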

finetune_data.py

from torch.utils.data import Dataset
import jsonlines
from transformers import CLIPTokenizer
from PIL import Image
from torchvision import transforms
import torch
import random

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
image_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=(224, 224), antialias=True),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
])


class ImageTextPairData(Dataset):
    def __init__(self, eval=False):
        self.image_path = []
        self.list_prompts = []
        filename = "train.jsonl"
        if eval:
            filename = "eval.jsonl"
        with jsonlines.open(filename) as reader:
            for i, obj in enumerate(reader):
                if len(obj["prompt"]) > 0:
                    self.image_path.append(obj["path"])
                    self.list_prompts.append(obj["prompt"])

    def __len__(self):
        return len(self.list_prompts)

    def __getitem__(self, idx):
        image = image_transform(Image.open(self.image_path[idx]).convert('RGB'))
        tokenized_text = tokenizer(self.list_prompts[idx], padding="max_length", truncation=True, return_tensors="pt")
        return image, tokenized_text.input_ids.squeeze(0), tokenized_text.attention_mask.squeeze(0)
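To sanity-check a single sample from the dataset I run something like this (a rough check; the actual paths come from my train.jsonl):

    from finetune_data import ImageTextPairData, tokenizer

    dataset = ImageTextPairData()
    image, input_ids, attention_mask = dataset[0]
    print(image.shape)       # should be torch.Size([3, 224, 224])
    print(input_ids.shape)   # should be torch.Size([77]) with max_length padding
    print(tokenizer.decode(input_ids, skip_special_tokens=True))  # should show the tags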

Samples from the data:

{"prompt": "1girl, bangs, bare_shoulders, bikini, black_choker, blush, breasts, choker, cleavage, collarbone, eyebrows_visible_through_hair, hair_between_eyes, huge_breasts, jacket, long_hair, mole, mole_on_breast, mole_under_eye, multicolored_hair, nail_polish, off_shoulder, red_hair, red_nails, simple_background, streaked_hair, swimsuit, thigh_strap, thighs, twintails, bremerton_(azur_lane)", "path": "dataset\\99993299_p0_master1200.jpg"}{"prompt": "1girl, bangs, black_bow, black_choker, black_hairband, black_legwear, blue_eyes, blue_hair, blush, bow, choker, dress, eyebrows_visible_through_hair, frills, garter_straps, hair_between_eyes, hair_bow, hairband, indoors, light_particles, long_sleeves, looking_at_viewer, maid, puffy_sleeves, shoes, short_hair, solo, thighhighs, window, wooden_floor, sora_ginko", "path": "dataset\\99995015_p0_master1200.jpg"}{"prompt": "1girl, bangs, bare_shoulders, bed, bed_sheet, blonde_hair, blue_eyes, blush, breasts, curtains, elf, eyebrows_visible_through_hair, hair_ribbon, indoors, large_breasts, long_hair, looking_at_viewer, navel, on_bed, parted_lips, pointy_ears, revealing_clothes, ribbon, seiza, sitting, solo, thighs, twintails, veil, very_long_hair, white_bow, white_ribbon, window", "path": "dataset\\99996628_p0_master1200.jpg"}

I tried different learning rates and batch sizes, but the model just gets worse in every case. The image encoder produces nearly the same embedding for completely different images, and the output of the text encoder is also different from what I expected (it doesn't match the images).
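This is roughly how I checked the image embeddings after training (a sketch; the two file paths are placeholders for unrelated images from my dataset):

    import torch
    from PIL import Image
    from transformers import CLIPModel, CLIPProcessor

    model = CLIPModel.from_pretrained("clip-finetuned").eval()
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # two completely different images (placeholder paths)
    images = [Image.open("dataset/img_a.jpg"), Image.open("dataset/img_b.jpg")]
    inputs = processor(images=images, return_tensors="pt")

    with torch.no_grad():
        embeds = model.get_image_features(**inputs)
    embeds = torch.nn.functional.normalize(embeds, dim=-1)
    # cosine similarity between the two image embeddings;
    # I see values close to 1.0 even for unrelated images
    print((embeds[0] @ embeds[1]).item())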

I have no experience with this, so I can't figure out what's going on. I found other people with a similar problem (e.g. https://github.com/openai/CLIP/issues/83), but there is no answer there.

