SofaPython3

Example of a ForceField implemented with JAX

Open leobois67 opened this issue 2 months ago • 10 comments

This example illustrates how to use JAX automatic differentiation to implement the methods 'addDForce()' and 'addKToMatrix()' automatically. It simulates a set of particles, each attached to the origin by a simple spring.
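
A minimal sketch of the idea, assuming each particle is tied to the origin by a zero-rest-length spring with a hypothetical stiffness STIFFNESS, so that f(x) = -k x. `d_force` plays the role of addDForce() through a Jacobian-vector product, and `stiffness_matrix` plays the role of addKToMatrix() through a dense Jacobian. The names are illustrative only; this is neither the SofaPython3 ForceField API nor the code of this PR.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # SOFA works in double precision

STIFFNESS = 10.0  # hypothetical spring constant k


def spring_force(x):
    """Force on each particle (x has shape (n, 3)) from a spring to the origin."""
    return -STIFFNESS * x


def d_force(x, dx):
    """df = (df/dx) dx, as addDForce() needs, via a Jacobian-vector product.
    No matrix is ever built."""
    _, df = jax.jvp(spring_force, (x,), (dx,))
    return df


def stiffness_matrix(x):
    """Dense Jacobian df/dx: shape (n, 3, n, 3), flattened to (3n, 3n)."""
    n = x.shape[0]
    return jax.jacfwd(spring_force)(x).reshape(3 * n, 3 * n)
```

The matrix-free `d_force` is enough for solvers that only need products K·dx; `stiffness_matrix` is only required when the solver assembles the system.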

I included a few scene options that might be unnecessary:

  • time integration and linear solver, to illustrate that not all combinations require all methods to be implemented
  • number of particles, to show how computational cost and memory requirements scale (important here, since the Jacobian computed by JAX is dense)
  • a built-in SOFA alternative, for comparison
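
To make the scaling concrete, a back-of-the-envelope helper, assuming 3 degrees of freedom per particle and 8-byte doubles (the function names are illustrative):

```python
def dense_jacobian_bytes(n_particles, dofs_per_particle=3, bytes_per_entry=8):
    """Memory of a dense (3n x 3n) Jacobian stored in double precision."""
    size = n_particles * dofs_per_particle
    return size * size * bytes_per_entry


def block_diagonal_entries(n_particles, dofs_per_particle=3):
    """Entries left if each particle's force depends only on its own position."""
    return n_particles * dofs_per_particle ** 2


for n in (100, 1000, 10000):
    print(f"{n:>6} particles: dense {dense_jacobian_bytes(n) / 1e6:g} MB, "
          f"block-diagonal {block_diagonal_entries(n)} entries")
```

With 1000 particles, the dense Jacobian has 9 million entries (72 MB), against 9000 entries for its block-diagonal part; at 10000 particles the dense matrix would reach 7.2 GB.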

leobois67 · Nov 07 '25 15:11

Note: it requires JAX to run.

leobois67 · Nov 07 '25 15:11

Sorry, but I get a segfault while running it at the very first timestep.

terminate called after throwing an instance of 'pybind11::type_error'
  what():  Can't read return value of AddKToMatrix. A numpy array is expected

bakpaul · Nov 13 '25 16:11

I believe this example stopped working after this commit 5029c050f68078ab5173d1ffcace69598caa7f8b by @alxbilger.

I don't know exactly how things worked before, but now, when I try to fix it, (1) I have to add an extra dimension to enter the correct branch of the if statement, so the KMatrix has shape (n, n, 1), and (2) it is very slow. I will look into it.

leobois67 · Nov 13 '25 17:11

@leobois67 you're probably right about my commit. However, I don't understand how it could have worked before. I suspect that the matrix was not being filled by your Python force field. This hypothesis is strengthened by the fact that it is slower now. Is that possible?

alxbilger · Nov 13 '25 17:11

@alxbilger I agree that it probably did not work as it was supposed to. The matrix being passed was not a numpy array but a JAX array of shape (n, 3, n, 3), so I don't know how it was processed, but I guess it was silently ignored.

Also, it is the processing of the matrix that is slow, not its computation: even returning a dense matrix full of zeros is just as slow. Is there a way to improve that?
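
Whatever exact layout the binding checks for (the comment above mentions that an extra trailing dimension was needed to reach the right branch), the returned value has to be an actual numpy array rather than a JAX array. A hedged sketch of that conversion, assuming the Jacobian comes out of `jax.jacfwd` with shape (n, 3, n, 3); `to_dense_numpy` is a hypothetical helper, and the final reshape may need adapting to what the binding expects:

```python
import numpy as np


def to_dense_numpy(jac):
    """Convert a Jacobian of shape (n, 3, n, 3), possibly a JAX array on GPU,
    into a C-contiguous float64 numpy array of shape (3n, 3n)."""
    arr = np.asarray(jac, dtype=np.float64)  # copies device data to the host if needed
    n = arr.shape[0]
    if arr.shape != (n, 3, n, 3):
        raise ValueError(f"unexpected Jacobian shape {arr.shape}")
    # Entry [i, a, j, b] lands at row 3*i + a, column 3*j + b.
    return np.ascontiguousarray(arr.reshape(3 * n, 3 * n))
```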

leobois67 · Nov 13 '25 17:11

In the code in https://github.com/sofa-framework/SofaPython3/commit/5029c050f68078ab5173d1ffcace69598caa7f8b, each call to mat->add is time-consuming, so it is best to reduce the number of calls. You can try testing whether a value is non-zero before calling mat->add.
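
The same idea can be applied on the Python side: drop the zeros before they reach the binding, so that the C++ loop only calls mat->add for stored entries. For 1000 particles, a dense return means (3 × 1000)² = 9 million entries to process. A minimal sketch; the (rows, cols, values) format and the name `dense_to_triplets` are hypothetical, and the format the binding actually accepts is defined by the commit linked above. Note that filtering a dense matrix still requires computing it; only the add calls are saved.

```python
import numpy as np


def dense_to_triplets(K, tol=0.0):
    """Keep only the entries with |K[i, j]| > tol, as (rows, cols, values)."""
    K = np.asarray(K, dtype=np.float64)
    rows, cols = np.nonzero(np.abs(K) > tol)
    return rows, cols, K[rows, cols]
```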

alxbilger · Nov 13 '25 17:11

I think I have something that works relatively well, returning only the non-zero values as you suggested. I left in the code that returns the dense matrix, for people who don't have a big sparse Jacobian.

To give you an idea of the impact on performance, here are some stats for the simulation of 1000 independent particles:

  • SOFA force field: ~1300 FPS
  • JAX force field, sparse Jacobian, optimized: ~390 FPS on GPU, ~40 FPS on CPU
  • JAX force field, sparse Jacobian, not optimized: ~120 FPS on GPU, ~7 FPS on CPU
  • JAX force field, dense Jacobian: ~7 FPS on GPU and CPU

By "sparse Jacobian" I mean returning only the non-zero values; by "optimized" I mean an optimization, described in the code, that leverages knowledge of the Jacobian's sparsity; by "GPU"/"CPU" I mean the device JAX uses.
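
One way to exploit that sparsity, sketched under the assumption that each particle's force depends only on its own position (true for independent particles tied to the origin): differentiate the per-particle force, map it over all particles with `jax.vmap`, and emit only the n diagonal 3x3 blocks as (rows, cols, values). This is a plausible reading of the "optimized" variant, not necessarily the one in the PR; `particle_force`, `K_SPRING`, and the triplet format are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)

K_SPRING = 10.0  # hypothetical spring constant


def particle_force(p):
    """Force on one particle (p has shape (3,)) from a spring to the origin."""
    return -K_SPRING * p


# jacfwd gives the (3, 3) block d f_i / d x_i; vmap maps it over all particles,
# so we get (n, 3, 3) diagonal blocks instead of a dense (3n, 3n) matrix.
block_jacobians = jax.jit(jax.vmap(jax.jacfwd(particle_force)))


def block_diagonal_triplets(x):
    """Return (rows, cols, values) of the block-diagonal stiffness for x of shape (n, 3)."""
    n = x.shape[0]
    blocks = np.asarray(block_jacobians(x))  # device -> host copy, shape (n, 3, 3)
    i = np.arange(n)[:, None, None]  # particle index
    a = np.arange(3)[None, :, None]  # row within the block
    b = np.arange(3)[None, None, :]  # column within the block
    rows = np.broadcast_to(3 * i + a, blocks.shape).ravel()
    cols = np.broadcast_to(3 * i + b, blocks.shape).ravel()
    return rows, cols, blocks.ravel()
```

This transfers 9n values instead of 9n² and never builds the dense matrix; zeros inside each block are still kept and could be filtered as well.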

Also, most of the time spent in addKToMatrix() goes into converting the JAX array (typically on the GPU) to a numpy array (on the CPU): about 90% of the time in the 390 FPS case. I guess JAX would do even better if it could share data with SOFA directly on the GPU. Is that doable?
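
To reproduce that breakdown, one can time the JAX computation separately from the device-to-host conversion. Because JAX dispatches work asynchronously, the computation must be synchronized with `jax.block_until_ready` before the clock stops; otherwise part of its cost may be attributed to the conversion. A minimal sketch; `time_compute_and_transfer` is a hypothetical helper, not part of SofaPython3:

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def time_compute_and_transfer(fn, *args, repeats=10):
    """Average (compute, transfer) time in seconds of fn(*args) and of the
    conversion of its result to a numpy array."""
    jax.block_until_ready(fn(*args))  # warm-up: keeps JIT compilation out of the timings
    compute = transfer = 0.0
    for _ in range(repeats):
        t0 = time.perf_counter()
        out = jax.block_until_ready(fn(*args))  # wait for the device to finish
        t1 = time.perf_counter()
        np.asarray(out)  # device -> host copy (may be cheap on CPU)
        t2 = time.perf_counter()
        compute += t1 - t0
        transfer += t2 - t1
    return compute / repeats, transfer / repeats


if __name__ == "__main__":
    jac = jax.jit(jax.jacfwd(lambda x: -10.0 * x))
    print(time_compute_and_transfer(jac, jnp.ones((100, 3))))
```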

leobois67 · Nov 14 '25 11:11

@leobois67 thanks for the benchmark. I don't know JAX well enough to answer your question. What I can say is that in C++/CUDA we manipulate a raw pointer, whether the data is on the CPU or on the GPU. We would need a way to communicate the memory location of this data to JAX somehow. I don't know if that's possible.
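
On the JAX side, both a raw buffer address and a standard exchange protocol exist, which may be a starting point (nothing in this thread shows SofaPython3 supporting either): `jax.Array.unsafe_buffer_pointer()` returns the address of an array's buffer, and JAX arrays implement DLPack. The sketch below runs on CPU; on GPU the pointer is a device address, and the DLPack consumer would have to be GPU-aware (for example CUDA code on the SOFA side).

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(6.0).reshape(2, 3)

# Raw address of the buffer backing x (a device address if x lives on a GPU).
# It is only valid while x is alive; JAX does not track external users of it.
ptr = x.unsafe_buffer_pointer()
print(hex(ptr))

# DLPack exchange: numpy consumes the JAX CPU buffer through the protocol.
view = np.from_dlpack(x)
print(view)
```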

alxbilger · Nov 14 '25 12:11

Hello! I still get a segfault in buildStiffnessMatrix. Could someone other than me try this? @alxbilger

bakpaul · Dec 04 '25 10:12

I ran it on CPU and I also get it, @bakpaul:

########## SIG 11 - SIGSEGV: segfault ##########
  sofa::helper::BackTrace::sig(int)
  sofa::core::behavior::BaseForceField::buildStiffnessMatrix(sofa::core::behavior::StiffnessMatrix*)
  sofa::component::linearsystem::MatrixLinearSystem<sofa::linearalgebra::CompressedRowSparseMatrixMechanical<double, sofa::linearalgebra::CRSMechanicalPolicy>, sofa::linearalgebra::FullVector<double> >::assembleSystem(sofa::core::MechanicalParams const*)::{lambda(sofa::component::linearsystem::MatrixLinearSystem<sofa::linearalgebra::CompressedRowSparseMatrixMechanical<double, sofa::linearalgebra::CRSMechanicalPolicy>, sofa::linearalgebra::FullVector<double> >::IndependentContributors&)#1}::operator()(sofa::component::linearsystem::MatrixLinearSystem<sofa::linearalgebra::CompressedRowSparseMatrixMechanical<double, sofa::linearalgebra::CRSMechanicalPolicy>, sofa::linearalgebra::FullVector<double> >::IndependentContributors&) const
  sofa::component::linearsystem::MatrixLinearSystem<sofa::linearalgebra::CompressedRowSparseMatrixMechanical<double, sofa::linearalgebra::CRSMechanicalPolicy>, sofa::linearalgebra::FullVector<double> >::assembleSystem(sofa::core::MechanicalParams const*)
  sofa::core::behavior::BaseMatrixLinearSystem::buildSystemMatrix(sofa::core::MechanicalParams const*)
  sofa::component::odesolver::backward::EulerImplicitSolver::solve(sofa::core::ExecParams const*, double, sofa::core::TMultiVecId<(sofa::core::VecType)1, (sofa::core::VecAccess)1>, sofa::core::TMultiVecId<(sofa::core::VecType)2, (sofa::core::VecAccess)1>)
  sofa::simulation::SolveVisitor::processSolver(sofa::simulation::Node*, sofa::core::behavior::OdeSolver*)
  void sofa::simulation::Visitor::for_each<sofa::simulation::SolveVisitor, sofa::simulation::Node, sofa::simulation::NodeSequence<sofa::core::behavior::OdeSolver, false>, sofa::core::behavior::OdeSolver>(sofa::simulation::SolveVisitor*, sofa::simulation::Node*, sofa::simulation::NodeSequence<sofa::core::behavior::OdeSolver, false> const&, void (sofa::simulation::SolveVisitor::*)(sofa::simulation::Node*, sofa::core::behavior::OdeSolver*), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)
  sofa::simulation::SolveVisitor::processNodeTopDown(sofa::simulation::Node*)
  sofa::simulation::Node::executeVisitorTopDown(sofa::simulation::Visitor*, std::__cxx11::list<sofa::simulation::Node*, std::allocator<sofa::simulation::Node*> >&, std::map<sofa::simulation::Node*, sofa::simulation::Node::StatusStruct, std::less<sofa::simulation::Node*>, std::allocator<std::pair<sofa::simulation::Node* const, sofa::simulation::Node::StatusStruct> > >&, sofa::simulation::Node*)
  sofa::simulation::Node::executeVisitorTopDown(sofa::simulation::Visitor*, std::__cxx11::list<sofa::simulation::Node*, std::allocator<sofa::simulation::Node*> >&, std::map<sofa::simulation::Node*, sofa::simulation::Node::StatusStruct, std::less<sofa::simulation::Node*>, std::allocator<std::pair<sofa::simulation::Node* const, sofa::simulation::Node::StatusStruct> > >&, sofa::simulation::Node*)
  sofa::simulation::Node::doExecuteVisitor(sofa::simulation::Visitor*, bool)
  sofa::simulation::DefaultAnimationLoop::solve(sofa::core::ExecParams const*, double) const
  sofa::simulation::DefaultAnimationLoop::animate(sofa::core::ExecParams const*, double) const
  sofa::simulation::DefaultAnimationLoop::step(sofa::core::ExecParams const*, double)
  sofa::simulation::node::animate(sofa::simulation::Node*, double)
  sofaglfw::SofaGLFWBaseGUI::runLoop(unsigned long)
  sofaglfw::SofaGLFWGUI::mainLoop()
  sofa::gui::common::GUIManager::MainLoop(boost::intrusive_ptr<sofa::simulation::Node>, char const*)
  __libc_start_main

@leobois67 could you please update us on the status on your machine?

hugtalbot · Dec 20 '25 15:12

I just checked and it still works on my machine, both on CPU and GPU.

To be sure that addKToMatrix() is to blame, have you tried a version without the stiffness matrix assembly?

  python3 SofaPython3/examples/jax/forcefield.py --method=implicit-matrix-free
  python3 SofaPython3/examples/jax/forcefield.py --method=explicit

leobois67 · Jan 05 '26 10:01

You can come to me next time you work on this PR, but in case I am not available, here are some other quick suggestions:

  • Try changing the number of particles to see if something changes: --particles=100 (default 1000)
  • Try one of the other options for addKToMatrix by uncommenting it and commenting out the current one. (You may want to lower the number of particles, since the other options are slower.) One way a segfault could occur is if the matrix output by addKToMatrix is bigger than the one the binding expects, though I don't see a reason for that to happen.
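
Following that last point, a cheap guard is to validate what addKToMatrix returns before it reaches the binding, so that an out-of-range index raises a Python error instead of crashing in C++. A sketch assuming the hypothetical (rows, cols, values) format and a system of n_dofs = 3 × (number of particles) degrees of freedom; it cannot detect a mismatch in the layout the binding itself expects:

```python
import numpy as np


def check_triplets(rows, cols, values, n_dofs):
    """Raise ValueError if a (rows, cols, values) contribution is malformed or
    would write outside an n_dofs x n_dofs matrix."""
    rows, cols, values = np.asarray(rows), np.asarray(cols), np.asarray(values)
    if rows.ndim != 1 or not (rows.shape == cols.shape == values.shape):
        raise ValueError("rows, cols and values must be 1-D arrays of equal length")
    if not (np.issubdtype(rows.dtype, np.integer) and np.issubdtype(cols.dtype, np.integer)):
        raise ValueError("row and column indices must be integers")
    if rows.size and (min(rows.min(), cols.min()) < 0 or max(rows.max(), cols.max()) >= n_dofs):
        raise ValueError(f"index out of range for a {n_dofs} x {n_dofs} matrix")
    if not np.all(np.isfinite(values)):
        raise ValueError("stiffness values must be finite")
```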

leobois67 · Jan 05 '26 13:01