I have implemented back propagation algorithm to train my neural network. It solves AND & OR perfectly, but when I try to train to solve XOR, the total error is really high.
The network topology for XOR network is : 2 neurons at input layer, 2 neurons at the hidden layer, and one neuron at the output layer.
I'm using sigmoid as my activation function, and weighted sum as input.
Here is the part of my code responsible for back propagation:
protected void updateOutputLayer(double[] outputErr)
{
double delta;
Neuron neuron;
double errorDerivative;
for ( int i=0;i<this.getNeuralNetwork().getOutputLayer().getSize();i++)
{
neuron=this.getNeuralNetwork().getOutputLayer().getAt(i);
errorDerivative=neuron.getTransferFunction().getDerivative(neuron.getNetInput());
delta=outputErr[i]*errorDerivative;
neuron.setDelta(roundThreeDecimals(delta));
// now update the weights
this.updateNeuronWeights(neuron);
}
}
protected void updateHiddenLayerNeurons()
{
List<Layer> layers=this.network.getLayers();
Layer currentLayer;
double neuronErr;
for ( int i=layers.size()-2;i>0;i--)
{
currentLayer= layers.get(i);
for (int j=0;j<currentLayer.getSize();j++)
{
neuronErr=calculateHiddenLayerError(currentLayer.getAt(j));
currentLayer.getAt(j).setDelta(neuronErr);
this.updateNeuronWeights(currentLayer.getAt(j));
}
}
//System.out.println("*****************************************");
}
protected double calculateHiddenLayerError(Neuron node)
{
List<Connection> outputCon= node.getOutputConnections();
double errFactor=0;
for (Connection outputCon1 : outputCon) {
//System.out.println("output od dst: "+outputCon1.getDst().getOutput());
// System.out.println("w dst: "+outputCon1.getWeight());
//System.out.println("in CalcErr Factor err: "+outputCon.get(i).getDst().getError()+" w: "+outputCon.get(i).getWeight());
errFactor += outputCon1.getDst().getDelta() * outputCon1.getWeight();
}
double derivative= node.getTransferFunction().getDerivative(node.getNetInput());
return roundThreeDecimals(derivative*errFactor);
}
public void updateNeuronWeights(Neuron neuron)
{
double weightChange;
double input, error;
for (Connection con: neuron.getInConnections())
{
input=con.getInput();
// System.out.println("input: "+input);
error = neuron.getDelta();
weightChange=this.learningRate*error*input;// error here is : output error * error derivative
con.setWeight(roundThreeDecimals(con.getWeight()+weightChange));
}
// now update bias
if(neuron.isBiasUsed())
{
//System.out.println("old bias: "+neuron.getBias());
double biasChange=neuron.getBias()+neuron.getDelta()*this.learningRate;
//System.out.println("new bias: "+biasChange);
neuron.setBias(roundThreeDecimals(biasChange));
}
}
I'm using a learning rate in the range [0.01,0.5]. Can anyone tell me what is wrong with my code?