diff --git a/.gitmodules b/.gitmodules
index 2b36e49..60c975a 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,4 +1,4 @@
 [submodule "3rdparty/llama.cpp"]
 	path = 3rdparty/llama.cpp
-	url = https://github.com/Eddie-Wang1120/llama.cpp.git
-	branch = merge-dev
+	url = git@github.com:Eddie-Wang1120/llama.cpp.git
+	branch = pp
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6ddaa51..bd1143b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -14,6 +14,7 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
 # option list
 option(BITNET_ARM_TL1 "bitnet.cpp: use tl1 on arm platform" OFF)
 option(BITNET_X86_TL2 "bitnet.cpp: use tl2 on x86 platform" OFF)
+option(BITNET_TL2_LOSS "bitnet.cpp: use tl2 on x86 platform" OFF)
 set(CMAKE_CXX_STANDARD_REQUIRED true)
@@ -24,6 +25,7 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)
 # override ggml options
 set(GGML_BITNET_ARM_TL1 ${BITNET_ARM_TL1})
 set(GGML_BITNET_X86_TL2 ${BITNET_X86_TL2})
+set(GGML_BITNET_TL2_LOSS ${BITNET_TL2_LOSS})
 if (GGML_BITNET_ARM_TL1)
     add_compile_definitions(GGML_BITNET_ARM_TL1)
@@ -31,9 +33,8 @@ endif()
 if (GGML_BITNET_X86_TL2)
     add_compile_definitions(GGML_BITNET_X86_TL2)
 endif()
-
-if (CMAKE_C_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
-    add_compile_options(-fpermissive)
+if (GGML_BITNET_TL2_LOSS)
+    add_compile_definitions(GGML_BITNET_TL2_LOSS)
 endif()
 find_package(Threads REQUIRED)
diff --git a/README.md b/README.md
index a439f0a..0b6bac5 100644
--- a/README.md
+++ b/README.md
@@ -43,8 +43,9 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
     I2_S
-    TL1
-    TL2
+    TL1(TL1_1)
+    TL2(TL2_1)
+    TL2-Loss(TL2_0)
     bitnet_b1_58-large
@@ -53,12 +54,14 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
     ✅
     ❌
     ✅
+    ✅
     ARM
     ✅
     ✅
     ❌
+    ✅
     bitnet_b1_58-3B
@@ -67,12 +70,14 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
     ❌
     ❌
     ✅
+    ✅
     ARM
     ❌
     ✅
     ❌
+    ✅
     Llama3-8B-1.58-100B-tokens
@@ -81,12 +86,14 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
     ✅
     ❌
     ✅
+    ✅
     ARM
     ✅
     ✅
     ❌
+    ✅
     Falcon3 Family
@@ -95,12 +102,14 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
     ✅
     ❌
     ✅
+    ✅
     ARM
     ✅
     ✅
     ❌
+    ✅
@@ -144,11 +153,11 @@ pip install -r requirements.txt
 3. Build the project
 ```bash
 # Download the model from Hugging Face, convert it to quantized gguf format, and build the project
-python setup_env.py --hf-repo tiiuae/Falcon3-7B-Instruct-1.58bit -q i2_s
+python setup_env.py --hf-repo 1bitLLM/bitnet_b1_58-large -q i2_s
 # Or you can manually download the model and run with local path
-huggingface-cli download tiiuae/Falcon3-7B-Instruct-1.58bit --local-dir models/Falcon3-7B-Instruct-1.58bit
-python setup_env.py -md models/Falcon3-7B-Instruct-1.58bit -q i2_s
+huggingface-cli download 1bitLLM/bitnet_b1_58-large --local-dir models/bitnet_b1_58-large
+python setup_env.py -md models/bitnet_b1_58-large -q i2_s
 ```
 usage: setup_env.py [-h] [--hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}] [--model-dir MODEL_DIR] [--log-dir LOG_DIR] [--quant-type {i2_s,tl1}] [--quant-embd]
diff --git a/include/bitnet-lut-kernels.h b/include/bitnet-lut-kernels.h
new file mode 100644
index 0000000..daf470e
--- /dev/null
+++ b/include/bitnet-lut-kernels.h
@@ -0,0 +1,627 @@
+#if defined(GGML_BITNET_ARM_TL1)
+#include "ggml-bitnet.h"
+#define GGML_BITNET_MAX_NODES 8192
+static bool initialized = false;
+static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;
+static size_t bitnet_tensor_extras_index = 0;
+static void * aligned_malloc(size_t size) {
+#if defined(_WIN32)
+    return _aligned_malloc(size, 64);
+#else
+    void * ptr = nullptr;
+    posix_memalign(&ptr, 64, size);
+    return ptr;
+#endif
+}
+static void aligned_free(void * ptr) {
+#if defined(_WIN32)
+    _aligned_free(ptr);
+#else
+    free(ptr);
+#endif
+}
+
+void per_tensor_quant(int k, void* lut_scales_, void* b_) {
+    bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
+    bitnet_float_type* b = (bitnet_float_type*)b_;
+#ifdef __ARM_NEON
+    float32x4_t temp_max = vdupq_n_f32(0);
+    for (int i=0; i < k / 4; i++) {
+      float32x4_t vec_bs = vld1q_f32(b + 4 * i);
+      float32x4_t abssum = vabsq_f32(vec_bs);
+      temp_max = vmaxq_f32(abssum, temp_max);
+    }
+    float32_t scales = 127 / vmaxvq_f32(temp_max);
+    *lut_scales = scales;
+#elif defined __AVX2__
+    __m256 max_vec = _mm256_set1_ps(0.f);
+    const __m256 vec_sign = _mm256_set1_ps(-0.0f);
+    // #pragma unroll
+    for (int i = 0; i < k / 8; i++) {
+        __m256 vec_b = _mm256_loadu_ps(b + i * 8);
+        __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);
+        max_vec = _mm256_max_ps(vec_babs, max_vec);
+    }
+    __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));
+    max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));
+    max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));
+    float scales = 127 / _mm_cvtss_f32(max1);
+    *lut_scales = scales;
+#endif
+}
+
+void partial_max_reset(void* lut_scales_) {
+    bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
+    *lut_scales = 0.0;
+}
+
+#ifdef __ARM_NEON
+inline void Transpose_8_8(
+    int16x8_t *v0,
+    int16x8_t *v1,
+    int16x8_t *v2,
+    int16x8_t *v3,
+    int16x8_t *v4,
+    int16x8_t *v5,
+    int16x8_t *v6,
+    int16x8_t *v7)
+{
+    int16x8x2_t q04 = vzipq_s16(*v0, *v4);
+    int16x8x2_t q15 = vzipq_s16(*v1, *v5);
+    int16x8x2_t q26 = vzipq_s16(*v2, *v6);
+    int16x8x2_t q37 = vzipq_s16(*v3, *v7);
+
+    int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);
+    int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);
+    int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);
+    int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);
+
+    int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);
+    int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);
+    int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);
+    int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);
+
+    *v0 = q_fin_0.val[0];
+    *v1 = q_fin_0.val[1];
+    *v2 = q_fin_1.val[0];
+    *v3 = q_fin_1.val[1];
+    *v4 = q_fin_2.val[0];
+    *v5 = q_fin_2.val[1];
+    *v6 = q_fin_3.val[0];
+    *v7 = q_fin_3.val[1];
+}
+#endif
+
+template <int act_k>
+inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {
+#ifdef __ARM_NEON
+    int16x8_t vec_lut[16];
+    float32_t scales = *lut_scales;
+        uint8_t tbl_mask[16];
+        tbl_mask[0] = 0;
+        tbl_mask[1] = 2;
+        tbl_mask[2] = 4;
+        tbl_mask[3] = 6;
+        tbl_mask[4] = 8;
+        tbl_mask[5] = 10;
+        tbl_mask[6] = 12;
+        tbl_mask[7] = 14;
+        tbl_mask[8] = 1;
+        tbl_mask[9] = 3;
+        tbl_mask[10] = 5;
+        tbl_mask[11] = 7;
+        tbl_mask[12] = 9;
+        tbl_mask[13] = 11;
+        tbl_mask[14] = 13;
+        tbl_mask[15] = 15;
+        uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);
+#pragma unroll
+    for (int k = 0; k < act_k / 16; ++k) {
+        float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);
+        float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);
+        float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);
+        float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);
+        float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);
+        float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);
+        int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);
+        int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);
+        int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);
+        int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);
+        int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);
+        int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);
+        int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);
+        int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);
+        int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);
+        int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);
+        vec_lut[0] = vdupq_n_s16(0);
+        vec_lut[0] = vec_lut[0] - vec_bs_0;
+        vec_lut[0] = vec_lut[0] - vec_bs_1;
+        vec_lut[1] = vdupq_n_s16(0);
+        vec_lut[1] = vec_lut[1] - vec_bs_0;
+        vec_lut[2] = vdupq_n_s16(0);
+        vec_lut[2] = vec_lut[2] - vec_bs_0;
+        vec_lut[2] = vec_lut[2] + vec_bs_1;
+        vec_lut[3] = vdupq_n_s16(0);
+        vec_lut[3] = vec_lut[3] - vec_bs_1;
+        vec_lut[4] = vdupq_n_s16(0);
+        vec_lut[5] = vec_bs_1;
+        vec_lut[6] = vec_bs_0;
+        vec_lut[6] = vec_lut[6] - vec_bs_1;
+        vec_lut[7] = vec_bs_0;
+        vec_lut[8] = vec_bs_0;
+        vec_lut[8] = vec_lut[8] + vec_bs_1;
+        Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),
+                      &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));
+        Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),
+                      &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));
+#pragma unroll
+        for (int idx = 0; idx < 8; idx++) {
+            int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);
+            int8x8_t q0_low = vget_low_s8(q0_s);
+            int8x8_t q0_high = vget_high_s8(q0_s);
+            int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);
+            int8x8_t q1_low = vget_low_s8(q1_s);
+            int8x8_t q1_high = vget_high_s8(q1_s);
+            vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);
+            vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);
+            vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);
+            vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);
+        }
+    }
+#endif
+}
+
+static bool is_type_supported(enum ggml_type type) {
+    if (type == GGML_TYPE_Q4_0 ||
+        type == GGML_TYPE_TL1) {
+        return true;
+    } else {
+        return false;
+    }
+}
+#include <arm_neon.h>
+
+#define BM1536_4096 256
+#define BBK1536_4096 128
+inline void tbl_impl_1536_4096(int32_t* c, int8_t* lut, uint8_t* a) {
+#ifdef __ARM_NEON
+    const int KK = BBK1536_4096 / 2;
+    const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
+    const int8x16_t vec_zero = vdupq_n_s16(0x0000);
+    int8x16_t vec_lut[2 * KK];
+    int16x8_t vec_c[4];
+#pragma unroll
+    for (int k = 0; k < 2 * KK; k++) {
+        vec_lut[k] = vld1q_s8(lut + k * 16);
+    }
+
+#pragma unroll
+    for (int i = 0; i < BM1536_4096; i += 32) {
+        #pragma unroll
+        for (int i=0; i<4; i++) {
+            vec_c[i] = vandq_s16(vec_c[i], vec_zero);
+        }
+
+#pragma unroll
+        for (int k = 0; k < KK / 4; k++) {
+            
+            uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
+            uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
+            uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
+            int8x16_t  vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
+            int8x16_t  vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
+            int8x16_t  vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
+            int8x16_t  vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
+            int8x16x2_t  vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
+            int8x16x2_t  vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
+            vec_c[0] += vec_v_left_0.val[0];
+            vec_c[0] += vec_v_right_0.val[0];
+            vec_c[1] += vec_v_left_0.val[1];
+            vec_c[1] += vec_v_right_0.val[1];
+        
+            uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
+            uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
+            uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
+            int8x16_t  vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
+            int8x16_t  vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
+            int8x16_t  vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
+            int8x16_t  vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
+            int8x16x2_t  vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
+            int8x16x2_t  vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
+            vec_c[0] += vec_v_left_1.val[0];
+            vec_c[0] += vec_v_right_1.val[0];
+            vec_c[1] += vec_v_left_1.val[1];
+            vec_c[1] += vec_v_right_1.val[1];
+        
+            uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
+            uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
+            uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
+            int8x16_t  vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
+            int8x16_t  vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
+            int8x16_t  vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
+            int8x16_t  vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
+            int8x16x2_t  vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
+            int8x16x2_t  vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
+            vec_c[2] += vec_v_left_2.val[0];
+            vec_c[2] += vec_v_right_2.val[0];
+            vec_c[3] += vec_v_left_2.val[1];
+            vec_c[3] += vec_v_right_2.val[1];
+        
+            uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
+            uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
+            uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
+            int8x16_t  vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
+            int8x16_t  vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
+            int8x16_t  vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
+            int8x16_t  vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
+            int8x16x2_t  vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
+            int8x16x2_t  vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
+            vec_c[2] += vec_v_left_3.val[0];
+            vec_c[2] += vec_v_right_3.val[0];
+            vec_c[3] += vec_v_left_3.val[1];
+            vec_c[3] += vec_v_right_3.val[1];
+        
+       }
+
+        int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
+        int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
+        vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
+        vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
+        int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
+        int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
+        vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
+        vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
+        int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
+        int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
+        vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
+        vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
+        int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
+        int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
+        vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
+        vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
+
+    }
+#endif
+}
+
+int32_t qgemm_lut_1536_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+    alignas(32) uint32_t CBits[BM1536_4096];
+    memset(&(CBits[0]), 0, BM1536_4096 * sizeof(int32_t));
+#pragma unroll
+    for (int32_t k_outer = 0; k_outer < 4096 / BBK1536_4096; ++k_outer) {
+        tbl_impl_1536_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_4096 / 2 / 2 * BM1536_4096)])));
+    }
+#pragma unroll
+    for (int i = 0; i < BM1536_4096; i++) {
+        ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
+    }
+  return 0;
+};
+#include <arm_neon.h>
+
+#define BM1536_1536 128
+#define BBK1536_1536 64
+inline void tbl_impl_1536_1536(int32_t* c, int8_t* lut, uint8_t* a) {
+#ifdef __ARM_NEON
+    const int KK = BBK1536_1536 / 2;
+    const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
+    const int8x16_t vec_zero = vdupq_n_s16(0x0000);
+    int8x16_t vec_lut[2 * KK];
+    int16x8_t vec_c[8];
+#pragma unroll
+    for (int k = 0; k < 2 * KK; k++) {
+        vec_lut[k] = vld1q_s8(lut + k * 16);
+    }
+
+#pragma unroll
+    for (int i = 0; i < BM1536_1536; i += 64) {
+        #pragma unroll
+        for (int i=0; i<8; i++) {
+            vec_c[i] = vandq_s16(vec_c[i], vec_zero);
+        }
+
+#pragma unroll
+        for (int k = 0; k < KK / 2; k++) {
+            
+            uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
+            uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
+            uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
+            int8x16_t  vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top);
+            int8x16_t  vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top);
+            int8x16_t  vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot);
+            int8x16_t  vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot);
+            int8x16x2_t  vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
+            int8x16x2_t  vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
+            vec_c[0] += vec_v_left_0.val[0];
+            vec_c[0] += vec_v_right_0.val[0];
+            vec_c[1] += vec_v_left_0.val[1];
+            vec_c[1] += vec_v_right_0.val[1];
+        
+            uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
+            uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
+            uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
+            int8x16_t  vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top);
+            int8x16_t  vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top);
+            int8x16_t  vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot);
+            int8x16_t  vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot);
+            int8x16x2_t  vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
+            int8x16x2_t  vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
+            vec_c[2] += vec_v_left_1.val[0];
+            vec_c[2] += vec_v_right_1.val[0];
+            vec_c[3] += vec_v_left_1.val[1];
+            vec_c[3] += vec_v_right_1.val[1];
+        
+            uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
+            uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
+            uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
+            int8x16_t  vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top);
+            int8x16_t  vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top);
+            int8x16_t  vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot);
+            int8x16_t  vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot);
+            int8x16x2_t  vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
+            int8x16x2_t  vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
+            vec_c[4] += vec_v_left_2.val[0];
+            vec_c[4] += vec_v_right_2.val[0];
+            vec_c[5] += vec_v_left_2.val[1];
+            vec_c[5] += vec_v_right_2.val[1];
+        
+            uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
+            uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
+            uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
+            int8x16_t  vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top);
+            int8x16_t  vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top);
+            int8x16_t  vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot);
+            int8x16_t  vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot);
+            int8x16x2_t  vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
+            int8x16x2_t  vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
+            vec_c[6] += vec_v_left_3.val[0];
+            vec_c[6] += vec_v_right_3.val[0];
+            vec_c[7] += vec_v_left_3.val[1];
+            vec_c[7] += vec_v_right_3.val[1];
+        
+       }
+
+        int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
+        int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
+        vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
+        vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
+        int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
+        int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
+        vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
+        vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
+        int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
+        int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
+        vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
+        vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
+        int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
+        int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
+        vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
+        vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
+        int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4]));
+        int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]);
+        vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4);
+        vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4);
+        int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5]));
+        int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]);
+        vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5);
+        vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5);
+        int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6]));
+        int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]);
+        vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6);
+        vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6);
+        int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7]));
+        int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]);
+        vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7);
+        vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7);
+
+    }
+#endif
+}
+
+int32_t qgemm_lut_1536_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+    alignas(32) uint32_t CBits[BM1536_1536];
+    memset(&(CBits[0]), 0, BM1536_1536 * sizeof(int32_t));
+#pragma unroll
+    for (int32_t k_outer = 0; k_outer < 1536 / BBK1536_1536; ++k_outer) {
+        tbl_impl_1536_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_1536 / 2 / 2 * BM1536_1536)])));
+    }
+#pragma unroll
+    for (int i = 0; i < BM1536_1536; i++) {
+        ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
+    }
+  return 0;
+};
+#include <arm_neon.h>
+
+#define BM4096_1536 256
+#define BBK4096_1536 128
+inline void tbl_impl_4096_1536(int32_t* c, int8_t* lut, uint8_t* a) {
+#ifdef __ARM_NEON
+    const int KK = BBK4096_1536 / 2;
+    const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
+    const int8x16_t vec_zero = vdupq_n_s16(0x0000);
+    int8x16_t vec_lut[2 * KK];
+    int16x8_t vec_c[4];
+#pragma unroll
+    for (int k = 0; k < 2 * KK; k++) {
+        vec_lut[k] = vld1q_s8(lut + k * 16);
+    }
+
+#pragma unroll
+    for (int i = 0; i < BM4096_1536; i += 32) {
+        #pragma unroll
+        for (int i=0; i<4; i++) {
+            vec_c[i] = vandq_s16(vec_c[i], vec_zero);
+        }
+
+#pragma unroll
+        for (int k = 0; k < KK / 4; k++) {
+            
+            uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
+            uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
+            uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
+            int8x16_t  vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
+            int8x16_t  vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
+            int8x16_t  vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
+            int8x16_t  vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
+            int8x16x2_t  vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
+            int8x16x2_t  vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
+            vec_c[0] += vec_v_left_0.val[0];
+            vec_c[0] += vec_v_right_0.val[0];
+            vec_c[1] += vec_v_left_0.val[1];
+            vec_c[1] += vec_v_right_0.val[1];
+        
+            uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
+            uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
+            uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
+            int8x16_t  vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
+            int8x16_t  vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
+            int8x16_t  vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
+            int8x16_t  vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
+            int8x16x2_t  vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
+            int8x16x2_t  vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
+            vec_c[0] += vec_v_left_1.val[0];
+            vec_c[0] += vec_v_right_1.val[0];
+            vec_c[1] += vec_v_left_1.val[1];
+            vec_c[1] += vec_v_right_1.val[1];
+        
+            uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
+            uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
+            uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
+            int8x16_t  vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
+            int8x16_t  vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
+            int8x16_t  vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
+            int8x16_t  vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
+            int8x16x2_t  vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
+            int8x16x2_t  vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
+            vec_c[2] += vec_v_left_2.val[0];
+            vec_c[2] += vec_v_right_2.val[0];
+            vec_c[3] += vec_v_left_2.val[1];
+            vec_c[3] += vec_v_right_2.val[1];
+        
+            uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
+            uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
+            uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
+            int8x16_t  vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
+            int8x16_t  vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
+            int8x16_t  vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
+            int8x16_t  vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
+            int8x16x2_t  vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
+            int8x16x2_t  vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
+            vec_c[2] += vec_v_left_3.val[0];
+            vec_c[2] += vec_v_right_3.val[0];
+            vec_c[3] += vec_v_left_3.val[1];
+            vec_c[3] += vec_v_right_3.val[1];
+        
+       }
+
+        int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
+        int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
+        vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
+        vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
+        int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
+        int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
+        vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
+        vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
+        int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
+        int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
+        vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
+        vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
+        int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
+        int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
+        vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
+        vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
+
+    }
+#endif
+}
+
+int32_t qgemm_lut_4096_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+    alignas(32) uint32_t CBits[BM4096_1536];
+    memset(&(CBits[0]), 0, BM4096_1536 * sizeof(int32_t));
+#pragma unroll
+    for (int32_t k_outer = 0; k_outer < 1536 / BBK4096_1536; ++k_outer) {
+        tbl_impl_4096_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_1536 / 2 / 2 * BM4096_1536)])));
+    }
+#pragma unroll
+    for (int i = 0; i < BM4096_1536; i++) {
+        ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
+    }
+  return 0;
+};
+
+template <int K>
+void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {
+  partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));
+  per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));
+
+  lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));
+}
+void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {
+    if (m == 1536 && k == 4096) {
+        preprocessor_k<4096>(B, LUT_Scales, QLUT);
+    }
+    else if (m == 1536 && k == 1536) {
+        preprocessor_k<1536>(B, LUT_Scales, QLUT);
+    }
+    else if (m == 4096 && k == 1536) {
+        preprocessor_k<1536>(B, LUT_Scales, QLUT);
+    }
+}
+void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+    if (m == 1536 && k == 4096) {
+        qgemm_lut_1536_4096(A, LUT, Scales, LUT_Scales, C);
+    }
+    else if (m == 1536 && k == 1536) {
+        qgemm_lut_1536_1536(A, LUT, Scales, LUT_Scales, C);
+    }
+    else if (m == 4096 && k == 1536) {
+        qgemm_lut_4096_1536(A, LUT, Scales, LUT_Scales, C);
+    }
+}
+
+void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {
+    if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
+        return;
+    }
+
+    int k = tensor->ne[0];
+    int m = tensor->ne[1];
+    const int lut_scales_size = 1;
+    const int scales_size = 1;
+    int bk = 0;
+    int bm = 0;
+
+    if (m == 1536 && k == 4096) {
+        bm = BM1536_4096;
+        bk = BBK1536_4096;
+    }
+else if (m == 1536 && k == 1536) {
+        bm = BM1536_1536;
+        bk = BBK1536_1536;
+    }
+else if (m == 4096 && k == 1536) {
+        bm = BM4096_1536;
+        bk = BBK4096_1536;
+    }
+
+    const int n_tile_num = m / bm;
+    const int BK = bk;
+    uint8_t * qweights;
+    bitnet_float_type * scales;
+
+    scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));
+    qweights = (uint8_t *) tensor->data;
+    float * i2_scales = (float * )(qweights + k * m / 4);
+    scales[0] = (bitnet_float_type) i2_scales[0];
+
+    tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;
+    bitnet_tensor_extras[bitnet_tensor_extras_index++] = {
+        /* .lut_scales_size = */ lut_scales_size,
+        /* .BK              = */ BK,
+        /* .n_tile_num      = */ n_tile_num,
+        /* .qweights        = */ qweights,
+        /* .scales          = */ scales
+    };
+}
+#endif
\ No newline at end of file
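
The generated TL1 header above follows a lookup-table scheme: `per_tensor_quant` picks a per-tensor activation scale (127 / max|b|), `lut_ctor` precomputes the partial sums of each activation pair for the nine ternary weight combinations, `tbl_impl_*` replaces multiply-adds with `vqtbl1q_s8` table lookups, and `qgemm_lut_*` rescales the int32 accumulators by `Scales / LUT_Scales`. As a reading aid only (not part of the patch), here is a minimal NumPy sketch of that pair-wise LUT idea, with the int8 packing and scale quantization omitted:

```python
import numpy as np

# Sketch of the TL1 idea: ternary weights {-1, 0, +1} are grouped in pairs, and each
# pair indexes a 9-entry table of precomputed partial sums of the activation pair.
def build_lut(b0, b1):
    # Same ordering as vec_lut[0..8] in lut_ctor above: w0*b0 + w1*b1 for
    # (w0, w1) in {-1, 0, 1} x {-1, 0, 1}, indexed by (w0 + 1) * 3 + (w1 + 1).
    return np.array([w0 * b0 + w1 * b1 for w0 in (-1, 0, 1) for w1 in (-1, 0, 1)])

def lut_gemv(W, b):
    m, k = W.shape                                   # W: ternary (m, k), b: (k,), k even
    c = np.zeros(m)
    for j in range(0, k, 2):
        lut = build_lut(b[j], b[j + 1])              # 9 partial sums for this activation pair
        idx = (W[:, j] + 1) * 3 + (W[:, j + 1] + 1)  # per-row table index for the weight pair
        c += lut[idx]                                # a lookup replaces two multiply-adds
    return c

rng = np.random.default_rng(0)
W = rng.integers(-1, 2, size=(8, 16))
b = rng.standard_normal(16)
assert np.allclose(lut_gemv(W, b), W @ b)
```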
diff --git a/include/ggml-bitnet.h b/include/ggml-bitnet.h
index 3f8571c..bf373df 100644
--- a/include/ggml-bitnet.h
+++ b/include/ggml-bitnet.h
@@ -5,8 +5,13 @@
 
 #ifdef __ARM_NEON
 #include <arm_neon.h>
+#if defined(GGML_BITNET_ARM_TL1)
 typedef float32_t bitnet_float_type;
 #else
+typedef float16_t bitnet_float_type;
+#endif
+#else
+#include <immintrin.h>
 typedef float bitnet_float_type;
 #endif
 
@@ -43,6 +48,10 @@ GGML_API void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* Q
 GGML_API void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C);
 GGML_API void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT);
 #endif
+#if defined(GGML_BITNET_TL2_LOSS)
+GGML_API void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C);
+GGML_API void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* Three_LUT_Scales, void* Two_LUT_Scales, void* Three_QLUT, void* Two_QLUT);
+#endif
 
 #ifdef  __cplusplus
 }
diff --git a/include/kernel_config.ini b/include/kernel_config.ini
new file mode 100644
index 0000000..5d94318
--- /dev/null
+++ b/include/kernel_config.ini
@@ -0,0 +1,21 @@
+[Kernels_0]
+m = 1536
+k = 4096
+bm = 256
+bk = 128
+bmm = 32
+
+[Kernels_1]
+m = 1536
+k = 1536
+bm = 128
+bk = 64
+bmm = 64
+
+[Kernels_2]
+m = 4096
+k = 1536
+bm = 256
+bk = 128
+bmm = 32
+
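
The new `include/kernel_config.ini` records the per-shape tiling used by the generated kernels: `bm`/`bk` match the `BM*`/`BBK*` defines in `bitnet-lut-kernels.h`, and `bmm` corresponds to the smaller inner tile passed to the codegen scripts as `--bm`. A small illustrative snippet (assumed usage, not from the patch) for reading the table back with `configparser`:

```python
from configparser import ConfigParser

# Illustrative only: read the kernel tiling table defined above.
config = ConfigParser()
config.read("include/kernel_config.ini")

tiles = {}
for section in config.sections():                    # Kernels_0, Kernels_1, ...
    m = config.getint(section, "m")                   # output rows handled by this kernel
    k = config.getint(section, "k")                   # reduction dimension
    tiles[(m, k)] = (config.getint(section, "bm"),    # block rows  (BM)
                     config.getint(section, "bk"),    # block depth (BK / BBK)
                     config.getint(section, "bmm"))   # inner tile rows (codegen --bm)
print(tiles)
# {(1536, 4096): (256, 128, 32), (1536, 1536): (128, 64, 64), (4096, 1536): (256, 128, 32)}
```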
diff --git a/setup_env.py b/setup_env.py
index 9256324..12d5100 100644
--- a/setup_env.py
+++ b/setup_env.py
@@ -44,8 +44,8 @@ SUPPORTED_HF_MODELS = {
 }
 
 SUPPORTED_QUANT_TYPES = {
-    "arm64": ["i2_s", "tl1"],
-    "x86_64": ["i2_s", "tl2"]
+    "arm64": ["i2_s", "tl1", "tl2-loss"],
+    "x86_64": ["i2_s", "tl2", "tl2-loss"]
 }
 
 COMPILER_EXTRA_ARGS = {
@@ -111,8 +111,10 @@ def prepare_model():
     gguf_path = os.path.join(model_dir, "ggml-model-" + quant_type + ".gguf")
     if not os.path.exists(gguf_path) or os.path.getsize(gguf_path) == 0:
         logging.info(f"Converting HF model to GGUF format...")
-        if quant_type.startswith("tl"):
+        if quant_type in ["tl1", "tl2"]:
             run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", quant_type, "--quant-embd"], log_step="convert_to_tl")
+        elif quant_type in ["tl2-loss"]:
+            run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", "tl2", "--quant-embd", "--loss", "--outfile", model_dir + str("/ggml-model-tl2-loss.gguf")], log_step="convert_to_tl")
         else: # i2s
             # convert to f32
             run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", "f32"], log_step="convert_to_f32_gguf")
@@ -156,11 +158,20 @@ def gen_code():
                 shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h")
                 shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl2.ini"), "include/kernel_config.ini")
         if get_model_name() == "bitnet_b1_58-large":
-            run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen")
+            if args.quant_type == "tl2-loss":
+                run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
+            else:
+                run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen")
         elif get_model_name() in llama3_f3_models:
-            run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen")
+            if args.quant_type == "tl2-loss":
+                run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
+            else:
+                run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen")
         elif get_model_name() == "bitnet_b1_58-3B":
-            run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
+            if args.quant_type == "tl2-loss":
+                run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
+            else:
+                run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
         else:
             raise NotImplementedError()
     else:
@@ -172,11 +183,20 @@ def gen_code():
                 sys.exit(1)
             shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h")
         if get_model_name() == "bitnet_b1_58-large":
-            run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen")
+            if args.quant_type == "tl2-loss":
+                run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
+            else:
+                run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen")
         elif get_model_name() in llama3_f3_models:
-            run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen")
+            if args.quant_type == "tl2-loss":
+                run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
+            else:
+                run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen")
         elif get_model_name() == "bitnet_b1_58-3B":
-            run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
+            if args.quant_type == "tl2-loss":
+                run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
+            else:
+                run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
         else:
             raise NotImplementedError()
 
@@ -192,7 +212,10 @@ def compile():
         logging.error(f"Arch {arch} is not supported yet")
         exit(0)
     logging.info("Compiling the code using CMake.")
-    run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files")
+    if args.quant_type == "tl2-loss":
+        run_command(["cmake", "-B", "build", "-DBITNET_TL2_LOSS=ON", *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files")
+    else:
+        run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files")
     # run_command(["cmake", "--build", "build", "--target", "llama-cli", "--config", "Release"])
     run_command(["cmake", "--build", "build", "--config", "Release"], log_step="compile")
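
With these changes, `tl2-loss` becomes a selectable quant type end to end: conversion emits `ggml-model-tl2-loss.gguf`, kernel generation goes through `utils/codegen_tl2_loss.py`, and CMake is configured with `-DBITNET_TL2_LOSS=ON`. A hypothetical invocation (the model repo is a placeholder reused from the existing examples, not taken from the patch itself):

```python
import subprocess, sys

# Hypothetical end-to-end run of the new tl2-loss path; assumes the repo is
# listed in SUPPORTED_HF_MODELS and the arch supports the "tl2-loss" quant type.
subprocess.run(
    [sys.executable, "setup_env.py",
     "--hf-repo", "1bitLLM/bitnet_b1_58-large",
     "-q", "tl2-loss"],
    check=True,
)
```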
 
diff --git a/src/ggml-bitnet-lut.cpp b/src/ggml-bitnet-lut.cpp
index 59422d5..680cdab 100644
--- a/src/ggml-bitnet-lut.cpp
+++ b/src/ggml-bitnet-lut.cpp
@@ -154,6 +154,80 @@ size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const stru
     return wsize;
 }
 
+int ggml_bitnet_get_type_bits(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_TL2:
+            return 2;
+        case GGML_TYPE_Q4_0:
+            return 4;
+        default:
+            return 0;
+    }
+}
+#endif
+
+#if defined(GGML_BITNET_TL2_LOSS)
+void ggml_bitnet_init(void) {
+    // LOG(INFO) << "ggml_bitnet_init";
+
+    if (initialized) {
+        return;
+    }
+    initialized = true;
+
+    // if (wrapper == nullptr) {
+    //     wrapper = new BITNET::BITNETGeMMWrapper();
+    // }
+    if (bitnet_tensor_extras == nullptr) {
+        bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES];
+    }
+    bitnet_tensor_extras_index = 0;
+}
+
+void ggml_bitnet_free(void) {
+    // LOG(INFO) << "ggml_bitnet_free";
+
+    if (!initialized) {
+        return;
+    }
+    initialized = false;
+
+    // delete wrapper;
+    // wrapper = nullptr;
+    for (size_t i = 0; i < bitnet_tensor_extras_index; i++) {
+        // aligned_free(bitnet_tensor_extras[i].qweights);
+        // aligned_free(bitnet_tensor_extras[i].scales);
+    }
+    delete[] bitnet_tensor_extras;
+    bitnet_tensor_extras = nullptr;
+}
+
+bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
+    if ((is_type_supported(src0->type)) &&
+        src1->type == GGML_TYPE_F32 &&
+        dst->type == GGML_TYPE_F32 &&
+        src0->backend == GGML_BACKEND_TYPE_CPU) {
+        if (src1->ne[1] <= 1) {
+            return true;
+        }
+    }
+    return false;
+}
+
+size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
+    const size_t ne01 = src0->ne[1];
+    const size_t ne10 = src1->ne[0];
+    const size_t ne11 = src1->ne[1];
+    
+    size_t wsize = ne10 * ne11 * 11 * sizeof(int8_t) + 2 * ne11 * 2 * sizeof(bitnet_float_type);
+    if (sizeof(bitnet_float_type) == 2) {
+        // Need fp32 to fp16 conversion
+        wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type);
+    }
+    wsize = ((wsize - 1) / 64 + 1) * 64;
+    return wsize;
+}
+
 int ggml_bitnet_get_type_bits(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_TL2:
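
For orientation, the `GGML_BITNET_TL2_LOSS` copy of `ggml_bitnet_mul_mat_get_wsize` above reserves 11 int8 LUT bytes per activation element plus per-batch scale slots, adds an fp32-to-fp16 staging buffer when `bitnet_float_type` is 2 bytes, and rounds up to a 64-byte multiple. A quick check of that arithmetic (shapes are hypothetical):

```python
def bitnet_wsize(ne01, ne10, ne11, float_size=2):
    # Mirrors ggml_bitnet_mul_mat_get_wsize above (float_size = sizeof(bitnet_float_type)).
    wsize = ne10 * ne11 * 11 + 2 * ne11 * 2 * float_size
    if float_size == 2:
        wsize += max(ne10, ne01) * ne11 * float_size   # fp32 -> fp16 conversion buffer
    return ((wsize - 1) // 64 + 1) * 64                # round up to a 64-byte boundary

# e.g. a weight with ne01 = 1536 output rows, ne10 = 4096 reduction dim, single token (ne11 = 1):
print(bitnet_wsize(ne01=1536, ne10=4096, ne11=1))      # 53312 bytes
```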
diff --git a/utils/codegen_tl2_loss.py b/utils/codegen_tl2_loss.py
new file mode 100644
index 0000000..ced9903
--- /dev/null
+++ b/utils/codegen_tl2_loss.py
@@ -0,0 +1,1056 @@
+import argparse
+import os
+from configparser import ConfigParser
+
+def gen_ctor_code():
+    kernel_code = "\n\
+#include \"ggml-bitnet.h\"\n\
+#include <immintrin.h>\n\
+#define GGML_BITNET_MAX_NODES 8192\n\
+static bool initialized = false;\n\
+static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;\n\
+static size_t bitnet_tensor_extras_index = 0;\n\
+static void * aligned_malloc(size_t size) {\n\
+#if defined(_WIN32)\n\
+    return _aligned_malloc(size, 64);\n\
+#else\n\
+    void * ptr = nullptr;\n\
+    posix_memalign(&ptr, 64, size);\n\
+    return ptr;\n\
+#endif\n\
+}\n\
+\n\
+static void aligned_free(void * ptr) {\n\
+#if defined(_WIN32)\n\
+    _aligned_free(ptr);\n\
+#else\n\
+    free(ptr);\n\
+#endif\n\
+}\n\
+#define BK2 32\n\
+#if defined __AVX2__\n\
+inline void _mm256_merge_epi32(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)\n\
+{\n\
+    __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0));\n\
+    __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0));\n\
+    *vl = _mm256_unpacklo_epi32(va, vb);\n\
+    *vh = _mm256_unpackhi_epi32(va, vb);\n\
+}\n\
+inline void _mm256_merge_epi64(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)\n\
+{\n\
+    __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0));\n\
+    __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0));\n\
+    *vl = _mm256_unpacklo_epi64(va, vb);\n\
+    *vh = _mm256_unpackhi_epi64(va, vb);\n\
+}\n\
+inline void _mm256_merge_si128(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)\n\
+{\n\
+    *vl = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 2, 0, 0));\n\
+    *vh = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 3, 0, 1));\n\
+}\n\
+inline void Transpose_8_8(\n\
+    __m256i *v0,\n\
+    __m256i *v1,\n\
+    __m256i *v2,\n\
+    __m256i *v3,\n\
+    __m256i *v4,\n\
+    __m256i *v5,\n\
+    __m256i *v6,\n\
+    __m256i *v7)\n\
+{\n\
+    __m256i w0, w1, w2, w3, w4, w5, w6, w7;\n\
+    __m256i x0, x1, x2, x3, x4, x5, x6, x7;\n\
+    _mm256_merge_epi32(*v0, *v1, &w0, &w1);\n\
+    _mm256_merge_epi32(*v2, *v3, &w2, &w3);\n\
+    _mm256_merge_epi32(*v4, *v5, &w4, &w5);\n\
+    _mm256_merge_epi32(*v6, *v7, &w6, &w7);\n\
+    _mm256_merge_epi64(w0, w2, &x0, &x1);\n\
+    _mm256_merge_epi64(w1, w3, &x2, &x3);\n\
+    _mm256_merge_epi64(w4, w6, &x4, &x5);\n\
+    _mm256_merge_epi64(w5, w7, &x6, &x7);\n\
+    _mm256_merge_si128(x0, x4, v0, v1);\n\
+    _mm256_merge_si128(x1, x5, v2, v3);\n\
+    _mm256_merge_si128(x2, x6, v4, v5);\n\
+    _mm256_merge_si128(x3, x7, v6, v7);\n\
+}\n\
+#elif defined __ARM_NEON\n\
+inline void Transpose_8_8(\n\
+    int8x8_t *v0,\n\
+    int8x8_t *v1,\n\
+    int8x8_t *v2,\n\
+    int8x8_t *v3,\n\
+    int8x8_t *v4,\n\
+    int8x8_t *v5,\n\
+    int8x8_t *v6,\n\
+    int8x8_t *v7)\n\
+{\n\
+    int8x8x2_t q04 = vzip_s8(*v0, *v4);\n\
+    int8x8x2_t q15 = vzip_s8(*v1, *v5);\n\
+    int8x8x2_t q26 = vzip_s8(*v2, *v6);\n\
+    int8x8x2_t q37 = vzip_s8(*v3, *v7);\n\
+    int8x8x2_t q0246_0 = vzip_s8(q04.val[0], q26.val[0]);\n\
+    int8x8x2_t q0246_1 = vzip_s8(q04.val[1], q26.val[1]);\n\
+    int8x8x2_t q1357_0 = vzip_s8(q15.val[0], q37.val[0]);\n\
+    int8x8x2_t q1357_1 = vzip_s8(q15.val[1], q37.val[1]);\n\
+    int8x8x2_t q_fin_0 = vzip_s8(q0246_0.val[0], q1357_0.val[0]);\n\
+    int8x8x2_t q_fin_1 = vzip_s8(q0246_0.val[1], q1357_0.val[1]);\n\
+    int8x8x2_t q_fin_2 = vzip_s8(q0246_1.val[0], q1357_1.val[0]);\n\
+    int8x8x2_t q_fin_3 = vzip_s8(q0246_1.val[1], q1357_1.val[1]);\n\
+    *v0 = q_fin_0.val[0];\n\
+    *v1 = q_fin_0.val[1];\n\
+    *v2 = q_fin_1.val[0];\n\
+    *v3 = q_fin_1.val[1];\n\
+    *v4 = q_fin_2.val[0];\n\
+    *v5 = q_fin_2.val[1];\n\
+    *v6 = q_fin_3.val[0];\n\
+    *v7 = q_fin_3.val[1];\n\
+}\n\
+#endif\n\
+inline int32_t two_partial_max(void* lut_scales_, void* b_) {\n\
+    bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
+    bitnet_float_type* b = (bitnet_float_type*)b_;\n\
+#if defined __AVX2__\n\
+    const __m256i vec_bi = _mm256_set_epi32(56, 48, 40, 32, 24, 16, 8, 0);\n\
+    __m256 vec_b0 = _mm256_i32gather_ps(b + 0, vec_bi, 1);\n\
+    __m256 vec_b1 = _mm256_i32gather_ps(b + 1, vec_bi, 1);\n\
+    const __m256 vec_sign = _mm256_set1_ps(-0.0f);\n\
+    __m256 vec_babs0 = _mm256_andnot_ps(vec_sign, vec_b0);\n\
+    __m256 vec_babs1 = _mm256_andnot_ps(vec_sign, vec_b1);\n\
+    __m256 abssum = _mm256_add_ps(vec_babs0, vec_babs1);\n\
+    __m128 max2 = _mm_max_ps(_mm256_extractf128_ps(abssum, 1), _mm256_castps256_ps128(abssum));\n\
+    max2 = _mm_max_ps(max2, _mm_movehl_ps(max2, max2));\n\
+    max2 = _mm_max_ss(max2, _mm_movehdup_ps(max2));\n\
+    bitnet_float_type scales = _mm_cvtss_f32(max2) / 127;\n\
+    *lut_scales = std::max(*lut_scales, scales);\n\
+#elif defined __ARM_NEON\n\
+    float16x8x2_t vec_bs = vld2q_f16(b);\n\
+    float16x8_t abssum = vabsq_f16(vec_bs.val[0]) + vabsq_f16(vec_bs.val[1]);\n\
+    float16_t scales = vmaxvq_f16(abssum) / 127;\n\
+    *lut_scales = std::max(*lut_scales, scales);\n\
+#endif\n\
+    return 0;\n\
+}\n\
+inline int32_t three_partial_max(void* lut_scales_, void* b_) {\n\
+    bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
+    bitnet_float_type* b = (bitnet_float_type*)b_;\n\
+#if defined __AVX2__\n\
+    const __m256i vec_bi = _mm256_set_epi32(84, 72, 60, 48, 36, 24, 12, 0);\n\
+    __m256 vec_b0 = _mm256_i32gather_ps(b + 0, vec_bi, 1);\n\
+    __m256 vec_b1 = _mm256_i32gather_ps(b + 1, vec_bi, 1);\n\
+    __m256 vec_b2 = _mm256_i32gather_ps(b + 2, vec_bi, 1);\n\
+    const __m256 vec_sign = _mm256_set1_ps(-0.0f);\n\
+    __m256 vec_babs0 = _mm256_andnot_ps(vec_sign, vec_b0);\n\
+    __m256 vec_babs1 = _mm256_andnot_ps(vec_sign, vec_b1);\n\
+    __m256 vec_babs2 = _mm256_andnot_ps(vec_sign, vec_b2);\n\
+    __m256 abssum = _mm256_add_ps(_mm256_add_ps(vec_babs0, vec_babs1), vec_babs2);\n\
+    __m128 max3 = _mm_max_ps(_mm256_extractf128_ps(abssum, 1), _mm256_castps256_ps128(abssum));\n\
+    max3 = _mm_max_ps(max3, _mm_movehl_ps(max3, max3));\n\
+    max3 = _mm_max_ss(max3, _mm_movehdup_ps(max3));\n\
+    bitnet_float_type scales = _mm_cvtss_f32(max3) / 127;\n\
+    *lut_scales = std::max(*lut_scales, scales);\n\
+#elif defined __ARM_NEON\n\
+    float16x8x3_t vec_bs = vld3q_f16(b);\n\
+    float16x8_t abssum = vabsq_f16(vec_bs.val[0]) + vabsq_f16(vec_bs.val[1]) + vabsq_f16(vec_bs.val[2]);\n\
+    float16_t scales = vmaxvq_f16(abssum) / 127;\n\
+    *lut_scales = std::max(*lut_scales, scales);\n\
+#endif\n\
+    return 0;\n\
+}\n\
+inline int32_t partial_max_reset(int32_t bs, void* lut_scales_) {\n\
+    bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
+    #pragma unroll\n\
+    for (int i=0; i< bs; i++) {\n\
+        lut_scales[i] = 0.0;\n\
+    }\n\
+    return 0;\n\
+}\n\
+template <int act_k>\n\
+inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {\n\
+#if defined __AVX2__\n\
+    __m256 vec_lut[16];\n\
+    const __m256i vec_bi = _mm256_set_epi32(84, 72, 60, 48, 36, 24, 12, 0);\n\
+    bitnet_float_type scales = *lut_scales;\n\
+    bitnet_float_type t_scales = scales ? 1.0f / scales : 0.0f;\n\
+#pragma unroll\n\
+    for (int k = 0; k < act_k / 24; ++k) {\n\
+        __m256 vec_b0 = _mm256_i32gather_ps(b + k * 24 + 0, vec_bi, 1);\n\
+        __m256 vec_b1 = _mm256_i32gather_ps(b + k * 24 + 1, vec_bi, 1);\n\
+        __m256 vec_b2 = _mm256_i32gather_ps(b + k * 24 + 2, vec_bi, 1);\n\
+\n\
+        vec_lut[15] = _mm256_setzero_ps();\n\
+        vec_lut[14] = _mm256_setzero_ps();\n\
+        vec_lut[13] = vec_b0;\n\
+        vec_lut[13] = _mm256_add_ps(vec_lut[13], vec_b1);\n\
+        vec_lut[13] = _mm256_add_ps(vec_lut[13], vec_b2);\n\
+        vec_lut[12] = vec_b0;\n\
+        vec_lut[12] = _mm256_add_ps(vec_lut[12], vec_b1);\n\
+        vec_lut[11] = vec_b0;\n\
+        vec_lut[11] = _mm256_add_ps(vec_lut[11], vec_b1);\n\
+        vec_lut[11] = _mm256_sub_ps(vec_lut[11], vec_b2);\n\
+        vec_lut[10] = vec_b0;\n\
+        vec_lut[10] = _mm256_add_ps(vec_lut[10], vec_b2);\n\
+        vec_lut[9] = vec_b0;\n\
+        vec_lut[8] = vec_b0;\n\
+        vec_lut[8] = _mm256_sub_ps(vec_lut[8], vec_b2);\n\
+        vec_lut[7] = vec_b0;\n\
+        vec_lut[7] = _mm256_sub_ps(vec_lut[7], vec_b1);\n\
+        vec_lut[7] = _mm256_add_ps(vec_lut[7], vec_b2);\n\
+        vec_lut[6] = vec_b0;\n\
+        vec_lut[6] = _mm256_sub_ps(vec_lut[6], vec_b1);\n\
+        vec_lut[5] = vec_b0;\n\
+        vec_lut[5] = _mm256_sub_ps(vec_lut[5], vec_b1);\n\
+        vec_lut[5] = _mm256_sub_ps(vec_lut[5], vec_b2);\n\
+        vec_lut[4] = vec_b1;\n\
+        vec_lut[4] = _mm256_add_ps(vec_lut[4], vec_b2);\n\
+        vec_lut[3] = vec_b1;\n\
+        vec_lut[2] = vec_b1;\n\
+        vec_lut[2] = _mm256_sub_ps(vec_lut[2], vec_b2);\n\
+        vec_lut[1] = vec_b2;\n\
+        vec_lut[0] = _mm256_setzero_ps();\n\
+\n\
+#pragma unroll\n\
+        for (int g = 0; g < 14; ++g) {\n\
+            vec_lut[g] = _mm256_mul_ps(vec_lut[g], _mm256_set1_ps(t_scales));\n\
+        }\n\
+        __m256i ix[16];\n\
+        for (int g = 0; g < 14; ++g) {\n\
+            ix[g] = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\
+        }\n\
+        __m256i shuffle_mask = _mm256_set_epi8(\n\
+                                               0x0f, 0x0e, 0x0d, 0x0c, 0x07, 0x06, 0x05, 0x04,\n\
+                                               0x0b, 0x0a, 0x09, 0x08, 0x03, 0x02, 0x01, 0x00,\n\
+                                               0x0f, 0x0e, 0x0d, 0x0c, 0x07, 0x06, 0x05, 0x04,\n\
+                                               0x0b, 0x0a, 0x09, 0x08, 0x03, 0x02, 0x01, 0x00\n\
+                                               );\n\
+        Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7]));\n\
+        Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15]));\n\
+        int8_t* qlut_i8 = reinterpret_cast<int8_t*>(qlut);\n\
+#pragma unroll\n\
+        for (int g = 0; g < 8; ++g) {\n\
+            ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]);\n\
+            ix[g] = _mm256_packs_epi16(ix[g], ix[g]);\n\
+            ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));\n\
+            ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask);\n\
+            _mm_storeu_si128(reinterpret_cast<__m128i*>(qlut_i8 + k * 128 + g * 16 + 0), _mm256_castsi256_si128(ix[g]));\n\
+        }\n\
+    }\n\
+\n\
+    *lut_scales = scales;\n\
+#elif defined __ARM_NEON\n\
+    float16x8_t vec_lut[16];\n\
+    float16_t scales = *lut_scales;\n\
+    float16_t t_scales = scales ? 1.0 / scales : 0.0;\n\
+#pragma unroll\n\
+    for (int k = 0; k < act_k / 24; ++k) {\n\
+        float16x8x3_t vec_bs = vld3q_f16(b + k * 24);\n\
+        vec_lut[15] = vdupq_n_f16(0);\n\
+        vec_lut[14] = vdupq_n_f16(0);\n\
+        vec_lut[13] = vec_bs.val[0] + vec_bs.val[1] + vec_bs.val[2];\n\
+        vec_lut[12] = vec_bs.val[0] + vec_bs.val[1];\n\
+        vec_lut[11] = vec_bs.val[0] + vec_bs.val[1] - vec_bs.val[2];\n\
+        vec_lut[10] = vec_bs.val[0] + vec_bs.val[2];\n\
+        vec_lut[9] = vec_bs.val[0];\n\
+        vec_lut[8] = vec_bs.val[0] - vec_bs.val[2];\n\
+        vec_lut[7] = vec_bs.val[0] - vec_bs.val[1] + vec_bs.val[2];\n\
+        vec_lut[6] = vec_bs.val[0] - vec_bs.val[1];\n\
+        vec_lut[5] = vec_bs.val[0] - vec_bs.val[1] - vec_bs.val[2];\n\
+        vec_lut[4] = vec_bs.val[1] + vec_bs.val[2];\n\
+        vec_lut[3] = vec_bs.val[1];\n\
+        vec_lut[2] = vec_bs.val[1] - vec_bs.val[2];\n\
+        vec_lut[1] = vec_bs.val[2];\n\
+        vec_lut[0] = vdupq_n_f16(0);\n\
+\n\
+#pragma unroll\n\
+        for (int g = 0; g < 14; ++g) {\n\
+            vec_lut[g] = vmulq_n_f16(vec_lut[g], t_scales);\n\
+        }\n\
+\n\
+        int8x8_t vec_qlut[16];\n\
+#pragma unroll\n\
+        for (int g = 0; g < 14; ++g) {\n\
+            vec_qlut[g] = vqmovn_s16(vcvtnq_s16_f16(vec_lut[g]));\n\
+        }\n\
+        Transpose_8_8(&(vec_qlut[0]), &(vec_qlut[1]), &(vec_qlut[2]), &(vec_qlut[3]),\n\
+                      &(vec_qlut[4]), &(vec_qlut[5]), &(vec_qlut[6]), &(vec_qlut[7]));\n\
+        Transpose_8_8(&(vec_qlut[8]), &(vec_qlut[9]), &(vec_qlut[10]), &(vec_qlut[11]),\n\
+                      &(vec_qlut[12]), &(vec_qlut[13]), &(vec_qlut[14]), &(vec_qlut[15]));\n\
+\n\
+#pragma unroll\n\
+        for (int idx = 0; idx < 8; idx++) {\n\
+            vst1_s8(qlut + k * 16 * 8 + idx * 16 + 0 * 8, vec_qlut[idx]);\n\
+            vst1_s8(qlut + k * 16 * 8 + idx * 16 + 1 * 8, vec_qlut[idx + 8]);\n\
+        }\n\
+    }\n\
+#endif\n\
+    return 0;\n\
+}\n\
+\n\
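+/* two_lut_ctor: same idea for the leftover activations that are grouped in pairs\n\
+   instead of triples - entry idx covers the nine signed combinations of two\n\
+   activations (idx = 3*w0 + w1 + 4), quantized to int8 with the shared scale. */\n\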
+template<int act_k>\n\
+inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {\n\
+#if defined __AVX2__\n\
+    __m256 vec_lut[16];\n\
+    const __m256i vec_bi = _mm256_set_epi32(56, 48, 40, 32, 24, 16, 8, 0);\n\
+    bitnet_float_type scales = *lut_scales;\n\
+    bitnet_float_type t_scales = scales ? 1.0f / scales : 0.0f;\n\
+#pragma unroll\n\
+    for (int k = 0; k < act_k / 16; ++k) {\n\
+        __m256 vec_b0 = _mm256_i32gather_ps(b + k * 16 + 0, vec_bi, 1);\n\
+        __m256 vec_b1 = _mm256_i32gather_ps(b + k * 16 + 1, vec_bi, 1);\n\
+        vec_lut[0] = _mm256_setzero_ps();\n\
+        vec_lut[0] = _mm256_sub_ps(vec_lut[0], vec_b0);\n\
+        vec_lut[0] = _mm256_sub_ps(vec_lut[0], vec_b1);\n\
+        vec_lut[1] = _mm256_setzero_ps();\n\
+        vec_lut[1] = _mm256_sub_ps(vec_lut[1], vec_b0);\n\
+        vec_lut[2] = _mm256_setzero_ps();\n\
+        vec_lut[2] = _mm256_sub_ps(vec_lut[2], vec_b0);\n\
+        vec_lut[2] = _mm256_add_ps(vec_lut[2], vec_b1);\n\
+        vec_lut[3] = _mm256_setzero_ps();\n\
+        vec_lut[3] = _mm256_sub_ps(vec_lut[3], vec_b1);\n\
+        vec_lut[4] = _mm256_setzero_ps();\n\
+        vec_lut[5] = vec_b1;\n\
+        vec_lut[6] = vec_b0;\n\
+        vec_lut[6] = _mm256_sub_ps(vec_lut[6], vec_b1);\n\
+        vec_lut[7] = vec_b0;\n\
+        vec_lut[8] = vec_b0;\n\
+        vec_lut[8] = _mm256_add_ps(vec_lut[8], vec_b1);\n\
+        vec_lut[9] = _mm256_setzero_ps();\n\
+        vec_lut[10] = _mm256_setzero_ps();\n\
+        vec_lut[11] = _mm256_setzero_ps();\n\
+        vec_lut[12] = _mm256_setzero_ps();\n\
+        vec_lut[13] = _mm256_setzero_ps();\n\
+        vec_lut[14] = _mm256_setzero_ps();\n\
+        vec_lut[15] = _mm256_setzero_ps();\n\
+\n\
+#pragma unroll\n\
+        for (int g = 0; g < 9; ++g) {\n\
+            vec_lut[g] = _mm256_mul_ps(vec_lut[g], _mm256_set1_ps(t_scales));\n\
+        }\n\
+        __m256i ix[16];\n\
+#pragma unroll\n\
+        for (int g = 0; g < 9; ++g) {\n\
+            ix[g] = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\
+        }\n\
+\n\
+        __m256i shuffle_mask = _mm256_set_epi8(\n\
+                                               0x0f, 0x0e, 0x0d, 0x0c, 0x07, 0x06, 0x05, 0x04,\n\
+                                               0x0b, 0x0a, 0x09, 0x08, 0x03, 0x02, 0x01, 0x00,\n\
+                                               0x0f, 0x0e, 0x0d, 0x0c, 0x07, 0x06, 0x05, 0x04,\n\
+                                               0x0b, 0x0a, 0x09, 0x08, 0x03, 0x02, 0x01, 0x00\n\
+                                               );\n\
+\n\
+        Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7]));\n\
+        Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15]));\n\
+\n\
+        int8_t* qlut_i8 = reinterpret_cast<int8_t*>(qlut);\n\
+#pragma unroll\n\
+        for (int g = 0; g < 8; ++g) {\n\
+            ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]);\n\
+            ix[g] = _mm256_packs_epi16(ix[g], ix[g]);\n\
+            ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));\n\
+            ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask);\n\
+            _mm_storeu_si128(reinterpret_cast<__m128i*>(qlut_i8 + k * 128 + g * 16 + 0), _mm256_castsi256_si128(ix[g]));\n\
+        }\n\
+    }\n\
+    *lut_scales = scales;\n\
+#elif defined __ARM_NEON\n\
+    float16x8_t vec_lut[16];\n\
+    float16_t scales = *lut_scales;\n\
+    float16_t t_scales = scales ? 1.0 / scales : 0.0;\n\
+\n\
+#pragma unroll\n\
+    for (int k = 0; k < act_k / 16; ++k) {\n\
+        float16x8x2_t vec_bs = vld2q_f16(b + k * 16);\n\
+        vec_lut[15] = vdupq_n_f16(0);\n\
+        vec_lut[14] = vdupq_n_f16(0);\n\
+        vec_lut[13] = vdupq_n_f16(0);\n\
+        vec_lut[12] = vdupq_n_f16(0);\n\
+        vec_lut[11] = vdupq_n_f16(0);\n\
+        vec_lut[10] = vdupq_n_f16(0);\n\
+        vec_lut[9] = vdupq_n_f16(0);\n\
+        vec_lut[8] = vec_bs.val[0] + vec_bs.val[1];\n\
+        vec_lut[7] = vec_bs.val[0];\n\
+        vec_lut[6] = vec_bs.val[0] - vec_bs.val[1];\n\
+        vec_lut[5] = vec_bs.val[1];\n\
+        vec_lut[4] = vdupq_n_f16(0);\n\
+        vec_lut[3] = -vec_bs.val[1];\n\
+        vec_lut[2] = -vec_bs.val[0] + vec_bs.val[1];\n\
+        vec_lut[1] = -vec_bs.val[0];\n\
+        vec_lut[0] = -vec_bs.val[0] - vec_bs.val[1];\n\
+\n\
+#pragma unroll\n\
+        for (int g = 0; g < 16; ++g) {\n\
+            vec_lut[g] = vmulq_n_f16(vec_lut[g], t_scales);\n\
+        }\n\
+\n\
+        int8x8_t vec_qlut[16];\n\
+#pragma unroll\n\
+        for (int g = 0; g < 16; ++g) {\n\
+            vec_qlut[g] = vqmovn_s16(vcvtnq_s16_f16(vec_lut[g]));\n\
+        }\n\
+        Transpose_8_8(&(vec_qlut[0]), &(vec_qlut[1]), &(vec_qlut[2]), &(vec_qlut[3]),\n\
+                      &(vec_qlut[4]), &(vec_qlut[5]), &(vec_qlut[6]), &(vec_qlut[7]));\n\
+        Transpose_8_8(&(vec_qlut[8]), &(vec_qlut[9]), &(vec_qlut[10]), &(vec_qlut[11]),\n\
+                      &(vec_qlut[12]), &(vec_qlut[13]), &(vec_qlut[14]), &(vec_qlut[15]));\n\
+\n\
+#pragma unroll\n\
+        for (int idx = 0; idx < 8; idx++) {\n\
+            vst1_s8(qlut + k * 16 * 8 + idx * 16 + 0 * 8, vec_qlut[idx]);\n\
+            vst1_s8(qlut + k * 16 * 8 + idx * 16 + 1 * 8, vec_qlut[idx + 8]);\n\
+        }\n\
+    }\n\
+#endif\n\
+    return 0;\n\
+}\n\
+static bool is_type_supported(enum ggml_type type) {\n\
+    if (type == GGML_TYPE_Q4_0 ||\n\
+        type == GGML_TYPE_TL2) {\n\
+        return true;\n\
+    } else {\n\
+        return false;\n\
+    }\n\
+}\n\
+"
+    return kernel_code
+
+def gen_tbl_impl(pre, BM, BK, bm, k_list):
+
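+    # Emits the C++ three_tbl_impl_* / two_tbl_impl_* kernel bodies (AVX2 and NEON)
+    # for one weight shape. `pre` is the "M_K" suffix, BM/BK become the BM*/BBK*
+    # macros, and bm (16 or 32) selects which SIMD micro-tile variant is generated.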
+    kernel_code = "\
+\n\
+#define BM{0} {1}\n\
+#define BBK{0} {2}\n\
+template\n\
+inline void three_tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) {{\n\
+".format(pre, BM, BK)
+
+    if bm == 16:
+        kernel_code = "".join([kernel_code, "\
+#ifdef __AVX2__\n\
+    const int KK = BBK{0}/ 3;\n\
+    for (int i = 0; i < BM{0}; i += 16) {{\n\
+        __m256i vec_c0 = _mm256_setzero_si256();\n\
+#pragma unroll\n\
+        for (int k = 0; k < KK / 16; k++) {{\n\
+            __m256i vec_sign = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + k * 32));\n\
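+            /* Each 4-bit index in vec_a selects an int8 LUT entry via vpshufb; the\n\
+               per-group sign bit is widened to 0x00 or 0xFF, and the add/xor pair\n\
+               below negates the looked-up value when that bit is clear. */\n\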
+            __m256i vec_k_top_256_0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(lut + 256 * k + 0 * 64));\n\
+            __m256i vec_k_bot_256_0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(lut + 256 * k + 0 * 64 + 32));\n\
+            __m256i vec_a_0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + k * 128 + 0 * 32));\n\
+            __m256i vec_a_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_a_bot_0 = _mm256_and_si256(vec_a_0, _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_sign_top_0 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 7 - 2 * 0), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_top_0 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_top_256_0, vec_a_top_0), vec_sign_top_0), vec_sign_top_0);\n\
+            __m256i vec_sign_bot_0 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 6 - 2 * 0), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_bot_0 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_bot_256_0, vec_a_bot_0), vec_sign_bot_0), vec_sign_bot_0);\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_bot_0)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_top_0)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_bot_0, 1)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_top_0, 1)));\n\
+            __m256i vec_k_top_256_1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(lut + 256 * k + 1 * 64));\n\
+            __m256i vec_k_bot_256_1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(lut + 256 * k + 1 * 64 + 32));\n\
+            __m256i vec_a_1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + k * 128 + 1 * 32));\n\
+            __m256i vec_a_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_a_bot_1 = _mm256_and_si256(vec_a_1, _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_sign_top_1 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 7 - 2 * 1), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_top_1 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_top_256_1, vec_a_top_1), vec_sign_top_1), vec_sign_top_1);\n\
+            __m256i vec_sign_bot_1 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 6 - 2 * 1), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_bot_1 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_bot_256_1, vec_a_bot_1), vec_sign_bot_1), vec_sign_bot_1);\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_bot_1)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_top_1)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_bot_1, 1)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_top_1, 1)));\n\
+            __m256i vec_k_top_256_2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(lut + 256 * k + 2 * 64));\n\
+            __m256i vec_k_bot_256_2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(lut + 256 * k + 2 * 64 + 32));\n\
+            __m256i vec_a_2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + k * 128 + 2 * 32));\n\
+            __m256i vec_a_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_a_bot_2 = _mm256_and_si256(vec_a_2, _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_sign_top_2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 7 - 2 * 2), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_top_2 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_top_256_2, vec_a_top_2), vec_sign_top_2), vec_sign_top_2);\n\
+            __m256i vec_sign_bot_2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 6 - 2 * 2), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_bot_2 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_bot_256_2, vec_a_bot_2), vec_sign_bot_2), vec_sign_bot_2);\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_bot_2)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_top_2)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_bot_2, 1)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_top_2, 1)));\n\
+            __m256i vec_k_top_256_3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(lut + 256 * k + 3 * 64));\n\
+            __m256i vec_k_bot_256_3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(lut + 256 * k + 3 * 64 + 32));\n\
+            __m256i vec_a_3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + k * 128 + 3 * 32));\n\
+            __m256i vec_a_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_a_bot_3 = _mm256_and_si256(vec_a_3, _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_sign_top_3 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 7 - 3 * 2), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_top_3 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_top_256_3, vec_a_top_3), vec_sign_top_3), vec_sign_top_3);\n\
+            __m256i vec_sign_bot_3 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 6 - 3 * 2), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_bot_3 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_bot_256_3, vec_a_bot_3), vec_sign_bot_3), vec_sign_bot_3);\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_bot_3)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_top_3)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_bot_3, 1)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_top_3, 1)));\n\
+        }}\n\
+        __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i));\n\
+        __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8));\n\
+        vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));\n\
+        vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));\n\
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i), vec_gc0);\n\
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8), vec_gc1);\n\
+    }}\n".format(pre)])
+
+        kernel_code = "".join([kernel_code, "\
+#elif defined __ARM_NEON\n\
+    const int KK = BBK{0} / 3;\n\
+    const uint8x16_t vec_mask = vdupq_n_u8(0x0f);\n\
+#pragma unroll\n\
+    for (int i = 0; i < BM{0}; i += 16) {{\n\
+        int16x8_t vec_c0 = vdupq_n_s16(0);\n\
+        int16x8_t vec_c1 = vdupq_n_s16(0);\n\
+#pragma unroll \n\
+        for (int k = 0; k < KK / 16; k++) {{\n\
+            uint8x16_t vec_sign_left = vmvnq_u8(vld1q_u8(sign + i * KK / 8 + k * 32));\n\
+            uint8x16_t vec_sign_right = vmvnq_u8(vld1q_u8(sign + i * KK / 8 + k * 32 + 16));\n".format(pre)])
+
+        for i in range(4):
+            kernel_code = "".join([kernel_code, "\
+            int8x16_t vec_k_left_left_{0} = vld1q_s8(lut + 256 * k + {0} * 64);\n\
+            int8x16_t vec_k_left_right_{0} = vld1q_s8(lut + 256 * k + {0} * 64 + 16);\n\
+            int8x16_t vec_k_right_left_{0} = vld1q_s8(lut + 256 * k + {0} * 64 + 32);\n\
+            int8x16_t vec_k_right_right_{0} = vld1q_s8(lut + 256 * k + {0} * 64 + 48);\n\
+            uint8x16_t vec_sign_left_left_{0} = vcltzq_s8(vshlq_n_u8(vec_sign_left, 2 * {0}));\n\
+            uint8x16_t vec_sign_left_right_{0} = vcltzq_s8(vshlq_n_u8(vec_sign_left, 2 * {0} + 1));\n\
+            uint8x16_t vec_sign_right_left_{0} = vcltzq_s8(vshlq_n_u8(vec_sign_right, 2 * {0}));\n\
+            uint8x16_t vec_sign_right_right_{0} = vcltzq_s8(vshlq_n_u8(vec_sign_right, 2 * {0} + 1));\n\
+            uint8x16_t vec_a_left_{0} = vld1q_u8(a + i * KK / 2 + k * 128 + {0} * 32);\n\
+            uint8x16_t vec_a_right_{0} = vld1q_u8(a + i * KK / 2 + k * 128 + {0} * 32 + 16);\n\
+            uint8x16_t vec_a_left_left_{0} = vshrq_n_u8(vec_a_left_{0}, 4);\n\
+            uint8x16_t vec_a_left_right_{0} = vandq_u8(vec_a_left_{0}, vec_mask);\n\
+            uint8x16_t vec_a_right_left_{0} = vshrq_n_u8(vec_a_right_{0}, 4);\n\
+            uint8x16_t vec_a_right_right_{0} = vandq_u8(vec_a_right_{0}, vec_mask);\n\
+            int8x16_t vec_v_top_left_tmp_{0} = vqtbl1q_s8(vec_k_left_left_{0}, vec_a_left_left_{0});\n\
+            int8x16_t vec_v_bot_left_tmp_{0} = vqtbl1q_s8(vec_k_left_right_{0}, vec_a_right_left_{0});\n\
+            int8x16_t vec_v_top_right_tmp_{0} = vqtbl1q_s8(vec_k_right_left_{0}, vec_a_left_right_{0});\n\
+            int8x16_t vec_v_bot_right_tmp_{0} = vqtbl1q_s8(vec_k_right_right_{0}, vec_a_right_right_{0});\n\
+            vec_v_top_left_tmp_{0} = vbslq_s8(vec_sign_left_left_{0}, vnegq_s8(vec_v_top_left_tmp_{0}), vec_v_top_left_tmp_{0});\n\
+            vec_v_bot_left_tmp_{0} = vbslq_s8(vec_sign_right_left_{0}, vnegq_s8(vec_v_bot_left_tmp_{0}), vec_v_bot_left_tmp_{0});\n\
+            vec_v_top_right_tmp_{0} = vbslq_s8(vec_sign_left_right_{0}, vnegq_s8(vec_v_top_right_tmp_{0}), vec_v_top_right_tmp_{0});\n\
+            vec_v_bot_right_tmp_{0} = vbslq_s8(vec_sign_right_right_{0}, vnegq_s8(vec_v_bot_right_tmp_{0}), vec_v_bot_right_tmp_{0});\n\
+            int16x8_t vec_v_top_left_high_{0} = vmovl_high_s8(vec_v_top_left_tmp_{0});\n\
+            int16x8_t vec_v_top_left_bot_{0} = vmovl_s8(vget_low_s8(vec_v_top_left_tmp_{0}));\n\
+            int16x8_t vec_v_top_right_high_{0} = vmovl_high_s8(vec_v_top_right_tmp_{0});\n\
+            int16x8_t vec_v_top_right_bot_{0} = vmovl_s8(vget_low_s8(vec_v_top_right_tmp_{0}));\n\
+            int16x8_t vec_v_bot_left_high_{0} = vmovl_high_s8(vec_v_bot_left_tmp_{0});\n\
+            int16x8_t vec_v_bot_left_bot_{0} = vmovl_s8(vget_low_s8(vec_v_bot_left_tmp_{0}));\n\
+            int16x8_t vec_v_bot_right_high_{0} = vmovl_high_s8(vec_v_bot_right_tmp_{0});\n\
+            int16x8_t vec_v_bot_right_bot_{0} = vmovl_s8(vget_low_s8(vec_v_bot_right_tmp_{0}));\n\
+            vec_c0 += vec_v_top_left_bot_{0};\n\
+            vec_c0 += vec_v_top_right_bot_{0};\n\
+            vec_c0 += vec_v_bot_left_bot_{0};\n\
+            vec_c0 += vec_v_bot_right_bot_{0};\n\
+            vec_c1 += vec_v_top_left_high_{0};\n\
+            vec_c1 += vec_v_top_right_high_{0};\n\
+            vec_c1 += vec_v_bot_left_high_{0};\n\
+            vec_c1 += vec_v_bot_right_high_{0};\n".format(i)])
+
+        kernel_code = "".join([kernel_code, "\
+        }\n\
+        int32x4_t vec_v_1 = vmovl_high_s16(vec_c0);\n\
+        int32x4_t vec_v_0 = vmovl_s16(vget_low_s16(vec_c0));\n\
+        int32x4_t vec_v_3 = vmovl_high_s16(vec_c1);\n\
+        int32x4_t vec_v_2 = vmovl_s16(vget_low_s16(vec_c1));\n\
+        vst1q_s32(c + i,      vld1q_s32(c + i     ) + vec_v_0);\n\
+        vst1q_s32(c + i + 4,  vld1q_s32(c + i + 4 ) + vec_v_1);\n\
+        vst1q_s32(c + i + 8,  vld1q_s32(c + i + 8 ) + vec_v_2);\n\
+        vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_3);\n\
+    }\n\
+#endif\n\
+}\n"])
+    elif bm == 32:
+        kernel_code = "".join([kernel_code, "\
+#ifdef __AVX2__\n\
+    const int KK = BBK{0} / 3;\n\
+    for (int i = 0; i < BM{0}; i += 32) {{\n\
+        __m256i vec_c0 = _mm256_set1_epi16(0);\n\
+        __m256i vec_c1 = _mm256_set1_epi16(0);\n\
+#pragma unroll\n\
+        for (int k = 0; k < KK / 8; k++) {{\n\
+            __m256i vec_sign = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + k * 32));\n\
+            __m128i vec_k_top_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + 128 * k + 0 * 32 + 0));\n\
+            __m256i vec_k_top_256_0 = _mm256_set_m128i(vec_k_top_0, vec_k_top_0);\n\
+            __m128i vec_k_bot_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + 128 * k + 0 * 32 + 16));\n\
+            __m256i vec_k_bot_256_0 = _mm256_set_m128i(vec_k_bot_0, vec_k_bot_0);\n\
+            __m256i vec_a_0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + k * 32 * 4 + 0 * 32));\n\
+            __m256i vec_a_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_a_bot_0 = _mm256_and_si256(vec_a_0, _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_sign_top_0 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 7 - 0), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_top_0 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_top_256_0, vec_a_top_0), vec_sign_top_0), vec_sign_top_0);\n\
+            __m256i vec_sign_bot_0 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 3 - 0), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_bot_0 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_bot_256_0, vec_a_bot_0), vec_sign_bot_0), vec_sign_bot_0);\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_bot_0)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_top_0)));\n\
+            vec_c1 = _mm256_add_epi16(vec_c1, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_bot_0, 1)));\n\
+            vec_c1 = _mm256_add_epi16(vec_c1, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_top_0, 1)));\n\
+            __m128i vec_k_top_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + 128 * k + 1 * 32 + 0));\n\
+            __m256i vec_k_top_256_1 = _mm256_set_m128i(vec_k_top_1, vec_k_top_1);\n\
+            __m128i vec_k_bot_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + 128 * k + 1 * 32 + 16));\n\
+            __m256i vec_k_bot_256_1 = _mm256_set_m128i(vec_k_bot_1, vec_k_bot_1);\n\
+            __m256i vec_a_1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + k * 32 * 4 + 1 * 32));\n\
+            __m256i vec_a_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_a_bot_1 = _mm256_and_si256(vec_a_1, _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_sign_top_1 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 7 - 1), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_top_1 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_top_256_1, vec_a_top_1), vec_sign_top_1), vec_sign_top_1);\n\
+            __m256i vec_sign_bot_1 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 3 - 1), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_bot_1 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_bot_256_1, vec_a_bot_1), vec_sign_bot_1), vec_sign_bot_1);\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_bot_1)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_top_1)));\n\
+            vec_c1 = _mm256_add_epi16(vec_c1, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_bot_1, 1)));\n\
+            vec_c1 = _mm256_add_epi16(vec_c1, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_top_1, 1)));\n\
+            __m128i vec_k_top_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + 128 * k + 2 * 32 + 0));\n\
+            __m256i vec_k_top_256_2 = _mm256_set_m128i(vec_k_top_2, vec_k_top_2);\n\
+            __m128i vec_k_bot_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + 128 * k + 2 * 32 + 16));\n\
+            __m256i vec_k_bot_256_2 = _mm256_set_m128i(vec_k_bot_2, vec_k_bot_2);\n\
+            __m256i vec_a_2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + k * 32 * 4 + 2 * 32));\n\
+            __m256i vec_a_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_a_bot_2 = _mm256_and_si256(vec_a_2, _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_sign_top_2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 7 - 2), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_top_2 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_top_256_2, vec_a_top_2), vec_sign_top_2), vec_sign_top_2);\n\
+            __m256i vec_sign_bot_2 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 3 - 2), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_bot_2 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_bot_256_2, vec_a_bot_2), vec_sign_bot_2), vec_sign_bot_2);\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_bot_2)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_top_2)));\n\
+            vec_c1 = _mm256_add_epi16(vec_c1, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_bot_2, 1)));\n\
+            vec_c1 = _mm256_add_epi16(vec_c1, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_top_2, 1)));\n\
+            __m128i vec_k_top_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + 128 * k + 3 * 32 + 0));\n\
+            __m256i vec_k_top_256_3 = _mm256_set_m128i(vec_k_top_3, vec_k_top_3);\n\
+            __m128i vec_k_bot_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + 128 * k + 3 * 32 + 16));\n\
+            __m256i vec_k_bot_256_3 = _mm256_set_m128i(vec_k_bot_3, vec_k_bot_3);\n\
+            __m256i vec_a_3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + k * 32 * 4 + 3 * 32));\n\
+            __m256i vec_a_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_a_bot_3 = _mm256_and_si256(vec_a_3, _mm256_set1_epi8(0x0f));\n\
+            __m256i vec_sign_top_3 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 7 - 3), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_top_3 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_top_256_3, vec_a_top_3), vec_sign_top_3), vec_sign_top_3);\n\
+            __m256i vec_sign_bot_3 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(vec_sign, 3 - 3), _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));\n\
+            __m256i vec_v_bot_3 = _mm256_xor_si256(_mm256_add_epi8(_mm256_shuffle_epi8(vec_k_bot_256_3, vec_a_bot_3), vec_sign_bot_3), vec_sign_bot_3);\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_bot_3)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_top_3)));\n\
+            vec_c1 = _mm256_add_epi16(vec_c1, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_bot_3, 1)));\n\
+            vec_c1 = _mm256_add_epi16(vec_c1, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_top_3, 1)));\n\
+        }}\n\
+        __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i));\n\
+        __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8));\n\
+        __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16));\n\
+        __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24));\n\
+        vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));\n\
+        vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));\n\
+        vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));\n\
+        vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));\n\
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i), vec_gc0);\n\
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8), vec_gc1);\n\
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16), vec_gc2);\n\
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24), vec_gc3);\n\
+    }}\n".format(pre)])
+
+        kernel_code = "".join([kernel_code, "\
+#elif defined __ARM_NEON\n\
+    const int KK = BBK{0} / 3;\n\
+    const uint8x16_t vec_mask = vdupq_n_u8(0x0f);\n\
+#pragma unroll\n\
+    for (int i = 0; i < BM{0}; i += 32) {{\n\
+        int16x8_t vec_c0 = vdupq_n_s16(0);\n\
+        int16x8_t vec_c1 = vdupq_n_s16(0);\n\
+        int16x8_t vec_c2 = vdupq_n_s16(0);\n\
+        int16x8_t vec_c3 = vdupq_n_s16(0);\n\
+#pragma unroll \n\
+        for (int k = 0; k < KK / 8; k++) {{\n\
+            uint8x16_t vec_sign_left = vmvnq_u8(vld1q_u8(sign + i * KK / 8 + k * 32));\n\
+            uint8x16_t vec_sign_right = vmvnq_u8(vld1q_u8(sign + i * KK / 8 + k * 32 + 16));\n".format(pre)])
+
+        for i in range(4):
+            kernel_code = "".join([kernel_code, "\
+            int8x16_t vec_k_left_{0} = vld1q_s8(lut + 128 * k + {0} * 32);\n\
+            int8x16_t vec_k_right_{0} = vld1q_s8(lut + 128 * k + {0} * 32 + 16);\n\
+            uint8x16_t vec_sign_left_left_{0} = vcltzq_s8(vshlq_n_u8(vec_sign_left, {0}));\n\
+            uint8x16_t vec_sign_left_right_{0} = vcltzq_s8(vshlq_n_u8(vec_sign_left, {0} + 4));\n\
+            uint8x16_t vec_sign_right_left_{0} = vcltzq_s8(vshlq_n_u8(vec_sign_right, {0}));\n\
+            uint8x16_t vec_sign_right_right_{0} = vcltzq_s8(vshlq_n_u8(vec_sign_right, {0} + 4));\n\
+            uint8x16_t vec_a_left_{0} = vld1q_u8(a + i * KK / 2 + k * 128 + {0} * 32);\n\
+            uint8x16_t vec_a_right_{0} = vld1q_u8(a + i * KK / 2 + k * 128 + {0} * 32 + 16);\n\
+            uint8x16_t vec_a_left_left_{0} = vshrq_n_u8(vec_a_left_{0}, 4);\n\
+            uint8x16_t vec_a_left_right_{0} = vandq_u8(vec_a_left_{0}, vec_mask);\n\
+            uint8x16_t vec_a_right_left_{0} = vshrq_n_u8(vec_a_right_{0}, 4);\n\
+            uint8x16_t vec_a_right_right_{0} = vandq_u8(vec_a_right_{0}, vec_mask);\n\
+            int8x16_t vec_v_top_left_tmp_{0} = vqtbl1q_s8(vec_k_left_{0}, vec_a_left_left_{0});\n\
+            int8x16_t vec_v_bot_left_tmp_{0} = vqtbl1q_s8(vec_k_left_{0}, vec_a_right_left_{0});\n\
+            int8x16_t vec_v_top_right_tmp_{0} = vqtbl1q_s8(vec_k_right_{0}, vec_a_left_right_{0});\n\
+            int8x16_t vec_v_bot_right_tmp_{0} = vqtbl1q_s8(vec_k_right_{0}, vec_a_right_right_{0});\n\
+            vec_v_top_left_tmp_{0} = vbslq_s8(vec_sign_left_left_{0}, vnegq_s8(vec_v_top_left_tmp_{0}), vec_v_top_left_tmp_{0});\n\
+            vec_v_bot_left_tmp_{0} = vbslq_s8(vec_sign_right_left_{0}, vnegq_s8(vec_v_bot_left_tmp_{0}), vec_v_bot_left_tmp_{0});\n\
+            vec_v_top_right_tmp_{0} = vbslq_s8(vec_sign_left_right_{0}, vnegq_s8(vec_v_top_right_tmp_{0}), vec_v_top_right_tmp_{0});\n\
+            vec_v_bot_right_tmp_{0} = vbslq_s8(vec_sign_right_right_{0}, vnegq_s8(vec_v_bot_right_tmp_{0}), vec_v_bot_right_tmp_{0});\n\
+            int16x8_t vec_v_top_left_high_{0} = vmovl_high_s8(vec_v_top_left_tmp_{0});\n\
+            int16x8_t vec_v_top_left_bot_{0} = vmovl_s8(vget_low_s8(vec_v_top_left_tmp_{0}));\n\
+            int16x8_t vec_v_top_right_high_{0} = vmovl_high_s8(vec_v_top_right_tmp_{0});\n\
+            int16x8_t vec_v_top_right_bot_{0} = vmovl_s8(vget_low_s8(vec_v_top_right_tmp_{0}));\n\
+            int16x8_t vec_v_bot_left_high_{0} = vmovl_high_s8(vec_v_bot_left_tmp_{0});\n\
+            int16x8_t vec_v_bot_left_bot_{0} = vmovl_s8(vget_low_s8(vec_v_bot_left_tmp_{0}));\n\
+            int16x8_t vec_v_bot_right_high_{0} = vmovl_high_s8(vec_v_bot_right_tmp_{0});\n\
+            int16x8_t vec_v_bot_right_bot_{0} = vmovl_s8(vget_low_s8(vec_v_bot_right_tmp_{0}));\n\
+            vec_c0 += vec_v_top_left_bot_{0};\n\
+            vec_c0 += vec_v_top_right_bot_{0};\n\
+            vec_c1 += vec_v_bot_left_bot_{0};\n\
+            vec_c1 += vec_v_bot_right_bot_{0};\n\
+            vec_c2 += vec_v_top_left_high_{0};\n\
+            vec_c2 += vec_v_top_right_high_{0};\n\
+            vec_c3 += vec_v_bot_left_high_{0};\n\
+            vec_c3 += vec_v_bot_right_high_{0};\n".format(i)])
+
+        kernel_code = "".join([kernel_code, "\
+        }\n\
+        int32x4_t vec_v_1 = vmovl_high_s16(vec_c0);\n\
+        int32x4_t vec_v_0 = vmovl_s16(vget_low_s16(vec_c0));\n\
+        int32x4_t vec_v_3 = vmovl_high_s16(vec_c1);\n\
+        int32x4_t vec_v_2 = vmovl_s16(vget_low_s16(vec_c1));\n\
+        int32x4_t vec_v_5 = vmovl_high_s16(vec_c2);\n\
+        int32x4_t vec_v_4 = vmovl_s16(vget_low_s16(vec_c2));\n\
+        int32x4_t vec_v_7 = vmovl_high_s16(vec_c3);\n\
+        int32x4_t vec_v_6 = vmovl_s16(vget_low_s16(vec_c3));\n\
+\n\
+        vst1q_s32(c + i,      vld1q_s32(c + i     ) + vec_v_0);\n\
+        vst1q_s32(c + i + 4,  vld1q_s32(c + i + 4 ) + vec_v_1);\n\
+        vst1q_s32(c + i + 8,  vld1q_s32(c + i + 8 ) + vec_v_4);\n\
+        vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_5);\n\
+        vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_2);\n\
+        vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_3);\n\
+        vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_6);\n\
+        vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_7);\n\
+    }\n\
+#endif\n\
+}\n"])
+
+    kernel_code = "".join([kernel_code, "\
+\n\
+template\n\
+inline int32_t two_tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a) {{\n\
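+    /* Two-weight tail: every byte of a packs two 4-bit LUT indices, so each\n\
+       vpshufb (AVX2) or vqtbl1q (NEON) lookup resolves a full register of weight\n\
+       pairs against the 16-entry tables built by two_lut_ctor. */\n\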
+#ifdef __AVX2__\n\
+    const __m256i vec_mask = _mm256_set1_epi8(0x0f);\n\
+    const __m256i vec_sub  = _mm256_set1_epi8(0x01);\n\
+    const int KK = 16;\n\
+    __m256i vec_lut[KK];\n\
+    for (int k = 0; k < KK; k++) {{\n\
+        __m128i vec_k = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + 8 * k));\n\
+        vec_lut[k] = _mm256_set_m128i(vec_k, vec_k);\n\
+    }}\n\
+#pragma unroll\n\
+    for (int i = 0; i < BM{0} / 2; i += 16) {{\n\
+        __m256i vec_c0 = _mm256_set1_epi16(0);\n\
+        __m256i vec_c1 = _mm256_set1_epi16(0);\n\
+#pragma unroll\n\
+        for (int k = 0; k < KK / 2; k++) {{\n\
+            __m256i vec_as = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK + k * 32));\n\
+            __m256i vec_v_bot = _mm256_shuffle_epi8(vec_lut[2 * k + 1], _mm256_and_si256(vec_as, vec_mask));\n\
+            __m256i vec_v_top = _mm256_shuffle_epi8(vec_lut[2 * k], _mm256_and_si256(_mm256_srli_epi16(vec_as, 4), vec_mask));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_bot)));\n\
+            vec_c1 = _mm256_add_epi16(vec_c1, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_bot, 1)));\n\
+            vec_c0 = _mm256_add_epi16(vec_c0, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vec_v_top)));\n\
+            vec_c1 = _mm256_add_epi16(vec_c1, _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vec_v_top, 1)));\n\
+        }}\n\
+        __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i * 2));\n\
+        __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i * 2 + 8));\n\
+        __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i * 2 + 16));\n\
+        __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i * 2 + 24));\n\
+        vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));\n\
+        vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));\n\
+        vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));\n\
+        vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));\n\
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i * 2), vec_gc0);\n\
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i * 2 + 8), vec_gc1);\n\
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i * 2 + 16), vec_gc2);\n\
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i * 2 + 24), vec_gc3);\n\
+    }}\n\
+#elif defined __ARM_NEON\n\
+    const int KK = 16;\n\
+    const uint8x16_t vec_mask = vdupq_n_u8(0x0f);\n\
+    const int8x16_t vec_zero = vdupq_n_s8(0x00);\n\
+    int8x16_t vec_lut[KK];\n\
+#pragma unroll\n\
+    for (int k = 0; k < KK; k++) {{\n\
+        vec_lut[k] = vld1q_s8(lut + k * 16);\n\
+    }}\n\
+    for (int i = 0; i < BM{0} / 2; i += 16) {{\n\
+        int16x8_t vec_c0 = vdupq_n_s16(0);\n\
+        int16x8_t vec_c1 = vdupq_n_s16(0);\n\
+        int16x8_t vec_c2 = vdupq_n_s16(0);\n\
+        int16x8_t vec_c3 = vdupq_n_s16(0);\n\
+        for (int k = 0; k < KK / 2; k++) {{\n\
+            uint8x16_t vec_a_top = vld1q_u8(a + i * KK + k * 32);\n\
+            uint8x16_t vec_a_bot = vld1q_u8(a + i * KK + k * 32 + 16);\n\
+            uint8x16_t vec_a_top_left = vshrq_n_u8(vec_a_top, 4);\n\
+            uint8x16_t vec_a_top_right = vandq_u8(vec_a_top, vec_mask);\n\
+            uint8x16_t vec_a_bot_left = vshrq_n_u8(vec_a_bot, 4);\n\
+            uint8x16_t vec_a_bot_right = vandq_u8(vec_a_bot, vec_mask);\n\
+            int8x16_t vec_v_top_left_tmp = vqtbl1q_s8(vec_lut[2 * k], vec_a_top_left);\n\
+            int8x16_t vec_v_top_right_tmp = vqtbl1q_s8(vec_lut[2 * k + 1], vec_a_top_right);\n\
+            int8x16_t vec_v_bot_left_tmp = vqtbl1q_s8(vec_lut[2 * k], vec_a_bot_left);\n\
+            int8x16_t vec_v_bot_right_tmp = vqtbl1q_s8(vec_lut[2 * k + 1], vec_a_bot_right);\n\
+            int16x8_t vec_v_top_left_high = vmovl_high_s8(vec_v_top_left_tmp);\n\
+            int16x8_t vec_v_top_left_bot = vmovl_s8(vget_low_s8(vec_v_top_left_tmp));\n\
+            int16x8_t vec_v_top_right_high = vmovl_high_s8(vec_v_top_right_tmp);\n\
+            int16x8_t vec_v_top_right_bot = vmovl_s8(vget_low_s8(vec_v_top_right_tmp));\n\
+            int16x8_t vec_v_bot_left_high = vmovl_high_s8(vec_v_bot_left_tmp);\n\
+            int16x8_t vec_v_bot_left_bot = vmovl_s8(vget_low_s8(vec_v_bot_left_tmp));\n\
+            int16x8_t vec_v_bot_right_high = vmovl_high_s8(vec_v_bot_right_tmp);\n\
+            int16x8_t vec_v_bot_right_bot = vmovl_s8(vget_low_s8(vec_v_bot_right_tmp));\n\
+            vec_c0 += vec_v_top_left_bot;\n\
+            vec_c0 += vec_v_top_right_bot;\n\
+            vec_c1 += vec_v_top_left_high;\n\
+            vec_c1 += vec_v_top_right_high;\n\
+            vec_c2 += vec_v_bot_left_bot;\n\
+            vec_c2 += vec_v_bot_right_bot;\n\
+            vec_c3 += vec_v_bot_left_high;\n\
+            vec_c3 += vec_v_bot_right_high;\n\
+        }}\n\
+        int32x4_t vec_v_1 = vmovl_high_s16(vec_c0);\n\
+        int32x4_t vec_v_0 = vmovl_s16(vget_low_s16(vec_c0));\n\
+        int32x4_t vec_v_3 = vmovl_high_s16(vec_c1);\n\
+        int32x4_t vec_v_2 = vmovl_s16(vget_low_s16(vec_c1));\n\
+        int32x4_t vec_v_5 = vmovl_high_s16(vec_c2);\n\
+        int32x4_t vec_v_4 = vmovl_s16(vget_low_s16(vec_c2));\n\
+        int32x4_t vec_v_7 = vmovl_high_s16(vec_c3);\n\
+        int32x4_t vec_v_6 = vmovl_s16(vget_low_s16(vec_c3));\n\
+        vst1q_s32(c + i * 2,      vld1q_s32(c + i * 2     ) + vec_v_0);\n\
+        vst1q_s32(c + i * 2 + 4,  vld1q_s32(c + i * 2 + 4 ) + vec_v_1);\n\
+        vst1q_s32(c + i * 2 + 8,  vld1q_s32(c + i * 2 + 8 ) + vec_v_2);\n\
+        vst1q_s32(c + i * 2 + 12, vld1q_s32(c + i * 2 + 12) + vec_v_3);\n\
+        vst1q_s32(c + i * 2 + 16, vld1q_s32(c + i * 2 + 16) + vec_v_4);\n\
+        vst1q_s32(c + i * 2 + 20, vld1q_s32(c + i * 2 + 20) + vec_v_5);\n\
+        vst1q_s32(c + i * 2 + 24, vld1q_s32(c + i * 2 + 24) + vec_v_6);\n\
+        vst1q_s32(c + i * 2 + 28, vld1q_s32(c + i * 2 + 28) + vec_v_7);\n\
+    }}\n\
+#endif\n\
+    return 0;\n\
+}};\n\
+\n\
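+/* three_qgemm_lut_*: runs the three-weight LUT kernel over each BBK-sized slice of\n\
+   K, accumulating int32 partial sums in CBits, then scales every row by the\n\
+   activation (LUT) scale and the weight scale into C; two_qgemm_lut_* below does\n\
+   the same for the two-weight tail and accumulates into C. */\n\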
+template<int BATCH_SIZE>\n\
+int32_t three_qgemm_lut_{0}(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
+    alignas(32) uint32_t CBits[BATCH_SIZE * BM{0}];\n\
+    memset(&(CBits[0]), 0, BATCH_SIZE * BM{0} * sizeof(int32_t));\n\
+#pragma unroll\n\
+    for (int32_t k_outer = 0; k_outer < {1} / BBK{0}; ++k_outer) {{\n\
+        three_tbl_impl_{0}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{0} / 3 * 16)])), (&(((uint8_t*)A)[(k_outer * BBK{0} / 3 / 2 * BM{0})])), (&(((uint8_t*)sign)[(k_outer * BBK{0} / 3 / 8 * BM{0})])));\n\
+    }}\n\
+#pragma unroll\n\
+    for (int i = 0; i < BM{0}; i++) {{\n\
+        ((bitnet_float_type*)C)[i] = (bitnet_float_type)((float)(((int32_t*)CBits)[i]) * ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]);\n\
+    }}\n\
+  return 0;\n\
+}}\n\
+\n\
+template<int BATCH_SIZE>\n\
+int32_t two_qgemm_lut_{0}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
+    alignas(32) uint32_t CBits[BATCH_SIZE * BM{0}];\n\
+    memset(&(CBits[0]), 0, BATCH_SIZE * BM{0} * sizeof(int32_t));\n\
+#pragma unroll\n\
+    for (int32_t k_outer = 0; k_outer < {2} / 32; ++k_outer) {{\n\
+        two_tbl_impl_{0}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 16)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM{0})])));\n\
+    }}\n\
+#pragma unroll\n\
+    for (int i = 0; i < BM{0}; i++) {{\n\
+        ((bitnet_float_type*)C)[i] += (bitnet_float_type)((float)(((int32_t*)CBits)[i]) * ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]);\n\
+    }}\n\
+  return 0;\n\
+}}\n\
+\n\
+".format(pre, k_list[1], k_list[0])])
+    return kernel_code
+
+def gen_top_api(kernel_shapes, k_list):
+
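+    # Emits the dispatch layer: ggml_preprocessor() derives per-batch activation
+    # scales with the partial-max helpers and builds the quantized LUTs via
+    # three_lut_ctor / two_lut_ctor, while ggml_qgemm_lut() routes each (m, k, BK)
+    # combination to the matching two_/three_qgemm_lut_* instantiation.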
+    kernel_code = "void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* Three_LUT_Scales, void* Two_LUT_Scales, void* Three_QLUT, void* Two_QLUT) {{\n\
+    partial_max_reset(bs, (&(((bitnet_float_type*)Three_LUT_Scales)[0])));\n\
+    partial_max_reset(bs, (&(((bitnet_float_type*)Two_LUT_Scales)[0])));\n\
+    for (int32_t b = 0; b < bs; b++) {{\n\
+        for (int32_t k_outer = 0; k_outer < (three_k + two_k) / 24; ++k_outer) {{\n\
+            three_partial_max((&(((bitnet_float_type*)Three_LUT_Scales)[b])), (&(((bitnet_float_type*)B)[(k_outer * 24)])));\n\
+        }}\n\
+        for (int32_t k_outer = 0; k_outer < (three_k + two_k) / 16; ++k_outer) {{\n\
+            two_partial_max((&(((bitnet_float_type*)Two_LUT_Scales)[b])), (&(((bitnet_float_type*)B)[(k_outer * 16)])));\n\
+        }}\n\
+    }}\n\
+    if (m == {0} && two_k == {1} && three_k == {2}) {{\n\
+        for (int32_t b = 0; b < bs; b++) {{\n\
+            three_lut_ctor<{2}>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 16])), (&(((bitnet_float_type*)B)[b * (three_k + two_k)])), (&(((bitnet_float_type*)Three_LUT_Scales)[b])));\n\
+            two_lut_ctor<{1}>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 16])), (&(((bitnet_float_type*)B)[b * (three_k + two_k) + {2}])), (&(((bitnet_float_type*)Two_LUT_Scales)[b])));\n\
+        }}\n\
+    }}\n\
+".format(kernel_shapes[0][0], k_list[0][0], k_list[0][1])
+    for i in range(1, len(kernel_shapes)):
+        kernel_code = "".join([kernel_code, "    else if (m == {0} && two_k == {1} && three_k == {2}) {{\n\
+        for (int32_t b = 0; b < bs; b++) {{\n\
+            three_lut_ctor<{2}>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 16])), (&(((bitnet_float_type*)B)[b * (three_k + two_k)])), (&(((bitnet_float_type*)Three_LUT_Scales)[b])));\n\
+            two_lut_ctor<{1}>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 16])), (&(((bitnet_float_type*)B)[b * (three_k + two_k) + {2}])), (&(((bitnet_float_type*)Two_LUT_Scales)[b])));\n\
+        }}\n\
+    }}\n".format(kernel_shapes[i][0], k_list[i][0], k_list[i][1])])
+    kernel_code = "".join([kernel_code, "}\n"])
+
+
+    kernel_code = "".join([kernel_code, "void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
+    if (m == {0} && k == {1}) {{\n\
+        if (BK == {2}) {{\n\
+            if (bs == 1) {{\n\
+                two_qgemm_lut_{4}<1>(A, LUT, Scales, LUT_Scales, C);\n\
+            }}\n\
+        }}\n\
+        else if (BK == {3}) {{\n\
+            if (bs == 1) {{\n\
+                three_qgemm_lut_{4}<1>(A, sign, LUT, Scales, LUT_Scales, C);\n\
+            }}\n\
+        }}\n\
+    }}\n\
+".format(kernel_shapes[0][0], kernel_shapes[0][1], k_list[0][0], k_list[0][1], "{}_{}".format(kernel_shapes[0][0], kernel_shapes[0][1]))])
+    for i in range(1, len(kernel_shapes)):
+        kernel_code = "".join([kernel_code, "    else if (m == {0} && k == {1}) {{\n\
+        if (BK == {2}) {{\n\
+            if (bs == 1) {{\n\
+                two_qgemm_lut_{4}<1>(A, LUT, Scales, LUT_Scales, C);\n\
+            }}\n\
+        }}\n\
+        else if (BK == {3}) {{\n\
+            if (bs == 1) {{\n\
+                three_qgemm_lut_{4}<1>(A, sign, LUT, Scales, LUT_Scales, C);\n\
+            }}\n\
+        }}\n\
+    }}\n\
+".format(kernel_shapes[i][0], kernel_shapes[i][1], k_list[i][0], k_list[i][1], "{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]))])
+    kernel_code = "".join([kernel_code, "}\n"])
+    return kernel_code
+
+def gen_transform_code(kernel_shapes):
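+    # Emits ggml_bitnet_transform_tensor(), which recognizes supported weight
+    # tensors at load time, selects the BM/BBK block sizes for their shape, and
+    # records the packed weights plus the trailing i2 scale in bitnet_tensor_extras.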
+    kernel_code = "\n\
+void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\
+    if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {\n\
+        return;\n\
+    }\n\
+\n\
+    int k = tensor->ne[0];\n\
+    int m = tensor->ne[1];\n\
+    const int lut_scales_size = 1;\n\
+    int bk = 0;\n\
+    int bm = 0;\n"
+
+    kernel_code = "".join([kernel_code, "\n\
+    if (m == {0} && k == {1}) {{\n\
+        bm = BM{0}_{1};\n\
+        bk = BBK{0}_{1};\n\
+    }}\n".format(kernel_shapes[0][0], kernel_shapes[0][1])])
+
+    for i in range(1, len(kernel_shapes)):
+        kernel_code = "".join([kernel_code, "else if (m == {0} && k == {1}) {{\n\
+        bm = BM{0}_{1};\n\
+        bk = BBK{0}_{1};\n\
+    }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])])
+
+    kernel_code = "".join([kernel_code, "\n\
+    const int n_tile_num = m / bm;\n\
+    const int BK = bk;\n\
+    uint8_t * qweights;\n\
+    bitnet_float_type * scales;\n\
+\n\
+    scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));\n\
+    qweights = (uint8_t *) tensor->data;\n\
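+    /* Packed weight size: the three-weight region stores each triple as a 4-bit\n\
+       LUT index plus one sign bit (5 bits per 3 weights); the code assumes a\n\
+       256-column two-weight tail at 4 bits per pair, and pads up to a 32-byte\n\
+       boundary to mirror the converter's output. */\n\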
+    int nbytes = (k - 256) * m / 3 * 5 / 8 + 256 * m / 2 * 4 / 8;\n\
+    nbytes = 32 - nbytes % 32 + nbytes;\n\
+    float * i2_scales = (float * )(qweights + nbytes);\n\
+\n"])
+
+    kernel_code = "".join([kernel_code, "\
+    scales[0] = (bitnet_float_type) i2_scales[0];\n"])
+
+    kernel_code = "".join([kernel_code, "\n\
+    tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\
+    bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\
+        /* .lut_scales_size = */ lut_scales_size,\n\
+        /* .BK              = */ BK,\n\
+        /* .n_tile_num      = */ n_tile_num,\n\
+        /* .qweights        = */ qweights,\n\
+        /* .scales          = */ scales\n\
+    };\n\
+}\n"])
+
+    return kernel_code
+
+def get_three_k_two_k(K, bk):
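+    # Split K into the largest BK-aligned region handled by the three-weight
+    # kernels and the remainder handled by the two-weight kernels.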
+    bk_num = K // bk
+    three_k = bk_num * bk
+    two_k = K - three_k
+    return two_k, three_k
+
+if __name__ == "__main__":
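+    # (M, K) shapes of each supported model's quantized linear layers; one kernel
+    # variant is generated per distinct shape.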
+    ModelShapeDict = {
+        "bitnet_b1_58-large"                : [[1536, 4096],
+                                               [1536, 1536],
+                                               [4096, 1536]],
+        "bitnet_b1_58-3B"                   : [[3200, 8640],
+                                               [3200, 3200],
+                                               [8640, 3200]],
+        "Llama3-8B-1.58-100B-tokens"        : [[14336, 4096],
+                                               [4096, 14336],
+                                               [1024, 4096],
+                                               [4096, 4096]] 
+    }
+
+    parser = argparse.ArgumentParser(description='gen impl')
+    parser.add_argument('--model', default="input", type=str, dest="model",
+                        help="choose from bitnet_b1_58-large/bitnet_b1_58-3B/Llama3-8B-1.58-100B-tokens.")
+    parser.add_argument('--BM', default="input", type=str,
+                        help="block length used to split one weight (M, K) into M / BM blocks of shape (BM, K).")
+    parser.add_argument('--BK', default="input", type=str,
+                        help="block length used to split one weight (M, K) into K / BK blocks of shape (M, BK).")
+    parser.add_argument('--bm', default="input", type=str,
+                        help="use SIMD instructions to compute a (bm, 192 / bm) tile within one block")
+    args = parser.parse_args()
+
+    kernel_shapes = ModelShapeDict[args.model]
+
+    BM_list = [int(item) for item in args.BM.split(',')]
+    BK_list = [int(item) for item in args.BK.split(',')]
+    bm_list = [int(item) for item in args.bm.split(',')]
+
+    tbl_impl_code = []
+    k_list = []
+
+    for i in range(len(kernel_shapes)):
+        k_list.append(get_three_k_two_k(kernel_shapes[i][1], BK_list[i]))
+
+    for i in range(len(kernel_shapes)):
+        tbl_impl_code.append(
+            gen_tbl_impl("{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]), BM_list[i], BK_list[i], bm_list[i], k_list[i])
+        )
+
+    assert(len(BM_list) == len(BK_list) == len(bm_list) == len(kernel_shapes)), "number of BM / BK / bm should be {}".format(len(kernel_shapes))
+    
+    for i in range(len(kernel_shapes)):
+        assert kernel_shapes[i][0] % BM_list[i] == 0, "M %% BM should be 0"
+        assert (kernel_shapes[i][1] % BK_list[i]) % 32 == 0, "K %% BK %% 32 should be 0"
+        assert bm_list[i] in [16, 32], "choose bm from [16, 32]"
+
+    ctor_code = gen_ctor_code()
+    api_code = gen_top_api(kernel_shapes, k_list)
+    trans_code = gen_transform_code(kernel_shapes)
+
+    output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "include")
+
+    with open(''.join([output_dir, "/bitnet-lut-kernels.h"]), 'w') as f:
+        f.write("#if defined(GGML_BITNET_TL2_LOSS)")
+        f.write(ctor_code)
+        for code in tbl_impl_code:
+            f.write(code)
+        f.write(api_code)
+        f.write(trans_code)
+        f.write("#endif")
+
+    config = ConfigParser()
+
+    for i in range(len(kernel_shapes)):
+        config.add_section('Kernels_{}'.format(i))
+        config.set('Kernels_{}'.format(i), 'M', str(kernel_shapes[i][0]))
+        config.set('Kernels_{}'.format(i), 'K', str(kernel_shapes[i][1]))
+        config.set('Kernels_{}'.format(i), 'BM', str(BM_list[i]))
+        config.set('Kernels_{}'.format(i), 'BK', str(BK_list[i]))
+        config.set('Kernels_{}'.format(i), 'bmm', str(bm_list[i]))
+
+    with open(''.join([output_dir, "/kernel_config.ini"]), 'w') as configfile:
+        config.write(configfile)
\ No newline at end of file
diff --git a/utils/convert-hf-to-gguf-bitnet.py b/utils/convert-hf-to-gguf-bitnet.py
index f525f58..6d0e46c 100644
--- a/utils/convert-hf-to-gguf-bitnet.py
+++ b/utils/convert-hf-to-gguf-bitnet.py
@@ -517,6 +517,92 @@ def preprocess_weights_tl1(
     return weight
 
 
+def preprocess_two_weights_tl2_loss(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
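+    # TL2-Loss packing for the two-weight tail: each pair of ternary weights is
+    # encoded as 3*w0 + w1 + 4 (0..8), two codes are packed per byte (high/low
+    # nibble), and the result is tiled to match the generated kernels' block layout.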
+    weight = np.reshape(weight, (weight_num // 2, 2))
+    hi_weight = np.multiply(np.split(weight, 2, axis=1)[0], 3)
+    lo_weight = np.split(weight, 2, axis=1)[1]
+
+    weight = np.reshape((hi_weight + lo_weight), weight_num // 2)
+    weight = weight + 4
+    weight = np.reshape(weight, (M, K // 2)).astype(np.uint8)
+    weight = weight.reshape((M // BM, BM, K // 2)).transpose(0, 2, 1)
+    weight = weight.reshape((M // BM, K // BY, BY // 2, BM)).transpose(0, 1, 3, 2)
+    weight = weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 2)).transpose(0, 1, 2, 4, 3)
+    weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, by // 2, bm)).transpose(0, 1, 2, 3, 5, 4)
+    weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, bm, by // 2))
+    weight_0 = weight[:, :, :, :, :, 0]
+    weight_1 = weight[:, :, :, :, :, 1]
+    weight_0 = weight_0 << 4
+    weight_1 = weight_1
+    weight = weight_0 + weight_1
+    weight = weight.reshape(M * K // bm // by, bm).reshape(M * K // by // 16, 16)
+
+    for i in range(weight.shape[0]):
+        final_weight.append(weight[i, :])
+
+def preprocess_three_weights_tl2_loss(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
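+    # TL2-Loss packing for the three-weight region: each triple of ternary weights
+    # becomes a magnitude index |9*w0 + 3*w1 + w2| (0..13) stored as a 4-bit code
+    # plus a separate per-triple sign bit; codes are nibble-packed and sign bits are
+    # gathered eight per byte, both tiled to the kernels' block layout.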
+    weight = np.reshape(weight, (weight_num // 3, 3))
+    split_weights = np.split(weight, 3, axis=1)
+    first_weight = np.multiply(split_weights[0], 9)
+    second_weight = np.multiply(split_weights[1], 3)
+    third_weight = split_weights[2]
+
+    weight = np.reshape((first_weight + second_weight + third_weight), weight_num // 3)
+    sign_weight = np.sign(weight)
+    sign_weight = np.where(sign_weight < 1, 0, sign_weight)
+    weight = np.abs(weight)
+
+    weight = np.reshape(weight, (M, K // 3)).astype(np.uint8)
+    sign_weight = np.reshape(sign_weight, (M, K // 3)).astype(np.uint8)
+
+    weight = weight.reshape((M // BM, BM, K // 3)).transpose(0, 2, 1)
+    weight = weight.reshape((M // BM, K // BY, BY // 3, BM)).transpose(0, 1, 3, 2)
+    weight = weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 3)).transpose(0, 1, 2, 4, 3)
+    weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, by // 3, bm)).transpose(0, 1, 2, 3, 5, 4)
+    weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, bm, by // 3))
+    
+    weight_list = []
+    for i in range(by // 3):
+        weight_list.append(weight[:, :, :, :, :, i])
+    
+    for i in range(by // 3 // 2):
+        weight_list[i] = weight_list[i] << 4
+        weight_list[i + by // 3 // 2] = weight_list[i + by // 3 // 2]
+        weight_list[i] = weight_list[i] + weight_list[i + by // 3 // 2]
+        weight_list[i] = weight_list[i].reshape(M * K // bm // by, bm).reshape(M * K // by // 16, 16)
+
+    for i in range(weight_list[0].shape[0]):
+        for j in range(by // 3 // 2):
+            final_weight.append(weight_list[j][i, :])
+
+    sign_weight = sign_weight.reshape((M // BM, BM, K // 3)).transpose(0, 2, 1)
+    sign_weight = sign_weight.reshape((M // BM, K // BY, BY // 3, BM)).transpose(0, 1, 3, 2)
+    sign_weight = sign_weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 3)).transpose(0, 1, 2, 4, 3)
+    sign_weight = sign_weight.reshape((M // BM, K // BY, BM // bm, BY // (by * 4), by // 3 * 4, bm)).transpose(0, 1, 2, 3, 5, 4).astype(np.uint8)
+
+    combine_weight_list = []
+    for i in range(by // 3 // 2):
+        combine_weight = np.zeros((M // BM, K // BY, BM // bm, BY // (by * 4), bm), dtype=np.uint8)
+        combine_weight_list.append(combine_weight)
+
+    for i in range(8):
+        for j in range(by // 3 // 2):
+            if bm == 16:
+                combine_weight_list[j] = combine_weight_list[j] + (sign_weight[:, :, :, :, :, by // 3 // 2 * i + j] << 7 - i)
+            elif bm == 32:
+                if i > 3 :
+                    ti = (i - 4) * 2 + 1
+                else:
+                    ti = i * 2
+                combine_weight_list[j] = combine_weight_list[j] + (sign_weight[:, :, :, :, :, by // 3 // 2 * ti + j] << 7 - i)
+
+    for i in range(by // 3 // 2):
+        combine_weight_list[i] = combine_weight_list[i].reshape((M * K // (by * 4)) // 16, 16)
+
+    for i in range(combine_weight_list[0].shape[0]):
+        for j in range(by // 3 // 2):
+            final_weight.append(combine_weight_list[j][i, :])
+
 def preprocess_two_weights_tl2(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
     weight = np.reshape(weight, (weight_num // 2, 2))
     hi_weight = np.multiply(np.split(weight, 2, axis=1)[0], 3)
@@ -603,7 +689,6 @@ def preprocess_weights_tl2(
     weight = w
     weight = np.where(np.abs(weight) < 1e-6, 0, weight).astype(np.float32)
     weight = np.sign(weight)
-    weight_num = np.prod(weight.shape)
 
     config.read('include/kernel_config.ini')
     BM = -1
@@ -631,7 +716,8 @@ def preprocess_weights_tl2(
 
     final_weight = []
 
-    preprocess_three_weights_tl2(three_weight.shape[0],
+    if args.loss:
+        preprocess_three_weights_tl2_loss(three_weight.shape[0],
                          three_weight.shape[1],
                          three_weight.shape[0] * three_weight.shape[1],
                          BM,
@@ -641,8 +727,29 @@ def preprocess_weights_tl2(
                          three_weight,
                          final_weight)
 
-    if (weight.shape[1] % BY != 0):
-        preprocess_two_weights_tl2(  two_weight.shape[0],
+        if (weight.shape[1] % BY != 0):
+            preprocess_two_weights_tl2_loss(two_weight.shape[0],
+                         two_weight.shape[1],
+                         two_weight.shape[0] * two_weight.shape[1],
+                         BM,
+                         32,
+                         32,
+                         4,
+                         two_weight,
+                         final_weight)
+    else:
+        preprocess_three_weights_tl2(three_weight.shape[0],
+                         three_weight.shape[1],
+                         three_weight.shape[0] * three_weight.shape[1],
+                         BM,
+                         BY,
+                         bm,
+                         by,
+                         three_weight,
+                         final_weight)
+
+        if (weight.shape[1] % BY != 0):
+            preprocess_two_weights_tl2(two_weight.shape[0],
                          two_weight.shape[1],
                          two_weight.shape[0] * two_weight.shape[1],
                          BM,
@@ -652,8 +759,10 @@ def preprocess_weights_tl2(
                          two_weight,
                          final_weight)
     weight = np.array(final_weight, dtype=np.uint8).reshape(-1)
-    weight = np.pad(weight, (0, (K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8 -
-                             weight.shape[0]), mode='constant', constant_values=0)
+    pad_nums = (K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8
+    pad_align_nums = 32 - ((K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8) % 32
+    pad_nums = pad_nums + pad_align_nums
+    weight = np.pad(weight, (0, pad_nums - weight.shape[0]), mode='constant', constant_values=0)
     return weight
 
 def transform_to_tl1(x: np.ndarray):
@@ -1116,6 +1225,7 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--model-name", type=str, default=None, help="name of the model")
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
     parser.add_argument("--quant-embd", action="store_true", help="quantize the embedding layer")
+    parser.add_argument("--loss", action="store_true", help="use loss tl2")
 
     return parser.parse_args()