
I am working with the CIFAR-10 dataset. Here is how I prepare the data:

```r
library(R.matlab)

# Read the five training batches, the label names, and the test batch
A1 <- readMat("data_batch_1.mat")
A2 <- readMat("data_batch_2.mat")
A3 <- readMat("data_batch_3.mat")
A4 <- readMat("data_batch_4.mat")
A5 <- readMat("data_batch_5.mat")
meta <- readMat("batches.meta.mat")
test <- readMat("test_batch.mat")

# Stack the training batches and convert RGB to grayscale
# (columns 1:1024 = red, 1025:2048 = green, 2049:3072 = blue)
A <- rbind(A1$data, A2$data, A3$data, A4$data, A5$data)
Gtrain <- 0.21*A[,1:1024] + 0.71*A[,1025:2048] + 0.07*A[,2049:3072]
ytrain <- c(A1$labels, A2$labels, A3$labels, A4$labels, A5$labels)
Gtest <- 0.21*test$data[,1:1024] + 0.71*test$data[,1025:2048] + 0.07*test$data[,2049:3072]
ytest <- test$labels

# Keep only classes 7 and 9; the target is TRUE for class 7
x_train <- Gtrain[ytrain %in% c(7,9),]
y_train <- ytrain[ytrain %in% c(7,9)] == 7
x_test <- Gtest[ytest %in% c(7,9),]
y_test <- ytest[ytest %in% c(7,9)] == 7
```

I train a deep neural network:

```r
library(deepnet)
dnn <- dbn.dnn.train(x_train, y_train, hidden = rep(10, 2), numepochs = 3)
```

Then I make predictions:

```r
prednn <- nn.predict(dnn, x_test)
```

This returns a vector filled with a single value (0.4603409 in this case; for other parameters it is always something around 0.5). What is wrong?

Norbert R
  • Have you seen the answer to this question? http://stackoverflow.com/questions/28623533/r-package-deepnet-training-and-testing-the-mnist-dataset?rq=1 – Marcin Jun 04 '15 at 17:21

1 Answer


Based on this answer to a similar question, consider this approach: "neuralnet prediction returns the same values for all predictions".

The first thing to check when you get weird results from a neural network is normalization. Your data must be normalized; otherwise training will produce a skewed network that outputs the same value every time. This is a common symptom.

Your data set contains values much greater than 1 (the grayscale pixels range from roughly 0 to 255), which means the network treats them all essentially the same. The reason is that the traditionally used activation functions, such as the sigmoid, are almost constant outside a narrow range around 0.
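A quick way to see the saturation effect is a minimal base-R sketch, assuming a logistic sigmoid activation. The helper names `sigmoid` and `sigmoid_grad` are illustrative only, not part of deepnet. In a real network the hidden unit sees a weighted sum over 1024 pixels, so with unscaled pixels that sum can easily land in the flat region.

```r
# Logistic sigmoid, a typical squashing activation
sigmoid <- function(z) 1 / (1 + exp(-z))

# Its derivative, which drives the weight updates during backpropagation
sigmoid_grad <- function(z) sigmoid(z) * (1 - sigmoid(z))

# Inputs near 0 give clearly different outputs
print(sigmoid(c(-2, 0, 2)))

# Large inputs, like sums of raw pixel values, all collapse to 1 in
# double precision, so the unit can no longer tell them apart
print(sigmoid(c(50, 100, 255)))

# The gradient is 0.25 at z = 0 but underflows to 0 for large z,
# so the weights feeding a saturated unit stop changing
print(sigmoid_grad(c(0, 50, 255)))
```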

Always normalize your data before feeding it into a neural network.
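One common way to do this is min-max scaling to [0, 1]. The sketch below uses small random stand-in matrices so it runs on its own; the names `min_max_scale`, `x_train_demo` and `x_test_demo` are hypothetical. The key point is that the range comes from the training data and is reused unchanged for the test data.

```r
# Min-max scaling to [0, 1]; lo and hi should be taken from the
# training data and reused for the test data
min_max_scale <- function(x, lo = min(x), hi = max(x)) {
  (x - lo) / (hi - lo)
}

# Stand-ins for the grayscale pixel matrices (values roughly 0..255);
# in the question these would be x_train and x_test
set.seed(1)
x_train_demo <- matrix(runif(20, 0, 255), nrow = 4)
x_test_demo  <- matrix(runif(10, 0, 255), nrow = 2)

lo <- min(x_train_demo)
hi <- max(x_train_demo)
x_train_scaled <- min_max_scale(x_train_demo, lo, hi)
x_test_scaled  <- min_max_scale(x_test_demo, lo, hi)

# Training data now spans exactly [0, 1]; test values may fall
# slightly outside that range, which is expected
print(range(x_train_scaled))
```

Because pixel intensities are bounded, simply dividing by 255 (`x_train / 255`, `x_test / 255`) is a common shortcut. Either way, you would then retrain on the scaled matrix, e.g. `dbn.dnn.train(x_train_scaled, y_train, hidden = rep(10, 2), numepochs = 3)`, and predict on the equally scaled test matrix. Scaling should at least rule out saturation as the cause, although more epochs or other parameters may still be needed for good accuracy.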

Marcin