perf: optionally use `torch_scatter.segment_coo` for feature aggregation

Open caic99 opened this issue 10 months ago • 2 comments

torch_scatter.segment_coo performs the same reduction as index_add in native PyTorch, but it can exploit the fact that the index is already sorted. My test shows a ~25% speed-up on DPA3 dynamic sel models (0.56s vs 0.76s per step, 24 thin).

In contrast to scatter(), this method expects values in index to be sorted along dimension index.dim() - 1. Due to the use of sorted indices, segment_coo() is usually faster than the more general scatter() operation.
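
To make the equivalence concrete, here is a minimal sketch on a toy input. The helper name `index_add_sum` and the sample tensors are illustrative assumptions, not code from this PR. The `segment_coo` comparison only runs if torch_scatter can be imported.

```python
import torch


def index_add_sum(data: torch.Tensor, owners: torch.Tensor, num_owner: int) -> torch.Tensor:
    """Sum rows of `data` that share the same owner index (native PyTorch)."""
    out = data.new_zeros((num_owner, data.shape[1]))
    return out.index_add_(0, owners, data)


# Toy input (hypothetical): three rows, owner 1 receives nothing.
data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
owners = torch.tensor([0, 0, 2])  # already sorted, as segment_coo requires

reference = index_add_sum(data, owners, 3)

try:
    from torch_scatter import segment_coo  # optional dependency
except (ImportError, OSError):
    segment_coo = None

if segment_coo is not None:
    # Same reduction, relying on the sorted index.
    segmented = segment_coo(data, owners, dim_size=3, reduce="sum")
    assert torch.allclose(segmented, reference)
```

Note that `index_add_` accepts an index in any order, whereas `segment_coo` assumes the index is sorted and does not check it.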

I compared the accuracy of segment_coo and index_add, using index_add in FP64 as the baseline. In FP64, the reduced data passes the torch.allclose check with atol=1e-8; in FP32, it passes with atol=1e-4. The MAE of segment_coo is lower than that of index_add (0.3x-0.7x, depending on the input data). Performing the calculation in FP32 would not have a significant impact on accuracy.
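
A check of this kind could look roughly like the sketch below. It only exercises the native `index_add_` path in FP32 against an FP64 baseline; the tensor sizes, the seed, and the helper name are assumptions for illustration, not the setup used for the numbers above.

```python
import torch

torch.manual_seed(0)

# Hypothetical problem size, chosen only for illustration.
num_owner, n_rows, n_feat = 50, 2000, 8
owners = torch.sort(torch.randint(0, num_owner, (n_rows,))).values
data64 = torch.randn(n_rows, n_feat, dtype=torch.float64)


def index_add_sum(data: torch.Tensor, owners: torch.Tensor, num_owner: int) -> torch.Tensor:
    """Sum rows of `data` per owner with native index_add_."""
    out = data.new_zeros((num_owner, data.shape[1]))
    return out.index_add_(0, owners, data)


# FP64 index_add is the baseline; the FP32 result is compared against it.
baseline = index_add_sum(data64, owners, num_owner)
result32 = index_add_sum(data64.float(), owners, num_owner)

mae = (result32.double() - baseline).abs().mean().item()
passes = torch.allclose(result32.double(), baseline, atol=1e-4)
print(f"FP32 MAE vs FP64 baseline: {mae:.3e}, allclose(atol=1e-4): {passes}")
```

Swapping `index_add_sum` for a `segment_coo` call on the FP32 input would give the second MAE needed for the comparison.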


scatter_reduce can also be used for this operation, and it is slightly faster than the original implementation (0.70s/step). However, its backward pass requires src.shape == index.shape, so the index must first be expanded: owners = owners.unsqueeze(-1).expand(-1,data.shape[1]).
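
As a rough sketch of the scatter_reduce variant, assuming 2-D `data` and a 1-D `owners` (the function name `aggregate_scatter_reduce` and the sample tensors are hypothetical):

```python
import torch


def aggregate_scatter_reduce(data: torch.Tensor, owners: torch.Tensor, num_owner: int) -> torch.Tensor:
    """Sum rows of `data` per owner using Tensor.scatter_reduce."""
    # Expand the 1-D owners to the shape of data so that
    # src.shape == index.shape, as the backward pass requires.
    index = owners.unsqueeze(-1).expand(-1, data.shape[1])
    out = data.new_zeros((num_owner, data.shape[1]))
    # include_self=True keeps the zero initial values, which is harmless for a sum.
    return out.scatter_reduce(0, index, data, reduce="sum", include_self=True)


data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], requires_grad=True)
owners = torch.tensor([0, 0, 2])

summed = aggregate_scatter_reduce(data, owners, 3)
summed.sum().backward()  # every input element contributes once to the sum
```

The expanded index is a view, so the extra step costs no additional memory for the index itself.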

Summary by CodeRabbit

  • Bug Fixes

    • Improved aggregation performance and compatibility by supporting optional use of the torch_scatter library, with automatic fallback to native PyTorch operations if unavailable.
  • Refactor

    • Updated internal logic for aggregation and removed a decorator to enhance flexibility and maintainability.

caic99 avatar Jun 20 '25 05:06 caic99

📝 Walkthrough

The code in deepmd/pt/model/network/utils.py has been updated to optionally use the torch_scatter library for aggregation operations within the aggregate function. The function now checks for torch_scatter's availability and uses it if present, otherwise defaults to the original PyTorch-based implementation. The @torch.jit.script decorator was removed.

Changes

| File(s) | Change Summary |
| --- | --- |
| deepmd/pt/model/network/utils.py | Added conditional import and usage of torch_scatter in aggregate; extended function signature with use_torch_scatter parameter; removed @torch.jit.script decorator. |

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant utils.py
    participant torch_scatter

    Caller->>utils.py: aggregate(data, owners, average, num_owner, use_torch_scatter)
    alt use_torch_scatter and not scripting
        utils.py->>torch_scatter: segment_coo(data, owners, num_owner, reduce)
        torch_scatter-->>utils.py: aggregated_output
        utils.py-->>Caller: aggregated_output
    else
        utils.py->>utils.py: aggregate using index_add_ and bincount
        utils.py-->>Caller: aggregated_output
    end
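
The dispatch in the diagram can be sketched as follows. This is not the code from deepmd/pt/model/network/utils.py: the default values, the handling of `num_owner=None`, and the clamping of empty owners are assumptions made to keep the example self-contained.

```python
from typing import Optional

import torch

try:
    import torch_scatter  # optional dependency
except (ImportError, OSError):
    torch_scatter = None


def aggregate(
    data: torch.Tensor,
    owners: torch.Tensor,
    average: bool = True,
    num_owner: Optional[int] = None,
    use_torch_scatter: bool = True,
) -> torch.Tensor:
    """Reduce rows of `data` per owner, optionally via torch_scatter."""
    if num_owner is None:
        # Assumption: infer the number of owners from the largest index.
        num_owner = int(owners.max().item()) + 1

    if use_torch_scatter and torch_scatter is not None and not torch.jit.is_scripting():
        # Fast path: requires `owners` to be sorted.
        reduce = "mean" if average else "sum"
        return torch_scatter.segment_coo(data, owners, dim_size=num_owner, reduce=reduce)

    # Fallback: native PyTorch, works for any index order.
    out = data.new_zeros((num_owner, data.shape[1]))
    out.index_add_(0, owners, data)
    if average:
        # Clamp so that owners without rows divide by 1 instead of 0.
        counts = torch.bincount(owners, minlength=num_owner).clamp(min=1)
        out = out / counts.unsqueeze(-1).to(out.dtype)
    return out


data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
owners = torch.tensor([0, 0, 2])
averaged = aggregate(data, owners, average=True, num_owner=3, use_torch_scatter=False)
```

Passing `use_torch_scatter=False` forces the native path, which is useful when comparing the two implementations.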


📜 Recent review details

Configuration used: CodeRabbit UI Review profile: CHILL Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1808972814ca06a5b859ed83614160fd9189262d and 8f5a1c55e3c19e5e927ef86008d048cf4f7439ab.

📒 Files selected for processing (1)
  • deepmd/pt/model/network/utils.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/model/network/utils.py
⏰ Context from checks skipped due to timeout of 90000ms (13)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
coderabbitai[bot] avatar Jun 20 '25 05:06 coderabbitai[bot]

Codecov Report

Attention: Patch coverage is 81.81818% with 2 lines in your changes missing coverage. Please review.

Project coverage is 84.57%. Comparing base (ab6e300) to head (1808972). Report is 17 commits behind head on devel.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| deepmd/pt/model/network/utils.py | 81.81% | 2 Missing :warning: |

Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4813      +/-   ##
==========================================
- Coverage   84.80%   84.57%   -0.23%     
==========================================
  Files         698      699       +1     
  Lines       67798    68043     +245     
  Branches     3542     3540       -2     
==========================================
+ Hits        57494    57547      +53     
- Misses       9171     9361     +190     
- Partials     1133     1135       +2     

codecov[bot] avatar Jun 23 '25 08:06 codecov[bot]