
Codegen all.dim

wonjoo-wj opened this issue 3 years ago • 4 comments

Fixes https://github.com/pytorch/xla/issues/3860


Codegen all.dim


LazyIr.h

class AllDim : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::all);
  }

  AllDim(const torch::lazy::Value& self, const int64_t& dim,
         const bool& keepdim, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::all), {self}, std::move(shapes),
                [&]() { return AllDimOutputShape(self, dim, keepdim); },
                /* num_outputs */ 1, torch::lazy::MHash(dim, keepdim)),
        dim(dim),
        keepdim(keepdim) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    ss << ", dim=" << dim;
    ss << ", keepdim=" << keepdim;
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self, const int64_t& dim,
                   const bool& keepdim) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;

  int64_t dim;
  bool keepdim;
};
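
The Lower method declared above is implemented by hand rather than generated. A minimal sketch of what it might look like, assuming the BuildAll reduction helper from torch_xla/csrc/reduction.h (the helper name and signature here are assumptions for illustration, not code from this PR):

// Sketch of a hand-written lowering. Assumes BuildAll(input, dimensions,
// keep_reduced_dimensions) performs a logical-and reduction over `dimensions`.
torch_xla::XlaOpVector AllDim::Lower(LoweringContext* loctx) const {
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  return ReturnOp(BuildAll(input, {dim}, keepdim), loctx);
}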

XLANativeFunctions.cpp:

at::Tensor XLANativeFunctions::all(const at::Tensor& self, int64_t dim,
                                   bool keepdim) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<AllDim>(
      lazy_self->GetIrValue(), dim, keepdim);
  if (!node) {
    // Run the op on meta tensors to infer the output shape.
    auto self_meta = to_meta(self);
    auto out_meta = at::meta::all(self_meta, dim, keepdim);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, dim, keepdim};
      const char* schema_str =
          "aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<AllDim>(lazy_self->GetIrValue(), dim, keepdim,
                                         std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
}

wonjoo-wj avatar Aug 10 '22 04:08 wonjoo-wj

Hm, it seems like we can't codegen all.dim yet; the generated LazyIr.h for all.dim looks like:

class AllDim : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::all);
  }

  AllDim(const torch::lazy::Value& self, const int64_t& dim,
         const bool& keepdim, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::all), {self}, std::move(shapes),
                [&]() { return AllDimOutputShape(self); },
                /* num_outputs */ 1, torch::lazy::MHash(dim, keepdim)),
        dim(dim),
        keepdim(keepdim)

The generated shape function call is AllDimOutputShape(self), but the shape function needs all of the inputs (const torch::lazy::Value& self, const int64_t& dim, const bool& keepdim) to infer the output shape.
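
For illustration, a shape function with the full signature could look roughly like the sketch below. This is only a sketch: GetXlaShape and the PRED output type follow the pattern of other torch_xla shape fns, and negative-dim normalization is omitted for brevity.

// Sketch: `dim` and `keepdim` are required to compute the reduced shape.
xla::Shape AllDimOutputShape(const torch::lazy::Value& input, int64_t dim,
                             bool keepdim) {
  xla::Shape shape = GetXlaShape(input);  // assumed operand-shape helper
  std::vector<int64_t> dims(shape.dimensions().begin(),
                            shape.dimensions().end());
  if (keepdim) {
    dims[dim] = 1;                   // reduced dimension kept with size 1
  } else {
    dims.erase(dims.begin() + dim);  // reduced dimension removed
  }
  // aten::all always produces a boolean tensor.
  return xla::ShapeUtil::MakeShape(xla::PrimitiveType::PRED, dims);
}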

wonjoo-wj avatar Aug 10 '22 04:08 wonjoo-wj

https://github.com/pytorch/xla/blob/master/scripts/gen_lazy_tensor.py#L47 controls what gets passed to the shape function. I don't think we want to pass every int64 and bool to the shape fn, because much of the time they don't affect shapes. One thing we can try is to add

a.name == "dim" or a.name == "keepdim"

to the check. I am slightly worried that the codegen becomes too specialized and hard to maintain, but this seems like the fastest way to unblock this PR.

JackCaoG avatar Aug 10 '22 21:08 JackCaoG

Actually, in https://github.com/pytorch/xla/pull/3771/files I already made bools get passed to all shape fns.

JackCaoG avatar Aug 10 '22 22:08 JackCaoG

https://github.com/pytorch/xla/pull/3771/files is merged; I think you just need to handle dim now.

JackCaoG avatar Aug 11 '22 18:08 JackCaoG

@JackCaoG, this should be ready for review. Thanks!

wonjoo-wj avatar Aug 17 '22 18:08 wonjoo-wj