うまい寿司が食いたい。

うまい寿司が遠慮なく食べれるようになるまで,進捗とか垂れ流すブログ

ベイズ線形回帰

動機

MCMCで計算をすると,カルマンフィルタを実装したときは実感できたベイズ更新の部分がよくわからないという僕の気持ちの問題がありました。
そこで,須山さんの「ベイズ推論による機械学習入門」

機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)

機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)

の3章で,学習と推論を解析的に導出していたため,自分で計算・実装を行い理解を深めたいというのがこの記事を書く動機です。

線形回帰の式

1次関数での線形回帰

想像しやすいように,1次関数での線形回帰を行います。式は y = a x + b+ \varepsilon_rです。
 \varepsilon_rはノイズで,回帰分析の事後分布として求めたい値は, a, bとなります。
ノイズは簡単のために正規分布に従うと仮定して,
 \varepsilon_r \sim N(\varepsilon_r | 0, \lambda^{-1})
と定義します。テキストに習って \lambdaは既知であるとします。
ある一つの観測値 y_kとある一つの入力値 x_kが得られたときに,事後分布 p(a| y_k, x_k), p(b|y_k, x_k)が導出できれば,学習と予測ができることになります。

まず,簡単のために,bは既知であるとして,aの事後分布を解析的に求めてみます。
ベイズの定理より,求めたい事後分布 p(a| y_k, x_k)は,
 p(a| y_k, x_k) = \frac{p(y_k|x_k, a) p(a)}{p(y_k, x_k)}
と書けます。
このときに,事前分布 p(a)正規分布していると仮定し
 p(a) = N(a | m_a, \Lambda_a^{-1})=  \frac{1}{\sqrt{2\pi\Lambda_a^{-1}}}\exp\big(  -\frac{1}{2}\Lambda_a(a-m_a)^{2}  \big)
と書けるとします。ここで, m_a, \Lambda_aは事前分布の平均と分散で固定値で与えるハイパーパラメータです。

さて,事前分布 p(a)と入力値  x_kが与えられた元で,観測値 y_kが得られる度合いを表す尤度を考えます。
先程,ノイズが正規分布に従うとしたため,
 p(y_k|x_k, a) = N(y_k | a * x_k +b, \lambda^{-1}  )
と書くことができます。ここで,この正規分布を真面目に書いてみると
 N(y_k |  a * x_k+b, \lambda^{-1}) = \frac{1}{\sqrt{2\pi\lambda^{-1}}}\exp\big(  -\frac{1}{2}\lambda(y_k-a*x_k-b)^{2}  \big)
となります。
今は aの事後分布だけが気になっているので,計算を簡単にするために事前分布,尤度をそれぞれの対数( \log)を取って, aについて整理していきます。

事前分布

 \log p(a) = \log \frac{1}{\sqrt{2\pi\Lambda_a^{-1}}}\exp\big(  -\frac{1}{2}\Lambda_a(a-m_a)^{2}  \big)
  = -\frac{1}{2} \log (2\pi\Lambda_a^{-1} )  -\frac{1}{2}\Lambda_a(a-m_a)^{2}

第一項は aに関係のない項なので,第二項を展開して, aに関する項を昇順に並べていきます。
 -\frac{1}{2}\Lambda_a(a-m_a)^{2} = -\frac{1}{2}  \Lambda_a a^{2} + \Lambda_a m_a a  -\frac{1}{2} \Lambda_a m_a^{2} (1)
となります。

尤度

次に,先程定義した尤度も同様に計算することで,
 -\frac{1}{2}  \lambda_a x_k ^{2} a^{2} + \lambda (y_k -b) x_k a  -\frac{1}{2} \lambda (y_k -b )^{2} (2)
と書けます。

事前分布と尤度の積

ベイズの定理の分子である p(y_k|x_k, a) p(a)は事前分布と尤度の積ですので,対数を取ったあとは和になります。
なので,(1)と(2)の和を計算して, aについて昇順に並べることで,事後分布の aについてはどのような形になるのか想像することができます。和を取ると  -\frac{1}{2}  ( \Lambda_a  + \lambda_a x_k ^{2} ) a^{2} +( \Lambda_a m_a + \lambda (y_k -b) x_k )a  -\frac{1}{2} ( \Lambda_a m_a^{2} + \lambda (y_k -b )^{2}) (3)
と書けます。
この形は事前分布や尤度で仮定した正規分布と全く同じ形になっていることから,事後分布も正規分布であると考えることができます。*1
分母は,周辺尤度を求めれば,正規化されていると思います。
このあたりの数式は天下り的に計算はしませんが,任意の事後分布の正規化の数式などは例えば,

こちらの本などに書かれているように思えます。
さて,事後分布も正規分布であるとして,一点を観測したあとの傾き aの事後分布は
 p(a| y_k, x_k) = N(a| ( \Lambda_a  + \lambda_a x_k ^{2} ), \frac{\Lambda_a m_a + \lambda (y_k -b) x_k}{( \Lambda_a  + \lambda_a x_k ^{2} )}) (4)
として計算できます。*2

今,観測点が一点だけの場合を考えていましたが,複数観測する場合には,式(3)は更に一般化することができて,  -\frac{1}{2}  ( \Lambda_a  +\sum_k \lambda_a x_k ^{2} ) a^{2} +( \Lambda_a m_a + \sum_k  \lambda (y_k -b) x_k )a  -\frac{1}{2} ( \Lambda_a m_a^{2} + \sum_k \lambda (y_k -b )^{2}) (3)
と,書けます。このとき,(4)は
 p(a| y_k, x_k) = N(a| \frac{\Lambda_a m_a +  \sum_k \lambda (y_k -b) x_k}{( \Lambda_a  +  \sum_k  \lambda_a x_k ^{2} )}, ( \Lambda_a  +  \sum_k  \lambda_a x_k ^{2} )) (5)
となりますので,この式を使うことで,一点一点を更新していく様子がみることができます。

実装

実装コードはこちらにあります。

github.com

傾き  a=5 , 切片  b=3, ノイズが  \sigma=1正規分布に従って生成されるとして,データをランダムに200点ほど作成します。

f:id:Leo0523:20190501223555p:plain
図1: 観測データ

図1のような,データが得られます。

データが得られたので,上の計算で行ったように,切片aをベイズ線形回帰で計算してみます。
事前分布は,中心ピークを4.5,標準偏差  \sigma=2正規分布を仮定します。

f:id:Leo0523:20190501223410p:plain
図2: 事前分布

この仮定の元,上記の計算を実行して,事後分布が毎回どのような形状になるのか計算してみます。
結果は次の動画で,

f:id:Leo0523:20190501224331g:plain
ベイズ更新

一点ごとの更新となりますので,徐々に予測の幅が狭くなり,値が真値である5に収束していく様子が見えます。
これは僕が当初見たかったベイズ線形回帰におけるベイズ更新の様子になります。

現在の仮定では事前分布の幅が広く,ピークも真値に近いので,パラメータを変えて実験を行ってみました。

事前分布:ピーク3, 標準偏差  \sigma=0.5, 観測データ点数100点

f:id:Leo0523:20190501225109g:plain
ベイズ更新2

事前分布:ピーク3, 標準偏差  \sigma=0.1 , 観測データ点数100点

f:id:Leo0523:20190501225247g:plain
ベイズ更新3

標準偏差が小さいときには,事前分布に引っ張られて真の値になかなかたどり着かない様子がこれらの動画から見ることができました。
事前分布,大事!

まとめ

  • 教科書に従って事後分布を解析的に導出した。

  • 導出した式を使ってベイズ更新の様子を描写した。

  • (数式的には自明だが)事前分布の標準偏差が小さいほど真値を予測するには観測データが沢山必要であることが動画からわかった。

*1:ほんまかいな?数学的な証明はわかりません

*2:詳細な計算は須山さんのベイズ推論による機械学習入門を読んでください