Update README and setup script to support the official BitNet b1.58 model (#171)

* Update README and setup file for the new model.

* Update the model file name.

---------

Co-authored-by: Yan Xia <yanxia@microsoft.com>
This commit is contained in:
Yan Xia
2025-04-15 14:53:56 +08:00
committed by GitHub
parent fa854cf8f8
commit fd3f355a0b
3 changed files with 48 additions and 9 deletions
+8 -1
View File
@@ -41,6 +41,9 @@ SUPPORTED_HF_MODELS = {
"tiiuae/Falcon3-1B-Instruct-1.58bit": {
"model_name": "Falcon3-1B-Instruct-1.58bit",
},
"microsoft/BitNet-b1.58-2B-4T": {
"model_name": "BitNet-b1.58-2B-4T",
},
}
SUPPORTED_QUANT_TYPES = {
@@ -161,6 +164,8 @@ def gen_code():
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen")
elif get_model_name() == "bitnet_b1_58-3B":
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
elif get_model_name() == "BitNet-b1.58-2B-4T":
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
else:
raise NotImplementedError()
else:
@@ -177,6 +182,8 @@ def gen_code():
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen")
elif get_model_name() == "bitnet_b1_58-3B":
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
elif get_model_name() == "BitNet-b1.58-2B-4T":
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
else:
raise NotImplementedError()
@@ -222,4 +229,4 @@ if __name__ == "__main__":
args = parse_args()
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
logging.basicConfig(level=logging.INFO)
main()
main()