概要
今回は変分オートエンコーダー(VAE)をPyroで実装したいと思います。理論的な部分については知っているものとして話を進めていこうと思うので、わからない部分がある場合はこちらをご確認ください。
Pyroについては、こちらに解説がありますのでさっと目を通すと読みやすいと思います。
それではさっそく実装していきましょう。
環境
Google Colabratory
ライブラリの準備
ここではpyroのみインストールしていますが、各々の環境に合わせて適宜追加してください。
!pip install pyro-ppl
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.contrib.examples.util import MNIST
再現性のためにランダムシードを指定しておきます。
pyro.set_rng_seed(0)
データの準備
データローダーを作成します。今回は手書き文字データセットMNISTを用います。
def setup_data_loaders(batch_size=128, use_cuda=False):
root = "./data"
download=True
trans = transforms.ToTensor()
train_set = MNIST(root=root, train=True, transform=trans, download=download)
test_set = MNIST(root=root, train=False, transform=trans)
kwargs = {"num_workers": 2, "pin_memory":use_cuda}
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False, **kwargs)
return train_loader, test_loader
データローダーは簡単に言うとミニバッチの集合です。詳しくはこちらの書籍やサイトを参考にしてください。
モデル
続いてモデルについて説明していきます。
VAEはエンコーダー q(z|x)とデコーダーp(x|z)の2つのネットワークから構成されます。
まずは、潜在ベクトルから画像を出力するデコーダーp(x|z)についてです。
class Decoder(nn.Module):
def __init__(self, z_dim, hidden_dim):
super().__init__()
# 全結合層
self.fc1 = nn.Linear(z_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, 784)
# 活性化関数
self.softplus = nn.Softplus()
self.sigmoid = nn.Sigmoid()
def forward(self, x):
hidden = self.softplus(self.fc1(x))
# 各出力がベルヌーイ分布のパラメータになる→大きいほど黒に近くなる
loc_img = self.sigmoid(self.fc2(hidden))
# 画像を出力
return loc_img
今回はシンプルな全結合層のみを採用したネットワークを採用していますが、nn.Linearの部分をnn.ConvTranspose2d等に適宜変えることで逆畳み込みを用いることもできます。
loc_imgがシグモイド関数の出力であることからもわかるように各要素は0から1の値をとるので、各要素はベルヌーイ分布のパラメータとみなせます。
次に画像を入力し、潜在ベクトルを出力するエンコーダー q(z|x)を考えます。
class Encoder(nn.Module):
def __init__(self, z_dim, hidden_dim):
super().__init__()
# 入力-中間層間の全結合層
self.fc1 = nn.Linear(784, hidden_dim)
# 中間-出力層間の全結合層
self.fc21 = nn.Linear(hidden_dim, z_dim) # 平均用
self.fc22 = nn.Linear(hidden_dim, z_dim) # 分散用
# 活性化関数の準備
self.softplus = nn.Softplus()
def forward(self, x):
# 入力の次元を修正(28×28=784)
x = x.reshape(-1, 784)
hidden = self.softplus(self.fc1(x))
# 各ピクセル値が従う正規分布の平均、分散を出力
z_loc = self.fc21(hidden)
z_scale = torch.exp(self.fc22(hidden))
return z_loc, z_scale
エンコーダーの出力は、正規分布の平均(z_loc),標準偏差(z_scale)です。これらを用いてz = z_loc+z_scale*\epsilon, \epsilon\sim\mathcal{N}(0,I)のように正規分布からのサンプリングを実現することで、誤差逆伝播法を適用することができるようになります。これが再パラメータ化のミソです。
ここまでで設計したエンコーダー・デコーダーをまとめてVAEのクラスにします。
class VAE(nn.Module):
def __init__(self, z_dim=50, hidden_dim = 400, use_cuda=True):
super().__init__()
self.encoder = Encoder(z_dim, hidden_dim)
self.decoder = Decoder(z_dim, hidden_dim)
if use_cuda:
self.cuda()
self.use_cuda = use_cuda
self.z_dim = z_dim
def model(self, x):
pyro.module("decoder", self.decoder)
with pyro.plate("data", x.shape[0]):
z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
loc_img = self.decoder(z)
pyro.sample("obs", dist.ContinuousBernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))
def guide(self, x):
pyro.module("encoder", self.encoder)
with pyro.plate("data", x.shape[0]):
z_loc, z_scale = self.encoder(x)
pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
def reconstruct_img(self, x):
z_loc, z_scale = self.encoder(x)
z = dist.Normal(z_loc, z_scale).sample()
loc_img = self.decoder(z)
return img
ここでは大きく三つの部分について説明しておきます。
- pyro.module("decoder", self.decoder)
これはpyroに、デコーダー内のネットワークのパラメータを学習対象として認識させるための部分です。pyro.module("encoder", self.encoder)についても同様です。 - .to_event(1)
詳しく説明すると長くなってしまうので、イメージだけ簡単に説明しておきます。例として
dist.Normal(torch.zeros(10), torch.ones(10))からのサンプリングを考えます。このサンプルは10次元の多変量正規分布からのサンプルでしょうか?あるいは10個のi.i.dなサンプルでしょうか?前者であれば、尤度は1つの値ですが、後者は10個の値になります。これらを区別するような目的で.to_event(1)を用いています。詳しくはこちらをご覧ください。 - dist.ContinuousBernoulli()
この部分でobs=x.reshape(-1,784)とされていることからわかるように、ここでは、(対数)尤度の計算が行われています。ここでxの各要素(各ピクセル)は正規化されており、0~1の連続値です。pyro公式ではこの部分でdist.Bernoulliを用いていますが、dist.Bernoulliの実現値は0 or 1であり、0~1の連続値に対する尤度計算を行うことができません。そこで、dist.Bernoulliを連続緩和したdist.CountinuousBernoulliを用いることで、尤度計算を可能にしているというわけです。(公式と異なる記述ですので、間違えている可能性もあります。参考までに)
ここまでで、モデルに関する説明が終わったのであとは学習するのみです。
学習
訓練データによる学習
def train(svi, train_loader, use_cuda=True):
epoch_loss = 0
for x,_ in train_loader:
if use_cuda:
x = x.cuda()
epoch_loss += svi.step(x)
normalizer_train = len(train_loader.dataset)
total_epoch_loss_train = epoch_loss/normalizer_train
return total_epoch_loss_train
テストデータでのELBOの評価
def evaluate(svi, test_loader, use_cuda=True):
test_loss = 0
for x,_ in test_loader:
if use_cuda:
x = x.cuda()
test_loss += svi.evaluate_loss(x)
normalizer_test = len(test_loader.dataset)
total_epoch_loss_test = test_loss/normalizer_test
return total_epoch_loss_test
学習
USE_CUDA = True
LEARNING_RATE = 1.0e-3
NUM_EPOCHS = 100
TEST_FREQUENCY = 5
train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=USE_CUDA)
pyro.clear_param_store() #pyroに登録されているパラメータを一旦消し去ります。
vae = VAE(use_cuda=USE_CUDA)
adam_args = {"lr": LEARNING_RATE}
optimizer = Adam(adam_args)
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())
train_elbo = []
test_elbo = []
for epoch in range(NUM_EPOCHS):
total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
train_elbo.append(-total_epoch_loss_train)
print("[epoch %03d] average training loss: %.4f" %(epoch, total_epoch_loss_train))
if epoch % TEST_FREQUENCY == 0:
total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
test_elbo.append(-total_epoch_loss_test)
print("[epoch %03d] average test loss: %.4f" %(epoch, total_epoch_loss_test))
pyro.clear_param_store()はpyroに登録されているパラメータを初期化する役割を持っています。pyroはパラメータをグローバルに管理しているので、VAEのパラメータのみを学習対象として登録するために、初期化を行っているというわけです。
他の部分は通常のpytorchと変わらないと思います。
可視化
学習過程でのELBOの推移を確認してみましょう。
import matplotlib.pyplot as plt
plt.plot(train_elbo[::5],"-.")
plt.show()
学習自体はうまく進行しているようですね。
正規乱数から適当な画像を生成してみます。
fig, ax = plt.subplots(nrows = 3, ncols=5, figsize=(20,12))
for r in range(3):
for c in range(5):
ax[r,c].imshow(vae.decoder(torch.randn(50).cuda()).detach().cpu().numpy().reshape(28,28))
今回は表現力の乏しい全結合層のみを用いているので綺麗な画像を再現することはできていませんが、それっぽく見える画像もありますね。畳み込み層を用いることでさらにきれいに画像生成を行うことができますので、興味のある方は是非試してみてください。
まとめ
今回はpyroを用いたVAEの実装について解説してみました。
次回はpytorchでの実装や、別の画像生成モデルについて解説してみたいと思います。
参考文献
Variational Autoencoders
pyro.distributionsの解説
データローダーの解説
HEADER PHOTO BY PYRO