Codegen all.dim
Fixes https://github.com/pytorch/xla/issues/3860
LazyIr.h:
class AllDim : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::all);
  }

  AllDim(const torch::lazy::Value& self, const int64_t& dim,
         const bool& keepdim, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::all), {self}, std::move(shapes),
                [&]() { return AllDimOutputShape(self, dim, keepdim); },
                /* num_outputs */ 1, torch::lazy::MHash(dim, keepdim)),
        dim(dim),
        keepdim(keepdim) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    ss << ", dim=" << dim;
    ss << ", keepdim=" << keepdim;
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self, const int64_t& dim,
                   const bool& keepdim) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;

  int64_t dim;
  bool keepdim;
};
XLANativeFunctions.cpp:
at::Tensor XLANativeFunctions::all(const at::Tensor& self, int64_t dim,
                                   bool keepdim) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);
  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<AllDim>(
      lazy_self->GetIrValue(), dim, keepdim);
  if (!node) {
    auto self_meta = to_meta(self);
    auto out_meta = at::meta::all(self_meta, dim, keepdim);
    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, dim, keepdim};
      const char* schema_str =
          "aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }
    node = torch::lazy::MakeNode<AllDim>(lazy_self->GetIrValue(), dim, keepdim,
                                         std::move(shapes));
    CacheNode(node);
  }
  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
}
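For reference, here is a minimal sketch of the user code this lowering supports (assuming the usual torch_xla device setup; this snippet is illustrative and not part of the generated sources):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.tensor([[True, False], [True, True]], device=device)

# all.dim: logical-AND reduction over dim 1, keeping the reduced dimension.
y = torch.all(x, dim=1, keepdim=True)
print(y.cpu())  # -> tensor([[False], [True]])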
Hm, it seems like we can't codegen all.dim yet; the generated LazyIr.h for all.dim looks like:
class AllDim : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::all);
  }

  AllDim(const torch::lazy::Value& self, const int64_t& dim,
         const bool& keepdim, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::all), {self}, std::move(shapes),
                [&]() { return AllDimOutputShape(self); },
                /* num_outputs */ 1, torch::lazy::MHash(dim, keepdim)),
        dim(dim),
        keepdim(keepdim)
The generated call is AllDimOutputShape(self), but the shape function needs all of const torch::lazy::Value& self, const int64_t& dim, and const bool& keepdim to infer the output shape.
https://github.com/pytorch/xla/blob/master/scripts/gen_lazy_tensor.py#L47 controls what gets passed to the shape function. I don't think we want to pass every int64 and bool to the shape fn, because in many cases they don't affect the shape. One thing we can try is to add
a.name == dim or a.name == keepdim
to the check. I am slightly worried that the codegen becomes too specialized and hard to maintain, but this seems like the fastest way to unblock this PR.
Actually, in https://github.com/pytorch/xla/pull/3771/files I already made bools get passed to all shape fns.
https://github.com/pytorch/xla/pull/3771/files is merged; I think you just need to handle dim now.
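To make the discussion concrete, here is a rough sketch of the kind of argument filter being talked about in scripts/gen_lazy_tensor.py. The helper names and attribute layout below are hypothetical, not the actual codegen source:

# Hypothetical sketch only: illustrates the filtering logic discussed above,
# not the real gen_lazy_tensor.py code.
def passed_to_shape_fn(a):
    if is_tensor_arg(a):  # hypothetical helper: lazy Value inputs always go
        return True
    if a.type == "bool":  # bools are already forwarded since pytorch/xla#3771
        return True
    # Remaining change discussed here: also forward `dim`, since it can
    # change the inferred output shape (as in aten::all.dim).
    return a.name == "dim"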
@JackCaoG, this should be ready for review. Thanks!