
I'm trying to train a Restricted Boltzmann Machine (RBM) with DeepLearning4J 0.7, without success. All the examples I found either don't do anything useful or no longer work with DeepLearning4J 0.7.

I need to train a single RBM with Contrastive Divergence and then compute the reconstruction error.
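For reference while debugging, here is what a single CD-1 step for a binary-binary RBM looks like, written in plain Java without DL4J so that the update rule is explicit. This is a sketch under stated assumptions, not DL4J's implementation: the class name, the per-example (non-minibatch) updates, the learning rate and the 0.01 Gaussian initialization are all illustrative choices.

```java
import java.util.Random;

/** Minimal binary-binary RBM trained with CD-1 on plain arrays (illustrative sketch, not DL4J). */
public class Cd1Sketch {
    final int nVis, nHid;
    final double[][] w;   // w[i][j]: weight between visible unit i and hidden unit j
    final double[] hBias; // hidden biases
    final double[] vBias; // visible biases
    final Random rng;

    Cd1Sketch(int nVis, int nHid, long seed) {
        this.nVis = nVis;
        this.nHid = nHid;
        this.rng = new Random(seed);
        this.w = new double[nVis][nHid];
        this.hBias = new double[nHid];
        this.vBias = new double[nVis];
        for (int i = 0; i < nVis; i++)
            for (int j = 0; j < nHid; j++)
                w[i][j] = 0.01 * rng.nextGaussian(); // small random initial weights
    }

    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    /** p(h_j = 1 | v) for every hidden unit. */
    double[] up(double[] v) {
        double[] h = new double[nHid];
        for (int j = 0; j < nHid; j++) {
            double a = hBias[j];
            for (int i = 0; i < nVis; i++) a += v[i] * w[i][j];
            h[j] = sigmoid(a);
        }
        return h;
    }

    /** p(v_i = 1 | h) for every visible unit. */
    double[] down(double[] h) {
        double[] v = new double[nVis];
        for (int i = 0; i < nVis; i++) {
            double a = vBias[i];
            for (int j = 0; j < nHid; j++) a += h[j] * w[i][j];
            v[i] = sigmoid(a);
        }
        return v;
    }

    /** One CD-1 update on a single example v0. */
    void cd1(double[] v0, double lr) {
        double[] h0 = up(v0);                          // positive phase
        double[] h0s = new double[nHid];               // sampled binary hidden states
        for (int j = 0; j < nHid; j++) h0s[j] = rng.nextDouble() < h0[j] ? 1.0 : 0.0;
        double[] v1 = down(h0s);                       // negative phase: reconstruction probabilities
        double[] h1 = up(v1);
        for (int i = 0; i < nVis; i++)
            for (int j = 0; j < nHid; j++)
                w[i][j] += lr * (v0[i] * h0[j] - v1[i] * h1[j]);
        for (int j = 0; j < nHid; j++) hBias[j] += lr * (h0[j] - h1[j]);
        for (int i = 0; i < nVis; i++) vBias[i] += lr * (v0[i] - v1[i]);
    }

    /** Mean squared error between v and its deterministic one-step reconstruction. */
    double reconstructionError(double[] v) {
        double[] r = down(up(v));
        double sum = 0;
        for (int i = 0; i < nVis; i++) sum += (v[i] - r[i]) * (v[i] - r[i]);
        return sum / nVis;
    }

    public static void main(String[] args) {
        Cd1Sketch rbm = new Cd1Sketch(6, 3, 42L);
        double[] x = {1, 1, 0, 0, 1, 0};
        System.out.println("before: " + rbm.reconstructionError(x));
        for (int step = 0; step < 500; step++) rbm.cd1(x, 0.1);
        System.out.println("after:  " + rbm.reconstructionError(x));
    }
}
```

Because the toy model is fitted to a single pattern, the error printed after training should be much lower than the roughly 0.25 printed before it; on real data, a reconstruction error that keeps rising is a hint that the learning rate or momentum is too aggressive.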

Here is what I have so far:

import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;

public class experiment3 {
    private static final Logger log = LoggerFactory.getLogger(experiment3.class);

    public static void main(String[] args) throws Exception {
        // minibatch size 100, 60,000 training examples, binarized pixels
        DataSetIterator mnistTrain = new MnistDataSetIterator(100, 60000, true);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .regularization(false)
            .iterations(1)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .list()
            .layer(0, new RBM.Builder()
                    .nIn(784).nOut(500)
                    .weightInit(WeightInit.XAVIER)
                    .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY)
                    .updater(Updater.NESTEROVS)
                    .learningRate(0.1)
                    .momentum(0.9)
                    .k(1)
                    .build())
            .pretrain(true).backprop(false)
            .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(600));

        for(int i = 0; i < 50; i++) {
            model.fit(mnistTrain);
        }
    }
}

It compiles and prints a score at each epoch, but the score increases when it should be decreasing, and I have not found any way to do a reconstruction.

I tried using the reconstruct method and computing the distance:

        while(mnistTrain.hasNext()){
            DataSet next = mnistTrain.next();
            INDArray in = next.getFeatureMatrix();
            INDArray out = model.reconstruct(in, 1); // also tried 0, but that throws ArrayIndexOutOfBoundsException

            log.info("distance(1):" + in.distance1(out));
        }

but the distance is always 0.0 for every minibatch, even when the model has not been trained for a single epoch, which should be impossible.
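One possible explanation, which I have not checked against the 0.7 sources: if MultiLayerNetwork.reconstruct(x, layerNum) simply returns the feed-forward activations at index layerNum - 1, then index 0 is the input itself, so layerNum = 1 would give a distance of exactly 0 and layerNum = 0 would ask for index -1. Either way, an actual RBM reconstruction is one pass up to the hidden layer and one pass back down, v' = sigm(sigm(vW + b)W^T + c). The plain-Java sketch below (no DL4J; the class and method names are hypothetical) computes that mean-field reconstruction and the L1 (Manhattan) distance that distance1 reports. In DL4J the three parameter arrays would presumably come from the layer's paramTable() (keys such as "W", "b" and "vb"), but check the keys for your version.

```java
/** One deterministic up-down pass through a binary RBM, plus L1 distance (illustrative sketch, not DL4J). */
public class RbmReconstruction {
    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    /**
     * v: visible vector (length nVis); w: nVis x nHid weights;
     * hBias: hidden biases (nHid); vBias: visible biases (nVis).
     * Returns p(v' | h) with h = p(h | v), i.e. a mean-field reconstruction.
     */
    static double[] reconstruct(double[] v, double[][] w, double[] hBias, double[] vBias) {
        int nVis = v.length, nHid = hBias.length;
        double[] h = new double[nHid];
        for (int j = 0; j < nHid; j++) {          // up pass: visible -> hidden
            double a = hBias[j];
            for (int i = 0; i < nVis; i++) a += v[i] * w[i][j];
            h[j] = sigmoid(a);
        }
        double[] r = new double[nVis];
        for (int i = 0; i < nVis; i++) {          // down pass: hidden -> visible
            double a = vBias[i];
            for (int j = 0; j < nHid; j++) a += h[j] * w[i][j];
            r[i] = sigmoid(a);
        }
        return r;
    }

    /** Manhattan (L1) distance: sum of absolute differences. */
    static double distance1(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) s += Math.abs(a[i] - b[i]);
        return s;
    }

    public static void main(String[] args) {
        double[] v = {1, 0, 1, 0};
        double[][] w = new double[4][2];          // all-zero weights stand in for an untrained model
        double[] hBias = new double[2], vBias = new double[4];
        double[] r = reconstruct(v, w, hBias, vBias);
        System.out.println(distance1(v, r));      // prints 2.0: every unit reconstructs to 0.5
    }
}
```

With all-zero parameters every unit reconstructs to 0.5, so the distance is clearly nonzero; an untrained model reporting exactly 0.0 suggests the returned array is the input rather than a reconstruction.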

Is this the correct way to train an RBM? How can I reconstruct the input with a single RBM?

Baptiste Wicht

3 Answers


I've spoken to Adam Gibson (the author) about this issue on the project's Gitter channel. He says they have effectively dropped support for RBMs: the code is still in the codebase, but RBM bugs can occur and won't be fixed.

They are dropping it because RBMs have largely been superseded by VAEs (variational autoencoders), so they encourage people to use those instead.

Nikita Volkov

Dropping RBMs or other types of neural networks means DL4J does not support every deep learning use case. Unfortunately, DL4J is not a complete DL library, not to mention its lack of documentation compared with the many Python alternatives.

me._

I realize the question is old, but recent activity brought it into my feed. I've been using RBMs in DL4J recently, both single-layer and multi-layer. They might not be officially supported, but they work. I also stumbled over reconstruct. To test the model, use output, as you would with a feed-forward network. In your case, I assume it would be:

INDArray in = next.getFeatureMatrix();
INDArray out = model.output(in);
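One caveat, stated as an assumption rather than a verified fact: in the question's configuration the network contains only the 784→500 RBM, so a forward pass presumably yields 500 hidden activations per example rather than 784 pixels, and comparing out with in directly would fail on the shape. The plain-Java sketch below (hypothetical names, no DL4J) computes a batch reconstruction error, the per-example sum of squared errors averaged over the batch, and rejects matrices whose shapes differ, which is a quick way to check that what you compare is actually a reconstruction.

```java
/** Batch reconstruction error with an explicit shape check (illustrative sketch, not DL4J). */
public class BatchReconError {
    /** Rows are examples, columns are units; returns the per-example squared error averaged over rows. */
    static double meanSquaredError(double[][] input, double[][] recon) {
        if (input.length != recon.length)
            throw new IllegalArgumentException("row counts differ: " + input.length + " vs " + recon.length);
        double total = 0;
        for (int r = 0; r < input.length; r++) {
            if (input[r].length != recon[r].length)
                throw new IllegalArgumentException(
                    "column counts differ: " + input[r].length + " vs " + recon[r].length
                    + " (is this a hidden-layer output rather than a reconstruction?)");
            double sq = 0;
            for (int c = 0; c < input[r].length; c++) {
                double d = input[r][c] - recon[r][c];
                sq += d * d;
            }
            total += sq;
        }
        return total / input.length;
    }

    public static void main(String[] args) {
        double[][] in = {{1, 0, 1}, {0, 1, 0}};
        double[][] rec = {{0.5, 0, 1}, {0, 1, 0.5}};
        System.out.println(meanSquaredError(in, rec)); // (0.25 + 0.25) / 2 = 0.25
        try {
            meanSquaredError(in, new double[][]{{0.5, 0.5}, {0.5, 0.5}}); // 3 columns vs 2
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```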

A couple of additions:

I'm using DL4J 0.9.1.

reden