oneflow
oneflow copied to clipboard
OneFlow is a deep learning framework designed to be user-friendly, scalable and efficient.
修复 set_acc_grad 后,再 backward 梯度没有正确累加的问题。 ```python # import oneflow as flow import torch as flow value = flow.tensor([[-0.0875, -0.4890, 0.9031], [ 0.4930, -0.6041, -1.5392]]).requires_grad_() value.grad = flow.randn((2, 3)) output =...
- 系统解决目前不支持非contiguous的算子的非contiguous输入inplace计算的问题 比如add不支持non-contiguous,那么对于non-contiguous的输入tensor是不允许使用inplace操作的,比如`a += 1`。但该PR会自动将`a += 1`重写为`b = a + 1; a = b`。 Fixes OneFlow-Inc/OneTeam#1621 Fixes OneFlow-Inc/OneTeam#1627
T5 + DP 4 - AddTaskIntoPlan 总时间从 1021ms 减少到 (51ms+ 337ms),减少了 62% - InferMemBlockId4MemReusedRegst 中的 GenRegstAllocFreeTimeLineAndRegstMutualExclusions 从 113 milliseconds 降低到 41 milliseconds【这部分增加了并行,降低了一半】 - 内存共享算法部分,已经是分设备并行计算的
- [x] 使用`ep::primitive::BroadcastElementwiseUnary`重构cast kernel,同时支持stride - [x] 支持half、nv_bfloat16类型 - [x] fix cast op不支持0-size tensor的bug - [x] test case
### This PR is done: Refine @Alive1024 的 flow.to_global/.to_local 增加多种输入类型的输入支持(包括 state dict)的 PR,增加 dict_to_global op,支持按 s 分片来 save 和 load 大模型,并推进合并。 - [x] Fork branch: https://github.com/Oneflow-Inc/oneflow/pull/8091 - [x] Resolve conflict...
背景:https://github.com/Oneflow-Inc/OneCloud/issues/70#issuecomment-1077397584 概述:module 的 `load_state_dict` 报错信息不太友好,并且没有检查 checkpoint 和 model 参数的 `is_global` 是否匹配,导致用户的 global model 加载 local state_dict 时,报错会被后面的 try except 捕捉到,返回一些无法提炼重要信息的错误内容  实现:在 module.py 的 `load_state_dict` 遍历加载参数时,增加了两者间 `is_global` 是否匹配的检查,防止在后面的 try except...
背景:https://github.com/Oneflow-Inc/oneflow/issues/8841 问题概述:缺少 `pairwise_distance` 算子 实现:Functor 层调用 `norm(x1 - x2, p=p)` torch实现:https://github.com/pytorch/pytorch/blob/6a09847c42bf7d33ba0aea5b083eebd846661ce1/aten/src/ATen/native/Distance.cpp#L16-L23
为oneflow/python/oneflow/test/modules下的文件 test_abs.py 和test_activate.py 中的算子增加性能测试profile
## Code to reproduce bug ``` import oneflow as flow x = flow.randn(5, 5) y = flow.where(x > 0, x, 0.0) ``` Output: ``` RuntimeError: Check failed: (x_dtype) == (GetDataType::value)...