From eb60fc39cb91d52cc217029a637424247cd79545 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Mon, 9 Mar 2026 12:09:19 +0000
Subject: [PATCH] fix: add weights_only=True to torch.load in GPU inference
 pipeline

Mitigate unsafe deserialization vulnerability (CWE-502) in the GPU
inference pipeline. torch.load without weights_only=True allows
arbitrary code execution via malicious pickle payloads in checkpoint
files.

Affected locations:
- gpu/convert_checkpoint.py:37 (checkpoint conversion utility)
- gpu/generate.py:67,69 (fp16 and int2 checkpoint loading)

The utils/ scripts already applied this parameter correctly; this
commit brings the GPU pipeline to the same safety standard.

Co-Authored-By: Claude Opus 4.6
---
 gpu/convert_checkpoint.py | 2 +-
 gpu/generate.py           | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/gpu/convert_checkpoint.py b/gpu/convert_checkpoint.py
index 797ad1d..d3a7037 100755
--- a/gpu/convert_checkpoint.py
+++ b/gpu/convert_checkpoint.py
@@ -34,7 +34,7 @@ def convert_ts_checkpoint(
     def convert_int8_to_int2(weight):
         return convert_weight_int8_to_int2(weight)

-    merged_result = torch.load(input_path, map_location="cpu", mmap=True)
+    merged_result = torch.load(input_path, map_location="cpu", mmap=True, weights_only=True)
     int2_result = {}
     fp16_result = {}
     zero = torch.zeros(1).to(torch.bfloat16)
diff --git a/gpu/generate.py b/gpu/generate.py
index 638ed7b..030b97f 100755
--- a/gpu/generate.py
+++ b/gpu/generate.py
@@ -64,9 +64,9 @@ class FastGen:
         decode_model = fast.Transformer(model_args_decode)

         fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
-        fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu")
+        fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu", weights_only=True)
         int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
-        int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu")
+        int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu", weights_only=True)

         prefill_model.load_state_dict(fp16_checkpoint, strict=True)
         decode_model.load_state_dict(int2_checkpoint, strict=True)