
[ONNX conversion] &lt;Does this model currently support format conversion?&gt; I'd like to try converting it to ONNX, but I don't know how to fill in the tensors. Has anyone done this conversion? I'd like to learn from it.

Open Wukuku opened this issue 2 years ago • 1 comment

Is your feature request related to a problem? Please describe.

https://www.zhihu.com/org/mnntuan-dui — following this Zhihu post, I don't know what values to fill in for the tensors; they seem wrong.

```python
import sys

import torch
from transformers import AutoModel


def model_export(model, model_args: tuple, output_path: str,
                 ordered_input_names, output_names, dynamic_axes, opset):
    from torch.onnx import export
    export(
        model,
        model_args,
        f=output_path,
        input_names=ordered_input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        do_constant_folding=True,
        opset_version=opset,
        verbose=False,
    )


model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True, resume_download=True
).float().cpu()

model_export(
    model,
    model_args=(
        torch.randn(4, 1, 4096),
        torch.tensor([[[[False, False, False, True],
                        [False, False, False, True],
                        [False, False, False, True],
                        [False, False, False, False]]]]),
        torch.tensor([[[0, 1, 2, 3], [0, 0, 0, 1]]]),
        torch.zeros(2, 0, 1, 32, 128),
    ),
    output_path="dyn_model/glm_block_{}.onnx".format(sys.argv[1]),
    ordered_input_names=["inputs_embeds", "attention_mask",
                         "position_ids", "past_key_values"],
    output_names=["hidden_states", "presents"],
    dynamic_axes={
        "inputs_embeds": {0: "seq_len"},
        "attention_mask": {2: "seq_len", 3: "seq_len"},
        "position_ids": {2: "seq_len"},
        "past_key_values": {1: "history_len"},
    },
    opset=14,
)
```
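For what it's worth, here is a sketch of what each dummy tensor above presumably represents, assuming ChatGLM-6B's published dimensions (hidden_size = 4096, 32 attention heads, head_dim = 128). The shape interpretations and comments are my own reading, not confirmed by the maintainers:

```python
import torch

# Assumed ChatGLM-6B dimensions: hidden_size = 4096 = 32 heads * 128 head_dim.
seq_len, batch = 4, 1

# (seq_len, batch, hidden_size): the embedded input tokens.
inputs_embeds = torch.randn(seq_len, batch, 4096)

# (batch, 1, seq_len, seq_len): boolean mask; True marks a masked position.
attention_mask = torch.zeros(batch, 1, seq_len, seq_len, dtype=torch.bool)

# (batch, 2, seq_len): two rows because ChatGLM uses 2D positional encoding
# (absolute positions plus block positions).
position_ids = torch.zeros(batch, 2, seq_len, dtype=torch.long)

# (key/value, history_len, batch, num_heads, head_dim): history_len = 0
# stands for an empty KV cache on the first decoding step.
past_key_values = torch.zeros(2, 0, batch, 32, 128)
```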

Solutions

1. Is there a concrete example of exporting to ONNX?
2. How should the input tensors be defined?

Additional context

No response

Wukuku avatar Jun 20 '23 15:06 Wukuku

+1

Crazybean-lwb avatar Jun 14 '24 09:06 Crazybean-lwb