BitNet
About 'replace_hf.py'
Hello @kyegomez
In the inference code in huggingface_example.py, replace_hf is executed and inference runs immediately afterwards. However, when I looked at replace_hf.py, I noticed that it converts the linear layers to BitLinear layers, and those new layers appear to be created with freshly initialized weights. Is additional code needed to transfer the original (pretrained) weights into the BitLinear layers?
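To illustrate the concern: any freshly constructed layer gets its own random initialization, so swapping a layer without copying its parameters silently discards the pretrained weights. Below is a minimal sketch in plain PyTorch; a second `nn.Linear` stands in for `BitLinear` (an assumption made only so it runs without the bitnet package).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# The "pretrained" layer whose weights we want to keep.
original = nn.Linear(4, 3)

# A freshly constructed layer of the same shape (stand-in for BitLinear):
# it gets its own random initialization, not the pretrained weights.
replacement = nn.Linear(original.in_features, original.out_features,
                        bias=original.bias is not None)
weights_match_before = torch.equal(replacement.weight, original.weight)

# Copy the pretrained parameters in place. no_grad is required because an
# in-place copy into a leaf tensor that requires grad would otherwise raise.
with torch.no_grad():
    replacement.weight.copy_(original.weight)
    replacement.bias.copy_(original.bias)

x = torch.randn(2, 4)
outputs_match_after = torch.allclose(replacement(x), original(x))
print(weights_match_before, outputs_match_after)  # prints: False True
```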
Maybe something like this?
```python
import torch
from torch import nn

from bitnet import BitLinear  # assumes the bitnet package exports BitLinear


def replace_linears_in_hf(model):
    """
    Replaces all instances of nn.Linear in the given model with BitLinear,
    copying over the original weights and biases.

    Args:
        model (nn.Module): The model to modify in place.

    Returns:
        None
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            # Replace the nn.Linear with a BitLinear of matching in_features and out_features,
            # placed on the same device and dtype as the original layer
            new_module = BitLinear(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
            ).to(device=module.weight.device, dtype=module.weight.dtype)
            # Transfer the pretrained parameters into the new layer
            with torch.no_grad():
                new_module.weight.copy_(module.weight)
                if module.bias is not None:
                    new_module.bias.copy_(module.bias)
            setattr(model, name, new_module)
        else:
            # Recursively apply to child modules
            replace_linears_in_hf(module)
```
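To check that the weights actually survive the swap, here is a self-contained sketch of the same recursive replace-and-copy pattern. `StandInBitLinear` and `replace_linears` are hypothetical names: the stand-in keeps `nn.Linear`'s forward, so the model's outputs should be unchanged if the transfer worked. With the real `BitLinear`, which is meant to quantize weights, outputs would legitimately differ, so a weight comparison is the meaningful check there.

```python
import torch
import torch.nn as nn


class StandInBitLinear(nn.Linear):
    """Hypothetical stand-in for bitnet's BitLinear so this sketch runs without
    the bitnet package. It keeps nn.Linear's forward unchanged."""


def replace_linears(model: nn.Module, linear_cls=StandInBitLinear) -> None:
    """Recursively swap every nn.Linear for linear_cls, copying its parameters."""
    for name, module in model.named_children():
        # linear_cls subclasses nn.Linear, so skip layers already converted.
        if isinstance(module, nn.Linear) and not isinstance(module, linear_cls):
            new_module = linear_cls(
                module.in_features,
                module.out_features,
                bias=module.bias is not None,
            ).to(device=module.weight.device, dtype=module.weight.dtype)
            with torch.no_grad():
                new_module.weight.copy_(module.weight)
                if module.bias is not None:
                    new_module.bias.copy_(module.bias)
            setattr(model, name, new_module)
        else:
            # Containers (and already-converted layers) are searched recursively.
            replace_linears(module, linear_cls)


# Usage: a small nested model, checked before and after replacement.
torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Sequential(nn.Linear(16, 16), nn.ReLU()),
    nn.Linear(16, 4),
)
x = torch.randn(5, 8)
expected = model(x)

replace_linears(model)
n_replaced = sum(isinstance(m, StandInBitLinear) for m in model.modules())
print(n_replaced, torch.allclose(model(x), expected))  # prints: 3 True
```

If the replacement layer keeps the parameter names `weight` and `bias`, another option is to snapshot `model.state_dict()` before swapping and call `model.load_state_dict(...)` afterwards; the per-layer copy above avoids holding a second full copy of the weights in memory.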
Thanks.
Stale issue message
@chyoob great idea, please submit a pull request.