To build a condition-based CNN we can either pass the full batch of inputs to every sub-model in Model 2 and then select the desired outputs from all of the sub-model outputs based on the conditions (which is what the models you defined in the question do), or we can take a faster route and follow the conditions step by step (the three conditions you listed). A minimal sketch of the first option is shown just below; the rest of this answer implements the second one.
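For the sketch we use hypothetical placeholder tensors (y_p_01, out_cat, out_dog) instead of real model outputs, assuming both sub-models have already been run on the full batch; tf.where then picks, per sample, whichever sub-model output the condition selects:
import tensorflow as tf
# Hypothetical placeholders for a batch of 3 samples and 2 classes
y_p_01 = tf.constant([[0.9, 0.1], [0.2, 0.8], [0.3, 0.7]])   # softmax output of the main model
out_cat = tf.constant([[0.7, 0.3], [0.6, 0.4], [0.8, 0.2]])  # cat sub-model output on the FULL batch
out_dog = tf.constant([[0.4, 0.6], [0.1, 0.9], [0.3, 0.7]])  # dog sub-model output on the FULL batch
# Condition per sample: did the main model predict class 0 (cat)?
is_cat = tf.math.equal(tf.math.argmax(y_p_01, axis=1), 0)     # shape (batch,)
# Select, row by row, the output of the sub-model chosen by the condition
final_out = tf.where(is_cat[:, tf.newaxis], out_cat, out_dog) # shape (batch, 2)
print(final_out)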
Example code showing the condition mechanism:
import tensorflow as tf

number_of_class = 2

# Mimic the test dataset and labels
batch = tf.constant([[1, 2, 3], [2, 3, 1], [3, 1, 2]])
y_all = [tf.one_hot(i, number_of_class, dtype=tf.float32) for i in range(number_of_class)]
# Mimic the outputs of model_01
y_p = tf.constant([[0.9, 0.1], [0.1, 0.9], [0.3, 0.7]])
y_p = tf.one_hot(tf.math.argmax(y_p, axis=1), number_of_class, dtype=tf.float32)
# Mimic the conditions by choosing the samples from the batch based on whether the prediction equals the label for each class
for y in y_all:
    condition = tf.reduce_all(tf.math.equal(y_p, y), 1)
    indices = tf.where(condition)
    choosed_inputs = tf.gather_nd(batch, indices)
    print("label:\n{}\ncondition:\n{}\nindices:\n{}\nchoosed_inputs:\n{}\n".format(y, condition, indices, choosed_inputs))
Outputs:
label:
[1. 0.]
condition:
[ True False False]
indices:
[[0]]
choosed_inputs:
[[1 2 3]]
label:
[0. 1.]
condition:
[False True True]
indices:
[[1]
[2]]
choosed_inputs:
[[2 3 1]
[3 1 2]]
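As a side note (not part of the original snippet), the tf.where + tf.gather_nd pair can be collapsed into a single tf.boolean_mask call, which selects the same rows:
import tensorflow as tf

batch = tf.constant([[1, 2, 3], [2, 3, 1], [3, 1, 2]])
condition = tf.constant([False, True, True])

# Same rows as tf.gather_nd(batch, tf.where(condition))
print(tf.boolean_mask(batch, condition))  # [[2 3 1] [3 1 2]]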
Example code that builds the condition-based CNN model and trains it in a custom-training fashion:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.utils import *
import numpy as np
img_rows, img_cols, number_of_class, batch_size = 256, 256, 2, 64
#----------- main model (Model 1) ------------------------------------
inputs = Input(shape=(img_rows, img_cols, 3))
conv_01 = Convolution2D(64, 3, 3, activation='relu', name = 'conv_01') (inputs)
conv_02 = Convolution2D(64, 3, 3, activation='relu', name = 'conv_02') (conv_01)
skip_dog = conv_02
conv_03 = Convolution2D(64, 3, 3, activation='relu', name = 'conv_03') (conv_02)
skip_cat = conv_03
conv_04 = Convolution2D(64, 3, 3, activation='relu', name = 'conv_04') (conv_03)
flatten_main_model = Flatten() (conv_04)
Output_main_model = Dense(units = number_of_class , activation = 'softmax', name = "Output_layer")(flatten_main_model)
#----------- Conditional Cat model ------------------------------------
inputs_1 = Input(shape=skip_cat.shape[1:])
conv_05 = Convolution2D(64, 3, 3, activation='relu', name = 'conv_05') (inputs_1)
flatten_cat_model = Flatten() (conv_05)
Output_cat_model = Dense(units = number_of_class , activation = 'softmax', name = "Output_layer_cat")(flatten_cat_model)
#----------- Conditional Dog model ------------------------------------
inputs_2 = Input(shape=skip_dog.shape[1:])
conv_06 = Convolution2D(64, 3, 3, activation='relu', name = 'conv_06') (inputs_2)
flatten_dog_model = Flatten() (conv_06)
Output_dog_model = Dense(units = number_of_class , activation = 'softmax', name = "Output_layer_dog")(flatten_dog_model)
#----------------------------- My discrete 3 models --------------------------------
model_01 = Model(inputs = inputs, outputs = [skip_cat, skip_dog, Output_main_model], name = 'model_main')
model_02_1 = Model(inputs = inputs_1, outputs = Output_cat_model, name = 'Conditional_cat_model')
model_02_2 = Model(inputs = inputs_2, outputs = Output_dog_model, name = 'Conditional_dog_model')
# Get one-hot vectors for all the labels
y_all = [tf.one_hot(i, number_of_class, dtype=tf.float32) for i in range(number_of_class)]
sub_models_all = [model_02_1, model_02_2]
# The cat branch trains conv_01..conv_03 of model_01 (6 kernel/bias variables) plus the cat sub-model,
# the dog branch trains conv_01..conv_02 (4 kernel/bias variables) plus the dog sub-model
sub_models_trainable_variables = [model_01.trainable_variables[:6] + model_02_1.trainable_variables,
                                  model_01.trainable_variables[:4] + model_02_2.trainable_variables]
cce = keras.losses.CategoricalCrossentropy()
optimizer_01 = keras.optimizers.Adam(learning_rate=1e-3, name='Adam_01')
optimizer_02 = keras.optimizers.Adam(learning_rate=2e-3, name='Adam_02')
@tf.function
def train_step(batch_imgs, labels):
    with tf.GradientTape(persistent=True) as tape:
        model_01_outputs = model_01(batch_imgs)
        y_p_01 = model_01_outputs[-1]
        loss_01 = cce(labels, y_p_01)
        # Convert the outputs of model_01 from floats in (0, 1) to one-hot vectors; no gradients flow back from here
        y_p_01 = tf.one_hot(tf.math.argmax(y_p_01, axis=1), number_of_class, dtype=tf.float32)
        loss_02_all = []
        for i in range(number_of_class):
            condition = tf.reduce_all(tf.math.equal(y_p_01, y_all[i]), 1)
            indices = tf.where(condition)
            choosed_inputs = tf.gather_nd(model_01_outputs[i], indices)
            # Note that the input batch size of each sub-model is dynamic here
            y_p_02 = sub_models_all[i](choosed_inputs)
            y_t = tf.gather_nd(labels, indices)
            loss_02 = cce(y_t, y_p_02)
            loss_02_all.append(loss_02)
    grads_01 = tape.gradient(loss_01, model_01.trainable_variables)
    optimizer_01.apply_gradients(zip(grads_01, model_01.trainable_variables))
    for i in range(number_of_class):
        grads_02 = tape.gradient(loss_02_all[i], sub_models_trainable_variables[i])
        optimizer_02.apply_gradients(zip(grads_02, sub_models_trainable_variables[i]))
    return loss_01, loss_02_all
def training():
    for j in range(10):
        random_imgs = np.random.rand(batch_size, img_rows, img_cols, 3)
        random_labels = np.eye(number_of_class)[np.random.choice(number_of_class, batch_size)]
        loss_01, loss_02_all = train_step(random_imgs, random_labels)
        print("Step: {}, Loss_01: {}, Loss_02_1: {}, Loss_02_2: {}".format(j, loss_01, loss_02_all[0], loss_02_all[1]))

training()
The output looks something like:
Step: 0, Loss_01: 0.6966696977615356, Loss_02_1: 0.0, Loss_02_2: 0.6886894702911377
Step: 1, Loss_01: 0.6912064552307129, Loss_02_1: 0.6968430280685425, Loss_02_2: 0.6911896467208862
Step: 2, Loss_01: 0.6910352110862732, Loss_02_1: 0.698455274105072, Loss_02_2: 0.6935626864433289
Step: 3, Loss_01: 0.6955667734146118, Loss_02_1: 0.6843984127044678, Loss_02_2: 0.6953505277633667
Step: 4, Loss_01: 0.6941269636154175, Loss_02_1: 0.673763632774353, Loss_02_2: 0.6994296908378601
Step: 5, Loss_01: 0.6872361898422241, Loss_02_1: 0.6769005060195923, Loss_02_2: 0.6907837390899658
Step: 6, Loss_01: 0.6931678056716919, Loss_02_1: 0.7674703598022461, Loss_02_2: 0.6935689449310303
Step: 7, Loss_01: 0.6976977586746216, Loss_02_1: 0.7503389120101929, Loss_02_2: 0.7076178789138794
Step: 8, Loss_01: 0.6932153105735779, Loss_02_1: 0.7428234219551086, Loss_02_2: 0.6935019493103027
Step: 9, Loss_01: 0.693305253982544, Loss_02_1: 0.6476342082023621, Loss_02_2: 0.6916818618774414
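The code above only covers training. At inference time the same condition mechanism can be reused to route each sample through the sub-model that matches the prediction of model_01. The predict helper below is not part of the original code, just a minimal sketch assuming the model_01, sub_models_all and y_all objects defined above:
def predict(batch_imgs):
    # Run the main model once on the whole batch
    skip_cat_out, skip_dog_out, y_p_01 = model_01(batch_imgs, training=False)
    y_p_01 = tf.one_hot(tf.math.argmax(y_p_01, axis=1), number_of_class, dtype=tf.float32)
    features = [skip_cat_out, skip_dog_out]
    results = []
    for i in range(number_of_class):
        # Route only the samples that model_01 assigned to class i
        condition = tf.reduce_all(tf.math.equal(y_p_01, y_all[i]), 1)
        indices = tf.where(condition)
        choosed_inputs = tf.gather_nd(features[i], indices)
        y_p_02 = sub_models_all[i](choosed_inputs, training=False)
        results.append((indices, y_p_02))  # original row indices plus the sub-model predictions
    return results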