0


神经网络之BP神经网络

一、BP神经网络的概念

BP神经网络是一种多层的前馈神经网络,其主要的特点是:信号是前向传播的,而误差是反向传播的。具体来说,对于如下的只含一个隐层的神经网络模型:

                                         (三层BP神经网络模型)

BP神经网络的过程主要分为两个阶段,第一阶段是信号的前向传播,从输入层经过隐含层,最后到达输出层;第二阶段是误差的反向传播,从输出层到隐含层,最后到输入层,依次调节隐含层到输出层的权重和偏置,输入层到隐含层的权重和偏置。

二、BP神经网络的流程

在知道了BP神经网络的特点后,我们需要依据信号的前向传播和误差的反向传播来构建整个网络。

1、网络的初始化

2、隐含层的输出

如上面的三层BP网络所示,隐含层的输出![H_{j}](https://latex.codecogs.com/gif.latex?H_%7Bj%7D)为:

3、输出层的输出

4、误差的计算

5、权值的更新

权值的更新公式为:

这里需要解释一下公式的由来:

这是误差反向传播的过程,我们的目标是使得误差函数达到最小值,即minE

  • 隐含层到输出层的权重更新

  • 则权重的更新公式为:

  • 输入层到隐含层的权重更新

其中

则权重的更新公式为:

6、偏置的更新

偏置的更新公式为:

  • 隐含层到输出层的偏置更新

则偏置的更新公式为:

  • 输入层到隐含层的偏置更新

其中

则偏置的更新公式为:

7、判断算法迭代是否结束

有很多的方法可以判断算法是否已经收敛,常见的有指定迭代的代数,判断相邻的两次误差之间的差别是否小于指定的值等等。

**代码 **

package machinelearning.ann;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

import weka.core.Instances;

/**
 * General ANN. Two methods are abstract: forward and backPropagation.
 * 
 * @author Rui Chen [email protected].
 */
public abstract class GeneralAnn {

    /**
     * The whole dataset.
     */
    Instances dataset;

    /**
     * Number of layers. It is counted according to nodes instead of edges.
     */
    int numLayers;

    /**
     * The number of nodes for each layer, e.g., [3, 4, 6, 2] means that there
     * are 3 input nodes (conditional attributes), 2 hidden layers with 4 and 6
     * nodes, respectively, and 2 class values (binary classification).
     */
    int[] layerNumNodes;

    /**
     * Momentum coefficient.
     */
    public double mobp;

    /**
     * Learning rate.
     */
    public double learningRate;

    /**
     * For random number generation.
     */
    Random random = new Random();

    /**
     ********************
     * The first constructor.
     * 
     * @param paraFilename
     *            The arff filename.
     * @param paraLayerNumNodes
     *            The number of nodes for each layer (may be different).
     * @param paraLearningRate
     *            Learning rate.
     * @param paraMobp
     *            Momentum coefficient.
     ********************
     */
    public GeneralAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate,
            double paraMobp) {
        // Step 1. Read data.
        try {
            FileReader tempReader = new FileReader(paraFilename);
            dataset = new Instances(tempReader);
            // The last attribute is the decision class.
            dataset.setClassIndex(dataset.numAttributes() - 1);
            tempReader.close();
        } catch (Exception ee) {
            System.out.println("Error occurred while trying to read \'" + paraFilename
                    + "\' in GeneralAnn constructor.\r\n" + ee);
            System.exit(0);
        } // Of try

        // Step 2. Accept parameters.
        layerNumNodes = paraLayerNumNodes;
        numLayers = layerNumNodes.length;
        // Adjust if necessary.
        layerNumNodes[0] = dataset.numAttributes() - 1;
        layerNumNodes[numLayers - 1] = dataset.numClasses();
        learningRate = paraLearningRate;
        mobp = paraMobp;    
    }//Of the first constructor    
    
    /**
     ********************
     * Forward prediction.
     * 
     * @param paraInput
     *            The input data of one instance.
     * @return The data at the output end.
     ********************
     */
    public abstract double[] forward(double[] paraInput);

    /**
     ********************
     * Back propagation.
     * 
     * @param paraTarget
     *            For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
     *            
     ********************
     */
    public abstract void backPropagation(double[] paraTarget);

    /**
     ********************
     * Train using the dataset.
     ********************
     */
    public void train() {
        double[] tempInput = new double[dataset.numAttributes() - 1];
        double[] tempTarget = new double[dataset.numClasses()];
        for (int i = 0; i < dataset.numInstances(); i++) {
            // Fill the data.
            for (int j = 0; j < tempInput.length; j++) {
                tempInput[j] = dataset.instance(i).value(j);
            } // Of for j

            // Fill the class label.
            Arrays.fill(tempTarget, 0);
            tempTarget[(int) dataset.instance(i).classValue()] = 1;

            // Train with this instance.
            forward(tempInput);
            backPropagation(tempTarget);
        } // Of for i
    }// Of train

    /**
     ********************
     * Get the index corresponding to the max value of the array.
     * 
     * @return the index.
     ********************
     */
    public static int argmax(double[] paraArray) {
        int resultIndex = -1;
        double tempMax = -1e10;
        for (int i = 0; i < paraArray.length; i++) {
            if (tempMax < paraArray[i]) {
                tempMax = paraArray[i];
                resultIndex = i;
            } // Of if
        } // Of for i

        return resultIndex;
    }// Of argmax

    /**
     ********************
     * Test using the dataset.
     * 
     * @return The precision.
     ********************
     */
    public double test() {
        double[] tempInput = new double[dataset.numAttributes() - 1];

        double tempNumCorrect = 0;
        double[] tempPrediction;
        int tempPredictedClass = -1;

        for (int i = 0; i < dataset.numInstances(); i++) {
            // Fill the data.
            for (int j = 0; j < tempInput.length; j++) {
                tempInput[j] = dataset.instance(i).value(j);
            } // Of for j

            // Train with this instance.
            tempPrediction = forward(tempInput);
            //System.out.println("prediction: " + Arrays.toString(tempPrediction));
            tempPredictedClass = argmax(tempPrediction);
            if (tempPredictedClass == (int) dataset.instance(i).classValue()) {
                tempNumCorrect++;
            } // Of if
        } // Of for i

        System.out.println("Correct: " + tempNumCorrect + " out of " + dataset.numInstances());

        return tempNumCorrect / dataset.numInstances();
    }// Of test
}//Of class GeneralAnn

本文转载自: https://blog.csdn.net/Rick_rui/article/details/125006626
版权归原作者 Rick_rui 所有, 如有侵权,请联系我们删除。

“神经网络之BP神经网络”的评论:

还没有评论