NumPyroによるベイズモデリング入門【線形回帰編】

はじめに

 こんにちは deepblue でインターン生として働いている渡邊です。
最近、PyroやTensorflow Probabilityなどの深層学習ライブラリベースのGPU対応PPL(確率的プログラミング言語)が出てきていますが、なかなか知られていないNumPyroなるものがあるそうです。NumPyroはバックエンドがJaxでサポートされているPPLで、マルコフ連鎖モンテカルロ(MCMC)法によるサンプリングが高速らしいので、今回は線形回帰で使用感を試してみたいと思います。

 今回の内容は、ベイズモデリングの用語(事前分布、事後分布など)はご存知の方が対象ですので、そこも怪しいと思われる方はこの記事をさらっと見てから来てくださると理解しやすいと思います。

参考サイト

使用環境

  • Google Colaboratory

ライブラリの準備

pip install numpyro
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import numpyro
sns.set_style("darkgrid")

単回帰モデル

 今回は1番シンプルな、単回帰モデル ~y=ax+b+\epsilon をモデリングしていきます。
確率モデル~L(a,b,\sigma)、事前分布~p(a,b,\sigma)はそれぞれ以下のように設定します。

  • 事前分布
    a \sim \mathrm{Normal}(0,100)
    b \sim \mathrm{Normal}(0,100)
    \sigma \sim \mathrm{LogNormal}(0,10)

  • 確率モデル
    y_i \sim \mathrm{Normal}(ax_i+b,\sigma)

 これは、以下のようなデータの生成過程をモデル化することに対応しています。

  1. 回帰係数~a、切片項~b、観測の分散パラメータ~\sigmaを事前分布からサンプル
  2. それらと~x_iにより確率モデルの形状を決定
  3. 確率モデルから観測~y_iをサンプル
  4. 2.3.をN回繰り返す

ここではあくまで、なんの学習(推論)も行っていない段階でこのようにモデル化できるだろうと考えているモデルです。
このような生成過程を記述する事前分布と確率モデルをNumPyroで書くと以下のようになります。

def model(x, y, N=100):
  ## 回帰係数 a、切片項 b の事前分布に平均0, 分散100の正規分布を置きます。
  a = numpyro.sample("a",
                     numpyro.distributions.Normal(loc=jnp.array(0.), scale=jnp.array(100.)))
  b = numpyro.sample("b",
                     numpyro.distributions.Normal(loc=jnp.array(0.), scale=jnp.array(100.)))
  ## 分散パラメータ sigma の事前分布に平均0, 分散10の対数正規分布を置きます。
  sigma = numpyro.sample("sigma",
                       numpyro.distributions.LogNormal(0, 10))
  ## 観測についての確率モデルとして、正規分布を置きます。
  ## yは実際に観測されているものなので、データyと観測モデルを紐づけるために、obs=yと設定しておきます
  with numpyro.plate("data", N):
    numpyro.sample("obs",numpyro.distributions.Normal(a*x + b, sigma),
                   obs=y)

(jnp.arrayはnp.arrayのようなものであると考えてください。)

 このようにNumPyroでは、自分で考えた生成過程を直感的にプログラムに落とし込むことができます。
確率モデルの記述には numpyro.plate を用いています。plateは、~y_iのように複数の確率変数をまとめて書くことができる便利な関数です。

データの生成

以下のようなモデルからデータを生成します。
y_i \sim \mathrm{Normal}(-5x+3,1.0)

確率モデルは、
~y_i \sim \mathrm{Normal}(ax_i+b,\sigma)
と設定していたので、~a=-5, b=3, \sigma=1.0とした場合のデータを生成することになります。
当たり前ですが、~a=-5, b=3, \sigma=1.0 は分析者にとって未知の値です。ここでの問題設定ではこの~a,b,\sigmaの分布を推論します。推論結果として事後分布は、それぞれ ~a=-5, b=3, \sigma=1.0 あたりにピークが来るような分布であると嬉しいです。
 今回はデータ数を 5, 10, 50 として、それぞれの挙動を確認してみましょう

import random
def toy_data(a ,b ,N):
  x = np.linspace(-1,1,N)
  y = a * x + b  + 1.0*jax.random.normal(jax.random.PRNGKey(1), x.shape)
  return x, y
x_data50, y_data50 = toy_data(-5, 3, 50)
x_data10, y_data10 = x_data50[np.arange(0,50,5)], y_data50[np.arange(0,50,5)]
x_data5, y_data5 = x_data50[np.arange(0,50,10)], y_data50[np.arange(0,50,10)]
fig, ax = plt.subplots(3,1,sharex=True,sharey=True,figsize=(7,15))
ax[0].plot(x_data5, y_data5, "o")
ax[0].set_title("N=5")
ax[1].plot(x_data10, y_data10, "o")
ax[1].set_title("N=10")
ax[2].plot(x_data50, y_data50, "o")
ax[2].set_title("N=50")


データは、 y=-5x+3 の周りに散らばっています。

MCMCによるサンプリング

 事後分布は以下のように表現できます。
p(a,b,\sigma|X,Y) = \frac{p(Y|X,a,b,\sigma)p(a,b,\sigma)}{\int_{a,b,\sigma}p(Y|X,a,b,\sigma)p(a,b,\sigma) \mathrm{d}a\mathrm{d}b\mathrm{d}\sigma}
しかし、大抵の場合はこの事後分布を解析的に求めることができないので、事後分布を近似する変分ベイズ法や事後分布からサンプリングするMCMC法を用いて、事後分布を近似的に求めます。

 今回は、後者のMCMC法を使いたいと思います。MCMCでは、事後分布~p(a,b,\sigma|X,Y)からたくさんサンプリングして、事後分布~p(a,b,\sigma|X,Y)の形状を知ろうとするわけです。解析的に求まらない事後分布~p(a,b,\sigma|X,Y)からどうやってサンプリングするのかと疑問に思う方もいらっしゃるかと思いますが、そういう方法があるぐらいに思っておいてください。もっと深く知りたいという方は、こちらのブログをご覧ください。

## おまじない
## num_warmup, num_samplesあたりは適切に設定する必要がありますが、この辺りは他のPPLと同様なので割愛します。
kernel = numpyro.infer.NUTS(model)
mcmc5 = numpyro.infer.MCMC(kernel, num_warmup=300, num_samples=1000)
mcmc10 = numpyro.infer.MCMC(kernel, num_warmup=300, num_samples=1000)
mcmc50 = numpyro.infer.MCMC(kernel, num_warmup=300, num_samples=1000)

サンプリングしていきます。

## N=5の場合
mcmc5.run(jax.random.PRNGKey(1),x = x_data5, y = y_data5, N = x_data5.shape[0])
## N=10の場合
mcmc10.run(jax.random.PRNGKey(1),x = x_data10, y = y_data10, N = x_data10.shape[0])
## N=5の場合
mcmc50.run(jax.random.PRNGKey(1),x = x_data50, y = y_data50, N = x_data50.shape[0])

結果

事後分布から得られたサンプリングを眺めてみましょう。
get_samplesを用いて、サンプルを獲得します。

##N=5のとき
samples5 = mcmc5.get_samples()
a_samples5 = samples5["a"].squeeze()
b_samples5 = samples5["b"].squeeze()
sigma_samples5 = samples5["sigma"].squeeze()
##N=10のとき
samples10 = mcmc10.get_samples()
a_samples10 = samples10["a"].squeeze()
b_samples10 = samples10["b"].squeeze()
sigma_samples10 = samples10["sigma"].squeeze()
##N=50のとき
samples50 = mcmc50.get_samples()
a_samples50 = samples50["a"].squeeze()
b_samples50 = samples50["b"].squeeze()
sigma_samples50 = samples50["sigma"].squeeze()
##aについて
fig1, ax1 = plt.subplots(1,3,figsize=(15,7),sharey=True,sharex=True)
sns.distplot(a_samples5, bins=20, ax=ax1[0]).set_title("N=5")
sns.distplot(a_samples10, bins=20, ax=ax1[1]).set_title("N=10")
sns.distplot(a_samples50, bins=20, ax=ax1[2]).set_title("N=50")
fig1.suptitle("posterior of a")
##bについて
fig2, ax2 = plt.subplots(1,3,figsize=(15,7),sharey=True,sharex=True)
sns.distplot(b_samples5, bins=20, ax=ax2[0]).set_title("N=5")
sns.distplot(b_samples10, bins=20, ax=ax2[1]).set_title("N=10")
sns.distplot(b_samples50, bins=20, ax=ax2[2]).set_title("N=50")
fig2.suptitle("posterior of b")
##sigmaについて
fig3, ax3 = plt.subplots(1,3,figsize=(15,7),sharey=True,sharex=True)
sns.distplot(sigma_samples5, bins=20, ax=ax3[0]).set_title("N=5")
sns.distplot(sigma_samples10, bins=20, ax=ax3[1]).set_title("N=10")
sns.distplot(sigma_samples50, bins=20, ax=ax3[2]).set_title("N=50")
fig3.suptitle("posterior of sigma")

だいたい真値 ~a=-5, b=3, \sigma=0.5 あたりにピークを持つ事後分布になっています。また、データを増やすほど、不確実性が小さくなっている様子が見て取れるます。ただ、データが少なくてもまあまあいい結果にはなっていますね。

予測分布

 予測分布は以下のような分布です。
p^*(y|x,X,Y) = \int_{a,b,\sigma}p(y|x,a,b,\sigma)p(a,b,\sigma|X,Y) \mathrm{d}a\mathrm{d}b\mathrm{d}\sigma

 新たなデータ~xを入力したときに観測~yが従う(と考えている)分布が予測分布です。X,Yは事後分布の推論(サンプリング)に用いたデータ(学習データ)に対応しています。

予測分布は以下のような流れで得ます。

  1. ~a_k,b_k,\sigma_k \sim p(a,b,\sigma|X,Y) をKセット、サンプリングする。
  2. Kセットのパラメータ(a_k,b_k,\sigma_k)に対して、~y_i \sim p(y_i|x_i,a_k,b_k,\sigma_k)からN個サンプリング

 第一段階は、MCMC法によるサンプリングに他ならないので、第二段階のサンプリングのみを行えば良いこととなります。今回は、~K=1000,\ N=2000としておきます。要するに、各パラメータ~a_k,b_k,\sigma_k~(k=1,\dots,1000)毎に観測~y_iを2000個サンプリングすることとします。

予測分布の計算にはnumpyro.infer.Predictiveを使います。

predictive5 = numpyro.infer.Predictive(model, samples5)
predictive10 = numpyro.infer.Predictive(model, samples10)
predictive50 = numpyro.infer.Predictive(model, samples50)
index_points = np.linspace(-2., 2., 2000)
predictive_samples5 = predictive5.get_samples(
    jax.random.PRNGKey(1),
    index_points,
    None,
    index_points.shape[0])["obs"]
predictive_samples10 = predictive10.get_samples(
    jax.random.PRNGKey(1),
    index_points,
    None,
    index_points.shape[0])["obs"]
predictive_samples50 = predictive50.get_samples(
    jax.random.PRNGKey(1),
    index_points,
    None,
    index_points.shape[0])["obs"]
/usr/local/lib/python3.6/dist-packages/numpyro/infer/util.py:559: FutureWarning: The method .get_samples has been deprecated in favor of .__call__.
  FutureWarning)

予測平均、68%予測区間、99%予測区間、学習データを図示します。

mean5 = predictive_samples5.mean(axis=0)
mean10 = predictive_samples10.mean(axis=0)
mean50 = predictive_samples50.mean(axis=0)
std5 = predictive_samples5.std(axis=0)
std10 = predictive_samples10.std(axis=0)
std50 = predictive_samples50.std(axis=0)
lower5_1, upper5_1 = mean5 - std5, mean5 + std5
lower10_1, upper10_1 = mean10 - std10, mean10 + std10
lower50_1, upper50_1 = mean50 - std50, mean50 + std50
lower5_3, upper5_3 = mean5 - 3*std5, mean5 + 3*std5
lower10_3, upper10_3 = mean10 - 3*std10, mean10 + 3*std10
lower50_3, upper50_3 = mean50 - 3*std50, mean50 + 3*std50
fig, ax = plt.subplots(3,1,sharex=True,sharey=True,figsize=(7,15))
ax[0].scatter(x_data5, y_data5, color="g")
ax[0].plot(index_points, mean5)
ax[0].fill_between(index_points.squeeze(), lower5_1, upper5_1, alpha=0.3, color="b")
ax[0].fill_between(index_points.squeeze(), lower5_3, upper5_3, alpha=0.1, color="b")
ax[0].set_title("N=5")
ax[1].scatter(x_data10, y_data10, color="g")
ax[1].plot(index_points, mean10)
ax[1].fill_between(index_points.squeeze(), lower10_1, upper10_1, alpha=0.3, color="b")
ax[1].fill_between(index_points.squeeze(), lower10_3, upper10_3, alpha=0.1, color="b")
ax[1].set_title("N=10")
ax[2].scatter(x_data50, y_data50, color="g")
ax[2].plot(index_points, mean50)
ax[2].fill_between(index_points.squeeze(), lower50_1, upper50_1, alpha=0.3, color="b")
ax[2].fill_between(index_points.squeeze(), lower50_3, upper50_3, alpha=0.1, color="b")
ax[2].set_title("N=50")
fig.legend(["predict mean",
            "68% bayes predictive interval",
            "99% bayes predictive interval",
            "training data"])
plt.show()

 N=5の場合だと、データがない部分の不確実性(分散)が大きくなっていますが、N=10とN=50ではほとんど結果が変わらないように見えます。このようなシンプルな場合では、データ数が10程度で十分そうですね。
いい感じに予測できているようなので、今回はここまでにします。最後まで読んでいただきありがとうございました。
PPLの中でもStanに関しては、以下の記事なども参考になると思いますので、是非参考にしてみてください。

Header photo by PYRO