Untitled

446 days ago by pub

Hiroshi TAKEMOTO (take@pwv.co.jp)

Sageで共役勾配法を試す

PRMLの第5章のニューラルネットワークの「勾配降下最適化」で「共役勾配法」という 方式がでてきたので、学習がてらsageを使って試してみます。

共役勾配法

朱鷺の杜Wiki の説明とプログラムの変数を合わせると、以下のような2時形式の関数を考える $$ f(x) = \frac{1}{2} x^T A X - b^T x $$ この時、極値に達するには、勾配▽fからある程度接線方向tにずれた共役勾配d 方向に進まなくてはならない。(収束がほぼ直角に折れながら進んでいることに注意)

従って、 $$ d_n = - \nabla f(x_n) + \beta_n d_{n-1} $$ と書けます(一つ前の$d_{n-1}$が接線方向tであることに注意)。$\beta_n$は $$ \beta_n = \frac{(\nabla f(x_n))^T \nabla f(x_n)}{(\nabla f(x_{n-1}))^T \nabla f(x_{n-1})} $$ となり、dの初期値は$d_0 = - \nabla f(x_0)$から始めます。

$x$は刻み値$\alpha$、 $$ \alpha_n = - \frac{d_n^T \nabla f(x_n)}{d_n^T A d_n} $$ を使って次式で更新します。 $$ x_{n+1} = x_n + \alpha_n d_n $$

# 共役勾配法(conjugate gradient method) # f(x) = 1/2 x^T A x - b^T x # d_0 = -▽f(x_0) # d_n = -▽f(x_n)+(▽f(x_n)^T ▽f(x_n))/(▽f(x_n_1)^T ▽f(x_n_1)) d_n_1 # x_n_p1 = x_n + t_n d_n # t_n = - (d_n^T ▽f(x_n))/(d_n^R A d_n) 
       

関数、変数の定義

syou6162さんの最適化理論][R]共役勾配法を実装してみた の例題をSageを使って試してみます。

数式処理システムSageの特徴を活かすため、関数の引数をベクトルとし、var関数でベクトルの要素を宣言し、 ベクトルvにセットします。

つぎに関数fを以下のように定義します。 $$ f(x) = \frac{3}{2} x_1^2 + x_1 x_2 + x_2^2 - t x_1 - 7 x_2 $$

# 変数定義 vars = var('x1 x2'); v = vector([x1, x2]); 
       
# 参考サイトhttp://d.hatena.ne.jp/syou6162/20090926/1253950932 # 例題の関数:f = 3/2*x1^2 + x1*x2 + x2^2 - 6*x1 -7*x2 # fを定義 def f(v): return 3/2 * v[0]^2 + v[0]*v[1] + v[1]^2 - 6*v[0] - 7*v[1]; 
       

▽fの計算

▽fの計算は、ちょっとトリックを使います。あらかじめ関数fの 各変数での偏微分をdfsに保持しておき、その結果に引数のベクトル vxの値を代入した結果を返しています。

# fを偏微分したリスト dfs = [diff(f(v), x_i) for x_i in v]; # ▽fを定義 (dfsにvxの要素の値を適応した結果を返す) def nabla_f(vx): # ベクトルvxの各要素の値をvの要素に対応づける s = dict(zip(v, vx)); # ベクトルの各要素の偏微分の結果にsを適応させる return vector([df.subs(s) for df in dfs]); 
       

ヘッセ行列

ヘッセ行列もSageの数式機能を使えば、簡単にもとめることができます。

# ヘッセ行列 H = matrix([[diff(diff(f(v),x_i), x_j) for x_i in v] for x_j in v]); print jsmath(H); 
       
\newcommand{\Bold}[1]{\mathbf{#1}}\left(\begin{array}{rr} 3 & 1 \\ 1 & 2 \end{array}\right)
\newcommand{\Bold}[1]{\mathbf{#1}}\left(\begin{array}{rr} 3 & 1 \\ 1 & 2 \end{array}\right)
# α_kの定義 def alpha_k(x, d): return -d.dot_product(nabla_f(x)) / (d * H * d); 
       

共役勾配法の反復処理

共役勾配法の反復処理は、至って単純です。条件を満たすまで与えられた式でxと共役勾配を 更新するだけです。

# 共役勾配法の反復処理 eps = 0.001; x0 = vector([2, 1]); d = - nabla_f(vx=x0); x = x0; k = 1; while (true): o_nabla_f_sqr = nabla_f(x).dot_product(nabla_f(x)); o_x = x; x += alpha_k(x, d)*d; if ((x - o_x).norm() < eps): break; beta = nabla_f(x).dot_product(nabla_f(x)) / o_nabla_f_sqr; d = -nabla_f(x) + beta*d; if (d.norm() == 0): # 0割り対策 break; k += 1; print "x=", x; print "k=", k; 
       
x= (1, 3)
k= 2
x= (1, 3)
k= 2

Sageの最適化機能で結果を検証

求まった解x= (1, 3)をSageの最適化機能で求めた結果と比較します。当然同じ結果になります。

# 同様の処理をsageの機能を使って計算してみる g = 3/2*x1^2 + x1*x2 + x2^2 - 6*x1 -7*x2; minimize(g, [2, 1], algorithm="cg") 
       
Optimization terminated successfully.
         Current function value: -13.500000
         Iterations: 2
         Function evaluations: 5
         Gradient evaluations: 5
(1.0, 3.0)
Optimization terminated successfully.
         Current function value: -13.500000
         Iterations: 2
         Function evaluations: 5
         Gradient evaluations: 5
(1.0, 3.0)

解のプロット

Sageの強みは簡単に結果をグラフ化できることです。

# 関数と解をプロット p3d = plot3d(g, [x1, -1, 4], [x2, -1, 4]); pt = point([1, 3, f(x)], color='red'); show(p3d+pt); 
       
# 初期値を変えて収束の様子をプロットしながら再計算 pts = pt pt2s = pt2 eps = 0.001 x0 = vector([-0.5, 1]) d = - nabla_f(vx=x0) x = x0 k = 1 while (true): o_nabla_f_sqr = nabla_f(x).dot_product(nabla_f(x)) o_x = x pts += point([o_x[0], o_x[1], f(o_x)], color='yellow') pt2s += point([o_x[0], o_x[1]], color='blue') x += alpha_k(x, d)*d if ((x - o_x).norm() < eps): break beta = nabla_f(x).dot_product(nabla_f(x)) / o_nabla_f_sqr d = -nabla_f(x) + beta*d if (d.norm() == 0): # 0割り対策 break k += 1 pts += point([x[0], x[1], f(x)], color='red') pt2s += point([x[0], x[1]], color='red') (p3d + pts).show() 
       

コンター図で解の収束を見る

共役勾配法での解の収束の様子をコンターズで示します。 朱鷺の杜Wiki の通り、2回の計算で収束しているのが分かります。

# コンターズで共役勾配と接線の法線方向の関係、および2回で収束することを示す。 cnt_plot = contour_plot(g, [x1, -1, 4], [x2, -1, 4], fill=False) (cnt_plot + pt2s).show()