PyTorch를 이용해 변분 오토인코더(Variational Autoencoder, VAE)를 구현하는 방법에 대해 이야기해보려 합니다. VAE는 이미지 같은 복잡한 데이터를 학습하고 새로운 데이터를 생성할 수 있는 딥러닝 모델 중 하나입니다.
이 포스트에서는, 간단한 MNIST 데이터셋을 사용하여 VAE 모델을 어떻게 구축하고 학습시킬 수 있는지 살펴봅니다.
이미지 저장 및 전체 소스코드는 아래 Github에 업로드 해놓았습니다.
https://github.com/dev-jinwoohong/vae-pytorch
GitHub - dev-jinwoohong/vae-pytorch
Contribute to dev-jinwoohong/vae-pytorch development by creating an account on GitHub.
github.com
Install
pip install torch torchvision
라이브러리 및 모듈 Import
import torch
from torch import nn, optim
import torch.utils.data
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
데이터 로딩
MNIST 데이터셋을 로드하고, 학습용과 테스트용 DataLoader를 생성한다. 이들은 각각 모델을 학습시키고 평가하는 데 사용된다. transform=transforms.ToTensor()
를 통해 이미지를 PyTorch 텐서로 변환하고, batch_size=128
은 한 번에 128개의 이미지를 처리한다.
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=True, download=True,
transform=transforms.ToTensor()),
batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=False, download=True,
transform=transforms.ToTensor()),
batch_size=128, shuffle=False, num_workers=
모델 정의 (VAE)
VAE
클래스는 Variational Autoencoder 모델을 정의한다. 이는 Input Image를 Latent space의 벡터로 인코딩한 다음, 이 벡터를 사용해 원본 이미지를 재구성하는 신경망이다. encode
, reparameterize
, decode
메서드는 각각 인코딩, 재매개변수화, 디코딩 과정을 담당한다.
class VAE(nn.Module):
def __init__(self, input_dim=784, h_dim=400, latent_dim=20):
super(VAE, self).__init__()
self.fc1 = nn.Linear(input_dim, h_dim)
self.fc2_mu = nn.Linear(h_dim, latent_dim)
self.fc2_var = nn.Linear(h_dim, latent_dim)
self.fc3 = nn.Linear(latent_dim, h_dim)
self.fc4 = nn.Linear(h_dim, input_dim)
def encode(self, x):
out = F.relu(self.fc1(x))
return self.fc2_mu(out), self.fc2_var(out)
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, x):
out = F.relu(self.fc3(x))
return torch.sigmoid(self.fc4(out))
def forward(self, x):
mu, log_var = self.encode(x.view(-1, 784))
z = self.reparameterize(mu, log_var)
return self.decode(z), mu, log_var
Loss Function
VAE의 손실 함수는 두 부분으로 구성된다 : 재구성 손실(Binary Cross Entropy)과 KL 발산(KLD).
재구성 손실은 원본 이미지와 재구성된 이미지 사이의 차이를 측정하고, KL 발산은 잠재 공간의 분포가 표준 정규 분포를 얼마나 잘 따르는지를 측정한다.
def loss_function(recon_x, x, mu, log_var):
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
return BCE + KLD
Train
train
함수는 모델을 학습 모드로 설정하고, 학습 datalodaer를 통해 반복하며 Loss을 계산하고 역전파를 통해 모델을 업데이트합니다. 각 epoch마다 평균 loss을 출력.
def train(epoch):
model.train()
train_loss = 0
for idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
recon, mu, log_var = model(data)
loss = loss_function(recon, data, mu, log_var)
loss.backward()
train_loss += loss.item()
optimizer.step()
print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset)))
Test
test
함수는 모델을 평가 모드로 설정하고, 테스트 datalodaer를 통해 반복하면서 손실을 계산한다. 첫 번째 배치의 이미지와 재구성된 이미지를 저장하여 모델의 성능을 시각적으로 확인할 수 있습니다.
def test(epoch):
model.eval()
test_loss = 0
with torch.no_grad():
for idx, (data, _) in enumerate(test_loader):
data = data.to(device)
recon, mu, log_var = model(data)
test_loss += loss_function(recon, data, mu, log_var).item()
if idx == 0:
n = min(data.size(0), 10)
comparison = torch.cat([data[:n],
recon.view(-1, 1, 28, 28)[:n]])
save_image(comparison.cpu(),
'./results/epoch_' + str(epoch) + '.png', nrow=n)
test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))
Main
train
과 test
를 실행한다. 무작위로 생성된 s
(latent space vector 값을 의미)를 10 epoch 마다 모델을 통해 새로운 이미지를 생성하고 저장한다. 이는 모델이 학습하는 동안 latent space에서 샘플링하여 얻은 벡터를 디코딩하여 얻는 과정을 볼 수 있다. 이는 모델이 얼마나 잘 일반화하는지를 보여주는 좋은 방법이다.
if __name__ == "__main__":
os.makedirs('./results', exist_ok=True)
s = torch.randn(64, 20).to(device)
for epoch in range(0, 101):
train(epoch)
test(epoch)
if epoch % 10 == 0:
sample = model.decode(s).cpu()
save_image(sample.view(64, 1, 28, 28),
'./results/sample_{}_result.png'.format(epoch))
결과 이미지


전체 코드
import torch
from torch import nn, optim
import torch.utils.data
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=True, download=True,
transform=transforms.ToTensor()),
batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=False, download=True,
transform=transforms.ToTensor()),
batch_size=128, shuffle=False, num_workers=4, pin_memory=True)
class VAE(nn.Module):
def __init__(self, input_dim=784, h_dim=400, latent_dim=20):
super(VAE, self).__init__()
self.fc1 = nn.Linear(input_dim, h_dim)
self.fc2_mu = nn.Linear(h_dim, latent_dim)
self.fc2_var = nn.Linear(h_dim, latent_dim)
self.fc3 = nn.Linear(latent_dim, h_dim)
self.fc4 = nn.Linear(h_dim, input_dim)
def encode(self, x):
out = F.relu(self.fc1(x))
return self.fc2_mu(out), self.fc2_var(out)
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, x):
out = F.relu(self.fc3(x))
return torch.sigmoid(self.fc4(out))
def forward(self, x):
mu, log_var = self.encode(x.view(-1, 784))
z = self.reparameterize(mu, log_var)
return self.decode(z), mu, log_var
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
def loss_function(recon_x, x, mu, log_var):
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
return BCE + KLD
def train(epoch):
model.train()
train_loss = 0
for idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
recon, mu, log_var = model(data)
loss = loss_function(recon, data, mu, log_var)
loss.backward()
train_loss += loss.item()
optimizer.step()
print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset)))
def test(epoch):
model.eval()
test_loss = 0
with torch.no_grad():
for idx, (data, _) in enumerate(test_loader):
data = data.to(device)
recon, mu, log_var = model(data)
test_loss += loss_function(recon, data, mu, log_var).item()
if idx == 0:
n = min(data.size(0), 10)
comparison = torch.cat([data[:n],
recon.view(-1, 1, 28, 28)[:n]])
save_image(comparison.cpu(),
'./results/epoch_' + str(epoch) + '.png', nrow=n)
test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))
if __name__ == "__main__":
s = torch.randn(64, 20).to(device)
for epoch in range(0, 101):
train(epoch)
test(epoch)
if epoch % 10 == 0:
sample = model.decode(s).cpu()
save_image(sample.view(64, 1, 28, 28),
'results/sample_{}_result.png'.format(epoch))
'AI Research' 카테고리의 다른 글
[Paper Review] AN IMAGE IS WORTH 16X16 WORDS:TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE (0) | 2025.01.24 |
---|---|
[Paper Review] Attention Is All You Need (0) | 2025.01.13 |
[Stable Diffusion] 상황 별 Negative prompt (1) | 2024.01.09 |
Diffusion Model vs GANs (2) | 2023.12.28 |