I've just tried PyBrain and hoped it could learn the simple linear function f(x) = 4x+1:
# First attempt: fit a small PyBrain network to f(x) = 4x + 1 and print
# what it predicts for x = 0..9.
from pybrain.datasets import SupervisedDataSet
from pybrain.supervised.trainers import BackpropTrainer
from pybrain.tools.shortcuts import buildNetwork

# Network built with layer sizes 1, 2, 1 and bias enabled.
net = buildNetwork(1, 2, 1, bias=True)

# Training set: each x in 0..999 paired with its target 4x + 1.
ds = SupervisedDataSet(1, 1)
for sample_x in range(1000):
    target = 4 * sample_x + 1
    ds.addSample((sample_x,), (target,))

# Run 100 training passes and report the error returned by each one.
trainer = BackpropTrainer(net, ds)
for epoch in range(100):
    error = trainer.train()
    print("Error: %0.2f" % error)

# Ask the trained network for its output on the first ten inputs.
print("Test function f(x)=4x+1")
for query in range(10):
    prediction = net.activate((query,))
    print("f(%i) = %i" % (query, prediction))
But when I execute this, I get horribly wrong results:
f(0) = 1962
f(1) = 1962
f(2) = 1962
f(3) = 1962
f(4) = 1962
f(5) = 1962
f(6) = 1962
f(7) = 1962
f(8) = 1962
f(9) = 1962
Why doesn't this work?
Try 2
Code:
# Second attempt: same network and target function, but trained with
# trainUntilConvergence and an explicit learning rate and momentum.
from pybrain.datasets import SupervisedDataSet
from pybrain.supervised.trainers import BackpropTrainer
from pybrain.tools.shortcuts import buildNetwork

# Network built with layer sizes 1, 2, 1 and bias enabled.
net = buildNetwork(1, 2, 1, bias=True)

# Training set: each x in 0..999 paired with 4x + 1; every pair is added twice.
ds = SupervisedDataSet(1, 1)
for sample_x in range(1000):
    sample = ((sample_x,), (4 * sample_x + 1,))
    ds.addSample(*sample)
    ds.addSample(*sample)

# Train for at most 100 epochs, holding out 10% of the data for validation.
trainer = BackpropTrainer(net, ds, learningrate=0.001, momentum=0.99)
print("Start training")
a = trainer.trainUntilConvergence(
    dataset=ds,
    maxEpochs=100,
    verbose=True,
    continueEpochs=10,
    validationProportion=0.1,
)
# Print whatever trainUntilConvergence returned.
print(a)
print("Finished training")

# Ask the trained network for its output on the first ten inputs.
print("Test function f(x)=4x+1")
for query in range(10):
    prediction = net.activate((query,))
    print("f(%i) = %i" % (query, prediction))
Output:
Start training
train-errors: [ 827395.411895 755443.286202 722073.904381 748336.584579
[...]
695939.638106 726953.086185 736527.150008 739789.458146 736074.235677 731222.936020 675937.725009]
valid-errors: [ 2479217.507148 915115.526570 703748.266402 605613.979311 592809.132542 686959.683977 612248.174146
[...]
655606.225724 637762.864477 643013.094767 620825.083765 609063.451602 607935.458244 716839.447374]
([827395.41189463751, 755443.28620243724, 722073.90438077366, 748336.58457926242, 739568.58919456392, 725496.58682491502,
[...]
637762.86447708646, 643013.09476733557, 620825.08376532339, 609063.45160197129, 607935.45824447344, 716839.44737418776])
Finished training
Test function f(x)=4x+1
f(0) = 1955
f(1) = 1955
f(2) = 1955
f(3) = 1955
f(4) = 1955
f(5) = 1955
f(6) = 1955
f(7) = 1955
f(8) = 1955
f(9) = 1955