diff --git a/setup_env.py b/setup_env.py index 1ea456b..8440929 100644 --- a/setup_env.py +++ b/setup_env.py @@ -20,12 +20,27 @@ SUPPORTED_HF_MODELS = { "HF1BitLLM/Llama3-8B-1.58-100B-tokens": { "model_name": "Llama3-8B-1.58-100B-tokens", }, - "tiiuae/falcon3-7b-instruct-1.58bit": { - "model_name": "falcon3-7b-1.58bit", + "tiiuae/Falcon3-7B-Instruct-1.58bit": { + "model_name": "Falcon3-7B-1.58bit", + }, + "tiiuae/Falcon3-7B-1.58bit": { + "model_name": "Falcon3-7B-1.58bit", + }, + "tiiuae/Falcon3-10B-Instruct-1.58bit": { + "model_name": "Falcon3-10B-1.58bit", + }, + "tiiuae/Falcon3-10B-1.58bit": { + "model_name": "Falcon3-10B-1.58bit", + }, + "tiiuae/Falcon3-3B-Instruct-1.58bit": { + "model_name": "Falcon3-3B-1.58bit", + }, + "tiiuae/Falcon3-3B-1.58bit": { + "model_name": "Falcon3-3B-1.58bit", + }, + "tiiuae/Falcon3-1B-Instruct-1.58bit": { + "model_name": "Falcon3-1B-1.58bit", }, - "tiiuae/falcon3-7b-1.58bit": { - "model_name": "falcon3-7b-1.58bit", - } } SUPPORTED_QUANT_TYPES = { @@ -139,7 +154,7 @@ def gen_code(): shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl2.ini"), "include/kernel_config.ini") if get_model_name() == "bitnet_b1_58-large": run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen") - elif get_model_name() == "Llama3-8B-1.58-100B-tokens": + elif get_model_name() in ["Llama3-8B-1.58-100B-tokens", "Falcon3-7B-1.58bit", "Falcon3-10B-1.58bit", "Falcon3-3B-1.58bit", "Falcon3-1B-1.58bit"]: run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen") elif get_model_name() == "bitnet_b1_58-3B": run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen") @@ -155,7 +170,7 @@ def gen_code(): shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h") if get_model_name() == "bitnet_b1_58-large": run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen") - elif get_model_name() in ["Llama3-8B-1.58-100B-tokens", "falcon3-7b-1.58bit"]: + elif get_model_name() in ["Llama3-8B-1.58-100B-tokens", "Falcon3-7B-1.58bit", "Falcon3-10B-1.58bit", "Falcon3-3B-1.58bit", "Falcon3-1B-1.58bit"]: run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen") elif get_model_name() == "bitnet_b1_58-3B": run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")