qihqi

42 comments by qihqi

afaik we still need quant_dequant_converter @lsy323
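For context on what such a converter handles (a sketch, not the actual `quant_dequant_converter` code): quantized graphs typically contain quantize/dequantize ("fake quant") op pairs. A minimal NumPy illustration of the affine quantize-then-dequantize round trip, with hypothetical helper names:

```python
import numpy as np

def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Affine quantization: scale, round to nearest int, clamp to int8 range.
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    return q.astype(np.int8)

def dequantize(q, scale, zero_point):
    # Map integers back to floats; rounding error is baked in.
    return (q.astype(np.float32) - zero_point) * scale

x = np.array([0.1, -0.5, 1.2], dtype=np.float32)
scale, zero_point = 0.01, 0
roundtrip = dequantize(quantize(x, scale, zero_point), scale, zero_point)
```

The converter's job is to recognize such quantize/dequantize pairs in the exported graph and lower them to the backend's native quantized ops.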

Hi, sorry for the delay. Just tried on HEAD and your example is passing. My guess is that it is fixed after https://github.com/pytorch/xla/pull/6460. We introduce f64 when scalars are...
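Background on why Python scalars can drag f64 into a graph (an illustration, not the torch_xla code path): Python floats are IEEE 754 double precision, so a backend that materializes them naively produces f64 constants inside an otherwise f32 computation. A stdlib-only sketch:

```python
import struct
import sys

# Python floats are 64-bit IEEE 754 doubles, so a scalar like 0.1 embedded
# naively into a traced graph becomes an f64 constant.
assert sys.float_info.mant_dig == 53      # 53-bit mantissa = double precision
assert len(struct.pack('d', 0.1)) == 8    # 8 bytes = 64 bits

# Narrowing the same value to f32 changes it slightly -- one reason mixing
# f64 constants into an f32 graph can shift numerical results.
f32_value = struct.unpack('f', struct.pack('f', 0.1))[0]
```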

cc @gmagogsfm @angelayi @zhxchen17 We chatted briefly on this topic and agreed that it is desirable, although there is no timeline for it. Feel free to post updates about the future plan...

Hi @kiratp, a few questions: On 1., would prompt_token_counts have the same behavior as per-request max_output_len? On 2., logprobs is a boolean arg on the input to signify that the logprobs...
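To make the shape of those per-request knobs concrete, here is a hypothetical request schema (the class name `GenerateRequest` and its layout are assumptions; only the field names come from the discussion above):

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical request sketch mirroring the per-request options discussed
# above; this is not any real serving API.
@dataclass
class GenerateRequest:
    prompt: str
    max_output_len: Optional[int] = None  # per-request cap on generated tokens
    logprobs: bool = False                # if True, return per-token logprobs

req = GenerateRequest(prompt="hello", max_output_len=16, logprobs=True)
```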

Hi, you can try this in torchax (https://github.com/pytorch/xla/tree/master/torchax):

```python
import torch
import torchax

torchax.enable_globally()
device = 'jax'
diff = torch.randn(8, 3, 1, requires_grad=True, device=device)
A = torch.randn(8, 251, 3, requires_grad=True, device=device)
B = torch.randn(8, 251, 1, requires_grad=True, device=device)
C = ...
```
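For readers without torchax installed: judging from the shapes above, the snippet appears to set up a batched least-squares problem A @ C ≈ B (this is an assumption, not stated in the thread). A plain NumPy sketch of the same computation, solving each batch independently since `np.linalg.lstsq` only accepts 2-D inputs:

```python
import numpy as np

# Same shapes as the torchax snippet, NumPy instead of torch tensors.
rng = np.random.default_rng(0)
A = rng.standard_normal((8, 251, 3))
B = rng.standard_normal((8, 251, 1))

# Solve each batch: C[i] minimizes ||A[i] @ C[i] - B[i]||_2.
C = np.stack([np.linalg.lstsq(A[i], B[i], rcond=None)[0] for i in range(8)])
residual = A @ C - B
```

At the least-squares solution the residual satisfies the normal equations, i.e. A[i].T @ residual[i] ≈ 0 for every batch.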

> [@qihqi](https://github.com/qihqi) thank you again! I can give this a try today. Other than the install instructions you provided in that link, are there any other differences to using jax...

> Regarding the training of the model: if I use the code above - you are recommending that I adjust the entire train paradigm we are using to use that...

Hi @ttdd11, I will further debug what is going on with the environment. Meanwhile, we landed a new feature `assume_pure` in nightly; I tried it and it can solve the...