Focal_Loss_Keras
This implementation is not equivalent to cross entropy when gamma = 0
In the following lines from the current implementation, the computation of weight multiplies by y_true a second time, even though ce already carries a factor of y_true:
ce = tf.multiply(y_true, -tf.log(model_out))
weight = tf.multiply(y_true, tf.pow(tf.subtract(1., model_out), gamma))
Because of this, I find that in the current implementation the focal loss is not equivalent to cross entropy when gamma = 0.
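As a tiny numeric illustration of the difference (toy numbers of my own; soft labels make the extra factor visible, since for strictly one-hot labels y_true**2 equals y_true):

import numpy as np

y = np.array([0.7, 0.3])   # soft labels, purely illustrative
p = np.array([0.6, 0.4])
ce = y * -np.log(p)                    # cross-entropy terms: y * -log(p)
fl_current  = (y * (1 - p)**0) * ce    # current weight  -> y**2 * -log(p)
fl_proposed = ((1 - p)**0) * ce        # proposed weight -> y * -log(p)
print(fl_current)    # ~[0.2503 0.0825]
print(fl_proposed)   # ~[0.3576 0.2749], the cross-entropy terms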
Maybe weight should be given as follows:
weight = tf.pow(tf.subtract(1., model_out), gamma)
With this change, setting gamma = 0 recovers cross entropy exactly.
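To spell out the algebra (writing $y_j$ for y_true and $p_j$ for model_out, notation mine), the current code computes

$$\mathrm{fl}_j = \alpha \,\bigl(y_j (1-p_j)^{\gamma}\bigr)\bigl(y_j(-\log p_j)\bigr) = \alpha \, y_j^{2}\,(1-p_j)^{\gamma}\,(-\log p_j),$$

while the proposed weight gives

$$\mathrm{fl}_j = \alpha \, y_j\,(1-p_j)^{\gamma}\,(-\log p_j),$$

which at $\gamma = 0$ and $\alpha = 1$ is exactly the cross-entropy term $y_j(-\log p_j)$.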
Remember to replace the line model_out = tf.add(y_pred, epsilon) with model_out = y_pred to get consistent results.
Test code:
import tensorflow as tf  # TF 1.x style (tf.log, tf.Session)

y_t = [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
y_p = [[0.9, 0.1, 0.0], [0.5, 0.25, 0.25], [0.15, 0.05, 0.8], [0.5, 0.4, 0.1]]

def focal_loss_fixed(y_true, y_pred, gamma=0.0, alpha=1.0):
    epsilon = 1.e-9
    y_true = tf.convert_to_tensor(y_true, tf.float32)
    y_pred = tf.convert_to_tensor(y_pred, tf.float32)
    # model_out = tf.add(y_pred, epsilon)
    model_out = y_pred  # caution: y_p contains an exact 0.0; without epsilon,
                        # tf.log(0.) is -inf and 0 * inf can surface as NaN there
    ce = tf.multiply(y_true, -tf.log(model_out))
    # proposed weight: no extra y_true factor (ce already carries it)
    weight = tf.pow(tf.subtract(1., model_out), gamma)
    fl = tf.multiply(alpha, tf.multiply(weight, ce))
    # for one-hot labels, the max over the class axis is the true-class term
    reduced_fl = tf.reduce_max(fl, axis=1)
    return reduced_fl, tf.reduce_mean(reduced_fl)

sess = tf.Session()
print(sess.run(focal_loss_fixed(tf.constant(y_t), tf.constant(y_p))))
print(sess.run(tf.keras.losses.categorical_crossentropy(tf.constant(y_t), tf.constant(y_p))))
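For anyone on TensorFlow 2.x, where tf.log and tf.Session no longer exist, a minimal eager-mode sketch of the same check could look like the following (focal_loss_v2 is my own name; I keep the epsilon here so the exact 0.0 in y_p does not produce log(0) = -inf):

import tensorflow as tf  # TF 2.x

def focal_loss_v2(y_true, y_pred, gamma=0.0, alpha=1.0, epsilon=1e-9):
    y_true = tf.convert_to_tensor(y_true, tf.float32)
    y_pred = tf.convert_to_tensor(y_pred, tf.float32)
    model_out = y_pred + epsilon              # avoids -inf at exact zeros
    ce = y_true * -tf.math.log(model_out)
    weight = tf.pow(1.0 - model_out, gamma)   # proposed weight, no y_true
    fl = alpha * weight * ce
    reduced_fl = tf.reduce_max(fl, axis=1)    # true-class term for one-hot labels
    return reduced_fl, tf.reduce_mean(reduced_fl)

y_t = [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
y_p = [[0.9, 0.1, 0.0], [0.5, 0.25, 0.25], [0.15, 0.05, 0.8], [0.5, 0.4, 0.1]]

print(focal_loss_v2(y_t, y_p))
print(tf.keras.losses.categorical_crossentropy(tf.constant(y_t), tf.constant(y_p)))

With gamma = 0, both lines should print approximately [2.3026, 0.6931, 0.2231, 0.6931], i.e. -log of the true-class probability in each row, differing only by the negligible epsilon offset.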