
To implement learning in a neural network I'm using stochastic gradient descent in which the mini batches are represented via the following list comprehension:

mini_batches = [training_data[j:j+mini_batch_size] for j in range(0,len(training_data),mini_batch_size)]

Inside the list comprehension, mini_batch_size and training_data are arguments passed to the stochastic gradient descent method, which is part of a neural network class.
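To make the slicing concrete, here is a minimal sketch of the same list comprehension on a small stand-in data set. The 10 random samples and the batch size of 4 are assumptions chosen for illustration, not the real MNIST data:

```python
import numpy as np

# Hypothetical stand-in for the MNIST data: 10 samples instead of 50,000.
rng = np.random.default_rng(0)
training_data = [(rng.random((784, 1)), np.zeros((10, 1))) for _ in range(10)]
mini_batch_size = 4

# Same comprehension as in the question: consecutive slices of the list.
mini_batches = [training_data[j:j + mini_batch_size]
                for j in range(0, len(training_data), mini_batch_size)]

batch_lengths = [len(batch) for batch in mini_batches]
print(batch_lengths)  # [4, 4, 2]
```

Note that the last mini batch is shorter whenever len(training_data) is not a multiple of mini_batch_size.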

The training_data is a list of tuples of arrays (x,y), where the first array x contains the input data and the second array y contains the output data (the classification). The array x has shape (784,1) and the array y has shape (10,1). The call len(training_data) returns 50,000.

For each mini batch, I'd like to construct two NumPy arrays: one matrix holding the input data and one matrix holding the output data. The two arrays would be used in a matrix-based approach to the learning algorithm. However, I am not sure of a simple way to construct them. Help in this direction would be greatly appreciated.

The mini_batch_size is an integer that specifies how the training_data will be partitioned. A single element of the training data, for example training_data[0], has the following form (the first part of the tuple contains 784 floats from 0. to 1. and the second part contains 10 floats):

(array([[ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.01171875],
    [ 0.0703125 ],
    [ 0.0703125 ],
    [ 0.0703125 ],
    [ 0.4921875 ],
    [ 0.53125   ],
    [ 0.68359375],
    [ 0.1015625 ],
    [ 0.6484375 ],
    [ 0.99609375],
    [ 0.96484375],
    [ 0.49609375],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.1171875 ],
    [ 0.140625  ],
    [ 0.3671875 ],
    [ 0.6015625 ],
    [ 0.6640625 ],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.87890625],
    [ 0.671875  ],
    [ 0.98828125],
    [ 0.9453125 ],
    [ 0.76171875],
    [ 0.25      ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.19140625],
    [ 0.9296875 ],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98046875],
    [ 0.36328125],
    [ 0.3203125 ],
    [ 0.3203125 ],
    [ 0.21875   ],
    [ 0.15234375],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.0703125 ],
    [ 0.85546875],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.7734375 ],
    [ 0.7109375 ],
    [ 0.96484375],
    [ 0.94140625],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.3125    ],
    [ 0.609375  ],
    [ 0.41796875],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.80078125],
    [ 0.04296875],
    [ 0.        ],
    [ 0.16796875],
    [ 0.6015625 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.0546875 ],
    [ 0.00390625],
    [ 0.6015625 ],
    [ 0.98828125],
    [ 0.3515625 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.54296875],
    [ 0.98828125],
    [ 0.7421875 ],
    [ 0.0078125 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.04296875],
    [ 0.7421875 ],
    [ 0.98828125],
    [ 0.2734375 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.13671875],
    [ 0.94140625],
    [ 0.87890625],
    [ 0.625     ],
    [ 0.421875  ],
    [ 0.00390625],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.31640625],
    [ 0.9375    ],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.46484375],
    [ 0.09765625],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.17578125],
    [ 0.7265625 ],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.5859375 ],
    [ 0.10546875],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.0625    ],
    [ 0.36328125],
    [ 0.984375  ],
    [ 0.98828125],
    [ 0.73046875],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.97265625],
    [ 0.98828125],
    [ 0.97265625],
    [ 0.25      ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.1796875 ],
    [ 0.5078125 ],
    [ 0.71484375],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.80859375],
    [ 0.0078125 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.15234375],
    [ 0.578125  ],
    [ 0.89453125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.9765625 ],
    [ 0.7109375 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.09375   ],
    [ 0.4453125 ],
    [ 0.86328125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.78515625],
    [ 0.3046875 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.08984375],
    [ 0.2578125 ],
    [ 0.83203125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.7734375 ],
    [ 0.31640625],
    [ 0.0078125 ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.0703125 ],
    [ 0.66796875],
    [ 0.85546875],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.76171875],
    [ 0.3125    ],
    [ 0.03515625],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.21484375],
    [ 0.671875  ],
    [ 0.8828125 ],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.953125  ],
    [ 0.51953125],
    [ 0.04296875],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.53125   ],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.98828125],
    [ 0.828125  ],
    [ 0.52734375],
    [ 0.515625  ],
    [ 0.0625    ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ],
    [ 0.        ]], dtype=float32), array([[ 0.],
    [ 0.],
    [ 0.],
    [ 0.],
    [ 0.],
    [ 1.],
    [ 0.],
    [ 0.],
    [ 0.],
    [ 0.]]))

The data comes from the MNIST training set. Note that training_data = list(training_data) is performed when the data is initially loaded; the network class is created at a later time. If the list(foo) operation is not used, the data appears as a zip object instead: <zip at 0x1d744027948>.
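The reason the list(...) call matters can be sketched as follows. In Python 3, zip returns a one-pass iterator that supports neither len() nor slicing, both of which the mini-batch comprehension relies on. The tiny inputs and labels lists here are made up for illustration:

```python
# Hypothetical toy data standing in for the loader's output.
inputs = [[0.0], [0.5], [1.0]]
labels = [5, 0, 4]

pairs = zip(inputs, labels)  # a lazy zip object in Python 3
try:
    len(pairs)
    zip_has_len = True
except TypeError:
    # zip objects have no length, so range(0, len(...), ...) would fail.
    zip_has_len = False

# Materializing the iterator gives a list that supports len() and slicing.
training_data = list(zip(inputs, labels))
first_two = training_data[0:2]
print(zip_has_len, len(training_data), first_two)
```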

sunspots
    See if this helps out - https://stackoverflow.com/questions/40084931/taking-subarrays-from-numpy-array-with-given-stride-stepsize – Divakar Dec 28 '17 at 12:25

1 Answer


Here's a solution that builds one array for the input data and one for the target data, with one column per training example in mini_batch:

input_data_array = np.asarray([input_data.ravel() for input_data, target_data in mini_batch]).T
target_data_array = np.asarray([target_data.ravel() for input_data, target_data in mini_batch]).T

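The code combines ravel and asarray from NumPy: ravel flattens each (784,1) or (10,1) array to one dimension, asarray stacks the flattened arrays as rows, and .T transposes the result so that each training example becomes a column.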
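As a cross-check, here is a small sketch that runs the two lines above on a made-up mini batch of three samples and compares the result with an alternative based on np.hstack. The random inputs and the labels 5, 0, 4 are assumptions for illustration only:

```python
import numpy as np

# Hypothetical mini batch: three (x, y) tuples shaped like the MNIST data.
rng = np.random.default_rng(0)
mini_batch = []
for label in (5, 0, 4):
    x = rng.random((784, 1))
    y = np.zeros((10, 1))
    y[label] = 1.0  # one-hot target
    mini_batch.append((x, y))

# Approach from the answer: flatten, stack as rows, transpose.
input_data_array = np.asarray([x.ravel() for x, y in mini_batch]).T
target_data_array = np.asarray([y.ravel() for x, y in mini_batch]).T

# Alternative: the arrays are already columns, so stack them side by side.
xs, ys = zip(*mini_batch)
X = np.hstack(xs)
Y = np.hstack(ys)

print(X.shape, Y.shape)  # (784, 3) (10, 3)
```

Both approaches should give identical matrices; np.hstack simply avoids the flatten-and-transpose round trip.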

Reference: The training_data and mnist_loader module:

import mnist_loader
training_data, validation_data, test_data = mnist_loader.load_data_wrapper()
training_data = list(training_data)

As for the motivation, backpropagation is faster in matrix form. The fundamental backpropagation equations are shown below (source: online text from Michael Nielsen: Neural networks and deep learning). The circle with a dot in the interior is notation for the Hadamard, or Schur, product:

[Figure: Backpropagation equations]
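To illustrate how the matrix form is used, here is a rough sketch of one gradient computation over a whole mini batch. It assumes a single sigmoid layer and the quadratic cost, which is a simplification of the full multi-layer network; the names W, b, X and Y are hypothetical. In NumPy the Hadamard product is the ordinary elementwise * operator:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sigmoid_prime(z):
    s = sigmoid(z)
    return s * (1.0 - s)

rng = np.random.default_rng(0)
m = 3                                      # mini-batch size (assumed)
X = rng.random((784, m))                   # one input example per column
Y = np.zeros((10, m))
Y[[5, 0, 4], [0, 1, 2]] = 1.0              # one-hot targets, one per column
W = rng.standard_normal((10, 784)) * 0.01  # hypothetical single-layer weights
b = np.zeros((10, 1))

Z = W @ X + b                              # weighted inputs for the whole batch
A = sigmoid(Z)                             # activations, shape (10, m)
delta = (A - Y) * sigmoid_prime(Z)         # output error; * is the Hadamard product
grad_W = delta @ X.T / m                   # gradient averaged over the batch
grad_b = delta.mean(axis=1, keepdims=True)

print(grad_W.shape, grad_b.shape)  # (10, 784) (10, 1)
```

The single matrix product delta @ X.T replaces a Python loop over the examples in the mini batch, which is where the speedup comes from.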

sunspots