Ethan!
First of all, Deeplearning4j uses ND4J as its backend, so your data will eventually have to be converted into `INDArray` objects before your model can use it. If your training data is two flat arrays of doubles, `inputsArray` and `desiredOutputsArray`, with the samples stored one after another (row-major order), you can do the following:
```java
INDArray inputs = Nd4j.create(inputsArray, new int[]{numSamples, inputDim});
INDArray desiredOutputs = Nd4j.create(desiredOutputsArray, new int[]{numSamples, outputDim});
```
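The `Nd4j.create(data, shape)` call reads the flat array using ND4J's default ordering, which is row-major ('c') unless you have changed it, so feature `j` of sample `i` must sit at index `i * inputDim + j`. If your data starts out as a `double[][]` with one row per sample, a plain-Java sketch like the following flattens it into that layout. `RowMajor`, `flatten` and `get` are hypothetical helper names, not part of ND4J.

```java
import java.util.Arrays;

/** Hypothetical helper: flattens a [numSamples][dim] matrix row by row (C order). */
final class RowMajor {
    private RowMajor() {}

    static double[] flatten(double[][] rows) {
        if (rows.length == 0) {
            return new double[0];
        }
        int dim = rows[0].length;
        double[] flat = new double[rows.length * dim];
        for (int i = 0; i < rows.length; i++) {
            if (rows[i].length != dim) {
                throw new IllegalArgumentException("row " + i + " has length "
                        + rows[i].length + ", expected " + dim);
            }
            // row i occupies flat[i * dim] .. flat[i * dim + dim - 1]
            System.arraycopy(rows[i], 0, flat, i * dim, dim);
        }
        return flat;
    }

    /** Element (i, j) of a row-major [rows x cols] array lives at i * cols + j. */
    static double get(double[] flat, int cols, int i, int j) {
        return flat[i * cols + j];
    }

    public static void main(String[] args) {
        double[][] inputs = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}};
        double[] flat = flatten(inputs);
        System.out.println(Arrays.toString(flat)); // [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        System.out.println(get(flat, 3, 1, 2));    // 6.0
    }
}
```

With `numSamples = inputs.length` and `inputDim = inputs[0].length`, the flattened array is what you would pass as `inputsArray`. ND4J can also take a `double[][]` directly through `Nd4j.create(double[][])`, which avoids the manual copy.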
Then you can train your model on those arrays directly:
```java
for (int epoch = 0; epoch < nEpochs; epoch++) {
    model.fit(inputs, desiredOutputs);
}
```
Alternatively, you can create a `DataSet` object and use it for training:
```java
DataSet ds = new DataSet(inputs, desiredOutputs);
for (int epoch = 0; epoch < nEpochs; epoch++) {
    model.fit(ds);
}
```
But creating a custom iterator is the safest approach, especially for larger datasets, since it feeds the model in mini-batches instead of one big array, gives you more control over your data and keeps things organized.
In your `DataSetIterator` implementation you pass in your data, and its `next(int num)` method should return a `DataSet` object containing the next batch of `num` training samples. It would look like this:
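```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class MyCustomIterator implements DataSetIterator {
    private INDArray inputs, desiredOutputs;
    private int itPosition = 0; // the iterator position in the set.

    public MyCustomIterator(double[] inputsArray,
                            double[] desiredOutputsArray,
                            int numSamples,
                            int inputDim,
                            int outputDim) {
        inputs = Nd4j.create(inputsArray, new int[]{numSamples, inputDim});
        desiredOutputs = Nd4j.create(desiredOutputsArray, new int[]{numSamples, outputDim});
    }

    @Override
    public DataSet next(int num) {
        // clamp so the last batch doesn't run past the final sample
        int end = Math.min(itPosition + num, (int) inputs.size(0));
        // get a view containing the next samples and desired outputs (end is exclusive)
        INDArray dsInput = inputs.get(
                NDArrayIndex.interval(itPosition, end),
                NDArrayIndex.all());
        INDArray dsDesired = desiredOutputs.get(
                NDArrayIndex.interval(itPosition, end),
                NDArrayIndex.all());
        itPosition = end;
        return new DataSet(dsInput, dsDesired);
    }

    @Override
    public boolean hasNext() {
        return itPosition < inputs.size(0);
    }

    @Override
    public void reset() {
        itPosition = 0;
    }
    // implement the remaining interface methods (next(), batch(), inputColumns(), ...)
}
```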
The `NDArrayIndex` methods above are used to access parts of an `INDArray`: `interval(a, b)` selects rows `a` through `b - 1`, and `all()` keeps every column. Now you can use the iterator for training:
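```java
MyCustomIterator it = new MyCustomIterator(
        inputsArray,          // the raw arrays, not the INDArrays created earlier
        desiredOutputsArray,
        numSamples,
        inputDim,
        outputDim);
for (int epoch = 0; epoch < nEpochs; epoch++) {
    it.reset(); // rewind so every epoch sees all the samples
    model.fit(it);
}
```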
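The part of such an iterator that is easiest to get wrong is the cursor arithmetic: each batch covers the half-open row range `[itPosition, itPosition + num)`, the final batch has to be clamped to the number of samples, and `reset()` has to rewind before the next epoch. The following is a minimal, ND4J-free sketch of just that bookkeeping, using a hypothetical `BatchCursor` class; each `{start, end}` pair it returns would become the bounds of an `NDArrayIndex.interval(start, end)` call.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;

/**
 * Hypothetical, ND4J-free sketch of the cursor a DataSetIterator keeps:
 * it hands out half-open [start, end) row ranges of at most batchSize rows.
 */
final class BatchCursor {
    private final int numSamples;
    private final int batchSize;
    private int position = 0; // index of the first row not yet returned

    BatchCursor(int numSamples, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        this.numSamples = numSamples;
        this.batchSize = batchSize;
    }

    boolean hasNext() {
        return position < numSamples;
    }

    /** Returns {start, end} for the next batch; the last batch may be smaller. */
    int[] next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        int start = position;
        int end = Math.min(start + batchSize, numSamples); // never run past the data
        position = end;
        return new int[]{start, end};
    }

    void reset() {
        position = 0; // start the next epoch from the first row
    }

    public static void main(String[] args) {
        BatchCursor cursor = new BatchCursor(10, 4);
        List<String> ranges = new ArrayList<>();
        while (cursor.hasNext()) {
            int[] r = cursor.next();
            ranges.add("[" + r[0] + ", " + r[1] + ")");
        }
        System.out.println(ranges); // [[0, 4), [4, 8), [8, 10)]
    }
}
```

With 10 samples and a batch size of 4, the last batch holds only 2 rows; without the `Math.min` clamp, the third call would ask for rows 8 through 11 and index past the end of the data. Once `hasNext()` returns false, nothing more comes out until `reset()` is called, which is why the epoch loop rewinds the iterator.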
This example will be particularly useful to you, since it implements an LSTM network and includes a custom iterator implementation (which can serve as a guide for implementing the remaining methods). Also, for more information on `NDArray`s, this is helpful: it gives detailed information on creating, modifying and accessing parts of an `NDArray`.