perf: optionally use `torch_scatter.segment_coo` for feature aggregation
`torch_scatter.segment_coo` performs the same operation as native PyTorch's `index_add`, but it can take advantage of the fact that the index is sorted. In my tests this gives a ~25% speed-up on DPA3 dynamic sel models (0.56 s vs 0.76 s per step, 24 thin).
> In contrast to `scatter()`, this method expects values in `index` to be sorted along dimension `index.dim() - 1`. Due to the use of sorted indices, `segment_coo()` is usually faster than the more general `scatter()` operation.
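To make the equivalence concrete, here is a minimal sketch on toy data (the tensors below are illustrative, not taken from DPA3). It sums rows into their owner's slot with native `index_add_` and, only if `torch_scatter` happens to be installed, checks that `segment_coo` with sorted owners gives the same result.

```python
import torch

# Toy data: 5 rows of 3 features, grouped under 3 owners.
# segment_coo requires owners to be sorted (non-decreasing); index_add_ does not.
data = torch.arange(15, dtype=torch.float64).reshape(5, 3)
owners = torch.tensor([0, 0, 1, 2, 2])
num_owner = 3

# Native PyTorch: accumulate each row of `data` into row `owners[i]` of the output.
out_native = torch.zeros(num_owner, data.shape[1], dtype=data.dtype)
out_native.index_add_(0, owners, data)

try:
    from torch_scatter import segment_coo  # optional dependency

    # Reduction runs along index.dim() - 1 == 0; dim_size fixes the output length.
    out_scatter = segment_coo(data, owners, dim_size=num_owner, reduce="sum")
    assert torch.allclose(out_native, out_scatter)
except (ImportError, OSError):
    out_scatter = None  # torch_scatter missing or unloadable: only the native path ran

print(out_native)
```

Both paths produce `[[3, 5, 7], [6, 7, 8], [21, 23, 25]]` here; the speed difference only shows up on large inputs, where `segment_coo` can process each contiguous run of equal owners as one segment.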
I compared the accuracy of `segment_coo` and `index_add`, using `index_add` in FP64 as the baseline. In FP64, the reduced results pass a `torch.allclose` check with `atol=1e-8`; in FP32 they pass with `atol=1e-4`. The MAE of `segment_coo` is lower than that of `index_add` (0.3x to 0.7x, depending on the input data). Performing the calculation in FP32 should therefore not have a significant impact on accuracy.
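The evaluation above can be reproduced along these lines. This is a sketch under assumptions: random Gaussian data, 64 owners and 4096 sorted rows are arbitrary choices, and `index_add_reduce` is a hypothetical helper, not code from the PR. The FP64 `index_add_` result is the reference; FP32 candidates are compared back against it in FP64.

```python
import torch


def index_add_reduce(data: torch.Tensor, owners: torch.Tensor, num_owner: int) -> torch.Tensor:
    """Hypothetical helper: sum rows of `data` into `num_owner` buckets with index_add_."""
    out = torch.zeros(num_owner, data.shape[1], dtype=data.dtype, device=data.device)
    return out.index_add_(0, owners, data)


torch.manual_seed(0)
num_owner, n_rows, n_feat = 64, 4096, 16
# Sorted owners, as segment_coo requires (arbitrary synthetic grouping).
owners = torch.sort(torch.randint(0, num_owner, (n_rows,))).values
data64 = torch.randn(n_rows, n_feat, dtype=torch.float64)

# Reference: index_add_ in FP64.
baseline = index_add_reduce(data64, owners, num_owner)

# Candidate 1: the same index_add_ reduction in FP32, upcast for comparison.
candidate = index_add_reduce(data64.float(), owners, num_owner).double()
mae = (candidate - baseline).abs().mean().item()
close = torch.allclose(candidate, baseline, atol=1e-4)
print(f"index_add FP32  MAE={mae:.3e}  allclose(atol=1e-4)={close}")

# Candidate 2 (only if torch_scatter is available): segment_coo in FP32.
try:
    from torch_scatter import segment_coo

    seg = segment_coo(data64.float(), owners, dim_size=num_owner, reduce="sum").double()
    seg_mae = (seg - baseline).abs().mean().item()
    print(f"segment_coo FP32 MAE={seg_mae:.3e}")
except (ImportError, OSError):
    pass
```

The MAE ratio between the two FP32 candidates is what the 0.3x to 0.7x figure refers to; it will vary with the data, so no specific value is claimed for this synthetic input.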
It is also possible to use `scatter_reduce` for this operation, which is slightly faster than the original implementation (0.70 s/step). However, its backward pass requires `src.shape == index.shape`, so an additional expansion step is needed: `owners = owners.unsqueeze(-1).expand(-1, data.shape[1])`.
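A minimal sketch of the `scatter_reduce` variant with the expansion step, assuming a sum reduction into a zero-initialized output (`scatter_reduce_sum` is a hypothetical name). The expanded index has the same shape as `data`, which is what makes the backward pass available.

```python
import torch


def scatter_reduce_sum(data: torch.Tensor, owners: torch.Tensor, num_owner: int) -> torch.Tensor:
    # Backward is only implemented when index.shape == src.shape, so broadcast
    # each row's owner id across the feature dimension (no copy: expand is a view).
    index = owners.unsqueeze(-1).expand(-1, data.shape[1])
    out = torch.zeros(num_owner, data.shape[1], dtype=data.dtype, device=data.device)
    # include_self=True adds onto the zero-initialized output, i.e. a plain sum.
    return out.scatter_reduce(0, index, data, reduce="sum", include_self=True)


data = torch.arange(15, dtype=torch.float64).reshape(5, 3).requires_grad_()
owners = torch.tensor([0, 0, 1, 2, 2])
out = scatter_reduce_sum(data, owners, 3)
# Every input row contributes exactly once to the total, so its gradient is all ones.
out.sum().backward()
print(out.detach())
print(data.grad)
```

The extra `expand` is cheap in memory, but it is one more step per call than `index_add_`, which takes the 1-D `owners` directly.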
Summary by CodeRabbit
- Bug Fixes
  - Improved aggregation performance and compatibility by supporting optional use of the `torch_scatter` library, with automatic fallback to native PyTorch operations if unavailable.
- Refactor
  - Updated internal logic for aggregation and removed a decorator to enhance flexibility and maintainability.
Walkthrough
The code in deepmd/pt/model/network/utils.py has been updated to optionally use the torch_scatter library for aggregation operations within the aggregate function. The function now checks for torch_scatter's availability and uses it if present, otherwise defaults to the original PyTorch-based implementation. The @torch.jit.script decorator was removed.
Changes
| File(s) | Change Summary |
|---|---|
| deepmd/pt/model/network/utils.py | Added conditional import and usage of torch_scatter in aggregate; extended function signature with use_torch_scatter parameter; removed @torch.jit.script decorator. |
Sequence Diagram(s)
sequenceDiagram
participant Caller
participant utils.py
participant torch_scatter
Caller->>utils.py: aggregate(data, owners, average, num_owner, use_torch_scatter)
alt use_torch_scatter and not scripting
utils.py->>torch_scatter: segment_coo(data, owners, num_owner, reduce)
torch_scatter-->>utils.py: aggregated_output
utils.py-->>Caller: aggregated_output
else
utils.py->>utils.py: aggregate using index_add_ and bincount
utils.py-->>Caller: aggregated_output
end
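The flow in the diagram can be sketched as follows. This is an illustrative reconstruction, not the code merged in `deepmd/pt/model/network/utils.py`: the module-level flag `HAS_TORCH_SCATTER`, the default for `use_torch_scatter`, and the way `num_owner` is inferred when omitted are all assumptions. It takes the `segment_coo` path only when the library imports and the function is not being compiled by TorchScript, and otherwise falls back to `index_add_` plus `bincount`.

```python
from typing import Optional

import torch

try:
    import torch_scatter  # optional: enables the sorted-index fast path

    HAS_TORCH_SCATTER = True  # hypothetical module-level flag
except (ImportError, OSError):
    HAS_TORCH_SCATTER = False


def aggregate(
    data: torch.Tensor,
    owners: torch.Tensor,
    average: bool = True,
    num_owner: Optional[int] = None,
    use_torch_scatter: bool = HAS_TORCH_SCATTER,
) -> torch.Tensor:
    """Sum or average rows of `data` per owner (sketch).

    `owners` must be sorted for the torch_scatter path to be valid.
    """
    if num_owner is None:
        # Assumption: infer the owner count from the largest owner id.
        num_owner = int(owners.max().item()) + 1 if owners.numel() > 0 else 0
    if use_torch_scatter and HAS_TORCH_SCATTER and not torch.jit.is_scripting():
        return torch_scatter.segment_coo(
            data, owners, dim_size=num_owner, reduce="mean" if average else "sum"
        )
    # Fallback: native index_add_, then divide by per-owner counts for the mean.
    output = data.new_zeros(num_owner, data.shape[1])
    output.index_add_(0, owners, data)
    if average:
        # clamp avoids division by zero for owners that receive no rows.
        counts = torch.bincount(owners, minlength=num_owner).clamp(min=1)
        output = output / counts.unsqueeze(-1).to(data.dtype)
    return output


data = torch.arange(15, dtype=torch.float64).reshape(5, 3)
owners = torch.tensor([0, 0, 1, 2, 2])
print(aggregate(data, owners, average=True, use_torch_scatter=False))
```

Calling an optional third-party module like this is one reason a function cannot stay under `@torch.jit.script`; the `torch.jit.is_scripting()` check keeps the native branch as the path TorchScript would see.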
Codecov Report
Attention: Patch coverage is 81.81818% with 2 lines in your changes missing coverage. Please review.
Project coverage is 84.57%. Comparing base (ab6e300) to head (1808972). Report is 17 commits behind head on devel.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| deepmd/pt/model/network/utils.py | 81.81% | 2 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## devel #4813 +/- ##
==========================================
- Coverage 84.80% 84.57% -0.23%
==========================================
Files 698 699 +1
Lines 67798 68043 +245
Branches 3542 3540 -2
==========================================
+ Hits 57494 57547 +53
- Misses 9171 9361 +190
- Partials 1133 1135 +2