概要
今回は画像生成や画像の表現学習などに用いられる深層生成モデルの一種である変分オートエンコーダー(以下、VAE)のPytorch実装について解説していきたいと思います。
確率的プログラミング言語Pyroを用いた実装はこちらで解説しています。
また、VAEのモデル・学習方法といった理論的な部分についてはこちらで解説しています。目を通しておいていただけると、実装の理解も進みやすいと思います。
どの部分を説明しているかわかりやすいように、コード内にコメントアウト部分で主に解説をしています。それでは、見ていきましょう。
前提知識
- 基本的なニューラルネットワークに関する知識
- pytorchの基礎、
- VAEのモデル、(学習方法)
ライブラリの準備
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torch.utils as utils
from torchvision import datasets, transforms
GPUを利用できる場合は、GPUを利用するようにします。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
データローダーの準備
データローダーを作成する関数を設計します。
データローダーはミニバッチ学習を行う際に必要となるものです。
例えば、全部で1000個のデータを32個ずつ小分けにして、32個使ってパラメータを学習 ⇒ 次の32個使ってパラメータを学習 ⇒ 次の32個~ のように繰り返し学習する際に用います。32がいわゆるバッチサイズです。
def setup_data_loaders(batch_size=128, use_cuda=True):
root = "../data"
download=True
# 画像にどのような前処理を施すか設定します。
# ToTensor()の他にもサイズを変更するResize()や標準化を行うnormalize()などを指定できます。
trans = transforms.ToTensor()
# まず、torchvision.datasetsからMNISTという手書き数字のデータセットを読み込みます。
train_set = datasets.MNIST(root=root, train=True, transform=trans, download=download)
valid_set = datasets.MNIST(root=root, train=False, transform=trans)
# データセットをbatch_size個のデータごとに小分けにしたものにして、ミニバッチ学習が可能なようにします。
# shuffle=True にすると画像の順序がランダムになったりしますが、ここらへんはどっちでもいいと思います。
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(dataset=valid_set, batch_size=batch_size, shuffle=False)
return train_loader, valid_loader
モデルの設計
class VAE(nn.Module): # nn.Moduleクラスを引き継ぎ
def __init__(self, z_dim, x_dim=28*28):
super(VAE, self).__init__()
self.x_dim = x_dim
self.z_dim = z_dim
# エンコーダー用の関数
self.fc1 = nn.Linear(x_dim, 20)
self.fc2_mean = nn.Linear(20, z_dim)
self.fc2_var = nn.Linear(20, z_dim)
# デコーダー用の関数
self.fc3 = nn.Linear(z_dim, 20)
self.fc4 = nn.Linear(20, x_dim)
# エンコーダー
def encoder(self, x):
x = x.view(-1, self.x_dim)
x = F.relu(self.fc1(x))
mean = self.fc2_mean(x) # 平均
log_var = self.fc2_var(x) # 分散の対数
return mean, log_var
# 潜在ベクトルのサンプリング(再パラメータ化)
def reparametrizaion(self, mean, log_var, device):
epsilon = torch.randn(mean.shape, device=device)
return mean + epsilon*torch.exp(0.5 * log_var)
# デコーダー
def decoder(self, z):
y = F.relu(self.fc3(z))
y = torch.sigmoid(self.fc4(y)) # 各要素にシグモイド関数を適用し、値を(0,1)の範囲に
return y
def forward(self, x, device):
x = x.view(-1, self.x_dim)
mean, log_var = self.encoder(x) # 画像xを入力して、平均・分散を出力
KL = 0.5 * torch.sum(1+log_var - mean**2 - torch.exp(log_var)) # KL[q(z|x)||p(z)]を計算
z = self.reparametrizaion(mean, log_var, device) # 潜在ベクトルをサンプリング(再パラメータ化)
x_hat = self.decoder(z) # 潜在ベクトルを入力して、再構築画像 y を出力
reconstruction = torch.sum(x * torch.log(x_hat+1e-8) + (1 - x) * torch.log(1 - x_hat + 1e-8)) #E[log p(x|z)]
lower_bound = -(KL + reconstruction) #変分下界(ELBO)=E[log p(x|z)] - KL[q(z|x)||p(z)]
return lower_bound , z, x_hat
学習
dataloader_train, dataloader_valid = setup_data_loaders(batch_size=1000) # データローダーを作成
model = VAE(z_dim = 10).to(device) # モデルをインスタンス化し、GPUにのせる
optimizer = optim.Adam(model.parameters(), lr=1e-3) # オプティマイザーの設定
model.train() # モデルを訓練モードに
num_epochs = 100
loss_list = []
for i in range(num_epochs):
losses = []
for x, t in dataloader_train: # データローダーからデータを取り出す。
x = x.to(device) # データをGPUにのせる
loss, z, y = model(x, device) # 損失関数の値 loss 、潜在ベクトル z 、再構築画像 y を出力
model.zero_grad() # モデルの勾配を初期化
loss.backward() # モデル内のパラメータの勾配を計算
optimizer.step() # 最適化を実行
losses.append(loss.cpu().detach().numpy()) # ミニバッチの損失を記録
loss_list.append(np.average(losses)) # バッチ全体の損失を登録
print("EPOCH: {} loss: {}".format(i, np.average(losses)))
EPOCH: 0 loss: 491549.34375
EPOCH: 1 loss: 300059.1875
EPOCH: 2 loss: 233781.21875
EPOCH: 3 loss: 216165.5
EPOCH: 4 loss: 206224.546875
EPOCH: 5 loss: 198658.3125
(結果一部)
可視化
まず、画像を再構築してみましょう。
fig = plt.figure(figsize=(20,4))
model.eval()
zs = []
for x, t in dataloader_valid:
for i, im in enumerate(x.view(-1,28,28).detach().numpy()[:10]):
# 元画像を可視化
ax = fig.add_subplot(2, 10, i+1, xticks=[], yticks=[])
ax.imshow(im, "gray")
x = x.to(device)
_, _, y = model(x, device) #再構築画像 y を出力
y = y.view(-1,28,28)
for i, im in enumerate(y.cpu().detach().numpy()[:10]):
# 再構築画像を可視化
ax = fig.add_subplot(2,10,11+i, xticks=[], yticks=[])
ax.imshow(im, "gray")
ぼやけてはいますが、だいたい同じような画像が生成できていますね。畳み込みを使ってないにしては良さげだと思います。
次は、正規乱数から画像を生成してみます。
fig, ax = plt.subplots(nrows = 3, ncols=5, figsize=(20,12))
model.eval()
for r in range(3):
for c in range(5):
ax[r,c].imshow(model.decoder(torch.randn(10).cuda()).detach().cpu().numpy().reshape(28,28), cmap="gray")
ax[r,c].axis("off")
数字と言えないこともないですね。
今回は以上です。
まとめ
今回はPytorchを用いたVAEの実装について解説してみました。極力簡単なモデルにするために畳み込み層や逆畳み込み層を用いていないのであまり精度は高くありませんが、余りサイズの大きくないMNISTであればそれなりのモデルと言えそうですね。
今回はPytorchでの実装でしたが、確率的プログラミング言語のPyroの実装解説もありますので、興味がある方は是非ご覧ください。
次回はVAEの派生手法について解説できればと思っています。
参考文献
小川雄太郎(2019):『つくりながら学ぶ! PyTorchによる発展ディープラーニング』,マイナビ出版