To implement learning in a neural network, I'm using stochastic gradient descent, in which the mini-batches are built with the following list comprehension:
mini_batches = [training_data[j:j+mini_batch_size] for j in range(0,len(training_data),mini_batch_size)]
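To make the partitioning concrete, here is a minimal sketch that runs the same comprehension on a hypothetical 10-item list with `mini_batch_size = 4`. It shows that each slice holds consecutive items and that the last mini-batch is shorter when the length is not a multiple of the batch size:

```python
# Hypothetical stand-in for training_data: 10 items instead of 50,000 tuples
training_data = list(range(10))
mini_batch_size = 4

# Same comprehension as above: step through the list in strides of mini_batch_size
mini_batches = [training_data[j:j + mini_batch_size]
                for j in range(0, len(training_data), mini_batch_size)]

print(mini_batches)                   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print([len(b) for b in mini_batches])  # [4, 4, 2]
```

With 50,000 examples, any `mini_batch_size` that divides 50,000 (e.g. 10) yields equally sized batches.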
In the list comprehension, `mini_batch_size` and `training_data` are arguments passed to the stochastic gradient descent method, which is part of a neural network class. `training_data` is a list of tuples of arrays `(x, y)`, where the first array `x` contains the input data and the second array `y` contains the output data (the classification). The array `x` has shape `(784, 1)` and the array `y` has shape `(10, 1)`. `len(training_data)` is 50,000.
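The structure described above can be checked with a short sketch. It builds hypothetical random data (100 tuples rather than 50,000) that only mimics the shapes of the MNIST tuples; it is not the real loader:

```python
import numpy as np

# Hypothetical synthetic data shaped like the MNIST tuples described above
rng = np.random.default_rng(0)
training_data = [(rng.random((784, 1), dtype=np.float32),   # pixel intensities in [0, 1)
                  np.eye(10, dtype=np.float32)[:, [k]])     # one-hot label column, shape (10, 1)
                 for k in rng.integers(0, 10, size=100)]

# Verify that every tuple has the shapes the training code relies on
assert all(x.shape == (784, 1) and y.shape == (10, 1) for x, y in training_data)
print(len(training_data))  # 100
```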
I'd like to construct two NumPy arrays: the first a matrix of the input data and the second a matrix of the output data. The two arrays would be used in a matrix-based approach to the learning algorithm. However, I am not sure of a simple way to construct them. Any help in this direction would be greatly appreciated.
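One possible approach, sketched under the assumption that each mini-batch is a list of `(x, y)` tuples with the shapes given above: unzip the tuples and place the column vectors side by side with `np.hstack`, so column `j` of each matrix is example `j`. The helper name `mini_batch_to_matrices` and the weights `W`, `b` (a hypothetical 30-neuron layer) are illustrative only:

```python
import numpy as np

def mini_batch_to_matrices(mini_batch):
    """Stack a list of (x, y) column-vector tuples into two matrices.

    Each x has shape (784, 1) and each y has shape (10, 1), so joining the
    columns gives X with shape (784, m) and Y with shape (10, m), where m is
    the number of examples in the mini-batch.
    """
    xs, ys = zip(*mini_batch)   # separate inputs from outputs
    X = np.hstack(xs)           # (784, m): column j is example j's input
    Y = np.hstack(ys)           # (10, m): column j is example j's one-hot label
    return X, Y

# Hypothetical synthetic mini-batch with the same shapes as the MNIST tuples
rng = np.random.default_rng(0)
mini_batch = [(rng.random((784, 1)), np.eye(10)[:, [k]]) for k in (5, 0, 4)]
X, Y = mini_batch_to_matrices(mini_batch)
print(X.shape, Y.shape)  # (784, 3) (10, 3)

# Matrix-based forward step: a single product handles the whole mini-batch.
# W and b are hypothetical weights and biases for a 30-neuron layer.
W = rng.standard_normal((30, 784))
b = rng.standard_normal((30, 1))
Z = W @ X + b                   # b (30, 1) broadcasts across the m columns

# Column 0 of Z equals the per-example computation W @ x_0 + b
assert np.allclose(Z[:, [0]], W @ mini_batch[0][0] + b)
```

`np.concatenate(xs, axis=1)` is equivalent to `np.hstack(xs)` here. If a row-per-example layout with shape `(m, 784)` is preferred instead, `X.T` gives it.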
`mini_batch_size` is an integer that specifies how `training_data` is partitioned. A single element of the training data, for example `training_data[0]`, has the following form (the first array of the tuple contains 784 floats between 0. and 1., and the second contains 10 floats):
(array([[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.01171875],
[ 0.0703125 ],
[ 0.0703125 ],
[ 0.0703125 ],
[ 0.4921875 ],
[ 0.53125 ],
[ 0.68359375],
[ 0.1015625 ],
[ 0.6484375 ],
[ 0.99609375],
[ 0.96484375],
[ 0.49609375],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.1171875 ],
[ 0.140625 ],
[ 0.3671875 ],
[ 0.6015625 ],
[ 0.6640625 ],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.87890625],
[ 0.671875 ],
[ 0.98828125],
[ 0.9453125 ],
[ 0.76171875],
[ 0.25 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.19140625],
[ 0.9296875 ],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98046875],
[ 0.36328125],
[ 0.3203125 ],
[ 0.3203125 ],
[ 0.21875 ],
[ 0.15234375],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.0703125 ],
[ 0.85546875],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.7734375 ],
[ 0.7109375 ],
[ 0.96484375],
[ 0.94140625],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.3125 ],
[ 0.609375 ],
[ 0.41796875],
[ 0.98828125],
[ 0.98828125],
[ 0.80078125],
[ 0.04296875],
[ 0. ],
[ 0.16796875],
[ 0.6015625 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.0546875 ],
[ 0.00390625],
[ 0.6015625 ],
[ 0.98828125],
[ 0.3515625 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.54296875],
[ 0.98828125],
[ 0.7421875 ],
[ 0.0078125 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.04296875],
[ 0.7421875 ],
[ 0.98828125],
[ 0.2734375 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.13671875],
[ 0.94140625],
[ 0.87890625],
[ 0.625 ],
[ 0.421875 ],
[ 0.00390625],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.31640625],
[ 0.9375 ],
[ 0.98828125],
[ 0.98828125],
[ 0.46484375],
[ 0.09765625],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.17578125],
[ 0.7265625 ],
[ 0.98828125],
[ 0.98828125],
[ 0.5859375 ],
[ 0.10546875],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.0625 ],
[ 0.36328125],
[ 0.984375 ],
[ 0.98828125],
[ 0.73046875],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.97265625],
[ 0.98828125],
[ 0.97265625],
[ 0.25 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.1796875 ],
[ 0.5078125 ],
[ 0.71484375],
[ 0.98828125],
[ 0.98828125],
[ 0.80859375],
[ 0.0078125 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.15234375],
[ 0.578125 ],
[ 0.89453125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.9765625 ],
[ 0.7109375 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.09375 ],
[ 0.4453125 ],
[ 0.86328125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.78515625],
[ 0.3046875 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.08984375],
[ 0.2578125 ],
[ 0.83203125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.7734375 ],
[ 0.31640625],
[ 0.0078125 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.0703125 ],
[ 0.66796875],
[ 0.85546875],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.76171875],
[ 0.3125 ],
[ 0.03515625],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.21484375],
[ 0.671875 ],
[ 0.8828125 ],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.953125 ],
[ 0.51953125],
[ 0.04296875],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0.53125 ],
[ 0.98828125],
[ 0.98828125],
[ 0.98828125],
[ 0.828125 ],
[ 0.52734375],
[ 0.515625 ],
[ 0.0625 ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ],
[ 0. ]], dtype=float32), array([[ 0.],
[ 0.],
[ 0.],
[ 0.],
[ 0.],
[ 1.],
[ 0.],
[ 0.],
[ 0.],
[ 0.]]))
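In the printout, the single `1.` in the second array sits at index 5, so this example's label is the digit 5. A minimal sketch of decoding such a tuple, using hypothetical zero-filled stand-ins with the same shapes and assuming (as is usual for MNIST loaders) that the 784 values are a 28 × 28 image flattened row by row:

```python
import numpy as np

# Hypothetical stand-ins with the same shapes as the tuple printed above
x = np.zeros((784, 1), dtype=np.float32)
y = np.zeros((10, 1), dtype=np.float32)
y[5] = 1.0                      # one-hot encoding: the 1 is at index 5

label = int(np.argmax(y))       # index of the largest entry is the digit
image = x.reshape(28, 28)       # assumes row-major flattening of 28 * 28 pixels
print(label, image.shape)       # 5 (28, 28)
```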
The data comes from the MNIST training set. Note that `training_data = list(training_data)` is executed when the data is first loaded; the network class is created later. Without the `list(foo)` conversion, the data appears as `<zip at 0x1d744027948>`, because in Python 3 `zip` returns a lazy iterator rather than a list.
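A minimal sketch, using hypothetical placeholder strings instead of arrays, of why the `list(...)` step matters: a zip object cannot be sliced (so the mini-batch comprehension would fail) and it is exhausted after one pass:

```python
# Hypothetical placeholders standing in for the input and label arrays
inputs = ["x0", "x1", "x2"]
outputs = ["y0", "y1", "y2"]

pairs = zip(inputs, outputs)        # Python 3: a lazy, single-use iterator
try:
    pairs[0:2]                      # the mini-batch slicing needs a sequence
except TypeError as err:
    print("slicing fails:", err)    # zip objects are not subscriptable

first_pass = list(pairs)            # consumes the iterator
second_pass = list(pairs)           # nothing is left on a second pass
print(len(first_pass), len(second_pass))  # 3 0

# Materializing once gives a list that supports len(), indexing and slicing
training_data = list(zip(inputs, outputs))
print(len(training_data), training_data[0:2])  # 3 [('x0', 'y0'), ('x1', 'y1')]
```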