
Vahadane numpy backend

Open andreped opened this issue 3 years ago • 8 comments

I have reproduced the Vahadane implementation from StainTools. The implementation uses a numpy backend, but two steps remain before it works with pure numpy.

Hence, this is a draft, where we can discuss where to go next.

The two challenges are:

  • Dictionary learning step
  • LARS-LASSO step

Scikit-learn offers implementations similar to the ones used by spams, but when swapping them in, I was unable to get the same results. Hence, the current implementation still depends on spams.
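For reference, the two spams calls roughly correspond to scikit-learn's `DictionaryLearning` and `sparse_encode`. A minimal sketch of the kind of swap attempted here, on toy data; the regularization values and settings are illustrative, not the ones StainTools actually uses:

```python
import numpy as np
from sklearn.decomposition import DictionaryLearning, sparse_encode

# Toy optical-density data: one row per pixel, three columns (RGB channels).
rng = np.random.default_rng(0)
od = rng.random((200, 3))

# Rough scikit-learn counterpart of spams.trainDL with positivity
# constraints (K=2 stain vectors, as in Vahadane). The 'cd' fit algorithm
# supports the positivity flags.
dl = DictionaryLearning(
    n_components=2,           # two stains: hematoxylin and eosin
    alpha=0.1,                # sparsity regularizer (illustrative value)
    fit_algorithm="cd",
    transform_algorithm="lasso_lars",
    positive_code=True,
    positive_dict=True,
    max_iter=100,
    random_state=0,
)
dl.fit(od)
stain_matrix = dl.components_        # shape (2, 3)

# Rough counterpart of spams.lasso: sparse, non-negative concentrations.
concentrations = sparse_encode(
    od, stain_matrix, algorithm="lasso_lars", alpha=0.1, positive=True
)                                    # shape (200, 2)
```

Even with matched constraints, the two libraries use different solvers and stopping rules, which is likely part of why the outputs do not match.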

This will also be a problem for both the TensorFlow and PyTorch backends, as it might be challenging to implement these steps from scratch in an optimized manner in either framework.

Adding spams itself is not that challenging. However, pip-installing the original project has proven challenging in several projects where I have used spams. Luckily, someone has built precompiled wheels which make installation seamless: https://github.com/samuelstjean/spams-python/releases/tag/v2.6.4

I have not yet added spams as a dependency, and will wait for further instructions.

andreped avatar Dec 22 '22 21:12 andreped

It can also be noted that @leengit and the devs at @Kitware have done a great job of adding support for the Vahadane algorithm to itk-spcn. A tutorial on how to use it can be found here.

If we ever get some kind of backend which supports it, we should benchmark against their implementation.

andreped avatar Jan 29 '23 12:01 andreped

Hi, andreped. I also tried Vahadane for stain normalization using staintools. But it seems that the current implementation of Vahadane does not work well with torch's DataLoader when setting num_workers>0 (using multiple workers to load data).

I looked at the code of the Vahadane implementation and found that the problem is that spams, which the Vahadane implementation requires, also uses multiple threads.

I also noticed that this problem can be solved by using the DataLoader with num_workers=0 or setting numThreads=1 in every spams invocation (please see Issue #43 in staintools), but this leads to extremely slow data loading for large WSIs.

So I wonder if there is an appropriate solution to this problem. Thank you so much.

liupei101 avatar Feb 06 '23 13:02 liupei101

Hello, @liupei101!

As I guess you saw from my comment above, modifying this implementation so that it only depends on PyTorch (hence removing the spams backend) is not trivial. Right now, it is outside the scope of what I have time for, as I have mostly contributed to this project in my spare time.

As you are aware, Vahadane is not a fast algorithm; it is made faster through parallel computing. That you observed runtime degradation after removing the threading within a single patch, in order to normalize more patches in parallel, is not really that surprising, as I believe the main bottleneck lies in the spams computation itself, which is done per image. Having threads on top of threads also rarely works well in practice.

Sometimes using multiprocessing instead, to process patches in parallel, works better, but in general that tends to run into deadlocks, especially if the goal is to normalize batches of patches on the fly during training. Hence, I'm not sure how one could further optimize the Vahadane implementation, unless the algorithm itself worked directly on batches and thus supported a batch dimension. But due to the complex nature of the algorithm, I don't see how to do that easily.
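To make the patch-level parallelism above concrete, here is a minimal sketch using a thread pool, where `fake_normalize` is a hypothetical stand-in for the per-patch Vahadane call (a real call into spams would spawn its own threads, which is exactly where the oversubscription problem comes from):

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def fake_normalize(patch):
    # Hypothetical stand-in for a per-patch stain-normalization call.
    # A real Vahadane step would call into spams here, which starts its
    # own threads -- threads on top of threads, as discussed above.
    return (patch - patch.mean()) / (patch.std() + 1e-8)

rng = np.random.default_rng(0)
patches = [rng.random((224, 224, 3)) for _ in range(8)]

# Parallelize across patches instead of relying on intra-patch threading.
with ThreadPoolExecutor(max_workers=4) as pool:
    normalized = list(pool.map(fake_normalize, patches))
```

The pattern only pays off if the per-patch call itself is single-threaded; otherwise the worker threads compete for the same cores.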

If I were you, I would try out the Vahadane implementation in itk-spcn (for tutorial see here).

AFAIK, they have implemented everything from scratch using only ITK in C++, without the spams backend. The code has then been wrapped to be accessible from Python. If that's the case, you might be able to take advantage of PyTorch's multi-worker DataLoader to normalize patches in parallel. But then again, you might run into the exact same issue. It might also be that it works, but that the normalization runtime does not decrease linearly with more workers.

If you were to make an attempt, it would be very interesting for me as well. It would be great if you could share a gist, so that I or others could debug this further.


EDIT: BTW, did you try to run our Macenko implementation in a similar batch-processing pipeline? If you have a gist that demonstrates this, it could be an idea to add it to the README, to show people how to do the same easily.

andreped avatar Feb 06 '23 15:02 andreped

Regarding multiprocessing, the LARS-LASSO part of the Vahadane algorithm can work directly on patches. The dictionary learning step requires the whole image, but that step is fast compared to the LARS-LASSO.
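A sketch of that split: fit the stain dictionary once on the whole image, then solve for concentrations independently per patch. Here a fabricated stain matrix stands in for the dictionary-learning output, and a plain least-squares solve with clipping is a simplified stand-in for the sparse non-negative LARS-LASSO; all names are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)

# Step 1: dictionary learning on the whole (possibly downsampled) image.
# We fabricate a 2x3 stain matrix here as a stand-in for that output.
stain_matrix = np.array([[0.65, 0.70, 0.29],
                         [0.07, 0.99, 0.11]])
stain_matrix /= np.linalg.norm(stain_matrix, axis=1, keepdims=True)

def patch_concentrations(patch_od, D):
    # Simplified stand-in for the per-patch LARS-LASSO: ordinary least
    # squares followed by clipping to non-negative values. The real step
    # solves a sparse non-negative lasso, but it is likewise independent
    # per patch, so patches can be processed in parallel.
    pixels = patch_od.reshape(-1, 3)
    C, *_ = np.linalg.lstsq(D.T, pixels.T, rcond=None)
    return np.clip(C.T, 0, None).reshape(patch_od.shape[:2] + (2,))

# Step 2: per-patch concentration solves, trivially parallelizable.
patches = [rng.random((64, 64, 3)) for _ in range(4)]
concentrations = [patch_concentrations(p, stain_matrix) for p in patches]
```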

eyedealism avatar Jan 04 '24 20:01 eyedealism

> Regarding multiprocessing, the LARS-LASSO part of the Vahadane algorithm can work directly on patches. The dictionary learning step requires the whole image, but this step is relatively faster than the LARS-LASSO.

@Robot-Eyes That's probably true. But the main issue here is that implementing the entire Vahadane algorithm in pure PyTorch/TF is much more difficult and tedious than we thought. I made an attempt, but I failed to achieve the expected results already in some of the early steps of the algorithm, so I did not venture any further.

One issue was that I tried to move away from the old LARS-LASSO step, which was not performed in pure Python, by using scipy's implementation, or something like that. But the results were very different. If they had been the same or similar, I would have based the PyTorch implementation on it by converting submodules.

As this did not work, we did not merge this solution into torchstain, as we do not wish to support this legacy backend.

andreped avatar Jan 05 '24 08:01 andreped

Thanks for your reply.

I saw in your commits that you printed the lasso results from scipy and spams to compare. I used to use staintools, and I know the staintools implementation has its own problem here: the default dictionary learning in spams takes all available threads and runs for 1 s, no matter how large the image is or how fast your CPUs are. If you run it on the same image twice, the results will be different every time. In scipy, by contrast, the default is 1000 iterations on a single thread, which I believe is more reasonable. If you tested on a small image and a fast computer, then spams probably ran many more iterations than 1000. Maybe one could fix the number of iterations and see whether it helps.
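To illustrate the point about fixed iterations: a toy alternating-minimization dictionary learner in pure numpy (not the spams or scipy algorithm, just a sketch with made-up names) is bit-identical across runs once the iteration count and seed are fixed, whereas a "run for one second" stopping rule cannot be, since its output depends on CPU speed:

```python
import numpy as np

def toy_dictionary_learning(X, n_atoms=2, n_iter=50, seed=0):
    # Toy alternating least-squares dictionary learner. Fixing n_iter and
    # the seed makes the result fully deterministic, unlike a time-budgeted
    # stopping rule, whose iteration count varies with hardware and load.
    rng = np.random.default_rng(seed)
    D = rng.random((n_atoms, X.shape[1]))
    for _ in range(n_iter):
        # Code update: least squares, clipped to non-negative values.
        C, *_ = np.linalg.lstsq(D.T, X.T, rcond=None)
        C = np.clip(C.T, 0, None)
        # Dictionary update: least squares, rows kept positive and normalized.
        D, *_ = np.linalg.lstsq(C, X, rcond=None)
        D = np.clip(D, 1e-8, None)
        D /= np.linalg.norm(D, axis=1, keepdims=True)
    return D

rng = np.random.default_rng(42)
X = rng.random((100, 3))
D1 = toy_dictionary_learning(X)
D2 = toy_dictionary_learning(X)
# Same data, same iteration count, same seed -> identical dictionaries.
```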

eyedealism avatar Jan 05 '24 15:01 eyedealism

Yes, I also used staintools and had great success with Vahadane, so finding ways to optimize it would be of great value. One way is to move to PyTorch/TensorFlow to enable GPU compute, but this can be accelerated further, and greatly so on GPUs, if compute can be performed in parallel where relevant.

I stopped implementing support for Vahadane with the various backends, as I failed to see how I could get the same or similar results using the implementation available in the other Python library (not spams). There is also a lot of other stuff that needs translating to torch/tf; hence, following a fail-fast principle, I chose to stop there.

But maybe you are interested in giving it a try? If you manage to get a torch implementation working, making an equivalent tf/numpy implementation is trivial. Are you interested, @Robot-Eyes? You can start from my branch, if you'd like.

andreped avatar Jan 06 '24 15:01 andreped

Yes, I could give it a try later.

eyedealism avatar Jan 07 '24 02:01 eyedealism