BitNet/gpu/convert_checkpoint.py
Ubuntu eb60fc39cb fix: add weights_only=True to torch.load in GPU inference pipeline
Mitigate an unsafe deserialization vulnerability (CWE-502) in the GPU
inference pipeline. torch.load without weights_only=True allows
arbitrary code execution via malicious pickle payloads in checkpoint
files.

Affected locations:
- gpu/convert_checkpoint.py:37 (checkpoint conversion utility)
- gpu/generate.py:67,69 (fp16 and int2 checkpoint loading)

The utils/ scripts already pass this parameter; this commit brings the
GPU pipeline up to the same safety standard.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 12:09:19 +00:00
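
For reference, a minimal sketch of the pattern this commit applies (the checkpoint path is illustrative; weights_only=True restricts deserialization to tensors and plain container types instead of arbitrary pickled objects):

import torch

# Unsafe: full pickle deserialization, so a crafted checkpoint can run
# arbitrary code at load time.
# state = torch.load("checkpoint.pt", map_location="cpu")

# Safe: only tensors and simple containers are deserialized; anything
# else is rejected at load time.
state = torch.load("checkpoint.pt", map_location="cpu", weights_only=True)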

import json
import os
import re
import sys
from pathlib import Path
from typing import Optional
from dataclasses import dataclass

import torch
from einops import rearrange
from safetensors.torch import save_file

import model
from pack_weight import convert_weight_int8_to_int2
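

# Convert a merged BF16 TorchScale checkpoint into two variants saved next
# to the input file: model_state_int2.pt (ternary weights packed into 2-bit
# storage, with bfloat16 scales) and model_state_fp16.pt (the same ternary
# quantization dequantized back to floating point).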
@torch.inference_mode()
def convert_ts_checkpoint(
    *,
    input_path: str = "",
) -> None:
    config = model.ModelArgs()
    print(f"Model config {config.__dict__}")

    def quant_weight_int8(weight):
        # Absmean ternary quantization: scale by s = 1 / mean(|W|), round,
        # and clamp so every element lands in {-1, 0, +1}. Returns the int8
        # ternary weights and a single bfloat16 scale (1 / s).
        s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
        new_weight = (weight * s).round().clamp(-1, 1).to(torch.int8)
        new_scale = (1.0 / s).to(torch.bfloat16)
        return new_weight, new_scale.reshape(1)

    def quant_weight_fp16(weight):
        # Same ternary rounding, immediately dequantized (divided by s) so
        # the result stays a floating-point tensor.
        s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
        new_weight = (weight * s).round().clamp(-1, 1) / s
        return new_weight

    def convert_int8_to_int2(weight):
        # Pack the ternary int8 weights into 2-bit storage (see pack_weight).
        return convert_weight_int8_to_int2(weight)

    # weights_only=True blocks arbitrary pickle payloads (CWE-502), and
    # mmap=True avoids materializing the whole checkpoint in memory.
    merged_result = torch.load(input_path, map_location="cpu", mmap=True, weights_only=True)
    int2_result = {}
    fp16_result = {}
    # Scale vectors are padded with zeros to a fixed length of 4 so every
    # weight_scale tensor in the int2 checkpoint has the same shape.
    zero = torch.zeros(1).to(torch.bfloat16)

    for key, value in merged_result.items():
        if 'wqkv' in key:
            # Split the fused QKV projection: Q spans the first config.dim
            # rows; K and V each span dim // n_heads * n_kv_heads rows.
            wq = value[:config.dim]
            wk = value[config.dim:config.dim // config.n_heads * config.n_kv_heads + config.dim]
            wv = value[config.dim // config.n_heads * config.n_kv_heads + config.dim:]
            # Quantize each projection with its own scale, then re-fuse.
            wq_weight, wq_scale = quant_weight_int8(wq)
            wk_weight, wk_scale = quant_weight_int8(wk)
            wv_weight, wv_scale = quant_weight_int8(wv)
            wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
            wqkv_scale = torch.cat([wq_scale, wk_scale, wv_scale, zero], dim=0)
            int2_result[key] = convert_int8_to_int2(wqkv_weight)
            int2_result[key.replace('weight', 'weight_scale')] = wqkv_scale
            # Fake-quantized copy for the fp16 checkpoint.
            wq_weight = quant_weight_fp16(wq)
            wk_weight = quant_weight_fp16(wk)
            wv_weight = quant_weight_fp16(wv)
            wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
            fp16_result[key] = wqkv_weight
        elif 'w13' in key:
            # Split the fused FFN input projection into its w1 and w3 halves.
            w1 = value[:config.ffn_dim]
            w3 = value[config.ffn_dim:]
            w1_weight, w1_scale = quant_weight_int8(w1)
            w3_weight, w3_scale = quant_weight_int8(w3)
            w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
            w13_scale = torch.cat([w1_scale, w3_scale, zero, zero], dim=0)
            int2_result[key] = convert_int8_to_int2(w13_weight)
            int2_result[key.replace('weight', 'weight_scale')] = w13_scale
            w1_weight = quant_weight_fp16(w1)
            w3_weight = quant_weight_fp16(w3)
            w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
            fp16_result[key] = w13_weight
        elif 'w2' in key or 'wo' in key:
            # Unfused projections carry a single scale, padded to length 4.
            weight, scale = quant_weight_int8(value)
            scale = torch.cat([scale, zero, zero, zero], dim=0)
            int2_result[key] = convert_int8_to_int2(weight)
            int2_result[key.replace('weight', 'weight_scale')] = scale
            weight = quant_weight_fp16(value)
            fp16_result[key] = weight
        else:
            # Non-quantized tensors (e.g. embeddings and norms) pass through.
            int2_result[key] = value.clone()
            fp16_result[key] = value.clone()

    output_dir = os.path.dirname(input_path)
    print(f"Saving checkpoint to {output_dir}/model_state_int2.pt")
    torch.save(int2_result, f"{output_dir}/model_state_int2.pt")
    print(f"Saving checkpoint to {output_dir}/model_state_fp16.pt")
    torch.save(fp16_result, f"{output_dir}/model_state_fp16.pt")


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Convert a TorchScale checkpoint to int2 and fp16 variants.')
    parser.add_argument('--input', type=str, required=True, help='path to the merged input checkpoint')
    args = parser.parse_args()
    convert_ts_checkpoint(
        input_path=args.input,
    )
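
Typical invocation is python gpu/convert_checkpoint.py --input /path/to/checkpoint.pt (path illustrative); the converted files land next to the input. As a quick standalone sanity check of the absmean ternary quantization above (a sketch that depends only on torch, not on the model or pack_weight modules; the random tensor is illustrative):

import torch

def quant_weight_int8(weight):
    # Same helper as in the script: s = 1 / mean(|W|), ternary rounding.
    s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
    new_weight = (weight * s).round().clamp(-1, 1).to(torch.int8)
    new_scale = (1.0 / s).to(torch.bfloat16)
    return new_weight, new_scale.reshape(1)

w = torch.randn(256, 256)
q, scale = quant_weight_int8(w)
print(q.unique())                  # values drawn from {-1, 0, 1}
print(scale)                       # single bfloat16 scale == mean(|w|)
w_hat = q.float() * scale.float()  # dequantized approximation of w
print((w - w_hat).abs().mean())    # mean absolute reconstruction error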