Say I have a classification problem with 30 potential binary labels. The labels are not mutually exclusive. They tend to be sparse: on average there is 1 positive label out of the 30, but sometimes there are more. In the following code, how can I penalize the model for predicting all zeros? The accuracy will be high, but recall will be awful!
import numpy as np
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

OUTPUT_NODES = 30

np.random.seed(0)


def get_dataset():
    """
    Get a dataset of X and y. This is a learnable problem as there is some signal in the features. 10% of the time, a
    positive-output's index will also have a positive feature for that index
    :return: X and y data for training
    """
    n_observations = 30000
    y = np.random.rand(n_observations, OUTPUT_NODES)
    # Makes a sparse output where there is roughly 1 positive label: ((1 / OUTPUT_NODES) * OUTPUT_NODES ≈ 1)
    y = (y <= (1 / OUTPUT_NODES)).astype(int)
    X = np.zeros((n_observations, OUTPUT_NODES))
    for i in range(len(y)):
        for j, feature in enumerate(y[i]):
            if feature == 1:
                X[i][j] = 1 if np.random.rand() > 0.9 else 0  # Makes the input features more noisy
                # X[i][j] = 1  # Using this instead will make the model perform very well
    return X, y


def create_model():
    input_layer = Input(shape=(OUTPUT_NODES, ))
    dense1 = Dense(100, activation='relu')(input_layer)
    dense2 = Dense(100, activation='relu')(dense1)
    # One sigmoid unit per label, since the labels are not mutually exclusive
    output_layer = Dense(OUTPUT_NODES, activation='sigmoid')(dense2)
    model = Model(inputs=input_layer, outputs=output_layer)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['Recall'])
    return model


def main():
    X, y = get_dataset()
    model = create_model()
    model.fit(X, y, epochs=10, batch_size=10)
    # Random inputs, used only to inspect what the trained model predicts
    X_pred = np.random.randint(0, 2, (100, OUTPUT_NODES))
    y_pred = model.predict(X_pred)
    print(X_pred)
    print(y_pred.round(1))


if __name__ == '__main__':
    main()
I believe I read here that I could use weighted_cross_entropy_with_logits to address this issue. How would that affect my final output layer's activation function? Would I still need an activation function at all? And how do I specify a penalty for misclassifying a true positive class?
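For reference, here is a minimal NumPy sketch of what I understand that loss to compute, based on the formula documented for tf.nn.weighted_cross_entropy_with_logits(labels, logits, pos_weight). The function name weighted_bce_with_logits and the pos_weight value of 29 (roughly the ratio of negatives to positives here) are my own assumptions for illustration, not a recommended setting. Since the loss takes logits, my guess is that the final Dense layer would need no activation and the sigmoid would only be applied at prediction time.

```python
import numpy as np


def weighted_bce_with_logits(labels, logits, pos_weight):
    """Per-label weighted binary cross-entropy computed from raw logits.

    Hypothetical helper that mirrors the documented formula:
        pos_weight * labels * -log(sigmoid(logits))
        + (1 - labels) * -log(1 - sigmoid(logits))
    """
    labels = np.asarray(labels, dtype=float)
    logits = np.asarray(logits, dtype=float)
    # softplus(-x) == -log(sigmoid(x)); logaddexp keeps it numerically stable
    neg_log_p = np.logaddexp(0.0, -logits)
    # softplus(x) == -log(1 - sigmoid(x))
    neg_log_not_p = np.logaddexp(0.0, logits)
    return pos_weight * labels * neg_log_p + (1.0 - labels) * neg_log_not_p


# One observation with 30 labels, exactly one of them positive
labels = np.zeros((1, 30))
labels[0, 3] = 1

# A model that confidently predicts "all zeros": large negative logits everywhere
all_zero_logits = np.full((1, 30), -4.0)

plain = weighted_bce_with_logits(labels, all_zero_logits, pos_weight=1.0).mean()
weighted = weighted_bce_with_logits(labels, all_zero_logits, pos_weight=29.0).mean()

print("unweighted loss for all-zero prediction:", plain)
print("weighted loss for all-zero prediction:  ", weighted)
```

With pos_weight=1.0 this reduces to ordinary binary cross-entropy. A pos_weight above 1 multiplies only the term for missed positives, so the all-zero prediction becomes more expensive while the cost of false positives stays unchanged.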