XLA (TPU) Support
Hey there,
Was wondering if there are plans to support TPUs in the future via XLA kernels?
That would be up to Google to support. They don't expose a programmable ISA externally. It's also unclear whether the hardware could support a block size below 128x128.
Would tf2xla not be enough to work with? https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/tf2xla/kernels
Looks as though they expose a fair amount there; although I'm not familiar enough with your block sparse implementations to say what you'd require.
XLA is great for lighter-weight primitives like element-wise ops, reductions, broadcasts, etc. But a blocksparse matmul/transformer is much more akin to a convolution primitive, and you can see that XLA implements convolution by calling into a lower-level op (e.g. ConvGeneralDilated) written more directly against their unexposed ISA.
That said, the underlying logic isn't all that hard to express in python/numpy, so a reference implementation should be straightforward: https://github.com/openai/blocksparse/blob/master/blocksparse/transformer.py#L186-L305
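To make the idea concrete, here's a minimal numpy sketch of a block-sparse matmul — not the actual blocksparse kernel, just an illustration of the semantics: a boolean `layout` grid marks which blocks of the weight matrix are non-zero, and only those blocks are stored and multiplied. All names (`layout`, `blocks`, `bs`) are made up for this example.

```python
import numpy as np

rng = np.random.default_rng(0)

bs = 32                                       # block size
layout = np.array([[1, 0],
                   [0, 1],
                   [1, 1]], dtype=bool)       # 3x2 grid of blocks -> W is 96x64
rows, cols = layout.nonzero()                 # coordinates of non-zero blocks
blocks = rng.standard_normal((len(rows), bs, bs))  # only non-zero blocks stored

x = rng.standard_normal((4, layout.shape[0] * bs))  # (batch, input_dim)
y = np.zeros((4, layout.shape[1] * bs))             # (batch, output_dim)

# Accumulate y[:, c-block] += x[:, r-block] @ w for every stored block.
for r, c, w in zip(rows, cols, blocks):
    y[:, c*bs:(c+1)*bs] += x[:, r*bs:(r+1)*bs] @ w

# Sanity check against the equivalent dense matmul with zero-filled blocks.
W = np.zeros((layout.shape[0] * bs, layout.shape[1] * bs))
for r, c, w in zip(rows, cols, blocks):
    W[r*bs:(r+1)*bs, c*bs:(c+1)*bs] = w
assert np.allclose(y, x @ W)
```

The fast GPU kernels do essentially this loop, but fused and parallelized over the non-zero blocks — which is the part that's hard to reproduce without a programmable ISA.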