I'm trying to fine-tunning the VAE of SD 1.4
I'm in a multi gpu environment, and I'm using accelerate
library for handling that.This is my code summarized:
import osimport torch.nn.functional as Fimport yamlfrom PIL import Imageimport torchfrom torch.utils.data import Dataset, DataLoaderfrom torchvision.transforms import Compose, Resize, ToTensor, Normalizefrom diffusers import AutoencoderKLfrom torch.optim import Adamfrom accelerate import Acceleratorfrom torch.utils.tensorboard import SummaryWriter# Load configurationwith open('config.yaml', 'r') as file: config = yaml.safe_load(file)def save_checkpoint(model, optimizer, epoch, step, filename="checkpoint.pth.tar"): checkpoint = {'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'epoch': epoch,'step': step } torch.save(checkpoint, filename)class ImageDataset(Dataset): def __init__(self, root_dir, transform=None): self.root_dir = root_dir self.transform = transform self.images = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.png')] def __len__(self): return len(self.images) def __getitem__(self, idx): img_path = self.images[idx] image = Image.open(img_path).convert('RGB') if self.transform: image = self.transform(image) return image# Setup dataset and dataloader based on configtransform = Compose([ Resize((config['dataset']['image_size'], config['dataset']['image_size'])), ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])dataset = ImageDataset(root_dir=config['dataset']['root_dir'], transform=transform)dataloader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=True, num_workers=config['training']['num_workers'])# Initialize model, accelerator, optimizer, and TensorBoard writerdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")model_path = config['model']['path']vae = AutoencoderKL.from_pretrained(model_path).to(device)optimizer = Adam(vae.parameters(), lr=config['training']['learning_rate'])accelerator = Accelerator()vae, dataloader = accelerator.prepare(vae, dataloader)writer = SummaryWriter()# Training loopfor epoch in range(config['training']['num_epochs']): vae.train() total_loss = 0 for step, batch in enumerate(dataloader): with accelerator.accumulate(vae): # Assuming the first element of the batch is the image target = batch[0].to(next(vae.parameters()).dtype) # Access the original model for custom methods model = vae.module if hasattr(vae, "module") else vae posterior = model.encode(target).latent_dist z = posterior.mode() pred = model.decode(z).sample kl_loss = posterior.kl().mean() mse_loss = F.mse_loss(pred, target, reduction="mean") loss = mse_loss + config['training']["kl_scale"] * kl_loss optimizer.zero_grad() accelerator.backward(loss) optimizer.step() optimizer.zero_grad() # Clear gradients after updating weights # Checkpointing every 10 steps if step % 10 == 0: checkpoint_path = f"checkpoint_epoch_{epoch}_step_{step}.pth" accelerator.save({"epoch": epoch,"model_state_dict": model.state_dict(),"optimizer_state_dict": optimizer.state_dict(),"loss": loss, }, checkpoint_path) print(f"Checkpoint saved to {checkpoint_path}")writer.close()print("Training complete.")
When running the code, I got the following error:
RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [128] and input of shape [128, 1024, 1024]:
My input folder contains a set of png images with different sizes, and resized to 1024x1024 in the configuration file.
I do not know why this is happening and if someone knows, or if there is a easier way to fine-tunning the VAE weights using my images.Thanks.
Edit:My config.yaml
file
model: path: 'vae1dot4' # Path to your pre-trained model directorydataset: root_dir: 'segmented' # Directory containing your PNG images image_size: 1024 # Target size for image resizingtraining: batch_size: 8 # Batch size for training num_epochs: 10 # Number of epochs to train learning_rate: 0.0005 # Learning rate for the optimizer num_workers: 4 # Number of worker processes for data loading kl_scale: 1 gradient_accumulation_steps: 1logging: tensorboard_dir: 'runs' # Directory for TensorBoard logs