I created a custom data layer in Python so that I can feed the data to the network directly.
However, I noticed that it runs extremely slowly and the GPU usage is at most 1%, even though the memory is allocated (when I run the script it allocates 2100 MB of VRAM, and terminating the training frees around 1 GB).
I'm not sure whether this is expected behavior or whether I'm doing something wrong.
Here is the script I wrote (based on this earlier PR):
import json
import caffe
import numpy as np
from random import shuffle
from PIL import Image
class MyDataLayer(caffe.Layer):
    """
    Python data layer that feeds images and labels for training a network on
    CIFAR10.

    The layer has no bottoms and two tops:
        top[0] ('data')  -- shape (batch_size, 3, im_height, im_width)
        top[1] ('label') -- shape (batch_size, 1), one class index per image

    ``param_str`` must describe a dict containing the keys 'batch_size',
    'image_root', 'im_width', 'im_height' and 'label' (see check_params).
    """

    def setup(self, bottom, top):
        """
        Parse the layer parameters, create the image loader and shape the tops.
        """
        self.top_names = ['data', 'label']

        # === Read input parameters ===
        # NOTE(security): this used to be eval(self.param_str), which executes
        # arbitrary code found in the prototxt. The parameters are JSON, so
        # parse them as JSON; a Python dict literal is still accepted through
        # ast.literal_eval, which only evaluates literals.
        try:
            params = json.loads(self.param_str)
        except ValueError:
            import ast
            params = ast.literal_eval(self.param_str)

        # Check the parameters for validity.
        check_params(params)

        # Store input as instance variables.
        self.batch_size = params['batch_size']

        # Create a batch loader to load the images.
        self.batch_loader = BatchLoader(params, None)

        # === Reshape tops ===
        # The input image size is fixed, so the tops can be shaped once here
        # instead of in every reshape() call.
        top[0].reshape(self.batch_size, 3, params['im_height'], params['im_width'])
        # One label per image.
        top[1].reshape(self.batch_size, 1)

        print_info("MyDataLayer", params)

    def forward(self, bottom, top):
        """
        Fill the tops with the next ``batch_size`` images and labels.

        NOTE(review): every image is loaded and preprocessed synchronously
        inside this call, one at a time; presumably the rest of the network
        waits for it -- confirm with a profiler before optimizing.
        """
        for slot in range(self.batch_size):
            # Use the batch loader to load the next image.
            image, label = self.batch_loader.load_next_image()

            # Write directly into the Caffe blobs.
            top[0].data[slot, ...] = image
            top[1].data[slot, ...] = label

    def reshape(self, bottom, top):
        """
        Nothing to do: the input has a fixed size (rows and columns), so the
        tops are shaped once in setup().
        """
        pass

    def backward(self, top, propagate_down, bottom):
        """
        Data layers do not back-propagate.
        """
        pass
class BatchLoader(object):
    """
    Serve preprocessed images and their labels one at a time.

    The image list (``params['label']``) is a text file with one entry per
    line: an image path relative to ``params['image_root']``, whitespace, and
    an integer class label, e.g.:

        png_data_batch_1/leptodactylus_pentadactylus_s_000004.png 6
        png_data_batch_1/camion_s_000148.png 9
        png_data_batch_1/tipper_truck_s_001250.png 9

    The list is served in file order for the first pass and is reshuffled
    every time the end of the list is reached.
    """

    def __init__(self, params, result):
        """
        params: dict with the keys 'batch_size', 'image_root', 'im_height',
                'im_width' and 'label' (path of the image list file).
        result: stored unchanged as ``self.result``; not used by this class.

        Raises ValueError if the image list contains no entries.
        """
        self.result = result
        self.batch_size = params['batch_size']
        self.image_root = params['image_root']
        self.im_shape = [params['im_height'], params['im_width']]

        # Path of the file listing the images and their labels.
        self.image_labels = params['label']

        # Read all "<path> <label>" entries. The file is closed as soon as it
        # has been read, and blank lines are skipped.
        with open(self.image_labels) as list_file:
            self.imagelist = [line.strip() for line in list_file if line.strip()]
        if not self.imagelist:
            # Fail here with a clear message instead of with an IndexError on
            # the first load_next_image() call.
            raise ValueError(
                "Image list {!r} contains no entries".format(self.image_labels))

        self._cur = 0  # index of the next entry to serve

        # This class does some simple data manipulations.
        self.transformer = SimpleTransformer()

        print("BatchLoader initialized with {} images".format(len(self.imagelist)))

    @staticmethod
    def _parse_entry(entry):
        """
        Split one image-list line into (relative image path, integer label).

        The label is the last whitespace-separated field, so labels with
        several digits and paths containing spaces are both handled.

        Raises ValueError if the line has no label or the label is not an
        integer.
        """
        fields = entry.strip().rsplit(None, 1)
        if len(fields) != 2:
            raise ValueError(
                "Malformed image list entry (expected '<path> <label>'): "
                "{!r}".format(entry))
        file_name, label_text = fields
        return file_name, int(label_text)

    def load_next_image(self):
        """
        Load the next image of the list.

        Returns a tuple (image, label): the image as produced by
        SimpleTransformer.preprocess(), and the label as a plain integer class
        index (not a one-hot vector).
        """
        # Did we finish an epoch?
        if self._cur == len(self.imagelist):
            self._cur = 0
            shuffle(self.imagelist)

        image_file_name, label = self._parse_entry(self.imagelist[self._cur])

        # Load the image.
        im = np.asarray(Image.open(self.image_root + '/' + image_file_name))

        # Simple data augmentation: mirror the image horizontally half of the
        # time (a step of -1 reverses the column axis, +1 leaves it alone).
        flip = np.random.choice(2) * 2 - 1
        im = im[:, ::flip, :]

        self._cur += 1
        return self.transformer.preprocess(im), label
def check_params(params):
    """
    Check that ``params`` contains every key the data layer needs.

    Required keys: 'batch_size', 'image_root', 'im_width', 'im_height' and
    'label'.

    Raises AssertionError ('Params must include <key>') for the first
    required key that is missing.
    """
    required = ('batch_size', 'image_root', 'im_width', 'im_height', 'label')
    for key in required:
        if key not in params:
            # Raised explicitly rather than through an `assert` statement so
            # the check is not stripped when Python runs with -O. The
            # exception type and message are unchanged.
            raise AssertionError('Params must include {}'.format(key))
def print_info(name, params):
    """
    Print a one-line summary of a data layer's configuration.

    name:   label printed at the start of the line (e.g. the layer class name).
    params: dict with the keys 'image_root', 'batch_size', 'im_height',
            'im_width' and 'label'.
    """
    # The previous format string had four placeholders for six arguments, so
    # the image width and the label file were silently dropped, and the image
    # root was printed under the heading "split". Every value now has its own
    # placeholder.
    print("{} initialized for image_root: {}, with bs: {}, "
          "im_shape: ({}, {}), label: {}.".format(
              name,
              params['image_root'],
              params['batch_size'],
              params['im_height'],
              params['im_width'],
              params['label']))
class SimpleTransformer(object):
    """
    Simple pre-processing and de-processing of images for Caffe.

    preprocess() turns an (H, W, 3) image into a float32 (3, H, W) array with
    the channel order reversed, the mean subtracted and the scale applied;
    deprocess() performs the inverse and returns a uint8 (H, W, 3) image.
    Neither method modifies the array passed in.
    """

    def __init__(self, mean=(125.30, 123.05, 114.06)):
        """
        mean: per-channel mean subtracted in preprocess().
        NOTE(review): the mean is applied after the channel flip, so mean[0]
        is subtracted from the last input channel -- confirm the values are
        given in that (BGR) order.
        """
        self.mean = np.array(mean, dtype=np.float32)
        self.scale = 1.0

    def set_mean(self, mean):
        """
        Set the mean to subtract for centering the data.
        """
        # Stored as float32, like the mean given to the constructor.
        self.mean = np.array(mean, dtype=np.float32)

    def set_scale(self, scale):
        """
        Set the data scaling.
        """
        self.scale = scale

    def preprocess(self, im):
        """
        Convert an (H, W, 3) RGB image to the float32 (3, H, W) BGR layout:
        reverse the channels, subtract the mean, multiply by the scale.
        """
        # np.array always copies, so the in-place operations below can never
        # write into the caller's array (np.float32(im) may return the input
        # itself when it is already float32).
        im = np.array(im, dtype=np.float32)
        im = im[:, :, ::-1]  # change to BGR
        im -= self.mean
        im *= self.scale
        return im.transpose((2, 0, 1))

    def deprocess(self, im):
        """
        Inverse of preprocess(): convert a (3, H, W) BGR array back to a
        uint8 (H, W, 3) RGB image.
        """
        # Work on a private float32 copy: transpose() only returns a view, so
        # in-place arithmetic on it would modify the caller's array, and it
        # would fail outright for integer input.
        im = np.array(im, dtype=np.float32).transpose(1, 2, 0)
        im /= self.scale
        im += self.mean
        im = im[:, :, ::-1]  # change to RGB
        # Clip before casting: the float -> uint8 cast is not well defined
        # for values outside 0..255.
        return np.clip(im, 0, 255).astype(np.uint8)
And in my train_test.prototxt file I have:
# Network definition: one Python data layer per phase, each producing the
# 'data' and 'label' tops.
name: "CIFAR10_SimpleTest_PythonLayer"
# Data layer included only in the TRAIN phase.
layer {
name: 'MyPythonLayer'
type: 'Python'
top: 'data'
top: 'label'
include {
phase: TRAIN
}
python_param {
# Python module that defines the layer class (the script file name;
# presumably given without the .py extension -- confirm it is importable).
module: 'mypythonlayer'
# Name of the layer class inside that module.
layer: 'MyDataLayer'
# Layer parameters as a JSON object; the layer reads them from param_str.
# "label" is the path of the training image list file.
param_str: '{"phase":"TRAIN", "batch_size":10, "im_height":32, "im_width":32, "image_root": "G:/Caffe/examples/cifar10/testbed/Train and Test using Pycaffe", "label": "G:/Caffe/examples/cifar10/testbed/Train and Test using Pycaffe/train_cifar10.txt"}'
}
}
# Data layer included only in the TEST phase. It uses the same layer name,
# module and class as the TRAIN layer; only the phase and the image list
# file differ.
layer {
name: 'MyPythonLayer'
type: 'Python'
top: 'data'
top: 'label'
include {
phase: TEST
}
python_param {
# Python module that defines the layer class (the script file name;
# presumably given without the .py extension -- confirm it is importable).
module: 'mypythonlayer'
# Name of the layer class inside that module.
layer: 'MyDataLayer'
# Layer parameters as a JSON object; the layer reads them from param_str.
# "label" is the path of the test image list file.
param_str: '{"phase":"TEST", "batch_size":10, "im_height":32, "im_width":32, "image_root": "G:/Caffe/examples/cifar10/testbed/Train and Test using Pycaffe", "label": "G:/Caffe/examples/cifar10/testbed/Train and Test using Pycaffe/test_cifar10.txt"}'
}
}
What's wrong here?