最急降下法で極小値を求める

関数の極小値を求めるために最急降下法を使ってみます。


y=f(x)の極小値を求めるとして、まず適当な値aをとります。f(x)のx=aでの微分f'(a)をとったとき、これが正であればaの値を減らし、f'(a)が負であればaの値を増やすと極小値に近づくはずというのが、最急降下法のアイデアです。
ということで、プログラムでは、funcメソッドが極小を求める関数、funcdメソッドが微分です。
実行すると、aの初期値によって求まる極限がかわります。初期値によっては、値が発散したり振動したりしてしまうこともあります。つまり、最急降下法では、本当の最適解は求めにくいということです。

import java.awt.*;
import java.awt.image.BufferedImage;
import javax.swing.*;

public class SteepestDescent {
    //求める関数
    static double func(double x){
        return (x - 1)*(x - 1) * (x + 1) * (x + 1) + x / 10;
    }
    //求める関数の微分
    static double dfunc(double x){
        return 4 * x * (x - 1) * (x + 1) + 1. / 10;
    }
    public static void main(String[] args){
        BufferedImage img = new BufferedImage(
                400, 300, BufferedImage.TYPE_INT_RGB);
        final Graphics g = img.getGraphics();
        
        JFrame f = new JFrame("最急降下法");
        f.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        f.setSize(400, 300);
        final JLabel lbl = new JLabel(new ImageIcon(img));
        f.add(lbl);
        f.setVisible(true);
        
        new Thread(){
            @Override
            public void run() {
                for(;;){
                double a = Math.random() * 6 - 3;
                double th = 0.0001;//収束したと判断する変化率
                double pre = Double.NaN;
                for(int i = 0; i < 1000; ++i){
                    double k = 0.1;//学習係数
                    double da = dfunc(a);//微分を求める
                    a -= da * k;//ここが最急降下法。微分値に学習係数かけた分ずらす
                    
                    if(Double.isInfinite(a) || Double.isNaN(a)){
                        //発散した
                        break;
                    }
                    if(pre != Double.NaN){
                        double d = pre - a;
                        if(Math.abs(d) < th){
                            //収束した
                            break;
                        }
                        pre = a;
                    }
                    
                    g.setColor(Color.WHITE);
                    g.fillRect(0, 0, 400, 300);
                    //求める式のグラフ
                    for(double x = -10; x < 10; x+= .01){
                        double px = (x + 10) * 30 - 150;
                        double py = 250 - func(x) *100;
                        g.setColor(Color.RED);
                        g.fillOval((int)px, (int)py, 3, 3);
                    }
                    //解の予測値
                    g.setColor(Color.BLUE);
                    g.fillOval(
                            (int)((a + 10) * 30 - 150), 
                            (int)(250 - func(a) * 100), 7, 7);
                    lbl.repaint();
                    try{ Thread.sleep(200);
                    }catch(InterruptedException e){}
                }
                }
            }
        }.start();
    }
}