GPU support in forecast
Description
All forecast models ran exclusively on the CPU, so I added GPU support for the models that can benefit from it:
- brnn
- nbeats
- nhits
- rnn
- tft
- trans
The CPU is still used by default. The GPU is used only if the user sets the `--use-gpu` flag and the GPU availability check passes.
Dependency-wise, no change is required, but the default torch installation only supports the CUDA compute capabilities sm_37, sm_50, sm_60, and sm_70. To run on newer architectures, the user will have to manually reinstall the appropriate build as described here:

```shell
pip uninstall torch
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
```
Note: the tcn model should also be able to support GPU acceleration, but I had to omit it due to the bug reported in #2732.
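The device-selection rule described above (CPU by default, GPU only when `--use-gpu` is set and the availability check passes) can be sketched roughly as follows. This is an illustration, not the PR's actual code: the function name is hypothetical, and `cuda_available` is a plain parameter here, whereas in practice it would come from `torch.cuda.is_available()`.

```python
def pick_accelerator(use_gpu: bool, cuda_available: bool) -> str:
    """Return the accelerator to hand to the model trainer.

    Hypothetical helper: `cuda_available` stands in for
    torch.cuda.is_available() so the sketch stays torch-free.
    """
    # The GPU is used only when BOTH conditions hold; CPU is the default.
    if use_gpu and cuda_available:
        return "gpu"
    return "cpu"

print(pick_accelerator(use_gpu=True, cuda_available=True))   # gpu
print(pick_accelerator(use_gpu=True, cuda_available=False))  # cpu
print(pick_accelerator(use_gpu=False, cuda_available=True))  # cpu
```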
- [X] Summary of the change / bug fix.
- [X] Link # issue, if applicable.
- [ ] Screenshot of the feature or the bug before/after fix, if applicable.
- [X] Relevant motivation and context.
- [X] List any dependencies that are required for this change.
How has this been tested?
I tested it by running all forecast models in the terminal with and without the `--use-gpu` flag.
- Please describe the tests that you ran to verify your changes.
- Provide instructions so we can reproduce.
- Please also list any relevant details for your test configuration.
- [X] Make sure affected commands still run in terminal
- [X] Ensure the api still works
- [ ] Check any related reports
Checklist:
- [X] Update our Hugo documentation following these guidelines.
- [ ] Update our tests following these guidelines.
- [X] Make sure you are following our CONTRIBUTING guidelines.
- [ ] If a feature was added make sure to add it to the corresponding scripts file.
Others
- [X] I have performed a self-review of my own code.
- [X] I have commented my code, particularly in hard-to-understand areas.
@husmen this is awesome!! @martinb-bb what do you think about this implementation?
@husmen Thanks for the contribution! This is a great start. 😄
We cannot merge just yet, as we must wait for the current installer to be released before our 2.0 launch within the next 2 weeks. 🚀
In the meantime, I will test and work diligently with you to have this merged in the next couple of weeks. We should also consider updating the installer to install GPU support if the user has one automatically. (something to think about)
Just contributing to the discussion
@husmen, a few months ago I tested MPS (Apple's new Metal Performance Shaders) for both training and inference. My finding was that, for the models that were in the terminal at the time, the performance gain was negative (computation on MPS was over 5 times slower than on CPU), because the models and batches are not large enough to actually exploit the parallelization MPS offers. I assume it's the same with CUDA. I should note that this was before the forecasting menu was built, and the prediction menu models used tensorflow, not pytorch.
Did you run any benchmarks comparing execution on CPU to execution on GPU? It would be great to see those numbers if you have them.
@piiq, actually yes, I can confirm that running on CUDA was slower than on CPU. I will provide the numbers when I find the time in a few days, but if I remember correctly the slowdown was not as dramatic as 5 times. That said, I still think it would be useful to offer the feature for anyone who wants to work with larger datasets and tune the model arguments accordingly. I might be able to test that as well.
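For producing those numbers, a minimal stdlib-only timing harness like the one below would suffice: call it once with the trainer pinned to CPU and once with `--use-gpu`, and compare. The helper name and the stand-in workload are hypothetical; in practice `fn` would be, e.g., a darts model's `.fit()` call.

```python
import time

def time_call(fn, *args, repeats=3, **kwargs):
    """Return the best-of-`repeats` wall-clock time of fn(*args, **kwargs).

    Hypothetical benchmarking helper; best-of-N reduces noise from
    one-off warm-up or scheduling hiccups.
    """
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args, **kwargs)
        best = min(best, time.perf_counter() - start)
    return best

# Cheap stand-in workload instead of a real model fit:
cpu_seconds = time_call(lambda: sum(i * i for i in range(100_000)))
print(f"best of 3: {cpu_seconds:.4f}s")
```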
Hi @martinb-bb, is this PR still relevant with the current version of the forecast menu? Thanks
@Chavithra it's relevant but we are not going to support GPU for now until we update install docs to add support for GPU enabled libraries.
> until we update install docs to add support for GPU enabled libraries
Can the documentation update be part of this PR? Can you specify what is needed?
- We will need to bring in docs for installing GPU torch support. Users will need to know which GPUs are supported (Nvidia only) and how to grab the right version.
- It is also important to test the minimum and maximum versions of the GPU-enabled torch build that Darts can run on.
Both of these should be added to this PR. See https://pytorch.org/get-started/locally/ - we will need options covering the different CUDA compatibilities as well as OS variations.
My only reservation is that I am not sure how this will affect the installer when we package the GPU build of torch, since we would then need two installers: with and without GPU support.