fix: add weights_only=True to torch.load in GPU inference pipeline

Mitigate an unsafe deserialization vulnerability (CWE-502) in the GPU
inference pipeline. Calling torch.load without weights_only=True allows
arbitrary code execution via malicious pickle payloads embedded in
checkpoint files.
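
For context on the attack: pickle executes arbitrary callables during
deserialization via __reduce__, so merely loading an untrusted
checkpoint runs attacker code. A minimal sketch (EvilPayload is a
hypothetical class for illustration, not anything from this repo):

    import os
    import pickle

    class EvilPayload:
        def __reduce__(self):
            # pickle invokes this on load; the command runs before any
            # model code ever touches the "checkpoint"
            return (os.system, ("echo pwned",))

    blob = pickle.dumps(EvilPayload())
    pickle.loads(blob)  # side effect: executes `echo pwned`

With weights_only=True, torch.load uses a restricted unpickler that
only reconstructs tensors and primitive containers and raises on
anything else.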

Affected locations:
- gpu/convert_checkpoint.py:37 (checkpoint conversion utility)
- gpu/generate.py:67,69 (fp16 and int2 checkpoint loading)

The utils/ scripts already pass this parameter; this commit brings the
GPU pipeline up to the same safety standard.
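
If a checkpoint ever legitimately needs to carry a non-tensor object,
prefer allowlisting over reverting to weights_only=False. A sketch
using torch.serialization.add_safe_globals (available in PyTorch 2.4+);
ModelArgs is a hypothetical config class, not one these checkpoints
actually contain:

    import torch
    from torch.serialization import add_safe_globals
    from model_config import ModelArgs  # hypothetical module/class

    add_safe_globals([ModelArgs])  # explicitly trust this one type
    ckpt = torch.load("model_state_fp16.pt", map_location="cpu",
                      weights_only=True)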

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

commit eb60fc39cb
parent 8fd3412fbc
Author: Ubuntu
Date: 2026-03-09 12:09:19 +00:00

2 changed files with 3 additions and 3 deletions

gpu/convert_checkpoint.py +1 -1
@@ -34,7 +34,7 @@ def convert_ts_checkpoint(
     def convert_int8_to_int2(weight):
         return convert_weight_int8_to_int2(weight)
-    merged_result = torch.load(input_path, map_location="cpu", mmap=True)
+    merged_result = torch.load(input_path, map_location="cpu", mmap=True, weights_only=True)
     int2_result = {}
     fp16_result = {}
     zero = torch.zeros(1).to(torch.bfloat16)
gpu/generate.py +2 -2
@@ -64,9 +64,9 @@ class FastGen:
         decode_model = fast.Transformer(model_args_decode)
         fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
-        fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu")
+        fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu", weights_only=True)
         int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
-        int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu")
+        int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu", weights_only=True)
         prefill_model.load_state_dict(fp16_checkpoint, strict=True)
         decode_model.load_state_dict(int2_checkpoint, strict=True)
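
A quick way to sanity-check existing checkpoints against the stricter
loader (a sketch; the ckpt_dir path is an assumption, matching the
layout generate.py expects):

    import torch
    from pathlib import Path

    ckpt_dir = Path("checkpoints")  # assumed location
    for name in ("model_state_fp16.pt", "model_state_int2.pt"):
        # raises pickle.UnpicklingError if the file contains anything
        # beyond tensors and primitive containers
        state = torch.load(ckpt_dir / name, map_location="cpu",
                           weights_only=True)
        print(f"{name}: {len(state)} entries loaded safely")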