Files
BitNet/gpu/convert_safetensors.py
T
2025-05-19 04:34:00 +00:00

116 lines
4.1 KiB
Python

import re
import torch
from pathlib import Path
from safetensors.torch import load_file
from einops import rearrange
from dataclasses import dataclass
from typing import Optional
# Named hyper-parameter presets. Each value is a set of keyword arguments
# that ModelArgs.from_name unpacks into the ModelArgs constructor.
transformer_configs = {
    "2B": {
        "n_layer": 30,
        "n_head": 20,
        "dim": 2560,
        "vocab_size": 128256,
        "n_local_heads": 5,
        "intermediate_size": 6912,
    },
}
@dataclass
class ModelArgs:
    """Hyper-parameter bundle describing a transformer model.

    Fields left at their sentinel defaults are resolved in ``__post_init__``:
    ``n_local_heads`` falls back to ``n_head``, ``intermediate_size`` is
    derived from ``dim``, and ``head_dim`` is always recomputed as
    ``dim // n_head`` (any value passed in is overwritten).
    """

    block_size: int = 4096
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    # None means "derive from dim" in __post_init__.
    intermediate_size: Optional[int] = None
    # -1 means "use n_head".
    n_local_heads: int = -1
    # Placeholder only: __post_init__ unconditionally replaces this value.
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5

    def __post_init__(self) -> None:
        """Fill in the derived fields once the dataclass has been built."""
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            # Two thirds of 4 * dim, rounded up to the next multiple of 256.
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            remainder = n_hidden % 256
            if remainder:
                n_hidden += 256 - remainder
            self.intermediate_size = n_hidden
        self.head_dim = self.dim // self.n_head

    @classmethod
    def from_name(cls, name: str) -> "ModelArgs":
        """Build the configuration registered under ``name``.

        An exact key of ``transformer_configs`` wins. Otherwise ``name`` is
        searched for a registered key as a substring, both as given and
        upper-cased, and exactly one key must match.

        Args:
            name: Preset name, or a longer string containing one. Non-string
                values are converted with ``str()`` before the substring
                search, so they fail with the error below rather than an
                AttributeError.

        Raises:
            AssertionError: If zero or several presets match ``name``.
        """
        if name in transformer_configs:
            return cls(**transformer_configs[name])
        label = str(name)
        label_upper = label.upper()
        matches = [
            key for key in transformer_configs
            if key in label_upper or key in label
        ]
        if len(matches) != 1:
            # Raised explicitly (not via ``assert``) so the check still runs
            # under ``python -O``; the exception type is kept for callers.
            raise AssertionError(f"Unknown model name: {name}")
        return cls(**transformer_configs[matches[0]])
def invert_convert_q(w: torch.Tensor, config: ModelArgs) -> torch.Tensor:
    """Reorder the rows of a query projection weight.

    The leading axis of ``w`` is read as ``(h, l, d)`` with
    ``h = config.n_head`` and ``l = 2``, and written back as ``(h, d, l)``.
    The trailing axis is left untouched.
    """
    pattern = '(h l d) i -> (h d l) i'
    axis_sizes = {"h": config.n_head, "l": 2}
    return rearrange(w, pattern, **axis_sizes)
def invert_convert_k(w: torch.Tensor, config: ModelArgs) -> torch.Tensor:
    """Reorder the rows of a key projection weight.

    Same row permutation as the query variant, ``(h, l, d) -> (h, d, l)``
    with ``l = 2``, but the head count is ``config.n_local_heads``.
    """
    pattern = '(h l d) i -> (h d l) i'
    axis_sizes = {"h": config.n_local_heads, "l": 2}
    return rearrange(w, pattern, **axis_sizes)
def convert_back(
    safetensors_path: str,
    output_file: str,
    model_name: Optional[str] = None,
) -> None:
    """Convert a safetensors checkpoint into a fused Torch state dict on disk.

    For every layer the q/k/v projections are concatenated into one ``wqkv``
    tensor (q and k pass through ``invert_convert_q`` / ``invert_convert_k``
    first), the gate and up projections are concatenated into ``w13``, and
    the remaining tensors are copied under their new key names. The
    resulting dict is written with ``torch.save``.

    Args:
        safetensors_path: Path of the ``.safetensors`` file to read.
        output_file: Destination path for ``torch.save``. Missing parent
            directories are created.
        model_name: Preset name resolved through ``ModelArgs.from_name``.
            Required in practice even though the signature defaults to
            ``None``.

    Raises:
        ValueError: If ``model_name`` is ``None``.
        KeyError: If the checkpoint lacks one of the expected tensors.
    """
    # Fail before the expensive load instead of crashing inside from_name.
    if model_name is None:
        raise ValueError("model_name must be provided (e.g. '2B')")

    st_dict = load_file(safetensors_path)
    cfg = ModelArgs.from_name(model_name)
    print(f"Using model configurations: {cfg}")

    recovered: dict = {}
    for layer in range(cfg.n_layer):
        base = f"model.layers.{layer}."
        prefix = f"layers.{layer}."

        # Fused attention projection: rows are stacked as q, then k, then v.
        wq = invert_convert_q(st_dict[f"{base}self_attn.q_proj.weight"], cfg)
        wk = invert_convert_k(st_dict[f"{base}self_attn.k_proj.weight"], cfg)
        wv = st_dict[f"{base}self_attn.v_proj.weight"]
        recovered[f"{prefix}attention.wqkv.weight"] = torch.cat([wq, wk, wv], dim=0)

        # Tensors that only change key name.
        recovered[f"{prefix}attention.wo.weight"] = st_dict[f"{base}self_attn.o_proj.weight"]
        recovered[f"{prefix}attention_norm.weight"] = st_dict[f"{base}input_layernorm.weight"]
        recovered[f"{prefix}ffn_norm.weight"] = st_dict[f"{base}post_attention_layernorm.weight"]
        recovered[f"{prefix}attention.attn_sub_norm.weight"] = st_dict[f"{base}self_attn.attn_sub_norm.weight"]
        recovered[f"{prefix}feed_forward.ffn_sub_norm.weight"] = st_dict[f"{base}mlp.ffn_sub_norm.weight"]

        # Fused feed-forward input projection: gate rows first, then up rows.
        gate = st_dict[f"{base}mlp.gate_proj.weight"]
        up = st_dict[f"{base}mlp.up_proj.weight"]
        recovered[f"{prefix}feed_forward.w13.weight"] = torch.cat([gate, up], dim=0)
        recovered[f"{prefix}feed_forward.w2.weight"] = st_dict[f"{base}mlp.down_proj.weight"]

    # The output projection is filled from the same embedding tensor.
    recovered["tok_embeddings.weight"] = st_dict["model.embed_tokens.weight"]
    recovered["output.weight"] = st_dict["model.embed_tokens.weight"]
    recovered["norm.weight"] = st_dict["model.norm.weight"]

    # torch.save does not create directories; make sure the target exists so
    # the conversion work is not lost to a missing "./checkpoints" folder.
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    print(f"Saving to {output_file}")
    torch.save(recovered, output_file)
if __name__ == "__main__":
    import argparse

    # Command-line entry point: declare the options from a table, parse
    # them, and run a single conversion.
    cli = argparse.ArgumentParser(
        description="Convert Safetensors back to Torch .pth checkpoint",
    )
    option_specs = (
        (
            "--safetensors_file",
            {"type": str, "required": True, "help": "Path to input .safetensors file"},
        ),
        (
            "--output",
            {"type": str, "default": "./checkpoints/model_state.pt", "help": "Path to output .pt file"},
        ),
        (
            "--model_name",
            {"type": str, "default": "2B", "help": "Model configuration name to use (e.g. 2B)"},
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()
    convert_back(
        safetensors_path=opts.safetensors_file,
        output_file=opts.output,
        model_name=opts.model_name,
    )