mirror of
https://github.com/microsoft/BitNet.git
synced 2026-05-03 11:20:36 +00:00
Merge pull request #421 from microsoft/fix/unsafe-deserialization-gpu-pipeline
fix: add weights_only=True to torch.load in GPU inference pipeline
This commit is contained in:
@@ -34,7 +34,7 @@ def convert_ts_checkpoint(
|
||||
def convert_int8_to_int2(weight):
|
||||
return convert_weight_int8_to_int2(weight)
|
||||
|
||||
merged_result = torch.load(input_path, map_location="cpu", mmap=True)
|
||||
merged_result = torch.load(input_path, map_location="cpu", mmap=True, weights_only=True)
|
||||
int2_result = {}
|
||||
fp16_result = {}
|
||||
zero = torch.zeros(1).to(torch.bfloat16)
|
||||
|
||||
+2
-2
@@ -64,9 +64,9 @@ class FastGen:
|
||||
decode_model = fast.Transformer(model_args_decode)
|
||||
|
||||
fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
|
||||
fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu")
|
||||
fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu", weights_only=True)
|
||||
int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
|
||||
int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu")
|
||||
int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu", weights_only=True)
|
||||
prefill_model.load_state_dict(fp16_checkpoint, strict=True)
|
||||
decode_model.load_state_dict(int2_checkpoint, strict=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user