じんべえざめのノート

仮想生物と強化学習と深層学習に興味がある大学院生のブログ。

KelpNetでXOR

今回は、KelpNetを用いてXORを学習していきます。XORは、x_1x_2の片方のみ1のときのみ1になる論理回路です。ANDやORとあわせて表にすると、以下のようになります。今回は、このXORを1か0かの二値分類問題として扱います。今回用いるコードは、KelpNetのプロジェクト内のTest2を参考にして書かれています。

x_1 x_2 OR AND XOR
0 0 0 0 0
0 1 1 0 1
1 0 1 0 1
1 1 1 1 0

この記事は深層学習ライブラリを初めて使う方・他の深層学習ライブラリを使っているけどKelpNetも気になる方など対象にしており、ニューラルネットワークの基礎が分かれば理解が出来る内容になっています。この記事では、KelpNetを用いてXORを学習するニューラルネットの実装を「ネットワークの定義」「学習部分」「学習結果の表示」「学習後の重みの表示」の順で説明しています。

ニューラルネットワークの基礎・KelpNetとインストール方法については、それぞれ以下の記事で紹介しています。また、全体のコードは一番下に載っています。

ネットワークの定義

今回は、下図のような中間層が1層でユニット数が2のネットワークを用います。図ではバイアスの表記は省略されています。入力はx_1x_2のため入力層のユニット数は2、出力は0~1の1個の実数であるため出力層のユニット数は1となっています。活性化関数は中間層と出力層ともにシグモイド関数を用います。このように、二値分類の際は出力層の活性化関数にはシグモイド関数がよく用いられます。

f:id:jinbeizame007:20180805235404p:plain

ネットワークを定義するコードは以下のようになります。Linear()は全結合層を示しており、第1引数が入力の数、第2引数が出力の数、第3引数が関数の名前となっています。また、Sigmoid()はシグモイド関数を示しており、引数は関数の名前となっています。

ネットワークを定義した後に、SetOptimizerを用いることでネットワークの重みの更新に用いる最適化手法を適用することが出来ます。最適化手法は、SGD以外にもAdamやRMSPropなど様々な手法が用意されています。

// ネットワークの構成を FunctionStack に書き連ねる
FunctionStack model = new FunctionStack(
    new Linear(2, 2, name: "l1 Linear"), // 入力2出力2の全結合層
    new Sigmoid(name: "l1 Sigmoid"), // シグモイド関数
    new Linear(2, 1, name: "l2 Linear"), // 入力2出力1の全結合層
    new Sigmoid(name: "l2 Sigmoid")  // シグモイド関数
);

// optimizerの宣言
model.SetOptimizer(new SGD());

学習

データと学習部分のコードは以下のようになります。KelpNetは、Trainer.Trainを用いることで、順伝播・損失の計算・誤差逆伝播・重みの更新を1行でまとめて記述することが出来ます。また、それぞれを別々に記述する方法もあり、こちらで紹介しています。誤差関数には二乗誤差関数 (MeanSquaredError) を用いています。

  • Trainer.Train(functionStack, input, teach, lossFunction, isUpdate = true)

    • Parameters
      • functionStack (FunctionStack): ネットワークのモデル
      • input (NdArray): ネットワークへの入力
      • teach (NdArray): 教師データ
      • lossFunction (LossFunction): 損失関数
      • isUpdate (bool): ネットワークの更新を行うかどうか
    • Return
      • sumLoss (Real): ミニバッチ内の各データの損失の合計
// 入力データ
Real[][] x =
{
    new Real[] { 0, 0 },
    new Real[] { 1, 0 },
    new Real[] { 0, 1 },
    new Real[] { 1, 1 }
};

// 教師データ
Real[][] target =
{
    new Real[] { 0 },
    new Real[] { 1 },
    new Real[] { 1 },
    new Real[] { 0 }
};

// 学習回数
const int EPOCH = 10000; 

// 順伝播・損失の計算・誤差逆伝播・重みの更新を行う
for (int ep = 0; ep < EPOCH; ep++) // 全データをEPOCHの回数分学習する
{
    for (int i = 0; i < x.Length; i++) // データの個数分繰り返す
    {
    Trainer.Train(model, new NdArray(x[i]), new NdArray(target[i]), new MeanSquaredError());
    }
}

学習結果

以下のコードを用いて学習結果を表示します。入力データxから1つずつデータを取り出し、Predictを用いて各入力データを伝播させた後、ネットワークの出力を表示しています。

//訓練結果を表示
foreach (Real[] input in x)
{
    NdArray output = model.Predict(input)[0];
    Console.WriteLine(input[0] + " xor " + input[1] + " = " + (output.Data[0] > 0.5 ? 1 : 0) + " " + output);
}

プログラムの出力は以下のようになります。XORが学習出来ていることが分かります。

0 xor 0 = 0 [0.03538282]
1 xor 0 = 1 [0.96938088]
0 xor 1 = 1 [0.96969101]
1 xor 1 = 0 [0.03145395]

学習後の重み

以下のコードを用いてネットワークの重みとバイアスを表示し、XORがニューラルネットワークでどのように表現されているかを確認します。

Linear l1 = (Linear)model.Functions[0];
Console.WriteLine("l1 Weight");
Console.WriteLine(l1.Weight);
Console.WriteLine("l1 Bias");
Console.WriteLine(l1.Bias);

Linear l2 = (Linear)model.Functions[2];
Console.WriteLine("l2 Weight");
Console.WriteLine(l2.Weight);
Console.WriteLine("l2 Bias");
Console.WriteLine(l2.Bias);

出力は以下のようになります。

l1 Weight
[[ 5.28109151 -5.49548032]
 [-5.71896640  5.56578034]]
l1 Bias
[-2.96292368 -3.13341178]
l2 Weight
[[8.24473497 8.17577662]]
l2 Bias
[-4.05190012]

このままの数字で考えると少し複雑なため、大雑把ではありますが以下のように四捨五入してどのように表現されているかを考えていきます。

\begin{align} w^{(1)}_{11} &= 5.0, w^{(1)}_{12} = -5.0, \\ w^{(1)}_{21} &= -5.0, w^{(1)}_{22} = 5.0, \\ b^{(1)}_1 &= -3.0, \\ b^{(1)}_2 &= -3.0, \\ \\ w^{(2)}_{11} &= 8.0, w^{(1)}_{12} = 8.0 \\ b^{(2)}_1 &= -4.0, \\ \end{align}

よって、中間層の出力s_1, s_2と出力層の出力yは以下のようになります。

\begin{align} s_1 &= f(x_1w^{(1)}_{11} + x_2w^{(1)}_{12} + b^{(1)}_1) \\ &= f(5.0x_1 - 5.0x_2 - 3.0) \\ \\ s_2 &= f(x_1w^{(1)}_{21} + x_2w^{(1)}_{22} + b^{(1)}_2) \\ &= f(-5.0x_1 + 5.0x_2 - 3.0) \\ \\ y &= f(s_1w^{(2)}_{11} + s_2w^{(2)}_{12} + b^{(2)}_1) \\ &=f( 8.0s_1 + 8.0s_2 - 4.0) \end{align}

ここで、f()は活性化関数のシグモイド関数を示しています。シグモイド関数は、下図のような関数です。

\begin{align} y = \frac{1}{1+e^{-x}} \end{align}

f:id:jinbeizame007:20180821152332p:plain

これらの式のx_1, x_2に値を代入したものを表にすると以下のようになります。

x_1 x_2 s_1 s_2 y
0 0 f(-3.0) \fallingdotseq 0.05 f(-3.0) \fallingdotseq 0.05 f(-3.2) \fallingdotseq 0.03
0 1 f(-8.0) \fallingdotseq 0.0 f(2.0) \fallingdotseq 0.9 f(3.8) \fallingdotseq 0.96
1 0 f(2.0) \fallingdotseq 0.9 f(-8.0) \fallingdotseq 0.0 f(3.8) \fallingdotseq 0.96
1 1 f(-3.0) \fallingdotseq 0.05 f(-3.0) \fallingdotseq 0.05 f(-3.2) \fallingdotseq 0.03

大雑把な値ではありますが、表に出来ました。この表から、s_1x_2のみ1であるときに発火し、s_2x_1のみ1であるときに発火することが分かります。また、ys_1s_2のどちらかが発火した場合に発火することが分かります。よって、このニューラルネットワークx_1=0, x_2=1またはx_1=1, x_2=0のときに1に近い値を、それ以外のときに0に近い値を出力するニューラルネットワークということが分かります。

最後に全体のコードを載せておきます。

using System;
using KelpNet.Loss;
using KelpNet.Common;
using KelpNet.Common.Functions.Container;
using KelpNet.Common.Tools;
using KelpNet.Functions.Activations;
using KelpNet.Functions.Connections;
using KelpNet.Optimizers;

class XOR
{
    public static void Main()
    {

        // 入力データ
        Real[][] x =
        {
             new Real[] { 0, 0 },
             new Real[] { 1, 0 },
             new Real[] { 0, 1 },
             new Real[] { 1, 1 }
         };

        // 教師データ
        Real[][] target =
        {
             new Real[] { 0 },
             new Real[] { 1 },
             new Real[] { 1 },
             new Real[] { 0 }
         };

        // ネットワークの構成を FunctionStack に書き連ねる
        FunctionStack model = new FunctionStack(
            new Linear(2, 2, name: "l1 Linear"), // 入力2出力2の全結合層
            new Sigmoid(name: "l1 Sigmoid"), // シグモイド関数
            new Linear(2, 1, name: "l2 Linear"), // 入力2出力1の全結合層
            new Sigmoid(name: "l2 Sigmoid")  // シグモイド関数
        );

        // optimizerの宣言
        model.SetOptimizer(new SGD());

        // 学習回数
        const int EPOCH = 10000;

        // 順伝播・損失の計算・誤差逆伝播・重みの更新を行う
        for (int ep = 0; ep < EPOCH; ep++) // 全データをEPOCHの回数分学習する
        {
            for (int i = 0; i < x.Length; i++) // データの個数分繰り返す
            {
                Trainer.Train(model, new NdArray(x[i]), new NdArray(target[i]), new MeanSquaredError());
            }
        }

        // 学習結果を表示
        foreach (Real[] input in x)
        {
            NdArray output = model.Predict(input)[0];
            Console.WriteLine(input[0] + " xor " + input[1] + " = " + (output.Data[0] > 0.5 ? 1 : 0) + " " + output);
        }
        Console.WriteLine("");

        // 学習後の重みを表示
        Linear l1 = (Linear)model.Functions[0];
        Console.WriteLine("l1 Weight");
        Console.WriteLine(l1.Weight);
        Console.WriteLine("l1 Bias");
        Console.WriteLine(l1.Bias);

        Linear l2 = (Linear)model.Functions[2];
        Console.WriteLine("l2 Weight");
        Console.WriteLine(l2.Weight);
        Console.WriteLine("l2 Bias");
        Console.WriteLine(l2.Bias);
    }
}