Ping Zhu
Ping Zhu
pytorch 这段代码是不报错的,oneflow 在 InferDataType 似乎默认了有浮点操作时 x 只能是 double,是符合预期的就没问题 https://github.com/Oneflow-Inc/oneflow/blob/1307edfdfe8750e1be46a8b827166df3e335c6e6/oneflow/user/ops/where_op.cpp#L264-L280
> ``` > import torch > x = torch.randn(5, 5) > y = torch.where(x > 0, x, 0.0) > Traceback (most recent call last): > File "", line 1, in...
> pytorch 这段代码是不报错的,oneflow 在 InferDataType 似乎默认了有浮点操作时 x 只能是 double,是符合预期的就没问题 > > https://github.com/Oneflow-Inc/oneflow/blob/1307edfdfe8750e1be46a8b827166df3e335c6e6/oneflow/user/ops/where_op.cpp#L264-L280 这个好像还是有点不对,python 调用 `where(condition, x, y)`, y 是 scalar 的时候其实隐含了 y 是 double,因为 python 没有 float32 类型,所以判断 x 是不是浮点数只判断...
I encountered the same problem. I think it caused by triton version, try install triton-nightly(2.1.0), it works for me. ```bash pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly ``` or install from...