stochman

[Improvement] Ensure the graph is connected before connecting geodesics

Open · a-pouplin opened this issue 2 years ago · 2 comments

Is your feature request related to a problem? When plotting geodesics with `connecting_geodesic`, the function assumes that the graph is connected. If it is not, `nx.shortest_path` (networkx) eventually raises an error that users may find cryptic.

Describe the solution you'd like: Either when discretising the manifold or before plotting the geodesics, a check could be added that the graph is connected, with an explicit error message, e.g. `assert nx.is_connected(graph), "Graph not connected"`. Additional guidance or best practices for ensuring that the graph is connected could also be documented.
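A minimal sketch of such a check, assuming the discretised manifold stores its graph as an undirected `networkx.Graph` (as the `nx.is_connected` suggestion implies). The helpers `check_graph_connected` and `check_endpoints_connected` are hypothetical, not part of stochman. They raise `ValueError` instead of using `assert`, because assertions are stripped when Python runs with `-O`.

```python
import networkx as nx


def check_graph_connected(graph: nx.Graph) -> None:
    """Hypothetical helper: raise an informative error if the graph is split into components."""
    # number_connected_components is only defined for undirected graphs.
    n_components = nx.number_connected_components(graph)
    if n_components > 1:
        sizes = sorted((len(c) for c in nx.connected_components(graph)), reverse=True)
        raise ValueError(
            f"Graph not connected: found {n_components} components (sizes {sizes}). "
            "Geodesics between nodes in different components cannot be computed."
        )


def check_endpoints_connected(graph: nx.Graph, source, target) -> None:
    """Hypothetical helper: a targeted check that only the two endpoints share a component."""
    if not nx.has_path(graph, source, target):
        raise ValueError(f"No path between nodes {source} and {target} in the graph.")


# Toy graph with two components: {0, 1, 2} and {3, 4}.
g = nx.Graph()
g.add_edges_from([(0, 1), (1, 2), (3, 4)])

check_endpoints_connected(g, 0, 2)  # same component, so no error is raised
try:
    check_graph_connected(g)
except ValueError as err:
    message = str(err)
```

The whole-graph check could run once after discretisation, while the endpoint check (via `nx.has_path`) is enough right before a single shortest-path query.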

a-pouplin, Mar 20 '23 10:03

Another improvement would be to check that tensor shapes are as expected. For example, in the `DiscretizedManifold.fit()` function:

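```python
with torch.no_grad():
    weight = model.curve_length(line(t))
    # One length per curve in the batch; weight.shape is a torch.Size, so compare against a tuple.
    assert weight.shape == (bs,), f"model.curve_length should return a tensor of shape {(bs,)} but found {tuple(weight.shape)}."
```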
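Since `weight.shape` is a `torch.Size` (a tuple subclass), the expected shape must be written as the tuple `(bs,)`: comparing it with a bare integer `bs` is always `False`. A self-contained sketch of the same check, using a hypothetical `checked_curve_length` wrapper and a hypothetical Euclidean stand-in for `model.curve_length` that takes curves of shape `(bs, n_points, dim)`:

```python
import torch


def checked_curve_length(curve_length_fn, curves: torch.Tensor, bs: int) -> torch.Tensor:
    """Hypothetical wrapper: call curve_length_fn and verify it returns one length per curve."""
    with torch.no_grad():
        weight = curve_length_fn(curves)
    expected = torch.Size([bs])
    if weight.shape != expected:
        raise ValueError(
            f"curve_length should return a tensor of shape {tuple(expected)}, "
            f"but returned shape {tuple(weight.shape)}."
        )
    return weight


def euclidean_curve_length(curves: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: sum of Euclidean segment lengths, curves of shape (bs, n_points, dim)."""
    return (curves[:, 1:] - curves[:, :-1]).norm(dim=-1).sum(dim=-1)


# Two straight two-point curves: lengths 5 and 1.
curves = torch.tensor([[[0.0, 0.0], [3.0, 4.0]], [[1.0, 1.0], [1.0, 2.0]]])
lengths = checked_curve_length(euclidean_curve_length, curves, bs=2)
```

A `curve_length` that accidentally keeps an extra dimension (shape `(bs, 1)`) would then fail immediately with a readable message instead of producing misaligned edge weights later.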

a-pouplin, Mar 20 '23 12:03

Another question regarding `DiscretizedManifold.fit()`: the graph's edge weights are computed from a curve evaluated at only two points:

```python
t = torch.linspace(0, 1, 2)

# (...)

with torch.no_grad():
    weight = model.curve_length(line(t))
```

This method mainly relies on a metric tensor to compute the graph weights (`curve_length` depends on `inner_product`, which depends on `metric`). However, when the metric tensor is not easily accessible, one might want to compute the curve length from the derivatives $\dot{\gamma}_t$ of the curve instead: $L[\gamma] = \int_0^1 \lVert \dot{\gamma}_t \rVert \, dt$. These derivatives can only be estimated accurately if the curve is discretised finely enough.
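For illustration, assuming a uniform grid $t_i = i/(N-1)$, $i = 0, \dots, N-1$, with $N$ curve points and a forward-difference estimate of the derivative, the length $L[\gamma] = \int_0^1 \lVert \dot{\gamma}_t \rVert \, dt$ becomes a finite sum:

$$
\dot{\gamma}_{t_i} \approx \frac{\gamma_{t_{i+1}} - \gamma_{t_i}}{\Delta t}, \qquad \Delta t = \frac{1}{N-1}, \qquad L[\gamma] \approx \sum_{i=0}^{N-2} \lVert \dot{\gamma}_{t_i} \rVert \, \Delta t .
$$

Because a norm (including a Finsler norm) is positively homogeneous, each term equals $\lVert \gamma_{t_{i+1}} - \gamma_{t_i} \rVert$, with the norm evaluated at $\gamma_{t_i}$ if it is position-dependent. With $N = 2$ the estimate therefore only sees the two endpoints of the curve.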

The question is: would it make sense to add an argument (e.g. `num_curve_points`) that controls how finely the curve is discretised, so that its derivatives can be used to compute the expected metric (or, for example, a Finsler metric)?

```python
def fit(self, model, grid, use_diagonals=True, batch_size=4, interpolation_noise=0.0, num_curve_points=2):

    # (...)

    t = torch.linspace(0, 1, num_curve_points)
```
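A minimal sketch of what `num_curve_points` would change, assuming a Euclidean norm as a stand-in for whatever (possibly Finsler) norm the model would provide. `discretized_curve_length` and `quarter_circle` are hypothetical helpers, not stochman APIs. The test curve is a quarter of the unit circle, whose true length is $\pi/2 \approx 1.571$.

```python
import math

import torch


def quarter_circle(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical curve: unit circle from angle 0 to pi/2; t has shape (n,), output (n, 2)."""
    angle = t * math.pi / 2
    return torch.stack([torch.cos(angle), torch.sin(angle)], dim=-1)


def discretized_curve_length(curve, num_curve_points: int, norm=torch.linalg.vector_norm) -> torch.Tensor:
    """Approximate L[gamma] = int_0^1 ||gamma'(t)|| dt using forward differences."""
    t = torch.linspace(0, 1, num_curve_points)
    points = curve(t)                                          # (num_curve_points, dim)
    dt = t[1:] - t[:-1]                                        # (num_curve_points - 1,)
    velocity = (points[1:] - points[:-1]) / dt.unsqueeze(-1)   # finite-difference derivative
    speed = norm(velocity, dim=-1)                             # ||gamma'(t_i)||, Euclidean by default
    return (speed * dt).sum()


lengths = {n: discretized_curve_length(quarter_circle, n).item() for n in (2, 10, 100)}
for n, length in lengths.items():
    print(f"num_curve_points={n}: length ~ {length:.4f}")
```

With two points the estimate is the chord length $\sqrt{2} \approx 1.414$, and it approaches $\pi/2$ as `num_curve_points` grows. The trade-off is cost: each edge needs `num_curve_points - 1` norm evaluations instead of one.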

a-pouplin, Mar 20 '23 18:03