mirror of
https://github.com/microsoft/BitNet.git
synced 2026-05-04 03:40:50 +00:00
commit paper code
This commit is contained in:
+33
-10
@@ -44,8 +44,8 @@ SUPPORTED_HF_MODELS = {
|
||||
}
|
||||
|
||||
SUPPORTED_QUANT_TYPES = {
|
||||
"arm64": ["i2_s", "tl1"],
|
||||
"x86_64": ["i2_s", "tl2"]
|
||||
"arm64": ["i2_s", "tl1", "tl2-loss"],
|
||||
"x86_64": ["i2_s", "tl2", "tl2-loss"]
|
||||
}
|
||||
|
||||
COMPILER_EXTRA_ARGS = {
|
||||
@@ -111,8 +111,10 @@ def prepare_model():
|
||||
gguf_path = os.path.join(model_dir, "ggml-model-" + quant_type + ".gguf")
|
||||
if not os.path.exists(gguf_path) or os.path.getsize(gguf_path) == 0:
|
||||
logging.info(f"Converting HF model to GGUF format...")
|
||||
if quant_type.startswith("tl"):
|
||||
if quant_type in ["tl1", "tl2"]:
|
||||
run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", quant_type, "--quant-embd"], log_step="convert_to_tl")
|
||||
elif quant_type in ["tl2-loss"]:
|
||||
run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", "tl2", "--quant-embd", "--loss", "--outfile", model_dir + str("/ggml-model-tl2-loss.gguf")], log_step="convert_to_tl")
|
||||
else: # i2s
|
||||
# convert to f32
|
||||
run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", "f32"], log_step="convert_to_f32_gguf")
|
||||
@@ -156,11 +158,20 @@ def gen_code():
|
||||
shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h")
|
||||
shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl2.ini"), "include/kernel_config.ini")
|
||||
if get_model_name() == "bitnet_b1_58-large":
|
||||
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen")
|
||||
if args.quant_type == "tl2-loss":
|
||||
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
else:
|
||||
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen")
|
||||
elif get_model_name() in llama3_f3_models:
|
||||
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen")
|
||||
if args.quant_type == "tl2-loss":
|
||||
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
else:
|
||||
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen")
|
||||
elif get_model_name() == "bitnet_b1_58-3B":
|
||||
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
|
||||
if args.quant_type == "tl2-loss":
|
||||
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
else:
|
||||
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
@@ -172,11 +183,20 @@ def gen_code():
|
||||
sys.exit(1)
|
||||
shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h")
|
||||
if get_model_name() == "bitnet_b1_58-large":
|
||||
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
if args.quant_type == "tl2-loss":
|
||||
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
else:
|
||||
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
elif get_model_name() in llama3_f3_models:
|
||||
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen")
|
||||
if args.quant_type == "tl2-loss":
|
||||
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
else:
|
||||
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen")
|
||||
elif get_model_name() == "bitnet_b1_58-3B":
|
||||
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
if args.quant_type == "tl2-loss":
|
||||
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
else:
|
||||
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -192,7 +212,10 @@ def compile():
|
||||
logging.error(f"Arch {arch} is not supported yet")
|
||||
exit(0)
|
||||
logging.info("Compiling the code using CMake.")
|
||||
run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files")
|
||||
if args.quant_type == "tl2-loss":
|
||||
run_command(["cmake", "-B", "build", "-DBITNET_TL2_LOSS=ON", *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files")
|
||||
else:
|
||||
run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files")
|
||||
# run_command(["cmake", "--build", "build", "--target", "llama-cli", "--config", "Release"])
|
||||
run_command(["cmake", "--build", "build", "--config", "Release"], log_step="compile")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user