Negative cost in OTTOutput
I'm computing all pairwise distances between 6 spatial transcriptomics slices, and some of the resulting costs are negative, sometimes absurdly large in magnitude, even though the solvers report convergence.
I prepared the problem like this:
stp = stp.prepare(time_key="Batch_idx", spatial_key="spatial", joint_attr=joint_attr, cost="sq_euclidean", policy="triu")
where joint_attr="X_pca" is a global PCA.
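For context, here is a minimal sketch of the full setup (the file name is hypothetical, and the import path may differ between moscot versions):
import scanpy as sc
from moscot.problems.spatiotemporal import SpatioTemporalProblem  # import path may vary by moscot version

# hypothetical AnnData with all 6 slices concatenated:
# obs["Batch_idx"] orders the slices, obsm["spatial"] holds the coordinates,
# obsm["X_pca"] is the PCA computed on the concatenated data
adata = sc.read_h5ad("slices.h5ad")

stp = SpatioTemporalProblem(adata)
stp = stp.prepare(
    time_key="Batch_idx",     # ordering of the slices
    spatial_key="spatial",    # coordinates used for the quadratic (GW) term
    joint_attr="X_pca",       # shared feature space used for the linear (fused) term
    cost="sq_euclidean",
    policy="triu",            # solve every pair of slices (upper triangle)
)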
Here's my call to the solver, and the output:
from ott.geometry import epsilon_scheduler
from ott.solvers.linear import acceleration

stp = stp.solve(epsilon=epsilon_scheduler.Epsilon(target=1e-3, init=100, decay=0.99),
                alpha=0.5,
                linear_solver_kwargs={"momentum": acceleration.Momentum(start=300)})
[((0, 1), OTTOutput[shape=(147, 292), cost=0.9862, converged=True]),
((2, 4), OTTOutput[shape=(441, 824), cost=0.9677, converged=True]),
((1, 2), OTTOutput[shape=(292, 441), cost=0.6782, converged=True]),
((0, 4), OTTOutput[shape=(147, 824), cost=0.9633, converged=True]),
((3, 4), OTTOutput[shape=(1169, 824), cost=0.9322, converged=True]),
((1, 5), OTTOutput[shape=(292, 744), cost=-0.0029, converged=True]),
((0, 3), OTTOutput[shape=(147, 1169), cost=-56.5436, converged=True]),
((1, 4), OTTOutput[shape=(292, 824), cost=-28217723322368.0, converged=True]),
((2, 3), OTTOutput[shape=(441, 1169), cost=0.9884, converged=True]),
((0, 2), OTTOutput[shape=(147, 441), cost=-3458.1289, converged=True]),
((4, 5), OTTOutput[shape=(824, 744), cost=-0.4292, converged=True]),
((0, 5), OTTOutput[shape=(147, 744), cost=1.1268, converged=True]),
((2, 5), OTTOutput[shape=(441, 744), cost=0.9649, converged=True]),
((1, 3), OTTOutput[shape=(292, 1169), cost=0.8918, converged=True]),
((3, 5), OTTOutput[shape=(1169, 744), cost=0.8036, converged=True])]
I then tried changing some of the solver parameters to see whether it would matter, and it seems it does. Removing the epsilon decay, we get only 3 negative values (and one nan).
stp = stp.solve(epsilon=1e-3, alpha=0.5, linear_solver_kwargs={"momentum": acceleration.Momentum(start=200)})
[((0, 1), OTTOutput[shape=(147, 292), cost=0.9831, converged=True]),
((2, 4), OTTOutput[shape=(441, 824), cost=nan, converged=False]),
((1, 2), OTTOutput[shape=(292, 441), cost=0.678, converged=True]),
((0, 4), OTTOutput[shape=(147, 824), cost=-11.4962, converged=True]),
((3, 4), OTTOutput[shape=(1169, 824), cost=0.9322, converged=True]),
((1, 5), OTTOutput[shape=(292, 744), cost=1.0901, converged=True]),
((0, 3), OTTOutput[shape=(147, 1169), cost=-55.8705, converged=True]),
((1, 4), OTTOutput[shape=(292, 824), cost=0.9495, converged=True]),
((2, 3), OTTOutput[shape=(441, 1169), cost=0.9883, converged=True]),
((0, 2), OTTOutput[shape=(147, 441), cost=-2565.3809, converged=True]),
((4, 5), OTTOutput[shape=(824, 744), cost=0.671, converged=True]),
((0, 5), OTTOutput[shape=(147, 744), cost=1.1225, converged=True]),
((2, 5), OTTOutput[shape=(441, 744), cost=0.9571, converged=True]),
((1, 3), OTTOutput[shape=(292, 1169), cost=1.0639, converged=True]),
((3, 5), OTTOutput[shape=(1169, 744), cost=0.7674, converged=True])]
With the epsilon decay removed and a fixed momentum value, we still get 4 negative costs (and many problems didn't converge, but I'm not sure that's relevant).
stp = stp.solve(epsilon=1e-3, alpha=0.5, linear_solver_kwargs={"momentum": acceleration.Momentum(value=1.6)})
[((0, 1), OTTOutput[shape=(147, 292), cost=0.9871, converged=True]),
((2, 4), OTTOutput[shape=(441, 824), cost=0.9025, converged=False]),
((1, 2), OTTOutput[shape=(292, 441), cost=0.6788, converged=False]),
((0, 4), OTTOutput[shape=(147, 824), cost=1.0402, converged=True]),
((3, 4), OTTOutput[shape=(1169, 824), cost=0.9117, converged=False]),
((1, 5), OTTOutput[shape=(292, 744), cost=1.0907, converged=True]),
((0, 3), OTTOutput[shape=(147, 1169), cost=-43.8229, converged=False]),
((1, 4), OTTOutput[shape=(292, 824), cost=0.9567, converged=True]),
((2, 3), OTTOutput[shape=(441, 1169), cost=0.9883, converged=False]),
((0, 2), OTTOutput[shape=(147, 441), cost=-1995.7041, converged=True]),
((4, 5), OTTOutput[shape=(824, 744), cost=0.6685, converged=False]),
((0, 5), OTTOutput[shape=(147, 744), cost=-9.9394, converged=False]),
((2, 5), OTTOutput[shape=(441, 744), cost=0.9292, converged=True]),
((1, 3), OTTOutput[shape=(292, 1169), cost=0.9441, converged=True]),
((3, 5), OTTOutput[shape=(1169, 744), cost=-2.0191, converged=False])]
What's strange is that it's not always the same pairs that end up with a negative cost, although (0, 3) and (0, 2) seem to always be negative while the others are on and off. Any idea what could be causing this? Thanks!
Hi @matthieuheitz,
Thanks for opening this issue.
A negative cost is possible because we (following ott-jax) report the entropy-regularized optimal transport cost, not the unregularized transport cost, so the regularization term can pull the reported value below zero.
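To make this concrete, here is a tiny toy computation using one common convention for the entropic objective (ott-jax's exact formulation differs in its details, but the point is the same: the regularizer is a signed term added to the transport term, so the total can dip below zero when little mass actually moves):
import numpy as np

# two identical marginals, so the optimal coupling barely moves any mass
C = np.array([[0.0, 1.0],
              [1.0, 0.0]])
P = np.array([[0.49, 0.01],
              [0.01, 0.49]])  # slightly blurred diagonal coupling
eps = 0.1

transport_term = np.sum(C * P)                      # 0.02
entropy_term = eps * np.sum(P * (np.log(P) - 1.0))  # about -0.18, negative by construction
print(transport_term + entropy_term)                # about -0.16: the regularizer dominates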
Unfortunately, cost=nan (together with converged=False) can happen.
For the case ((1, 4), OTTOutput[shape=(292, 824), cost=-28217723322368.0, converged=True]), could you try plotting the cost (https://moscot.readthedocs.io/en/latest/genapi/moscot.backends.ott.OTTOutput.plot_costs.html)? As for the errors, let's discuss those in issue #768.
The cost plot doesn't seem crazy...
stp.solutions[(1,4)].plot_costs()
Yeah, I don't know where this -28217723322368.0 value comes from; it's nowhere in stp.solutions[(1,4)]._costs:
Array([ 1.7440577, 1.2362485, 0.9642093, 0.9561533, 0.9549437,
0.9557142, -1. , -1. , -1. , -1. ,
-1. , -1. , -1. , -1. , -1. ,
-1. , -1. , -1. , -1. , -1. ,
-1. , -1. , -1. , -1. , -1. ,
-1. , -1. , -1. , -1. , -1. ,
-1. , -1. , -1. , -1. , -1. ,
-1. , -1. , -1. , -1. , -1. ,
-1. , -1. , -1. , -1. , -1. ,
-1. , -1. , -1. , -1. , -1. ], dtype=float32)
It looks like the cost attribute is not always being set the way it should be.
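One way to cross-check (a quick sketch; sol.cost and sol.converged are the values shown in the repr, and _costs is the private per-iteration trace printed above):
import numpy as np

sol = stp.solutions[(1, 4)]
trace = np.asarray(sol._costs)  # per-iteration costs, padded with -1
print(sol.cost, sol.converged)  # summary values from the repr
print(trace[trace != -1][-1])   # last recorded cost, which should roughly match sol.cost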
I guess this could be some numerical overflow. Are you running with float32 or float64? If float32, could you try float64?
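For reference, a minimal way to switch JAX to double precision (it has to be set before any JAX arrays are created; you may also want to cast the input matrices to float64):
import jax

# enable 64-bit floats globally, before creating any JAX arrays
jax.config.update("jax_enable_x64", True)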
@michalk8, do you have any idea where this might come from?
Closing it for now, please feel free to re-open if necessary.