For those facing the same problem: I found a good solution by reimplementing cross entropy (CE) in a numerically stable way. If you want to know why you shouldn't implement CE directly from its definition, -∑ p_i log(q_i), check out this tutorial.
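In short, the problem is that exp() overflows for large logits, so the naive softmax turns into nan. A quick NumPy sketch (the values are made up purely for illustration):

    import numpy as np

    x = np.array([1000.0, 1001.0])
    naive = np.exp(x) / np.exp(x).sum()               # [nan, nan]: exp() overflows to inf
    shifted = x - x.max()                             # same softmax, shifted logits
    stable = np.exp(shifted) / np.exp(shifted).sum()  # [0.269, 0.731]: finite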
The implementation I used to apply the priors works as follows:
    import numpy as np
    import tensorflow as tf

    def modified_CE(logits=None, labels=None, priors=None):
        # subtract the per-row maximum so exp() can't overflow to inf
        # (keepdims avoids hard-coding the batch size in a reshape)
        scaled_logits = logits - tf.reduce_max(logits, axis=1, keepdims=True)
        # renormalize as the final step of the log-softmax
        normalized_logits = scaled_logits - tf.reduce_logsumexp(scaled_logits, axis=1, keepdims=True)
        # apply the priors in log space
        normalized_logits -= tf.log(np.array(priors, dtype=np.float32))
        # renormalize so each row is a valid log-probability again
        normalized_logits -= tf.reduce_logsumexp(normalized_logits, axis=1, keepdims=True)
        # labels carries an extra leading dimension in my data, hence labels[0, :, :]
        return tf.reduce_mean(-tf.reduce_sum(labels[0, :, :] * normalized_logits, 1))
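Note that subtracting log(priors) and renormalizing is equivalent to dividing the softmax probabilities by the priors and rescaling them to sum to 1. For completeness, here is a hypothetical way to wire the function up (TF1-style graph code; the class count, shapes, and uniform priors below are placeholders I made up, not part of my data):

    num_classes = 10
    logits = tf.placeholder(tf.float32, shape=(None, num_classes))
    labels = tf.placeholder(tf.float32, shape=(1, None, num_classes))  # one-hot
    priors = [1.0 / num_classes] * num_classes                         # placeholder priors
    loss = modified_CE(logits=logits, labels=labels, priors=priors)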