commit paper code

Eddie-Wang1120
2025-02-16 15:03:25 +08:00
parent 437b321dcf
commit 4c736e3728
10 changed files with 1956 additions and 26 deletions
+2 -2
@@ -1,4 +1,4 @@
[submodule "3rdparty/llama.cpp"]
path = 3rdparty/llama.cpp
-url = https://github.com/Eddie-Wang1120/llama.cpp.git
-branch = merge-dev
+url = git@github.com:Eddie-Wang1120/llama.cpp.git
+branch = pp
+4 -3
@@ -14,6 +14,7 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
# option list
option(BITNET_ARM_TL1 "bitnet.cpp: use tl1 on arm platform" OFF)
option(BITNET_X86_TL2 "bitnet.cpp: use tl2 on x86 platform" OFF)
option(BITNET_TL2_LOSS "bitnet.cpp: use tl2-loss on x86 platform" OFF)
set(CMAKE_CXX_STANDARD_REQUIRED true)
@@ -24,6 +25,7 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)
# override ggml options
set(GGML_BITNET_ARM_TL1 ${BITNET_ARM_TL1})
set(GGML_BITNET_X86_TL2 ${BITNET_X86_TL2})
set(GGML_BITNET_TL2_LOSS ${BITNET_TL2_LOSS})
if (GGML_BITNET_ARM_TL1)
add_compile_definitions(GGML_BITNET_ARM_TL1)
@@ -31,9 +33,8 @@ endif()
if (GGML_BITNET_X86_TL2)
add_compile_definitions(GGML_BITNET_X86_TL2)
endif()
-if (CMAKE_C_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
-    add_compile_options(-fpermissive)
+if (GGML_BITNET_TL2_LOSS)
+    add_compile_definitions(GGML_BITNET_TL2_LOSS)
endif()
find_package(Threads REQUIRED)
+14 -5
@@ -43,8 +43,9 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
</tr>
<tr>
<th>I2_S</th>
-<th>TL1</th>
-<th>TL2</th>
+<th>TL1(TL1_1)</th>
+<th>TL2(TL2_1)</th>
+<th>TL2-Loss(TL2_0)</th>
</tr>
<tr>
<td rowspan="2"><a href="https://huggingface.co/1bitLLM/bitnet_b1_58-large">bitnet_b1_58-large</a></td>
@@ -53,12 +54,14 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
<td>✅</td>
<td>❌</td>
<td>✅</td>
+<td>✅</td>
</tr>
<tr>
<td>ARM</td>
<td>✅</td>
<td>✅</td>
<td>❌</td>
+<td>✅</td>
</tr>
<tr>
<td rowspan="2"><a href="https://huggingface.co/1bitLLM/bitnet_b1_58-3B">bitnet_b1_58-3B</a></td>
@@ -67,12 +70,14 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
<td>❌</td>
<td>❌</td>
<td>✅</td>
+<td>✅</td>
</tr>
<tr>
<td>ARM</td>
<td>❌</td>
<td>✅</td>
<td>❌</td>
+<td>✅</td>
</tr>
<tr>
<td rowspan="2"><a href="https://huggingface.co/HF1BitLLM/Llama3-8B-1.58-100B-tokens">Llama3-8B-1.58-100B-tokens</a></td>
@@ -81,12 +86,14 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
<td>✅</td>
<td>❌</td>
<td>✅</td>
+<td>✅</td>
</tr>
<tr>
<td>ARM</td>
<td>✅</td>
<td>✅</td>
<td>❌</td>
+<td>✅</td>
</tr>
<tr>
<td rowspan="2"><a href="https://huggingface.co/collections/tiiuae/falcon3-67605ae03578be86e4e87026">Falcon3 Family</a></td>
@@ -95,12 +102,14 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
<td>✅</td>
<td>❌</td>
<td>✅</td>
+<td>✅</td>
</tr>
<tr>
<td>ARM</td>
<td>✅</td>
<td>✅</td>
<td>❌</td>
+<td>✅</td>
</tr>
</table>
@@ -144,11 +153,11 @@ pip install -r requirements.txt
3. Build the project
```bash
# Download the model from Hugging Face, convert it to quantized gguf format, and build the project
-python setup_env.py --hf-repo tiiuae/Falcon3-7B-Instruct-1.58bit -q i2_s
+python setup_env.py --hf-repo 1bitLLM/bitnet_b1_58-large -q i2_s
# Or you can manually download the model and run with local path
-huggingface-cli download tiiuae/Falcon3-7B-Instruct-1.58bit --local-dir models/Falcon3-7B-Instruct-1.58bit
-python setup_env.py -md models/Falcon3-7B-Instruct-1.58bit -q i2_s
+huggingface-cli download 1bitLLM/bitnet_b1_58-large --local-dir models/bitnet_b1_58-large
+python setup_env.py -md models/bitnet_b1_58-large -q i2_s
```
<pre>
usage: setup_env.py [-h] [--hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}] [--model-dir MODEL_DIR] [--log-dir LOG_DIR] [--quant-type {i2_s,tl1}] [--quant-embd]
+627
@@ -0,0 +1,627 @@
#if defined(GGML_BITNET_ARM_TL1)
#include "ggml-bitnet.h"
#define GGML_BITNET_MAX_NODES 8192
static bool initialized = false;
static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;
static size_t bitnet_tensor_extras_index = 0;
static void * aligned_malloc(size_t size) {
#if defined(_WIN32)
return _aligned_malloc(size, 64);
#else
void * ptr = nullptr;
posix_memalign(&ptr, 64, size);
return ptr;
#endif
}
static void aligned_free(void * ptr) {
#if defined(_WIN32)
_aligned_free(ptr);
#else
free(ptr);
#endif
}
void per_tensor_quant(int k, void* lut_scales_, void* b_) {
bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
bitnet_float_type* b = (bitnet_float_type*)b_;
#ifdef __ARM_NEON
float32x4_t temp_max = vdupq_n_f32(0);
for (int i=0; i < k / 4; i++) {
float32x4_t vec_bs = vld1q_f32(b + 4 * i);
float32x4_t abssum = vabsq_f32(vec_bs);
temp_max = vmaxq_f32(abssum, temp_max);
}
float32_t scales = 127 / vmaxvq_f32(temp_max);
*lut_scales = scales;
#elif defined __AVX2__
__m256 max_vec = _mm256_set1_ps(0.f);
const __m256 vec_sign = _mm256_set1_ps(-0.0f);
// #pragma unroll
for (int i = 0; i < k / 8; i++) {
__m256 vec_b = _mm256_loadu_ps(b + i * 8);
__m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);
max_vec = _mm256_max_ps(vec_babs, max_vec);
}
__m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));
max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));
max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));
float scales = 127 / _mm_cvtss_f32(max1);
*lut_scales = scales;
#endif
}
void partial_max_reset(void* lut_scales_) {
bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
*lut_scales = 0.0;
}
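For reference, the scale these two helpers maintain is just the symmetric int8 range divided by the absolute maximum of the activation block — a minimal numpy sketch (helper name hypothetical, not part of the commit):

```python
import numpy as np

def per_tensor_quant_ref(b: np.ndarray) -> float:
    # same quantity the NEON/AVX2 max-reductions above compute:
    # a scale that maps max|b| onto the int8 limit 127
    return 127.0 / np.abs(b).max()
```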
#ifdef __ARM_NEON
inline void Transpose_8_8(
int16x8_t *v0,
int16x8_t *v1,
int16x8_t *v2,
int16x8_t *v3,
int16x8_t *v4,
int16x8_t *v5,
int16x8_t *v6,
int16x8_t *v7)
{
int16x8x2_t q04 = vzipq_s16(*v0, *v4);
int16x8x2_t q15 = vzipq_s16(*v1, *v5);
int16x8x2_t q26 = vzipq_s16(*v2, *v6);
int16x8x2_t q37 = vzipq_s16(*v3, *v7);
int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);
int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);
int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);
int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);
int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);
int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);
int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);
int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);
*v0 = q_fin_0.val[0];
*v1 = q_fin_0.val[1];
*v2 = q_fin_1.val[0];
*v3 = q_fin_1.val[1];
*v4 = q_fin_2.val[0];
*v5 = q_fin_2.val[1];
*v6 = q_fin_3.val[0];
*v7 = q_fin_3.val[1];
}
#endif
template<int act_k>
inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {
#ifdef __ARM_NEON
int16x8_t vec_lut[16];
float32_t scales = *lut_scales;
uint8_t tbl_mask[16];
tbl_mask[0] = 0;
tbl_mask[1] = 2;
tbl_mask[2] = 4;
tbl_mask[3] = 6;
tbl_mask[4] = 8;
tbl_mask[5] = 10;
tbl_mask[6] = 12;
tbl_mask[7] = 14;
tbl_mask[8] = 1;
tbl_mask[9] = 3;
tbl_mask[10] = 5;
tbl_mask[11] = 7;
tbl_mask[12] = 9;
tbl_mask[13] = 11;
tbl_mask[14] = 13;
tbl_mask[15] = 15;
uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);
#pragma unroll
for (int k = 0; k < act_k / 16; ++k) {
float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);
float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);
float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);
float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);
float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);
float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);
int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);
int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);
int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);
int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);
int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);
int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);
int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);
int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);
int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);
int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);
vec_lut[0] = vdupq_n_s16(0);
vec_lut[0] = vec_lut[0] - vec_bs_0;
vec_lut[0] = vec_lut[0] - vec_bs_1;
vec_lut[1] = vdupq_n_s16(0);
vec_lut[1] = vec_lut[1] - vec_bs_0;
vec_lut[2] = vdupq_n_s16(0);
vec_lut[2] = vec_lut[2] - vec_bs_0;
vec_lut[2] = vec_lut[2] + vec_bs_1;
vec_lut[3] = vdupq_n_s16(0);
vec_lut[3] = vec_lut[3] - vec_bs_1;
vec_lut[4] = vdupq_n_s16(0);
vec_lut[5] = vec_bs_1;
vec_lut[6] = vec_bs_0;
vec_lut[6] = vec_lut[6] - vec_bs_1;
vec_lut[7] = vec_bs_0;
vec_lut[8] = vec_bs_0;
vec_lut[8] = vec_lut[8] + vec_bs_1;
Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),
&(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));
Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),
&(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));
#pragma unroll
for (int idx = 0; idx < 8; idx++) {
int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);
int8x8_t q0_low = vget_low_s8(q0_s);
int8x8_t q0_high = vget_high_s8(q0_s);
int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);
int8x8_t q1_low = vget_low_s8(q1_s);
int8x8_t q1_high = vget_high_s8(q1_s);
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);
}
}
#endif
}
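The table built here is the heart of the TL1 scheme: for each pair of quantized activations, all nine partial sums over ternary weight pairs are precomputed, so the GEMM kernels can replace multiply-adds with 4-bit table lookups. A reference sketch in numpy (the index convention 3*(w0+1) + (w1+1) is read off the vec_lut assignments above; function name hypothetical):

```python
import numpy as np

def lut_ctor_ref(b0: float, b1: float, scale: float) -> np.ndarray:
    q0, q1 = round(b0 * scale), round(b1 * scale)  # quantized activation pair
    lut = np.zeros(9, dtype=np.int16)
    for w0 in (-1, 0, 1):
        for w1 in (-1, 0, 1):
            # entry 0 is -q0-q1, entry 4 is 0, entry 8 is q0+q1,
            # matching vec_lut[0..8] above
            lut[3 * (w0 + 1) + (w1 + 1)] = w0 * q0 + w1 * q1
    return lut
```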
static bool is_type_supported(enum ggml_type type) {
if (type == GGML_TYPE_Q4_0 ||
type == GGML_TYPE_TL1) {
return true;
} else {
return false;
}
}
#include <arm_neon.h>
#define BM1536_4096 256
#define BBK1536_4096 128
inline void tbl_impl_1536_4096(int32_t* c, int8_t* lut, uint8_t* a) {
#ifdef __ARM_NEON
const int KK = BBK1536_4096 / 2;
const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
const int16x8_t vec_zero = vdupq_n_s16(0x0000);
int8x16_t vec_lut[2 * KK];
int16x8_t vec_c[4];
#pragma unroll
for (int k = 0; k < 2 * KK; k++) {
vec_lut[k] = vld1q_s8(lut + k * 16);
}
#pragma unroll
for (int i = 0; i < BM1536_4096; i += 32) {
#pragma unroll
for (int j = 0; j < 4; j++) {
vec_c[j] = vandq_s16(vec_c[j], vec_zero);
}
#pragma unroll
for (int k = 0; k < KK / 4; k++) {
uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
vec_c[0] += vec_v_left_0.val[0];
vec_c[0] += vec_v_right_0.val[0];
vec_c[1] += vec_v_left_0.val[1];
vec_c[1] += vec_v_right_0.val[1];
uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
vec_c[0] += vec_v_left_1.val[0];
vec_c[0] += vec_v_right_1.val[0];
vec_c[1] += vec_v_left_1.val[1];
vec_c[1] += vec_v_right_1.val[1];
uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
vec_c[2] += vec_v_left_2.val[0];
vec_c[2] += vec_v_right_2.val[0];
vec_c[3] += vec_v_left_2.val[1];
vec_c[3] += vec_v_right_2.val[1];
uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
vec_c[2] += vec_v_left_3.val[0];
vec_c[2] += vec_v_right_3.val[0];
vec_c[3] += vec_v_left_3.val[1];
vec_c[3] += vec_v_right_3.val[1];
}
int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
}
#endif
}
int32_t qgemm_lut_1536_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
alignas(32) uint32_t CBits[BM1536_4096];
memset(&(CBits[0]), 0, BM1536_4096 * sizeof(int32_t));
#pragma unroll
for (int32_t k_outer = 0; k_outer < 4096 / BBK1536_4096; ++k_outer) {
tbl_impl_1536_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_4096 / 2 / 2 * BM1536_4096)])));
}
#pragma unroll
for (int i = 0; i < BM1536_4096; i++) {
((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
}
return 0;
}
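The epilogue above is the only floating-point step: the int32 accumulators hold sums of activations that were premultiplied by the LUT scale, so the result is recovered by undoing that scale and applying the per-tensor weight scale. Equivalent numpy (a sketch; names hypothetical):

```python
import numpy as np

def dequant_ref(cbits: np.ndarray, lut_scale: float, weight_scale: float) -> np.ndarray:
    # mirrors: C[i] = CBits[i] / LUT_Scales[0] * Scales[0]
    return cbits.astype(np.float32) / lut_scale * weight_scale
```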
#include <arm_neon.h>
#define BM1536_1536 128
#define BBK1536_1536 64
inline void tbl_impl_1536_1536(int32_t* c, int8_t* lut, uint8_t* a) {
#ifdef __ARM_NEON
const int KK = BBK1536_1536 / 2;
const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
const int16x8_t vec_zero = vdupq_n_s16(0x0000);
int8x16_t vec_lut[2 * KK];
int16x8_t vec_c[8];
#pragma unroll
for (int k = 0; k < 2 * KK; k++) {
vec_lut[k] = vld1q_s8(lut + k * 16);
}
#pragma unroll
for (int i = 0; i < BM1536_1536; i += 64) {
#pragma unroll
for (int j = 0; j < 8; j++) {
vec_c[j] = vandq_s16(vec_c[j], vec_zero);
}
#pragma unroll
for (int k = 0; k < KK / 2; k++) {
uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top);
int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top);
int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot);
int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot);
int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
vec_c[0] += vec_v_left_0.val[0];
vec_c[0] += vec_v_right_0.val[0];
vec_c[1] += vec_v_left_0.val[1];
vec_c[1] += vec_v_right_0.val[1];
uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top);
int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top);
int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot);
int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot);
int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
vec_c[2] += vec_v_left_1.val[0];
vec_c[2] += vec_v_right_1.val[0];
vec_c[3] += vec_v_left_1.val[1];
vec_c[3] += vec_v_right_1.val[1];
uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top);
int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top);
int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot);
int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot);
int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
vec_c[4] += vec_v_left_2.val[0];
vec_c[4] += vec_v_right_2.val[0];
vec_c[5] += vec_v_left_2.val[1];
vec_c[5] += vec_v_right_2.val[1];
uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top);
int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top);
int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot);
int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot);
int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
vec_c[6] += vec_v_left_3.val[0];
vec_c[6] += vec_v_right_3.val[0];
vec_c[7] += vec_v_left_3.val[1];
vec_c[7] += vec_v_right_3.val[1];
}
int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4]));
int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]);
vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4);
vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4);
int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5]));
int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]);
vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5);
vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5);
int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6]));
int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]);
vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6);
vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6);
int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7]));
int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]);
vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7);
vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7);
}
#endif
}
int32_t qgemm_lut_1536_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
alignas(32) uint32_t CBits[BM1536_1536];
memset(&(CBits[0]), 0, BM1536_1536 * sizeof(int32_t));
#pragma unroll
for (int32_t k_outer = 0; k_outer < 1536 / BBK1536_1536; ++k_outer) {
tbl_impl_1536_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_1536 / 2 / 2 * BM1536_1536)])));
}
#pragma unroll
for (int i = 0; i < BM1536_1536; i++) {
((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
}
return 0;
}
#include <arm_neon.h>
#define BM4096_1536 256
#define BBK4096_1536 128
inline void tbl_impl_4096_1536(int32_t* c, int8_t* lut, uint8_t* a) {
#ifdef __ARM_NEON
const int KK = BBK4096_1536 / 2;
const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
const int16x8_t vec_zero = vdupq_n_s16(0x0000);
int8x16_t vec_lut[2 * KK];
int16x8_t vec_c[4];
#pragma unroll
for (int k = 0; k < 2 * KK; k++) {
vec_lut[k] = vld1q_s8(lut + k * 16);
}
#pragma unroll
for (int i = 0; i < BM4096_1536; i += 32) {
#pragma unroll
for (int j = 0; j < 4; j++) {
vec_c[j] = vandq_s16(vec_c[j], vec_zero);
}
#pragma unroll
for (int k = 0; k < KK / 4; k++) {
uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
vec_c[0] += vec_v_left_0.val[0];
vec_c[0] += vec_v_right_0.val[0];
vec_c[1] += vec_v_left_0.val[1];
vec_c[1] += vec_v_right_0.val[1];
uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
vec_c[0] += vec_v_left_1.val[0];
vec_c[0] += vec_v_right_1.val[0];
vec_c[1] += vec_v_left_1.val[1];
vec_c[1] += vec_v_right_1.val[1];
uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
vec_c[2] += vec_v_left_2.val[0];
vec_c[2] += vec_v_right_2.val[0];
vec_c[3] += vec_v_left_2.val[1];
vec_c[3] += vec_v_right_2.val[1];
uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
vec_c[2] += vec_v_left_3.val[0];
vec_c[2] += vec_v_right_3.val[0];
vec_c[3] += vec_v_left_3.val[1];
vec_c[3] += vec_v_right_3.val[1];
}
int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
}
#endif
}
int32_t qgemm_lut_4096_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
alignas(32) uint32_t CBits[BM4096_1536];
memset(&(CBits[0]), 0, BM4096_1536 * sizeof(int32_t));
#pragma unroll
for (int32_t k_outer = 0; k_outer < 1536 / BBK4096_1536; ++k_outer) {
tbl_impl_4096_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_1536 / 2 / 2 * BM4096_1536)])));
}
#pragma unroll
for (int i = 0; i < BM4096_1536; i++) {
((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
}
return 0;
}
template<int K>
void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {
partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));
per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));
lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));
}
void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {
if (m == 1536 && k == 4096) {
preprocessor_k<4096>(B, LUT_Scales, QLUT);
}
else if (m == 1536 && k == 1536) {
preprocessor_k<1536>(B, LUT_Scales, QLUT);
}
else if (m == 4096 && k == 1536) {
preprocessor_k<1536>(B, LUT_Scales, QLUT);
}
}
void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
if (m == 1536 && k == 4096) {
qgemm_lut_1536_4096(A, LUT, Scales, LUT_Scales, C);
}
else if (m == 1536 && k == 1536) {
qgemm_lut_1536_1536(A, LUT, Scales, LUT_Scales, C);
}
else if (m == 4096 && k == 1536) {
qgemm_lut_4096_1536(A, LUT, Scales, LUT_Scales, C);
}
}
void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {
if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
return;
}
int k = tensor->ne[0];
int m = tensor->ne[1];
const int lut_scales_size = 1;
const int scales_size = 1;
int bk = 0;
int bm = 0;
if (m == 1536 && k == 4096) {
bm = BM1536_4096;
bk = BBK1536_4096;
}
else if (m == 1536 && k == 1536) {
bm = BM1536_1536;
bk = BBK1536_1536;
}
else if (m == 4096 && k == 1536) {
bm = BM4096_1536;
bk = BBK4096_1536;
}
const int n_tile_num = m / bm;
const int BK = bk;
uint8_t * qweights;
bitnet_float_type * scales;
scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));
qweights = (uint8_t *) tensor->data;
float * i2_scales = (float * )(qweights + k * m / 4);
scales[0] = (bitnet_float_type) i2_scales[0];
tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;
bitnet_tensor_extras[bitnet_tensor_extras_index++] = {
/* .lut_scales_size = */ lut_scales_size,
/* .BK = */ BK,
/* .n_tile_num = */ n_tile_num,
/* .qweights = */ qweights,
/* .scales = */ scales
};
}
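The `k * m / 4` offset above implies the 2-bits-per-ternary-weight packing of the TL1 weight buffer: the per-tensor fp32 scale is stored immediately after the packed weights. A quick check with the 1536x4096 shape from this commit's kernel config (values illustrative):

```python
k, m = 4096, 1536
packed_bytes = k * m // 4   # 2 bits per weight -> 1572864 bytes
print(packed_bytes)         # the float scales begin at this offset
```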
#endif
+9
@@ -5,8 +5,13 @@
#ifdef __ARM_NEON
#include <arm_neon.h>
#if defined(GGML_BITNET_ARM_TL1)
typedef float32_t bitnet_float_type;
#else
typedef float16_t bitnet_float_type;
#endif
#else
#include <immintrin.h>
typedef float bitnet_float_type;
#endif
@@ -43,6 +48,10 @@ GGML_API void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* Q
GGML_API void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C);
GGML_API void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT);
#endif
#if defined(GGML_BITNET_TL2_LOSS)
GGML_API void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C);
GGML_API void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* Three_LUT_Scales, void* Two_LUT_Scales, void* Three_QLUT, void* Two_QLUT);
#endif
#ifdef __cplusplus
}
+21
@@ -0,0 +1,21 @@
[Kernels_0]
m = 1536
k = 4096
bm = 256
bk = 128
bmm = 32
[Kernels_1]
m = 1536
k = 1536
bm = 128
bk = 64
bmm = 64
[Kernels_2]
m = 4096
k = 1536
bm = 256
bk = 128
bmm = 32
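These sections record the pretuned tile shapes (BM/BK and the inner bm) per matmul size; the converter below reads the same file via configparser when packing weights. A minimal sketch of consuming it (assuming the include/kernel_config.ini path used in utils/convert-hf-to-gguf-bitnet.py):

```python
import configparser

config = configparser.ConfigParser()
config.read("include/kernel_config.ini")
for section in config.sections():
    m, k = config[section].getint("m"), config[section].getint("k")
    bm, bk = config[section].getint("bm"), config[section].getint("bk")
    print(f"{section}: {m}x{k} tiled with BM={bm}, BK={bk}")
```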
+33 -10
@@ -44,8 +44,8 @@ SUPPORTED_HF_MODELS = {
}
SUPPORTED_QUANT_TYPES = {
"arm64": ["i2_s", "tl1"],
"x86_64": ["i2_s", "tl2"]
"arm64": ["i2_s", "tl1", "tl2-loss"],
"x86_64": ["i2_s", "tl2", "tl2-loss"]
}
COMPILER_EXTRA_ARGS = {
@@ -111,8 +111,10 @@ def prepare_model():
gguf_path = os.path.join(model_dir, "ggml-model-" + quant_type + ".gguf")
if not os.path.exists(gguf_path) or os.path.getsize(gguf_path) == 0:
logging.info(f"Converting HF model to GGUF format...")
-if quant_type.startswith("tl"):
+if quant_type in ["tl1", "tl2"]:
run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", quant_type, "--quant-embd"], log_step="convert_to_tl")
+elif quant_type in ["tl2-loss"]:
+    run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", "tl2", "--quant-embd", "--loss", "--outfile", model_dir + str("/ggml-model-tl2-loss.gguf")], log_step="convert_to_tl")
else: # i2s
# convert to f32
run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", "f32"], log_step="convert_to_f32_gguf")
@@ -156,11 +158,20 @@ def gen_code():
shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h")
shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl2.ini"), "include/kernel_config.ini")
if get_model_name() == "bitnet_b1_58-large":
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen")
if args.quant_type == "tl2-loss":
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
else:
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen")
elif get_model_name() in llama3_f3_models:
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen")
if args.quant_type == "tl2-loss":
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
else:
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen")
elif get_model_name() == "bitnet_b1_58-3B":
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
if args.quant_type == "tl2-loss":
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
else:
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
else:
raise NotImplementedError()
else:
@@ -172,11 +183,20 @@ def gen_code():
sys.exit(1)
shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h")
if get_model_name() == "bitnet_b1_58-large":
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen")
if args.quant_type == "tl2-loss":
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
else:
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen")
elif get_model_name() in llama3_f3_models:
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen")
if args.quant_type == "tl2-loss":
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
else:
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen")
elif get_model_name() == "bitnet_b1_58-3B":
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
if args.quant_type == "tl2-loss":
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
else:
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
else:
raise NotImplementedError()
@@ -192,7 +212,10 @@ def compile():
logging.error(f"Arch {arch} is not supported yet")
exit(0)
logging.info("Compiling the code using CMake.")
run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files")
if args.quant_type == "tl2-loss":
run_command(["cmake", "-B", "build", "-DBITNET_TL2_LOSS=ON", *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files")
else:
run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files")
# run_command(["cmake", "--build", "build", "--target", "llama-cli", "--config", "Release"])
run_command(["cmake", "--build", "build", "--config", "Release"], log_step="compile")
+74
@@ -154,6 +154,80 @@ size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const stru
return wsize;
}
int ggml_bitnet_get_type_bits(enum ggml_type type) {
switch (type) {
case GGML_TYPE_TL2:
return 2;
case GGML_TYPE_Q4_0:
return 4;
default:
return 0;
}
}
#endif
#if defined(GGML_BITNET_TL2_LOSS)
void ggml_bitnet_init(void) {
// LOG(INFO) << "ggml_bitnet_init";
if (initialized) {
return;
}
initialized = true;
// if (wrapper == nullptr) {
// wrapper = new BITNET::BITNETGeMMWrapper<bitnet_bitnet_float_type>();
// }
if (bitnet_tensor_extras == nullptr) {
bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES];
}
bitnet_tensor_extras_index = 0;
}
void ggml_bitnet_free(void) {
// LOG(INFO) << "ggml_bitnet_free";
if (!initialized) {
return;
}
initialized = false;
// delete wrapper;
// wrapper = nullptr;
for (size_t i = 0; i < bitnet_tensor_extras_index; i++) {
// aligned_free(bitnet_tensor_extras[i].qweights);
// aligned_free(bitnet_tensor_extras[i].scales);
}
delete[] bitnet_tensor_extras;
bitnet_tensor_extras = nullptr;
}
bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
if ((is_type_supported(src0->type)) &&
src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 &&
src0->backend == GGML_BACKEND_TYPE_CPU) {
if (src1->ne[1] <= 1) {
return true;
}
}
return false;
}
size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
const size_t ne01 = src0->ne[1];
const size_t ne10 = src1->ne[0];
const size_t ne11 = src1->ne[1];
size_t wsize = ne10 * ne11 * 11 * sizeof(int8_t) + 2 * ne11 * 2 * sizeof(bitnet_float_type);
if (sizeof(bitnet_float_type) == 2) {
// Need fp32 to fp16 conversion
wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type);
}
wsize = ((wsize - 1) / 64 + 1) * 64;
return wsize;
}
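Evaluating the new TL2-Loss workspace formula for a single-token call (ne11 = 1) with fp32 bitnet_float_type, so the fp16-conversion branch is skipped — a sketch with an assumed ne10 = 1536:

```python
ne10, ne11 = 1536, 1
wsize = ne10 * ne11 * 11 * 1 + 2 * ne11 * 2 * 4  # sizeof(int8_t)=1, sizeof(float)=4
wsize = ((wsize - 1) // 64 + 1) * 64             # round up to a 64-byte multiple
print(wsize)                                     # 16960
```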
int ggml_bitnet_get_type_bits(enum ggml_type type) {
switch (type) {
case GGML_TYPE_TL2:
File diff suppressed because it is too large
+116 -6
@@ -517,6 +517,92 @@ def preprocess_weights_tl1(
return weight
def preprocess_two_weights_tl2_loss(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
weight = np.reshape(weight, (weight_num // 2, 2))
hi_weight = np.multiply(np.split(weight, 2, axis=1)[0], 3)
lo_weight = np.split(weight, 2, axis=1)[1]
weight = np.reshape((hi_weight + lo_weight), weight_num // 2)
weight = weight + 4
weight = np.reshape(weight, (M, K // 2)).astype(np.uint8)
weight = weight.reshape((M // BM, BM, K // 2)).transpose(0, 2, 1)
weight = weight.reshape((M // BM, K // BY, BY // 2, BM)).transpose(0, 1, 3, 2)
weight = weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 2)).transpose(0, 1, 2, 4, 3)
weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, by // 2, bm)).transpose(0, 1, 2, 3, 5, 4)
weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, bm, by // 2))
weight_0 = weight[:, :, :, :, :, 0]
weight_1 = weight[:, :, :, :, :, 1]
weight_0 = weight_0 << 4
weight_1 = weight_1
weight = weight_0 + weight_1
weight = weight.reshape(M * K // bm // by, bm).reshape(M * K // by // 16, 16)
for i in range(weight.shape[0]):
final_weight.append(weight[i, :])
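In isolation, the pair encoding above collapses two ternary weights into one index 3*hi + lo + 4 in [0, 8] (hence the multiply-by-3 and the +4 shift), and two such indices share a byte as high/low nibbles — a sketch (helper name hypothetical):

```python
def encode_pair(hi: int, lo: int) -> int:
    # ternary pair -> 4-bit LUT index in [0, 8]
    return 3 * hi + lo + 4

byte = (encode_pair(1, -1) << 4) | encode_pair(-1, 0)
print(hex(byte))  # 0x61
```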
def preprocess_three_weights_tl2_loss(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
weight = np.reshape(weight, (weight_num // 3, 3))
split_weights = np.split(weight, 3, axis=1)
first_weight = np.multiply(split_weights[0], 9)
second_weight = np.multiply(split_weights[1], 3)
third_weight = split_weights[2]
weight = np.reshape((first_weight + second_weight + third_weight), weight_num // 3)
sign_weight = np.sign(weight)
sign_weight = np.where(sign_weight < 1, 0, sign_weight)
weight = np.abs(weight)
weight = np.reshape(weight, (M, K // 3)).astype(np.uint8)
sign_weight = np.reshape(sign_weight, (M, K // 3)).astype(np.uint8)
weight = weight.reshape((M // BM, BM, K // 3)).transpose(0, 2, 1)
weight = weight.reshape((M // BM, K // BY, BY // 3, BM)).transpose(0, 1, 3, 2)
weight = weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 3)).transpose(0, 1, 2, 4, 3)
weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, by // 3, bm)).transpose(0, 1, 2, 3, 5, 4)
weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, bm, by // 3))
weight_list = []
for i in range(by // 3):
weight_list.append(weight[:, :, :, :, :, i])
for i in range(by // 3 // 2):
weight_list[i] = weight_list[i] << 4
weight_list[i + by // 3 // 2] = weight_list[i + by // 3 // 2]
weight_list[i] = weight_list[i] + weight_list[i + by // 3 // 2]
weight_list[i] = weight_list[i].reshape(M * K // bm // by, bm).reshape(M * K // by // 16, 16)
for i in range(weight_list[0].shape[0]):
for j in range(by // 3 // 2):
final_weight.append(weight_list[j][i, :])
sign_weight = sign_weight.reshape((M // BM, BM, K // 3)).transpose(0, 2, 1)
sign_weight = sign_weight.reshape((M // BM, K // BY, BY // 3, BM)).transpose(0, 1, 3, 2)
sign_weight = sign_weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 3)).transpose(0, 1, 2, 4, 3)
sign_weight = sign_weight.reshape((M // BM, K // BY, BM // bm, BY // (by * 4), by // 3 * 4, bm)).transpose(0, 1, 2, 3, 5, 4).astype(np.uint8)
combine_weight_list = []
for i in range(by // 3 // 2):
combine_weight = np.zeros((M // BM, K // BY, BM // bm, BY // (by * 4), bm), dtype=np.uint8)
combine_weight_list.append(combine_weight)
for i in range(8):
for j in range(by // 3 // 2):
if bm == 16:
combine_weight_list[j] = combine_weight_list[j] + (sign_weight[:, :, :, :, :, by // 3 // 2 * i + j] << 7 - i)
elif bm == 32:
if i > 3 :
ti = (i - 4) * 2 + 1
else:
ti = i * 2
combine_weight_list[j] = combine_weight_list[j] + (sign_weight[:, :, :, :, :, by // 3 // 2 * ti + j] << 7 - i)
for i in range(by // 3 // 2):
combine_weight_list[i] = combine_weight_list[i].reshape((M * K // (by * 4)) // 16, 16)
for i in range(combine_weight_list[0].shape[0]):
for j in range(by // 3 // 2):
final_weight.append(combine_weight_list[j][i, :])
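The three-weight path encodes 9*w0 + 3*w1 + w2, which spans [-13, 13]: the magnitude fits in 4 bits and the sign is split into the separately packed bitmap handled via combine_weight_list above. A sketch of the per-triple encoding (helper name hypothetical):

```python
def encode_triple(w0: int, w1: int, w2: int) -> tuple:
    v = 9 * w0 + 3 * w1 + w2
    # sign bit is 1 only for strictly positive values, matching the
    # np.where(sign_weight < 1, 0, sign_weight) step above
    return abs(v), 1 if v > 0 else 0

print(encode_triple(1, -1, 1))  # (7, 1)
```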
def preprocess_two_weights_tl2(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
weight = np.reshape(weight, (weight_num // 2, 2))
hi_weight = np.multiply(np.split(weight, 2, axis=1)[0], 3)
@@ -603,7 +689,6 @@ def preprocess_weights_tl2(
weight = w
weight = np.where(np.abs(weight) < 1e-6, 0, weight).astype(np.float32)
weight = np.sign(weight)
weight_num = np.prod(weight.shape)
config.read('include/kernel_config.ini')
BM = -1
@@ -631,7 +716,8 @@ def preprocess_weights_tl2(
final_weight = []
-preprocess_three_weights_tl2(three_weight.shape[0],
+if args.loss:
+    preprocess_three_weights_tl2_loss(three_weight.shape[0],
three_weight.shape[1],
three_weight.shape[0] * three_weight.shape[1],
BM,
@@ -641,8 +727,29 @@ def preprocess_weights_tl2(
three_weight,
final_weight)
-if (weight.shape[1] % BY != 0):
-    preprocess_two_weights_tl2( two_weight.shape[0],
+    if (weight.shape[1] % BY != 0):
+        preprocess_two_weights_tl2_loss(two_weight.shape[0],
two_weight.shape[1],
two_weight.shape[0] * two_weight.shape[1],
BM,
32,
32,
4,
two_weight,
final_weight)
else:
preprocess_three_weights_tl2(three_weight.shape[0],
three_weight.shape[1],
three_weight.shape[0] * three_weight.shape[1],
BM,
BY,
bm,
by,
three_weight,
final_weight)
if (weight.shape[1] % BY != 0):
preprocess_two_weights_tl2(two_weight.shape[0],
two_weight.shape[1],
two_weight.shape[0] * two_weight.shape[1],
BM,
@@ -652,8 +759,10 @@ def preprocess_weights_tl2(
two_weight,
final_weight)
weight = np.array(final_weight, dtype=np.uint8).reshape(-1)
-weight = np.pad(weight, (0, (K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8 -
-                weight.shape[0]), mode='constant', constant_values=0)
+pad_nums = (K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8
+pad_align_nums = 32 - ((K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8) % 32
+pad_nums = pad_nums + pad_align_nums
+weight = np.pad(weight, (0, pad_nums - weight.shape[0]), mode='constant', constant_values=0)
return weight
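The padding now overshoots to a 32-byte boundary instead of stopping exactly at the packed size (an already-aligned buffer still gains 32 bytes, since pad_align_nums is never 0). A quick check, e.g. with the 1536x4096 shape from the kernel config above:

```python
M, K = 1536, 4096
pad_nums = (K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8
pad_align_nums = 32 - pad_nums % 32         # always in [1, 32]
print(pad_nums, pad_nums + pad_align_nums)  # 1327104 1327136
```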
def transform_to_tl1(x: np.ndarray):
@@ -1116,6 +1225,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--model-name", type=str, default=None, help="name of the model")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
parser.add_argument("--quant-embd", action="store_true", help="quantize the embedding layer")
parser.add_argument("--loss", action="store_true", help="use loss tl2")
return parser.parse_args()