Negative cost in OTTOutput

Open matthieuheitz opened this issue 1 year ago • 4 comments

I'm computing all pairwise distances between 6 spatial transcriptomics slices. Some of the resulting costs are negative, a few with absurdly large magnitudes, even though the solver reports that they converged.

I prepared the problem like this: stp = stp.prepare(time_key="Batch_idx",spatial_key="spatial",joint_attr=joint_attr,cost='sq_euclidean',policy="triu") where joint_attr="X_pca" is a global PCA.

Here's my call to the solver, and the output:

stp = stp.solve(epsilon=epsilon_scheduler.Epsilon(target=1e-3, init=100, decay=0.99), 
                           alpha=0.5, 
                           linear_solver_kwargs={"momentum":acceleration.Momentum(start=300)})

[((0, 1), OTTOutput[shape=(147, 292), cost=0.9862, converged=True]),
 ((2, 4), OTTOutput[shape=(441, 824), cost=0.9677, converged=True]),
 ((1, 2), OTTOutput[shape=(292, 441), cost=0.6782, converged=True]),
 ((0, 4), OTTOutput[shape=(147, 824), cost=0.9633, converged=True]),
 ((3, 4), OTTOutput[shape=(1169, 824), cost=0.9322, converged=True]),
 ((1, 5), OTTOutput[shape=(292, 744), cost=-0.0029, converged=True]),
 ((0, 3), OTTOutput[shape=(147, 1169), cost=-56.5436, converged=True]),
 ((1, 4), OTTOutput[shape=(292, 824), cost=-28217723322368.0, converged=True]),
 ((2, 3), OTTOutput[shape=(441, 1169), cost=0.9884, converged=True]),
 ((0, 2), OTTOutput[shape=(147, 441), cost=-3458.1289, converged=True]),
 ((4, 5), OTTOutput[shape=(824, 744), cost=-0.4292, converged=True]),
 ((0, 5), OTTOutput[shape=(147, 744), cost=1.1268, converged=True]),
 ((2, 5), OTTOutput[shape=(441, 744), cost=0.9649, converged=True]),
 ((1, 3), OTTOutput[shape=(292, 1169), cost=0.8918, converged=True]),
 ((3, 5), OTTOutput[shape=(1169, 744), cost=0.8036, converged=True])]

I then tried to change some parameters of the solvers to see if it would matter, and it seems it does. Removing the epsilon decay, we only get 3 negative values (and a nan).

stp = stp.solve(epsilon=1e-3, alpha=0.5, linear_solver_kwargs={"momentum":acceleration.Momentum(start=200)}) 

[((0, 1), OTTOutput[shape=(147, 292), cost=0.9831, converged=True]),
 ((2, 4), OTTOutput[shape=(441, 824), cost=nan, converged=False]),
 ((1, 2), OTTOutput[shape=(292, 441), cost=0.678, converged=True]),
 ((0, 4), OTTOutput[shape=(147, 824), cost=-11.4962, converged=True]),
 ((3, 4), OTTOutput[shape=(1169, 824), cost=0.9322, converged=True]),
 ((1, 5), OTTOutput[shape=(292, 744), cost=1.0901, converged=True]),
 ((0, 3), OTTOutput[shape=(147, 1169), cost=-55.8705, converged=True]),
 ((1, 4), OTTOutput[shape=(292, 824), cost=0.9495, converged=True]),
 ((2, 3), OTTOutput[shape=(441, 1169), cost=0.9883, converged=True]),
 ((0, 2), OTTOutput[shape=(147, 441), cost=-2565.3809, converged=True]),
 ((4, 5), OTTOutput[shape=(824, 744), cost=0.671, converged=True]),
 ((0, 5), OTTOutput[shape=(147, 744), cost=1.1225, converged=True]),
 ((2, 5), OTTOutput[shape=(441, 744), cost=0.9571, converged=True]),
 ((1, 3), OTTOutput[shape=(292, 1169), cost=1.0639, converged=True]),
 ((3, 5), OTTOutput[shape=(1169, 744), cost=0.7674, converged=True])]

Removing the epsilon decay and choosing a fixed momentum, we still get 4 negative values (and many problems didn't converge, but I'm not sure that's relevant).

stp = stp.solve(epsilon=1e-3, alpha=0.5, linear_solver_kwargs={"momentum":acceleration.Momentum(value=1.6)})

[((0, 1), OTTOutput[shape=(147, 292), cost=0.9871, converged=True]),
 ((2, 4), OTTOutput[shape=(441, 824), cost=0.9025, converged=False]),
 ((1, 2), OTTOutput[shape=(292, 441), cost=0.6788, converged=False]),
 ((0, 4), OTTOutput[shape=(147, 824), cost=1.0402, converged=True]),
 ((3, 4), OTTOutput[shape=(1169, 824), cost=0.9117, converged=False]),
 ((1, 5), OTTOutput[shape=(292, 744), cost=1.0907, converged=True]),
 ((0, 3), OTTOutput[shape=(147, 1169), cost=-43.8229, converged=False]),
 ((1, 4), OTTOutput[shape=(292, 824), cost=0.9567, converged=True]),
 ((2, 3), OTTOutput[shape=(441, 1169), cost=0.9883, converged=False]),
 ((0, 2), OTTOutput[shape=(147, 441), cost=-1995.7041, converged=True]),
 ((4, 5), OTTOutput[shape=(824, 744), cost=0.6685, converged=False]),
 ((0, 5), OTTOutput[shape=(147, 744), cost=-9.9394, converged=False]),
 ((2, 5), OTTOutput[shape=(441, 744), cost=0.9292, converged=True]),
 ((1, 3), OTTOutput[shape=(292, 1169), cost=0.9441, converged=True]),
 ((3, 5), OTTOutput[shape=(1169, 744), cost=-2.0191, converged=False])]

What's strange is that it's not always the same problems that have a negative cost. Though (0,3) and (0,2) seem to always be negative, and others are on and off. Any idea what could be causing this? Thanks!

matthieuheitz avatar Dec 06 '24 06:12 matthieuheitz

Hi @matthieuheitz ,

Thanks for opening this issue.

A negative cost is possible because we (following ott-jax) report the entropy-regularized optimal transport cost.
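To make the point above concrete, here is a minimal NumPy sketch, not moscot's or ott-jax's actual code. It assumes the objective has the textbook form <P, C> - eps * H(P), with H the Shannon entropy of the coupling; the exact expression ott-jax reports may differ. The helper names (`sinkhorn_plan`, `entropic_cost`) and the toy inputs are made up for illustration.

```python
import numpy as np


def sinkhorn_plan(C, a, b, eps, n_iter=200):
    """Plain Sinkhorn iterations; returns the coupling P (illustrative only)."""
    K = np.exp(-C / eps)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]


def entropic_cost(P, C, eps):
    """Assumed objective: transport cost minus eps times the Shannon entropy of P."""
    transport = float(np.sum(P * C))
    entropy = float(-np.sum(P * np.log(P)))  # P > 0 everywhere because K > 0
    return transport - eps * entropy


# Toy problem: two points per side, tiny off-diagonal cost, large epsilon.
C = 0.01 * (1.0 - np.eye(2))
a = np.full(2, 0.5)
b = np.full(2, 0.5)
eps = 1.0

P = sinkhorn_plan(C, a, b, eps)
cost = entropic_cost(P, C, eps)
print(cost)  # negative: the entropy term outweighs the small transport cost
```

Under this assumed form the entropy term can lower the cost by at most eps * log(n * m) for an n-by-m coupling, since the entropy of a distribution over n * m cells is at most log(n * m).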

The cost=nan (with converged=False) can happen unfortunately.

For the case ((1, 4), OTTOutput[shape=(292, 824), cost=-28217723322368.0, converged=True]), could you try plotting the cost (https://moscot.readthedocs.io/en/latest/genapi/moscot.backends.ott.OTTOutput.plot_costs.html)? About the errors, let's discuss that in issue #768.

MUCDK avatar Dec 06 '24 07:12 MUCDK

The cost plot doesn't seem crazy... stp.solutions[(1,4)].plot_costs()

[Image: output of plot_costs() for problem (1, 4)]

matthieuheitz avatar Dec 06 '24 19:12 matthieuheitz

Yeah, I don't know where this -28217723322368.0 value comes from, it's nowhere in stp.solutions[(1,4)]._costs:

Array([ 1.7440577,  1.2362485,  0.9642093,  0.9561533,  0.9549437,
        0.9557142, -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ,
       -1.       , -1.       , -1.       , -1.       , -1.       ],      dtype=float32)

It looks like the setting of the cost attribute is not always happening like it should.
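The comparison can be reproduced from the numbers printed above. This sketch assumes that the -1 entries are placeholders for iterations at which no cost was recorded, which nothing in this thread confirms; note that the filter would also drop a genuine cost of exactly -1.

```python
import numpy as np

# Values copied from the _costs array above: 6 recorded entries, then 44 entries of -1.
costs = np.array(
    [1.7440577, 1.2362485, 0.9642093, 0.9561533, 0.9549437, 0.9557142]
    + [-1.0] * 44,
    dtype=np.float32,
)
reported = -28217723322368.0  # value shown in the OTTOutput repr for (1, 4)

# Assumption: -1 marks "no cost recorded", so drop those entries.
recorded = costs[costs != -1.0]
last_recorded = float(recorded[-1])

print(len(recorded), last_recorded)
# Check whether the reported value appears anywhere in the trace.
print(bool(np.isclose(costs, reported).any()))
```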

matthieuheitz avatar Dec 06 '24 19:12 matthieuheitz

I guess that would be some numerical overflow. Are you running it with 32-bit or 64-bit floats? If 32-bit, could you try with 64-bit?
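As a generic illustration of why the precision matters (this is not the solver's actual computation, and the value 100.0 is a hypothetical stand-in for a quantity divided by a small epsilon), float32 overflows where float64 does not:

```python
import numpy as np

# Hypothetical scaled quantity, e.g. 0.1 / 1e-3; not taken from the solver.
scaled = 100.0

with np.errstate(over="ignore"):  # silence the expected overflow warning
    x32 = np.exp(np.float32(scaled))  # exceeds the float32 range (max ~3.4e38)
    x64 = np.exp(np.float64(scaled))  # well within the float64 range

print(x32, x64)

# JAX uses 32-bit floats by default. 64-bit can be enabled at startup with:
#   import jax
#   jax.config.update("jax_enable_x64", True)
```

Whether switching precision resolves the values reported in this issue is not established here; it only removes one possible source of overflow.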

@michalk8 , do you have any idea where this might come from?

MUCDK avatar Dec 09 '24 22:12 MUCDK

Closing it for now, please feel free to re-open if necessary.

MUCDK avatar May 05 '25 13:05 MUCDK