I am using the SmallerVGGNet model and a modified version of the training script from the following tutorial to continue training a previously trained model. Original source of the model and script: https://www.pyimagesearch.com/2018/04/16/keras-and-convolutional-neural-networks-cnns/
This is what I did:
1st training session:
- train the model on an image dataset with two classes, A and B, using the original training script from the tutorial (a rough sketch of that flow is shown right after this item)
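(For context, the core of the tutorial's train.py is roughly the sketch below. The arrays are random placeholders, session1.model is just an example file name, and the one-hot conversion is something I add here so the 2-class labels match the 2-unit softmax head; none of it is copied from my actual runs.)

import numpy as np
from keras.optimizers import Adam
from keras.utils import np_utils
from sklearn.preprocessing import LabelBinarizer
from smallervggnet import SmallerVGGNet

# placeholder images/labels for classes A and B (the real script loads them from --dataset)
trainX = np.random.random((20, 256, 256, 3)).astype("float32")
lb = LabelBinarizer()
trainY = np_utils.to_categorical(lb.fit_transform(np.array(["A", "B"] * 10)), 2)

# build a SmallerVGGNet with a 2-unit softmax head, train it, and save it to disk
model = SmallerVGGNet.build(width=256, height=256, depth=3, classes=2)
model.compile(loss="categorical_crossentropy",
    optimizer=Adam(lr=1e-3, decay=1e-3 / 100),
    metrics=["accuracy"])
model.fit(trainX, trainY, batch_size=10, epochs=1)
model.save("session1.model")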
2nd training session:
- load the trained Keras model and train it on an image dataset containing only class C, with no new data for classes A and B, using the modified training script below (the method for loading and saving the model follows this Stack Overflow thread: Loading a trained Keras model and continue training; a minimal sketch of that pattern follows this list)
- load the pickled label array from the 1st session, concatenate it with the new label array from the 2nd session, and save the combined result to lb.pickle
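(The load-and-continue pattern from that thread boils down to the snippet below; the random arrays and the names session1.model / session2.model / train_continue.py are placeholders. load_model restores the architecture, weights, and optimizer state that model.save wrote, and the commented line shows how the modified script would be invoked with the flags defined in its argparse block.)

import numpy as np
from keras.models import load_model
from keras.utils import np_utils

# placeholder data standing in for the preprocessed class-C images/labels
trainX = np.random.random((20, 256, 256, 3)).astype("float32")
trainY = np_utils.to_categorical(np.random.randint(0, 2, size=(20,)), 2)

model = load_model("session1.model")                 # restores architecture, weights, and optimizer state
model.fit(trainX, trainY, batch_size=10, epochs=1)   # continue training on the new data
model.save("session2.model")

# example invocation of the modified script (script and path names are placeholders):
# python train_continue.py --dataset dataset_C --loadmodel session1.model \
#     --model session2.model --labelbin lb.pickle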
Result:
After the 2nd session, the trained model only recognizes the new class C. The classes trained in the 1st session seem to be lost. It just doesn't work.
My question: how can I fix the script below so that incremental training works? Alternatively, is there any other suggestion or reference for incremental training that fits a case like mine?
My modified training script:
from keras.utils import np_utils
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.preprocessing.image import img_to_array
from keras.models import load_model
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from smallervggnet import SmallerVGGNet
from imutils import paths
import numpy as np
import argparse, os, sys
import random
import pickle
import cv2
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--dataset", required=True,
    help="path to input dataset (i.e., directory of images)")
ap.add_argument("-im", "--loadmodel", required=True,
    help="path to model to be loaded")
ap.add_argument("-m", "--model", required=True,
    help="path to output model")
ap.add_argument("-l", "--labelbin", required=True,
    help="path to output label binarizer")
ap.add_argument("-p", "--plot", type=str, default="plot.png",
    help="path to output accuracy/loss plot")
args = vars(ap.parse_args())
EPOCHS = 100
INIT_LR = 1e-3
BS = 10
IMAGE_DIMS = (256, 256, 3)
data = []
labels = []
print("[INFO] loading images...")
imagePaths = sorted(list(paths.list_images(args["dataset"])))
random.seed(42)
random.shuffle(imagePaths)
for imagePath in imagePaths:
    # load the image, resize it, and store it in the data list
    image = cv2.imread(imagePath)
    image = cv2.resize(image, (IMAGE_DIMS[1], IMAGE_DIMS[0]))
    image = img_to_array(image)
    data.append(image)
    # extract the class label from the image path (the parent directory name)
    label = imagePath.split(os.path.sep)[-2]
    labels.append(label)
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)
print("[INFO] data matrix: {:.2f}MB".format(
data.nbytes / (1024 * 1000.0)))
# binarize the labels and split the data into training and testing sets
lb = LabelBinarizer()
bLabels = lb.fit_transform(labels)
(trainX, testX, trainY, testY) = train_test_split(data,
    bLabels, test_size=0.2, random_state=42)
# convert the binarized labels to 2-column one-hot vectors
# (added these 2 lines to avoid an error)
trainY = np_utils.to_categorical(trainY, 2)
testY = np_utils.to_categorical(testY, 2)
aug = ImageDataGenerator(rotation_range=25, width_shift_range=0.1,
    height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,
    horizontal_flip=True, fill_mode="nearest")
print("[INFO] load previously trained model")
modelPath = args["loadmodel"]
model = load_model(modelPath)
print("[INFO] training network...")
H = model.fit_generator(
    aug.flow(trainX, trainY, batch_size=BS),
    validation_data=(testX, testY),
    steps_per_epoch=len(trainX) // BS,
    epochs=EPOCHS, verbose=1)
print("[INFO] serializing network...")
model.save(args["model"])
# my attempt to keep the labels from all training sessions in the label binarizer
prevArray = './train_output/previous_data_array.pickle'
arrPickle = labels
if os.path.exists(prevArray) and os.path.getsize(prevArray) > 0:
    # append the labels pickled in previous sessions to the current ones
    prev = pickle.loads(open(prevArray, 'rb').read())
    arrPickle = np.concatenate((prev, labels), axis=0)
lb = LabelBinarizer()
lb.fit_transform(arrPickle)
print("[INFO] serializing combined label array...")
f = open(prevArray, "wb")
f.write(pickle.dumps(arrPickle))
f.close()
print("[INFO] serializing label binarizer...")
f = open(args["labelbin"], "wb")
f.write(pickle.dumps(lb))
f.close()
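For reference, recognition after each session is checked roughly the way the tutorial's classify.py does it: load the saved model and label binarizer, then decode the argmax of the prediction. The image and file paths below are placeholders.

import pickle
import cv2
import numpy as np
from keras.models import load_model
from keras.preprocessing.image import img_to_array

# load and preprocess a test image the same way as in training
image = cv2.imread("examples/test.jpg")
image = cv2.resize(image, (256, 256))
image = image.astype("float") / 255.0
image = img_to_array(image)
image = np.expand_dims(image, axis=0)

# load the trained model and the label binarizer, then decode the top prediction
model = load_model("session2.model")
lb = pickle.loads(open("lb.pickle", "rb").read())
proba = model.predict(image)[0]
idx = np.argmax(proba)
print("predicted label: {} ({:.2f}%)".format(lb.classes_[idx], proba[idx] * 100))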