[Improvement] Ensure the graph is connected before computing connecting geodesics
Is your feature related to a problem?
When plotting geodesics with connecting_geodesic, the function assumes that the graph is connected. If it is not, networkx.shortest_path eventually raises an error that users may find cryptic.
Describe the solution you would like:
Either when discretising the manifold or before computing the geodesics, an explicit check with an informative error message could verify that the graph is connected, e.g. assert nx.is_connected(graph), "Graph not connected". Additional guidance or best practices could also be documented to help users ensure the graph is connected.
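A minimal sketch of such a check, assuming the discretised manifold exposes its networkx graph. The helper name check_graph_connected is hypothetical and not part of the library. It raises a ValueError instead of using assert, because assertions are stripped when Python runs with -O, and it reports the component sizes so the failure is easier to diagnose:

```python
import networkx as nx


def check_graph_connected(graph: nx.Graph) -> None:
    """Raise a descriptive error if `graph` is empty or not connected.

    Hypothetical helper: the name and the message are illustrative only.
    """
    if graph.number_of_nodes() == 0:
        # nx.is_connected raises its own exception on an empty graph.
        raise ValueError("The discretised manifold graph is empty.")
    if not nx.is_connected(graph):
        # Sizes of the connected components, largest first.
        sizes = sorted((len(c) for c in nx.connected_components(graph)), reverse=True)
        raise ValueError(
            f"The discretised manifold graph is not connected: found {len(sizes)} "
            f"components with sizes {sizes}. Geodesics cannot be computed between "
            "points that lie in different components."
        )
```

It could be called at the end of fit() or at the start of connecting_geodesic, before networkx.shortest_path is invoked.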
Another improvement would be to check that tensor shapes are as expected. For example, in the DiscretizedManifold.fit() function:
with torch.no_grad():
    weight = model.curve_length(line(t))
assert weight.shape == (bs,), f"model.curve_length should return a tensor of shape ({bs},) but returned shape {tuple(weight.shape)}."
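A sketch of the same shape check as a reusable helper (check_curve_length_shape is a hypothetical name). Since torch.Size is a subclass of tuple, the returned shape has to be compared with the tuple (bs,) rather than the integer bs. The helper only relies on a .shape attribute, so it works for both torch.Tensor and numpy.ndarray:

```python
def check_curve_length_shape(weight, batch_size: int):
    """Validate that a batched curve-length result has shape (batch_size,).

    Hypothetical helper; accepts any array-like exposing `.shape`.
    Raises ValueError (not assert) so the check survives `python -O`.
    """
    expected = (batch_size,)
    actual = tuple(weight.shape)
    if actual != expected:
        raise ValueError(
            f"model.curve_length should return a tensor of shape {expected}, "
            f"but returned shape {actual}."
        )
    return weight
```

Inside fit() it could wrap the existing call, e.g. weight = check_curve_length_shape(model.curve_length(line(t)), bs).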
Another question regarding DiscretizedManifold.fit(): the weight of each graph edge is computed from a curve evaluated at only two points:
t = torch.linspace(0, 1, 2)
(...)
with torch.no_grad():
    weight = model.curve_length(line(t))
This approach mainly relies on a metric tensor to compute the edge weights (curve_length depends on inner_product, which depends on metric). Yet, when the metric tensor is not easily accessible, one might want to compute the curve length from the derivatives $\dot{\gamma}_t$ of the curve: $L[\gamma] = \int_0^1 \|\dot{\gamma}_t\| \, dt$, or more generally $L[\gamma] = \int_0^1 F(\gamma_t, \dot{\gamma}_t) \, dt$ for a Finsler metric $F$. These derivatives can only be approximated accurately (e.g. by finite differences) if the curve is discretised finely enough.
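To illustrate why the number of curve points matters, here is a sketch of a derivative-based length estimate, written in NumPy for brevity (discrete_curve_length and its local_norm argument are hypothetical names). The finite difference $\gamma_{t_{i+1}} - \gamma_{t_i} \approx \dot{\gamma}_{t_i} \Delta t$ approximates the scaled velocity, and because any Riemannian or Finsler norm $F$ is positively 1-homogeneous in the velocity, summing $F$ over these differences is a Riemann sum of the length integral:

```python
import numpy as np


def discrete_curve_length(points, local_norm=None):
    """Approximate the length of a curve sampled at T increasing times.

    points: array of shape (T, d), the curve gamma evaluated at t_0 < ... < t_{T-1}.
    local_norm: optional callable F(x, v) -> array of shape (T-1,), evaluated at the
        segment midpoints. It must be positively 1-homogeneous in v, which holds for
        any Riemannian or Finsler norm. Defaults to the Euclidean norm.
    """
    points = np.asarray(points, dtype=float)
    # Finite differences: gamma_{i+1} - gamma_i ~ velocity at t_i times dt.
    deltas = points[1:] - points[:-1]
    midpoints = 0.5 * (points[1:] + points[:-1])
    if local_norm is None:
        segment_lengths = np.linalg.norm(deltas, axis=-1)
    else:
        segment_lengths = local_norm(midpoints, deltas)
    # Homogeneity gives F(x, v * dt) = F(x, v) * dt, so this is a Riemann sum.
    return float(np.sum(segment_lengths))


if __name__ == "__main__":
    # Quarter of the unit circle, whose true length is pi / 2.
    for num_points in (2, 10, 1000):
        theta = np.linspace(0.0, np.pi / 2, num_points)
        curve = np.stack([np.cos(theta), np.sin(theta)], axis=-1)
        print(num_points, discrete_curve_length(curve))
```

For the quarter circle, two points only recover the chord of length √2 ≈ 1.414, while the estimate approaches π/2 ≈ 1.571 as the number of points grows. A Riemannian metric fits the same interface, e.g. local_norm=lambda x, v: np.sqrt(np.einsum("ni,nij,nj->n", v, G(x), v)) for a hypothetical metric function G returning a (T-1, d, d) array.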
The question is: would it make sense to add an argument (e.g. num_curve_points) to discretise the curve more finely and use its derivatives to compute the curve length under the intended metric (for example a Finsler metric)?
def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise=0.0, num_curve_points=2):
    (...)
    t = torch.linspace(0, 1, num_curve_points)
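The proposed num_curve_points argument could then feed a batched edge-weight computation along the lines of the sketch below. As a simplification, it assumes each edge is a straight segment between two grid nodes (in fit() the points come from line(t)) and uses a hypothetical position-dependent cost $F(x, v) = (1 + x_0^2)\|v\|$; all function names are illustrative. Even for straight edges, the number of points matters once the metric varies along the edge:

```python
import numpy as np


def straight_line_points(begin, end, num_curve_points=2):
    """Sample a batch of straight lines; begin/end have shape (bs, d), result (bs, T, d)."""
    # Same role as t = torch.linspace(0, 1, num_curve_points) in fit().
    t = np.linspace(0.0, 1.0, num_curve_points)
    begin = np.asarray(begin, dtype=float)
    end = np.asarray(end, dtype=float)
    return begin[:, None, :] + t[None, :, None] * (end - begin)[:, None, :]


def batched_curve_length(points, local_norm):
    """Discrete length of each curve in a (bs, T, d) batch, using finite differences."""
    deltas = points[:, 1:] - points[:, :-1]  # ~ velocity * dt for each segment
    midpoints = 0.5 * (points[:, 1:] + points[:, :-1])
    return local_norm(midpoints, deltas).sum(axis=-1)  # shape (bs,)


def edge_weights(begin, end, local_norm, num_curve_points=2):
    """Hypothetical edge-weight routine parameterised by num_curve_points."""
    if num_curve_points < 2:
        raise ValueError("num_curve_points must be at least 2")
    points = straight_line_points(begin, end, num_curve_points)
    weights = batched_curve_length(points, local_norm)
    expected = (points.shape[0],)
    if weights.shape != expected:
        raise ValueError(f"expected weights of shape {expected}, got {weights.shape}")
    return weights


def weighted_norm(x, v):
    """Illustrative position-dependent cost F(x, v) = (1 + x_0^2) * ||v||."""
    return (1.0 + x[..., 0] ** 2) * np.linalg.norm(v, axis=-1)


if __name__ == "__main__":
    begin = np.array([[0.0, 0.0], [0.0, 0.0]])
    end = np.array([[1.0, 0.0], [0.0, 1.0]])
    for n in (2, 5, 100):
        print(n, edge_weights(begin, end, weighted_norm, num_curve_points=n))
```

For the edge from (0, 0) to (1, 0), num_curve_points=2 gives 1.25, whereas the exact length is $\int_0^1 (1 + t^2) \, dt = 4/3$; increasing num_curve_points closes the gap. The edge from (0, 0) to (0, 1) has $x_0 = 0$ throughout, so its weight is 1 for any number of points.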