PyTorch를 이용해 변분 오토인코더(Variational Autoencoder, VAE)를 구현하는 방법에 대해 이야기해보려 합니다. VAE는 이미지 같은 복잡한 데이터를 학습하고 새로운 데이터를 생성할 수 있는 딥러닝 모델 중 하나입니다.

 

이 포스트에서는, 간단한 MNIST 데이터셋을 사용하여 VAE 모델을 어떻게 구축하고 학습시킬 수 있는지 살펴봅니다.

 

이미지 저장 및 전체 소스코드는 아래 Github에 업로드 해놓았습니다.

https://github.com/dev-jinwoohong/vae-pytorch

 

GitHub - dev-jinwoohong/vae-pytorch

Contribute to dev-jinwoohong/vae-pytorch development by creating an account on GitHub.

github.com

Install

pip install torch torchvision

라이브러리 및 모듈 Import

import torch
from torch import nn, optim
import torch.utils.data
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

데이터 로딩

MNIST 데이터셋을 로드하고, 학습용과 테스트용 DataLoader를 생성한다. 이들은 각각 모델을 학습시키고 평가하는 데 사용된다. transform=transforms.ToTensor()를 통해 이미지를 PyTorch 텐서로 변환하고, batch_size=128은 한 번에 128개의 이미지를 처리한다.

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=True, num_workers=4, pin_memory=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=False, num_workers=

모델 정의 (VAE)

VAE 클래스는 Variational Autoencoder 모델을 정의한다. 이는 Input Image를 Latent space의 벡터로 인코딩한 다음, 이 벡터를 사용해 원본 이미지를 재구성하는 신경망이다. encode, reparameterize, decode 메서드는 각각 인코딩, 재매개변수화, 디코딩 과정을 담당한다.

class VAE(nn.Module):
    def __init__(self, input_dim=784, h_dim=400, latent_dim=20):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(input_dim, h_dim)
        self.fc2_mu = nn.Linear(h_dim, latent_dim)
        self.fc2_var = nn.Linear(h_dim, latent_dim)
        self.fc3 = nn.Linear(latent_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, input_dim)

    def encode(self, x):
        out = F.relu(self.fc1(x))
        return self.fc2_mu(out), self.fc2_var(out)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, x):
        out = F.relu(self.fc3(x))
        return torch.sigmoid(self.fc4(out))

    def forward(self, x):
        mu, log_var = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

Loss Function

VAE의 손실 함수는 두 부분으로 구성된다 : 재구성 손실(Binary Cross Entropy)KL 발산(KLD).

재구성 손실은 원본 이미지와 재구성된 이미지 사이의 차이를 측정하고, KL 발산은 잠재 공간의 분포가 표준 정규 분포를 얼마나 잘 따르는지를 측정한다.

def loss_function(recon_x, x, mu, log_var):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return BCE + KLD

Train

train 함수는 모델을 학습 모드로 설정하고, 학습 datalodaer를 통해 반복하며 Loss을 계산하고 역전파를 통해 모델을 업데이트합니다. 각 epoch마다 평균 loss을 출력.

def train(epoch):
    model.train()
    train_loss = 0
    for idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon, mu, log_var = model(data)
        loss = loss_function(recon, data, mu, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

Test

test 함수는 모델을 평가 모드로 설정하고, 테스트 datalodaer를 통해 반복하면서 손실을 계산한다. 첫 번째 배치의 이미지와 재구성된 이미지를 저장하여 모델의 성능을 시각적으로 확인할 수 있습니다.

def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for idx, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon, mu, log_var = model(data)
            test_loss += loss_function(recon, data, mu, log_var).item()
            if idx == 0:
                n = min(data.size(0), 10)
                comparison = torch.cat([data[:n],
                                        recon.view(-1, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                           './results/epoch_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

Main

traintest 를 실행한다. 무작위로 생성된 s(latent space vector 값을 의미)를 10 epoch 마다 모델을 통해 새로운 이미지를 생성하고 저장한다. 이는 모델이 학습하는 동안 latent space에서 샘플링하여 얻은 벡터를 디코딩하여 얻는 과정을 볼 수 있다. 이는 모델이 얼마나 잘 일반화하는지를 보여주는 좋은 방법이다.

if __name__ == "__main__":
    os.makedirs('./results', exist_ok=True)
    s = torch.randn(64, 20).to(device)
    for epoch in range(0, 101):
        train(epoch)
        test(epoch)
        if epoch % 10 == 0:
            sample = model.decode(s).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       './results/sample_{}_result.png'.format(epoch))

 

결과 이미지

테스트 데이터셋에 대한 epoch별 결과 이미지

 

epoch별로 생성된 잠재 공간 샘플 이미지

 

전체 코드

import torch
from torch import nn, optim
import torch.utils.data
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=True, num_workers=4, pin_memory=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

class VAE(nn.Module):
    def __init__(self, input_dim=784, h_dim=400, latent_dim=20):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(input_dim, h_dim)
        self.fc2_mu = nn.Linear(h_dim, latent_dim)
        self.fc2_var = nn.Linear(h_dim, latent_dim)
        self.fc3 = nn.Linear(latent_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, input_dim)

    def encode(self, x):
        out = F.relu(self.fc1(x))
        return self.fc2_mu(out), self.fc2_var(out)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, x):
        out = F.relu(self.fc3(x))
        return torch.sigmoid(self.fc4(out))

    def forward(self, x):
        mu, log_var = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def loss_function(recon_x, x, mu, log_var):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return BCE + KLD

def train(epoch):
    model.train()
    train_loss = 0
    for idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon, mu, log_var = model(data)
        loss = loss_function(recon, data, mu, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for idx, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon, mu, log_var = model(data)
            test_loss += loss_function(recon, data, mu, log_var).item()
            if idx == 0:
                n = min(data.size(0), 10)
                comparison = torch.cat([data[:n],
                                        recon.view(-1, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                           './results/epoch_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

if __name__ == "__main__":
    s = torch.randn(64, 20).to(device)
    for epoch in range(0, 101):
        train(epoch)
        test(epoch)
        if epoch % 10 == 0:
            sample = model.decode(s).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       'results/sample_{}_result.png'.format(epoch))