Focal_Loss_Keras

This implementation is not equivalent to cross entropy when gamma = 0

Open • Construction opened this issue 7 years ago • 1 comment

In the following lines, the computation of `weight` multiplies by `y_true` a second time, even though `y_true` is already a factor in `ce`:

```python
ce = tf.multiply(y_true, -tf.log(model_out))
weight = tf.multiply(y_true, tf.pow(tf.subtract(1., model_out), gamma))
```
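In symbols (a sketch of what the two lines compute, writing $p_c$ for `model_out`, $y_c$ for `y_true`, and $\alpha$, $\gamma$ for `alpha`, `gamma` at class $c$), the label ends up entering the focal term squared:

$$
\mathrm{ce}_c = -y_c \log p_c, \qquad w_c = y_c\,(1 - p_c)^{\gamma}, \qquad \mathrm{fl}_c = \alpha\, w_c\, \mathrm{ce}_c = -\alpha\, y_c^{2}\,(1 - p_c)^{\gamma} \log p_c
$$

With the `weight` proposed in this issue, the factor would be $y_c$ rather than $y_c^{2}$.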

I find that, in the current implementation, focal loss is not equivalent to cross entropy when gamma = 0. Maybe `weight` should be computed as follows:

```python
weight = tf.pow(tf.subtract(1., model_out), gamma)
```
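To see when the two versions of `weight` agree, here is a plain-Python sketch (no TensorFlow) of the per-sample loss with the max-over-classes reduction used in the test code later in this thread. The helper `per_sample_focal`, its `square_labels` switch, the soft-label example, and the clipping of probabilities to [1e-7, 1 - 1e-7] (an assumption borrowed from Keras's default epsilon, so `math.log` never receives 0) are all illustrative, not part of the repository.

```python
import math

EPS = 1e-7  # assumption: Keras-style clipping so math.log never sees 0


def per_sample_focal(y_true, y_pred, gamma, alpha=1.0, square_labels=True):
    """Focal loss for one sample, reduced with max over classes.

    square_labels=True  -> weight = y * (1 - p) ** gamma  (current code)
    square_labels=False -> weight = (1 - p) ** gamma      (proposed change)
    """
    terms = []
    for y, p in zip(y_true, y_pred):
        p = min(max(p, EPS), 1.0 - EPS)
        ce = y * -math.log(p)          # per-class cross entropy term
        weight = (1.0 - p) ** gamma    # focal modulating factor
        if square_labels:
            weight *= y                # the extra y_true factor
        terms.append(alpha * weight * ce)
    return max(terms)


# One-hot label: both versions give the same value for any gamma.
cur_onehot = per_sample_focal([0.0, 1.0, 0.0], [0.9, 0.1, 0.0], gamma=2.0)
new_onehot = per_sample_focal([0.0, 1.0, 0.0], [0.9, 0.1, 0.0], gamma=2.0,
                              square_labels=False)

# Soft label: y**2 != y, so the two versions diverge.
cur_soft = per_sample_focal([0.5, 0.5, 0.0], [0.5, 0.25, 0.25], gamma=0.0)
new_soft = per_sample_focal([0.5, 0.5, 0.0], [0.5, 0.25, 0.25], gamma=0.0,
                            square_labels=False)

print(cur_onehot, new_onehot)
print(cur_soft, new_soft)
```

With the soft label the current version is half the proposed one here, because $0.5^2 \ne 0.5$; note that with soft labels the max reduction also stops matching cross entropy, which sums over classes.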

Construction, Mar 20 '19 02:03

It is equivalent, as long as `y_true` is one-hot: multiplying a 0/1 label by itself leaves it unchanged.
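Term by term, with the per-class focal term $\mathrm{fl}_c = -\alpha\, y_c^{2}\,(1 - p_c)^{\gamma} \log p_c$ from the current code, and assuming a one-hot label with true class $k$ and $p_c > 0$ for every class (the case the epsilon guards against):

$$
y_c \in \{0, 1\} \;\Rightarrow\; y_c^{2} = y_c, \qquad
\mathrm{fl}_c = \begin{cases} -\alpha\,(1 - p_k)^{\gamma} \log p_k, & c = k \\ 0, & c \neq k \end{cases}
$$

The maximum over classes picks out $c = k$, since $-\log p_k \ge 0$ and every other entry is $0$. Setting $\gamma = 0$ and $\alpha = 1$:

$$
\max_c \mathrm{fl}_c = -\log p_k = -\sum_c y_c \log p_c ,
$$

which is the categorical cross entropy of the sample.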

Remember to replace `model_out = tf.add(y_pred, epsilon)` with `model_out = y_pred` to get consistent results.

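Test code (TensorFlow 1.x):

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.log, tf.Session)

y_t = [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
y_p = [[0.9, 0.1, 0.0], [0.5, 0.25, 0.25], [0.15, 0.05, 0.8], [0.5, 0.4, 0.1]]

def focal_loss_fixed(y_true, y_pred, gamma=0.0, alpha=1.0):
    epsilon = 1.e-9
    y_true = tf.convert_to_tensor(y_true, tf.float32)
    y_pred = tf.convert_to_tensor(y_pred, tf.float32)
    # model_out = tf.add(y_pred, epsilon)
    model_out = y_pred  # epsilon dropped so values line up with Keras
    # Caveat: y_p[0] contains 0.0, so that entry of ce is 0 * -log(0) = nan;
    # the printed value then depends on how tf.reduce_max handles nan.
    ce = tf.multiply(y_true, -tf.log(model_out))  # per-class y * -log(p)
    # Focal weight, still multiplied by y_true as in the current implementation
    weight = tf.multiply(y_true, tf.pow(tf.subtract(1., model_out), gamma))
    fl = tf.multiply(alpha, tf.multiply(weight, ce))
    # One-hot labels leave one non-zero entry per row; max picks it out
    reduced_fl = tf.reduce_max(fl, axis=1)
    return reduced_fl, tf.reduce_mean(reduced_fl)

with tf.Session() as sess:
    print(sess.run(focal_loss_fixed(tf.constant(y_t), tf.constant(y_p))))
    print(sess.run(tf.keras.losses.categorical_crossentropy(tf.constant(y_t), tf.constant(y_p))))
```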
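Without TensorFlow 1.x at hand, the two print statements of the test can be cross-checked in plain Python. This is a sketch, not the repository's code: `clip`, `focal_per_sample` and `categorical_crossentropy` are hypothetical helpers, and clipping to [1e-7, 1 - 1e-7] is an assumption modelled on how Keras's `categorical_crossentropy` guards against log(0); it also keeps the 0.0 entry in `y_p` away from `math.log`.

```python
import math

EPS = 1e-7  # assumption: same clipping constant as Keras's default epsilon

y_t = [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
y_p = [[0.9, 0.1, 0.0], [0.5, 0.25, 0.25], [0.15, 0.05, 0.8], [0.5, 0.4, 0.1]]


def clip(p):
    return min(max(p, EPS), 1.0 - EPS)


def focal_per_sample(y_true, y_pred, gamma=0.0, alpha=1.0):
    # Mirrors focal_loss_fixed: weight still includes y_true, reduced with max.
    return max(
        alpha * (y * (1.0 - clip(p)) ** gamma) * (y * -math.log(clip(p)))
        for y, p in zip(y_true, y_pred)
    )


def categorical_crossentropy(y_true, y_pred):
    # Keras-style: normalise the row, clip, then sum -y * log(p) over classes.
    total = sum(y_pred)
    return -sum(y * math.log(clip(p / total)) for y, p in zip(y_true, y_pred))


fl = [focal_per_sample(t, p) for t, p in zip(y_t, y_p)]
ce = [categorical_crossentropy(t, p) for t, p in zip(y_t, y_p)]
print([round(v, 6) for v in fl], round(sum(fl) / len(fl), 6))
print([round(v, 6) for v in ce], round(sum(ce) / len(ce), 6))
# both lines print: [2.302585, 0.693147, 0.223144, 0.693147] 0.978006
```

Both lists are simply -log of the true-class probability in each row (-log 0.1, -log 0.5, -log 0.8, -log 0.5), which is what the equivalence predicts for one-hot labels with gamma = 0.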

uranusx86, Sep 17 '19 02:09