パーセプトロンの収束に学習率は関係ない件

(2012-06-20追記) ただし,重みベクトルの初期値が0のときだけということの記述忘れ.これについて以下の記事参照 (パーセプトロンの収束に学習率が関係ないのは初期値が0のときだけ)

前回の記事に @shuyo さんから以下のコメントを頂いた.前半部分を引用する.

「学習率を適当な値に設定すると経験的には収束が速くなる」とありますが、本当にそうでしょうか?
重みの初期値は0ですから、更新式よりη倍した解が求まるだけで、収束性に影響があるように思えません。

はい,そのとおりです.重みベクトルは定数倍しても分離超平面は変化しない.
そのため,学習率がいくつになろうと,@shuyo さんのおっしゃるようにη倍した重みベクトルが得られるだけで,収束に影響しない.というわけで「経験的には適当な学習率を設定した方が収束が速い」というのは誤り.お詫びして訂正します.そして @shuyo さんありがとうございます.

検証

というわけで検証するまでもないのだけれど,体験してみることが大切な気がするので線形分離可能なデータセットを作成し,異なる学習率を設定したパーセプトロンで収束までの誤り回数と最終的な重みベクトルを表示するプログラムを書いて実行してみた.(実験に用いたコードは末尾に記載)

結果は以下のとおり.上から学習率が100, 10, 1, 0.1, 0.01, 0.001.

% python gen_data.py > toy.txt
% python perceptron.py toy.txt
Converged.
Error count=56621
[3441900.0, -356.84360002073299, -19582.371400038977]
Converged.
Error count=56621
[344190.0, -35.68436000034535, -1958.2371400047396]
Converged.
Error count=56621
[34419.0, -3.5684360000869333, -195.82371400042496]
Converged.
Error count=56621
[3441.8999999979651, -0.356843600001298, -19.582371400016243]
Converged.
Error count=56621
[344.18999999983197, -0.035684359999769488, -1.9582371400026064]
Converged.
Error count=56621
[34.419000000010485, -0.0035684360000880627, -0.19582371400070708]

結果などみなくてもわかるとおり,全て同じ誤り回数で収束していることがわかる.

ただし,学習率が大きいと重みベクトルの大きさが激しいことになっている.正則化の観点からすると,ノルムは小さい方がよい.また,線形分離可能でないときには偉いことになっている重みベクトルを得ることがあるので,実用上は学習率は小さめに設定しておいた方がよいと思うのだけれど,それは正則化についてきちんと頭の整理がついてから,また別の機会に書くことにする.

コメントされて気が付いたときは布団の上で暴れるほど恥ずかしくて死にそうだったけれど,大恥をかく前に知ることができてよかったと思う.@shuyo さんに改めて感謝.

利用したコード

  • gen_date.py
# -*- coding:utf-8 -*-

import random
import math

def box_muller ():
    x  = random.random()
    y  = random.random()
    z1 = math.sqrt( - 2.0 * math.log(x) ) * math.cos( 2.0 * math.pi * y )
    z2 = math.sqrt( - 2.0 * math.log(x) ) * math.sin( 2.0 * math.pi * y )
    return (z1, z2)

def gauss_rand_2d (mu1, sigma1, mu2, sigma2):
    pair = box_muller()
    return ( (sigma1 * pair[ 0 ]) + mu1, (sigma2 * pair[ 1 ]) + mu2 )

def gen_data (settings):
    pos_num = settings['pos_num']
    pos_xmu = settings['pos_xmu']
    pos_xsigma = settings['pos_xsigma']
    pos_ymu = settings['pos_ymu']
    pos_ysigma = settings['pos_ysigma']

    neg_num = settings['neg_num']
    neg_xmu = settings['neg_xmu']
    neg_xsigma = settings['neg_xsigma']
    neg_ymu = settings['neg_ymu']
    neg_ysigma = settings['neg_ysigma']

    for i in range(0, pos_num):
        xy_pair = gauss_rand_2d( pos_xmu, pos_xsigma, pos_ymu, pos_ysigma )
        print "1 1:%f 2:%f" % (xy_pair[0], xy_pair[1])

    for i in range(0, neg_num):
        xy_pair = gauss_rand_2d( neg_xmu, neg_xsigma, neg_ymu, neg_ysigma )
        print "-1 1:%f 2:%f" % (xy_pair[0], xy_pair[1])


if __name__ == '__main__':
    settings = {'pos_num':50, 'pos_xmu':50, 'pos_xsigma':50, 'pos_ymu':50, 'pos_ysigma':50,
                'neg_num':50, 'neg_xmu':300, 'neg_xsigma':50, 'neg_ymu':300, 'neg_ysigma':50}
    gen_data( settings )
  • perceptron.py
# -*- coding: utf-8 -*-

import sys
import random

class Instance:

    def __init__ (self, line):
        self.label = 0
        self.fvec  = []
        self.parse_line( line )


    def parse_line (self, line):
        pair_list = line.split(' ')
        label = int( pair_list.pop( 0 ) )

        if label != -1 and label != 1:
            print "Label must be +1 or -1"
            sys.exit(1)

        self.label = label

        for pair in pair_list:
            if pair.find(':') < 0:
                continue

            fid, fval = pair.split(':')
            self.fvec.append( (int(fid), float(fval)) )


    def get_max_fid (self):
        return self.fvec[ -1 ][ 0 ]


class Instances:

    def __init__ (self, filename):
        self.ins_list = []
        self.load_file( filename )


    def load_file (self, filename):
        max_fid = -1
        for line in open( filename, 'r' ):
            ins = Instance( line )
            cur_max_fid = ins.get_max_fid()
            if cur_max_fid > max_fid:
                max_fid = cur_max_fid
            self.ins_list.append( ins )
        self.max_fid = cur_max_fid


    def get_random_instance (self):
        idx = random.randint(0, len(self.ins_list) - 1)
        return self.ins_list[ idx ]


class Perceptron:

    def __init__ (self, inss, iternum = 1000, eta = 0.01):
        self.instances = inss
        self.wvec_size = self.instances.max_fid + 1 # Consider bias term
        self.wvec      = [ 0.0 ] * self.wvec_size
        self.iternum   = iternum
        self.eta       = eta


    def check_convergence (self, iternum, eta):
        # Initialize weight vector and eta value
        self.wvec      = [ 0.0 ] * self.wvec_size
        self.eta       = eta

        convergence_flag  = True
        error_count = 0
        for i in range(0, iternum):
            convergence_flag = True
            for ins in self.instances.ins_list:
                wx = self.inner_prod( ins )
                if ins.label * wx <= 0:
                    # Error
                    self.update_weight( ins )
                    error_count += 1
                    convergence_flag = False

            if convergence_flag:
                break

        if convergence_flag:
            print "Converged."
        else:
            print "Failed to converge."
        print "Error count=%d" % (error_count)
        print self.wvec


    def train (self):
        for i in range(0, self.iternum):
            ins = self.instances.get_random_instance()
            wx = self.inner_prod( ins )

            if ins.label * wx <= 0:
                self.update_weight( ins )


    def predict (self, instance):
        return 1 if self.inner_prod( instance ) >= 0 else -1


    def inner_prod (self, instance):
        fx = self.wvec[ 0 ] # bias term
        for pair in instance.fvec:
            fx += self.wvec[ pair[ 0 ] ] * pair[ 1 ]

        return fx


    def update_weight (self, instance):
        self.wvec[ 0 ] += self.eta * instance.label

        for pair in instance.fvec:
            self.wvec[ pair[ 0 ] ] += self.eta * instance.label * pair[ 1 ]


if __name__ == '__main__':
    argv = sys.argv
    argc = len( sys.argv )

    if (argc < 2):
        print "Input filename"
        quit()

    inss = Instances( argv[ 1 ] )
    p = Perceptron( inss )

    p.check_convergence( 100000, 100 )
    p.check_convergence( 100000, 10 )
    p.check_convergence( 100000, 1 )
    p.check_convergence( 100000, 0.1 )
    p.check_convergence( 100000, 0.01 )
    p.check_convergence( 100000, 0.001 )