mirror of
https://github.com/microsoft/BitNet.git
synced 2026-05-03 11:20:36 +00:00
commit paper code
This commit is contained in:
+2
-2
@@ -1,4 +1,4 @@
|
||||
[submodule "3rdparty/llama.cpp"]
|
||||
path = 3rdparty/llama.cpp
|
||||
url = https://github.com/Eddie-Wang1120/llama.cpp.git
|
||||
branch = merge-dev
|
||||
url = git@github.com:Eddie-Wang1120/llama.cpp.git
|
||||
branch = pp
|
||||
|
||||
+4
-3
@@ -14,6 +14,7 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
# option list
|
||||
option(BITNET_ARM_TL1 "bitnet.cpp: use tl1 on arm platform" OFF)
|
||||
option(BITNET_X86_TL2 "bitnet.cpp: use tl2 on x86 platform" OFF)
|
||||
option(BITNET_TL2_LOSS "bitnet.cpp: use tl2 on x86 platform" OFF)
|
||||
|
||||
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED true)
|
||||
@@ -24,6 +25,7 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)
|
||||
# override ggml options
|
||||
set(GGML_BITNET_ARM_TL1 ${BITNET_ARM_TL1})
|
||||
set(GGML_BITNET_X86_TL2 ${BITNET_X86_TL2})
|
||||
set(GGML_BITNET_TL2_LOSS ${BITNET_TL2_LOSS})
|
||||
|
||||
if (GGML_BITNET_ARM_TL1)
|
||||
add_compile_definitions(GGML_BITNET_ARM_TL1)
|
||||
@@ -31,9 +33,8 @@ endif()
|
||||
if (GGML_BITNET_X86_TL2)
|
||||
add_compile_definitions(GGML_BITNET_X86_TL2)
|
||||
endif()
|
||||
|
||||
if (CMAKE_C_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
add_compile_options(-fpermissive)
|
||||
if (GGML_BITNET_TL2_LOSS)
|
||||
add_compile_definitions(GGML_BITNET_TL2_LOSS)
|
||||
endif()
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
@@ -43,8 +43,9 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
|
||||
</tr>
|
||||
<tr>
|
||||
<th>I2_S</th>
|
||||
<th>TL1</th>
|
||||
<th>TL2</th>
|
||||
<th>TL1(TL1_1)</th>
|
||||
<th>TL2(TL2_1)</th>
|
||||
<th>TL2-Loss(TL2_0)</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td rowspan="2"><a href="https://huggingface.co/1bitLLM/bitnet_b1_58-large">bitnet_b1_58-large</a></td>
|
||||
@@ -53,12 +54,14 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
|
||||
<td>✅</td>
|
||||
<td>❌</td>
|
||||
<td>✅</td>
|
||||
<td>✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>ARM</td>
|
||||
<td>✅</td>
|
||||
<td>✅</td>
|
||||
<td>❌</td>
|
||||
<td>✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td rowspan="2"><a href="https://huggingface.co/1bitLLM/bitnet_b1_58-3B">bitnet_b1_58-3B</a></td>
|
||||
@@ -67,12 +70,14 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
|
||||
<td>❌</td>
|
||||
<td>❌</td>
|
||||
<td>✅</td>
|
||||
<td>✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>ARM</td>
|
||||
<td>❌</td>
|
||||
<td>✅</td>
|
||||
<td>❌</td>
|
||||
<td>✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td rowspan="2"><a href="https://huggingface.co/HF1BitLLM/Llama3-8B-1.58-100B-tokens">Llama3-8B-1.58-100B-tokens</a></td>
|
||||
@@ -81,12 +86,14 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
|
||||
<td>✅</td>
|
||||
<td>❌</td>
|
||||
<td>✅</td>
|
||||
<td>✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>ARM</td>
|
||||
<td>✅</td>
|
||||
<td>✅</td>
|
||||
<td>❌</td>
|
||||
<td>✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td rowspan="2"><a href="https://huggingface.co/collections/tiiuae/falcon3-67605ae03578be86e4e87026">Falcon3 Family</a></td>
|
||||
@@ -95,12 +102,14 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
|
||||
<td>✅</td>
|
||||
<td>❌</td>
|
||||
<td>✅</td>
|
||||
<td>✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>ARM</td>
|
||||
<td>✅</td>
|
||||
<td>✅</td>
|
||||
<td>❌</td>
|
||||
<td>✅</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
@@ -144,11 +153,11 @@ pip install -r requirements.txt
|
||||
3. Build the project
|
||||
```bash
|
||||
# Download the model from Hugging Face, convert it to quantized gguf format, and build the project
|
||||
python setup_env.py --hf-repo tiiuae/Falcon3-7B-Instruct-1.58bit -q i2_s
|
||||
python setup_env.py --hf-repo 1bitLLM/bitnet_b1_58-large -q i2_s
|
||||
|
||||
# Or you can manually download the model and run with local path
|
||||
huggingface-cli download tiiuae/Falcon3-7B-Instruct-1.58bit --local-dir models/Falcon3-7B-Instruct-1.58bit
|
||||
python setup_env.py -md models/Falcon3-7B-Instruct-1.58bit -q i2_s
|
||||
huggingface-cli download 1bitLLM/bitnet_b1_58-large --local-dir models/bitnet_b1_58-large
|
||||
python setup_env.py -md models/bitnet_b1_58-large -q i2_s
|
||||
```
|
||||
<pre>
|
||||
usage: setup_env.py [-h] [--hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}] [--model-dir MODEL_DIR] [--log-dir LOG_DIR] [--quant-type {i2_s,tl1}] [--quant-embd]
|
||||
|
||||
@@ -0,0 +1,627 @@
|
||||
#if defined(GGML_BITNET_ARM_TL1)
|
||||
#include "ggml-bitnet.h"
|
||||
#define GGML_BITNET_MAX_NODES 8192
|
||||
static bool initialized = false;
|
||||
static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;
|
||||
static size_t bitnet_tensor_extras_index = 0;
|
||||
static void * aligned_malloc(size_t size) {{
|
||||
#if defined(_WIN32)
|
||||
return _aligned_malloc(size, 64);
|
||||
#else
|
||||
void * ptr = nullptr;
|
||||
posix_memalign(&ptr, 64, size);
|
||||
return ptr;
|
||||
#endif
|
||||
}}
|
||||
static void aligned_free(void * ptr) {{
|
||||
#if defined(_WIN32)
|
||||
_aligned_free(ptr);
|
||||
#else
|
||||
free(ptr);
|
||||
#endif
|
||||
}}
|
||||
|
||||
void per_tensor_quant(int k, void* lut_scales_, void* b_) {{
|
||||
bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
|
||||
bitnet_float_type* b = (bitnet_float_type*)b_;
|
||||
#ifdef __ARM_NEON
|
||||
float32x4_t temp_max = vdupq_n_f32(0);
|
||||
for (int i=0; i < k / 4; i++) {{
|
||||
float32x4_t vec_bs = vld1q_f32(b + 4 * i);
|
||||
float32x4_t abssum = vabsq_f32(vec_bs);
|
||||
temp_max = vmaxq_f32(abssum, temp_max);
|
||||
}}
|
||||
float32_t scales = 127 / vmaxvq_f32(temp_max);
|
||||
*lut_scales = scales;
|
||||
#elif defined __AVX2__
|
||||
__m256 max_vec = _mm256_set1_ps(0.f);
|
||||
const __m256 vec_sign = _mm256_set1_ps(-0.0f);
|
||||
// #pragma unroll
|
||||
for (int i = 0; i < k / 8; i++) {{
|
||||
__m256 vec_b = _mm256_loadu_ps(b + i * 8);
|
||||
__m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);
|
||||
max_vec = _mm256_max_ps(vec_babs, max_vec);
|
||||
}}
|
||||
__m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));
|
||||
max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));
|
||||
max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));
|
||||
float scales = 127 / _mm_cvtss_f32(max1);
|
||||
*lut_scales = scales;
|
||||
#endif
|
||||
}}
|
||||
|
||||
void partial_max_reset(void* lut_scales_) {{
|
||||
bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
|
||||
*lut_scales = 0.0;
|
||||
}}
|
||||
|
||||
#ifdef __ARM_NEON
|
||||
inline void Transpose_8_8(
|
||||
int16x8_t *v0,
|
||||
int16x8_t *v1,
|
||||
int16x8_t *v2,
|
||||
int16x8_t *v3,
|
||||
int16x8_t *v4,
|
||||
int16x8_t *v5,
|
||||
int16x8_t *v6,
|
||||
int16x8_t *v7)
|
||||
{{
|
||||
int16x8x2_t q04 = vzipq_s16(*v0, *v4);
|
||||
int16x8x2_t q15 = vzipq_s16(*v1, *v5);
|
||||
int16x8x2_t q26 = vzipq_s16(*v2, *v6);
|
||||
int16x8x2_t q37 = vzipq_s16(*v3, *v7);
|
||||
|
||||
int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);
|
||||
int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);
|
||||
int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);
|
||||
int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);
|
||||
|
||||
int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);
|
||||
int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);
|
||||
int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);
|
||||
int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);
|
||||
|
||||
*v0 = q_fin_0.val[0];
|
||||
*v1 = q_fin_0.val[1];
|
||||
*v2 = q_fin_1.val[0];
|
||||
*v3 = q_fin_1.val[1];
|
||||
*v4 = q_fin_2.val[0];
|
||||
*v5 = q_fin_2.val[1];
|
||||
*v6 = q_fin_3.val[0];
|
||||
*v7 = q_fin_3.val[1];
|
||||
}}
|
||||
#endif
|
||||
|
||||
template<int act_k>
|
||||
inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{
|
||||
#ifdef __ARM_NEON
|
||||
int16x8_t vec_lut[16];
|
||||
float32_t scales = *lut_scales;
|
||||
uint8_t tbl_mask[16];
|
||||
tbl_mask[0] = 0;
|
||||
tbl_mask[1] = 2;
|
||||
tbl_mask[2] = 4;
|
||||
tbl_mask[3] = 6;
|
||||
tbl_mask[4] = 8;
|
||||
tbl_mask[5] = 10;
|
||||
tbl_mask[6] = 12;
|
||||
tbl_mask[7] = 14;
|
||||
tbl_mask[8] = 1;
|
||||
tbl_mask[9] = 3;
|
||||
tbl_mask[10] = 5;
|
||||
tbl_mask[11] = 7;
|
||||
tbl_mask[12] = 9;
|
||||
tbl_mask[13] = 11;
|
||||
tbl_mask[14] = 13;
|
||||
tbl_mask[15] = 15;
|
||||
uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < act_k / 16; ++k) {{
|
||||
float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);
|
||||
float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);
|
||||
float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);
|
||||
float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);
|
||||
float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);
|
||||
float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);
|
||||
int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);
|
||||
int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);
|
||||
int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);
|
||||
int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);
|
||||
int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);
|
||||
int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);
|
||||
int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);
|
||||
int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);
|
||||
int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);
|
||||
int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);
|
||||
vec_lut[0] = vdupq_n_s16(0);
|
||||
vec_lut[0] = vec_lut[0] - vec_bs_0;
|
||||
vec_lut[0] = vec_lut[0] - vec_bs_1;
|
||||
vec_lut[1] = vdupq_n_s16(0);
|
||||
vec_lut[1] = vec_lut[1] - vec_bs_0;
|
||||
vec_lut[2] = vdupq_n_s16(0);
|
||||
vec_lut[2] = vec_lut[2] - vec_bs_0;
|
||||
vec_lut[2] = vec_lut[2] + vec_bs_1;
|
||||
vec_lut[3] = vdupq_n_s16(0);
|
||||
vec_lut[3] = vec_lut[3] - vec_bs_1;
|
||||
vec_lut[4] = vdupq_n_s16(0);
|
||||
vec_lut[5] = vec_bs_1;
|
||||
vec_lut[6] = vec_bs_0;
|
||||
vec_lut[6] = vec_lut[6] - vec_bs_1;
|
||||
vec_lut[7] = vec_bs_0;
|
||||
vec_lut[8] = vec_bs_0;
|
||||
vec_lut[8] = vec_lut[8] + vec_bs_1;
|
||||
Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),
|
||||
&(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));
|
||||
Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),
|
||||
&(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));
|
||||
#pragma unroll
|
||||
for (int idx = 0; idx < 8; idx++) {{
|
||||
int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);
|
||||
int8x8_t q0_low = vget_low_s8(q0_s);
|
||||
int8x8_t q0_high = vget_high_s8(q0_s);
|
||||
int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);
|
||||
int8x8_t q1_low = vget_low_s8(q1_s);
|
||||
int8x8_t q1_high = vget_high_s8(q1_s);
|
||||
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);
|
||||
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);
|
||||
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);
|
||||
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);
|
||||
}}
|
||||
}}
|
||||
#endif
|
||||
}}
|
||||
|
||||
static bool is_type_supported(enum ggml_type type) {{
|
||||
if (type == GGML_TYPE_Q4_0 ||
|
||||
type == GGML_TYPE_TL1) {{
|
||||
return true;
|
||||
}} else {{
|
||||
return false;
|
||||
}}
|
||||
}}
|
||||
#include <arm_neon.h>
|
||||
|
||||
#define BM1536_4096 256
|
||||
#define BBK1536_4096 128
|
||||
inline void tbl_impl_1536_4096(int32_t* c, int8_t* lut, uint8_t* a) {
|
||||
#ifdef __ARM_NEON
|
||||
const int KK = BBK1536_4096 / 2;
|
||||
const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
|
||||
const int8x16_t vec_zero = vdupq_n_s16(0x0000);
|
||||
int8x16_t vec_lut[2 * KK];
|
||||
int16x8_t vec_c[4];
|
||||
#pragma unroll
|
||||
for (int k = 0; k < 2 * KK; k++) {
|
||||
vec_lut[k] = vld1q_s8(lut + k * 16);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < BM1536_4096; i += 32) {
|
||||
#pragma unroll
|
||||
for (int i=0; i<4; i++) {
|
||||
vec_c[i] = vandq_s16(vec_c[i], vec_zero);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < KK / 4; k++) {
|
||||
|
||||
uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
|
||||
uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
|
||||
uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
|
||||
int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
|
||||
int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
|
||||
int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
|
||||
int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
|
||||
int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
|
||||
int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
|
||||
vec_c[0] += vec_v_left_0.val[0];
|
||||
vec_c[0] += vec_v_right_0.val[0];
|
||||
vec_c[1] += vec_v_left_0.val[1];
|
||||
vec_c[1] += vec_v_right_0.val[1];
|
||||
|
||||
uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
|
||||
uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
|
||||
uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
|
||||
int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
|
||||
int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
|
||||
int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
|
||||
int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
|
||||
int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
|
||||
int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
|
||||
vec_c[0] += vec_v_left_1.val[0];
|
||||
vec_c[0] += vec_v_right_1.val[0];
|
||||
vec_c[1] += vec_v_left_1.val[1];
|
||||
vec_c[1] += vec_v_right_1.val[1];
|
||||
|
||||
uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
|
||||
uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
|
||||
uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
|
||||
int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
|
||||
int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
|
||||
int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
|
||||
int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
|
||||
int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
|
||||
int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
|
||||
vec_c[2] += vec_v_left_2.val[0];
|
||||
vec_c[2] += vec_v_right_2.val[0];
|
||||
vec_c[3] += vec_v_left_2.val[1];
|
||||
vec_c[3] += vec_v_right_2.val[1];
|
||||
|
||||
uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
|
||||
uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
|
||||
uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
|
||||
int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
|
||||
int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
|
||||
int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
|
||||
int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
|
||||
int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
|
||||
int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
|
||||
vec_c[2] += vec_v_left_3.val[0];
|
||||
vec_c[2] += vec_v_right_3.val[0];
|
||||
vec_c[3] += vec_v_left_3.val[1];
|
||||
vec_c[3] += vec_v_right_3.val[1];
|
||||
|
||||
}
|
||||
|
||||
int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
|
||||
int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
|
||||
vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
|
||||
vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
|
||||
int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
|
||||
int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
|
||||
vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
|
||||
vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
|
||||
int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
|
||||
int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
|
||||
vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
|
||||
vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
|
||||
int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
|
||||
int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
|
||||
vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
|
||||
vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
|
||||
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
int32_t qgemm_lut_1536_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
|
||||
alignas(32) uint32_t CBits[BM1536_4096];
|
||||
memset(&(CBits[0]), 0, BM1536_4096 * sizeof(int32_t));
|
||||
#pragma unroll
|
||||
for (int32_t k_outer = 0; k_outer < 4096 / BBK1536_4096; ++k_outer) {
|
||||
tbl_impl_1536_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_4096 / 2 / 2 * BM1536_4096)])));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < BM1536_4096; i++) {
|
||||
((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
#include <arm_neon.h>
|
||||
|
||||
#define BM1536_1536 128
|
||||
#define BBK1536_1536 64
|
||||
inline void tbl_impl_1536_1536(int32_t* c, int8_t* lut, uint8_t* a) {
|
||||
#ifdef __ARM_NEON
|
||||
const int KK = BBK1536_1536 / 2;
|
||||
const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
|
||||
const int8x16_t vec_zero = vdupq_n_s16(0x0000);
|
||||
int8x16_t vec_lut[2 * KK];
|
||||
int16x8_t vec_c[8];
|
||||
#pragma unroll
|
||||
for (int k = 0; k < 2 * KK; k++) {
|
||||
vec_lut[k] = vld1q_s8(lut + k * 16);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < BM1536_1536; i += 64) {
|
||||
#pragma unroll
|
||||
for (int i=0; i<8; i++) {
|
||||
vec_c[i] = vandq_s16(vec_c[i], vec_zero);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < KK / 2; k++) {
|
||||
|
||||
uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
|
||||
uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
|
||||
uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
|
||||
int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top);
|
||||
int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top);
|
||||
int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot);
|
||||
int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot);
|
||||
int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
|
||||
int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
|
||||
vec_c[0] += vec_v_left_0.val[0];
|
||||
vec_c[0] += vec_v_right_0.val[0];
|
||||
vec_c[1] += vec_v_left_0.val[1];
|
||||
vec_c[1] += vec_v_right_0.val[1];
|
||||
|
||||
uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
|
||||
uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
|
||||
uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
|
||||
int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top);
|
||||
int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top);
|
||||
int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot);
|
||||
int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot);
|
||||
int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
|
||||
int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
|
||||
vec_c[2] += vec_v_left_1.val[0];
|
||||
vec_c[2] += vec_v_right_1.val[0];
|
||||
vec_c[3] += vec_v_left_1.val[1];
|
||||
vec_c[3] += vec_v_right_1.val[1];
|
||||
|
||||
uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
|
||||
uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
|
||||
uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
|
||||
int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top);
|
||||
int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top);
|
||||
int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot);
|
||||
int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot);
|
||||
int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
|
||||
int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
|
||||
vec_c[4] += vec_v_left_2.val[0];
|
||||
vec_c[4] += vec_v_right_2.val[0];
|
||||
vec_c[5] += vec_v_left_2.val[1];
|
||||
vec_c[5] += vec_v_right_2.val[1];
|
||||
|
||||
uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
|
||||
uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
|
||||
uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
|
||||
int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top);
|
||||
int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top);
|
||||
int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot);
|
||||
int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot);
|
||||
int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
|
||||
int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
|
||||
vec_c[6] += vec_v_left_3.val[0];
|
||||
vec_c[6] += vec_v_right_3.val[0];
|
||||
vec_c[7] += vec_v_left_3.val[1];
|
||||
vec_c[7] += vec_v_right_3.val[1];
|
||||
|
||||
}
|
||||
|
||||
int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
|
||||
int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
|
||||
vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
|
||||
vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
|
||||
int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
|
||||
int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
|
||||
vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
|
||||
vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
|
||||
int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
|
||||
int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
|
||||
vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
|
||||
vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
|
||||
int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
|
||||
int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
|
||||
vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
|
||||
vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
|
||||
int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4]));
|
||||
int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]);
|
||||
vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4);
|
||||
vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4);
|
||||
int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5]));
|
||||
int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]);
|
||||
vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5);
|
||||
vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5);
|
||||
int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6]));
|
||||
int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]);
|
||||
vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6);
|
||||
vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6);
|
||||
int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7]));
|
||||
int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]);
|
||||
vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7);
|
||||
vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7);
|
||||
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
int32_t qgemm_lut_1536_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
|
||||
alignas(32) uint32_t CBits[BM1536_1536];
|
||||
memset(&(CBits[0]), 0, BM1536_1536 * sizeof(int32_t));
|
||||
#pragma unroll
|
||||
for (int32_t k_outer = 0; k_outer < 1536 / BBK1536_1536; ++k_outer) {
|
||||
tbl_impl_1536_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_1536 / 2 / 2 * BM1536_1536)])));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < BM1536_1536; i++) {
|
||||
((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
#include <arm_neon.h>
|
||||
|
||||
#define BM4096_1536 256
|
||||
#define BBK4096_1536 128
|
||||
inline void tbl_impl_4096_1536(int32_t* c, int8_t* lut, uint8_t* a) {
|
||||
#ifdef __ARM_NEON
|
||||
const int KK = BBK4096_1536 / 2;
|
||||
const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
|
||||
const int8x16_t vec_zero = vdupq_n_s16(0x0000);
|
||||
int8x16_t vec_lut[2 * KK];
|
||||
int16x8_t vec_c[4];
|
||||
#pragma unroll
|
||||
for (int k = 0; k < 2 * KK; k++) {
|
||||
vec_lut[k] = vld1q_s8(lut + k * 16);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < BM4096_1536; i += 32) {
|
||||
#pragma unroll
|
||||
for (int i=0; i<4; i++) {
|
||||
vec_c[i] = vandq_s16(vec_c[i], vec_zero);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < KK / 4; k++) {
|
||||
|
||||
uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
|
||||
uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
|
||||
uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
|
||||
int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
|
||||
int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
|
||||
int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
|
||||
int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
|
||||
int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
|
||||
int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
|
||||
vec_c[0] += vec_v_left_0.val[0];
|
||||
vec_c[0] += vec_v_right_0.val[0];
|
||||
vec_c[1] += vec_v_left_0.val[1];
|
||||
vec_c[1] += vec_v_right_0.val[1];
|
||||
|
||||
uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
|
||||
uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
|
||||
uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
|
||||
int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
|
||||
int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
|
||||
int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
|
||||
int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
|
||||
int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
|
||||
int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
|
||||
vec_c[0] += vec_v_left_1.val[0];
|
||||
vec_c[0] += vec_v_right_1.val[0];
|
||||
vec_c[1] += vec_v_left_1.val[1];
|
||||
vec_c[1] += vec_v_right_1.val[1];
|
||||
|
||||
uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
|
||||
uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
|
||||
uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
|
||||
int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
|
||||
int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
|
||||
int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
|
||||
int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
|
||||
int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
|
||||
int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
|
||||
vec_c[2] += vec_v_left_2.val[0];
|
||||
vec_c[2] += vec_v_right_2.val[0];
|
||||
vec_c[3] += vec_v_left_2.val[1];
|
||||
vec_c[3] += vec_v_right_2.val[1];
|
||||
|
||||
uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
|
||||
uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
|
||||
uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
|
||||
int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
|
||||
int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
|
||||
int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
|
||||
int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
|
||||
int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
|
||||
int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
|
||||
vec_c[2] += vec_v_left_3.val[0];
|
||||
vec_c[2] += vec_v_right_3.val[0];
|
||||
vec_c[3] += vec_v_left_3.val[1];
|
||||
vec_c[3] += vec_v_right_3.val[1];
|
||||
|
||||
}
|
||||
|
||||
int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
|
||||
int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
|
||||
vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
|
||||
vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
|
||||
int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
|
||||
int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
|
||||
vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
|
||||
vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
|
||||
int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
|
||||
int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
|
||||
vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
|
||||
vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
|
||||
int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
|
||||
int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
|
||||
vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
|
||||
vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
|
||||
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
int32_t qgemm_lut_4096_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
|
||||
alignas(32) uint32_t CBits[BM4096_1536];
|
||||
memset(&(CBits[0]), 0, BM4096_1536 * sizeof(int32_t));
|
||||
#pragma unroll
|
||||
for (int32_t k_outer = 0; k_outer < 1536 / BBK4096_1536; ++k_outer) {
|
||||
tbl_impl_4096_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_1536 / 2 / 2 * BM4096_1536)])));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < BM4096_1536; i++) {
|
||||
((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
|
||||
template<int K>
|
||||
void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{
|
||||
partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));
|
||||
per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));
|
||||
|
||||
lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));
|
||||
}}
|
||||
void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {
|
||||
if (m == 1536 && k == 4096) {
|
||||
preprocessor_k<4096>(B, LUT_Scales, QLUT);
|
||||
}
|
||||
else if (m == 1536 && k == 1536) {
|
||||
preprocessor_k<1536>(B, LUT_Scales, QLUT);
|
||||
}
|
||||
else if (m == 4096 && k == 1536) {
|
||||
preprocessor_k<1536>(B, LUT_Scales, QLUT);
|
||||
}
|
||||
}
|
||||
void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
|
||||
if (m == 1536 && k == 4096) {
|
||||
qgemm_lut_1536_4096(A, LUT, Scales, LUT_Scales, C);
|
||||
}
|
||||
else if (m == 1536 && k == 1536) {
|
||||
qgemm_lut_1536_1536(A, LUT, Scales, LUT_Scales, C);
|
||||
}
|
||||
else if (m == 4096 && k == 1536) {
|
||||
qgemm_lut_4096_1536(A, LUT, Scales, LUT_Scales, C);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {
|
||||
if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
|
||||
return;
|
||||
}
|
||||
|
||||
int k = tensor->ne[0];
|
||||
int m = tensor->ne[1];
|
||||
const int lut_scales_size = 1;
|
||||
const int scales_size = 1;
|
||||
int bk = 0;
|
||||
int bm = 0;
|
||||
|
||||
if (m == 1536 && k == 4096) {
|
||||
bm = BM1536_4096;
|
||||
bk = BBK1536_4096;
|
||||
}
|
||||
else if (m == 1536 && k == 1536) {
|
||||
bm = BM1536_1536;
|
||||
bk = BBK1536_1536;
|
||||
}
|
||||
else if (m == 4096 && k == 1536) {
|
||||
bm = BM4096_1536;
|
||||
bk = BBK4096_1536;
|
||||
}
|
||||
|
||||
const int n_tile_num = m / bm;
|
||||
const int BK = bk;
|
||||
uint8_t * qweights;
|
||||
bitnet_float_type * scales;
|
||||
|
||||
scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));
|
||||
qweights = (uint8_t *) tensor->data;
|
||||
float * i2_scales = (float * )(qweights + k * m / 4);
|
||||
scales[0] = (bitnet_float_type) i2_scales[0];
|
||||
|
||||
tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;
|
||||
bitnet_tensor_extras[bitnet_tensor_extras_index++] = {
|
||||
/* .lut_scales_size = */ lut_scales_size,
|
||||
/* .BK = */ BK,
|
||||
/* .n_tile_num = */ n_tile_num,
|
||||
/* .qweights = */ qweights,
|
||||
/* .scales = */ scales
|
||||
};
|
||||
}
|
||||
#endif
|
||||
@@ -5,8 +5,13 @@
|
||||
|
||||
#ifdef __ARM_NEON
|
||||
#include <arm_neon.h>
|
||||
#if defined(GGML_BITNET_ARM_TL1)
|
||||
typedef float32_t bitnet_float_type;
|
||||
#else
|
||||
typedef float16_t bitnet_float_type;
|
||||
#endif
|
||||
#else
|
||||
#include <immintrin.h>
|
||||
typedef float bitnet_float_type;
|
||||
#endif
|
||||
|
||||
@@ -43,6 +48,10 @@ GGML_API void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* Q
|
||||
GGML_API void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C);
|
||||
GGML_API void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT);
|
||||
#endif
|
||||
#if defined(GGML_BITNET_TL2_LOSS)
|
||||
GGML_API void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C);
|
||||
GGML_API void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* Three_LUT_Scales, void* Two_LUT_Scales, void* Three_QLUT, void* Two_QLUT);
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
[Kernels_0]
m = 1536
k = 4096
bm = 256
bk = 128
bmm = 32

[Kernels_1]
m = 1536
k = 1536
bm = 128
bk = 64
bmm = 64

[Kernels_2]
m = 4096
k = 1536
bm = 256
bk = 128
bmm = 32
|
||||
|
||||
+33
-10
@@ -44,8 +44,8 @@ SUPPORTED_HF_MODELS = {
|
||||
}
|
||||
|
||||
SUPPORTED_QUANT_TYPES = {
|
||||
"arm64": ["i2_s", "tl1"],
|
||||
"x86_64": ["i2_s", "tl2"]
|
||||
"arm64": ["i2_s", "tl1", "tl2-loss"],
|
||||
"x86_64": ["i2_s", "tl2", "tl2-loss"]
|
||||
}
|
||||
|
||||
COMPILER_EXTRA_ARGS = {
|
||||
@@ -111,8 +111,10 @@ def prepare_model():
|
||||
gguf_path = os.path.join(model_dir, "ggml-model-" + quant_type + ".gguf")
|
||||
if not os.path.exists(gguf_path) or os.path.getsize(gguf_path) == 0:
|
||||
logging.info(f"Converting HF model to GGUF format...")
|
||||
if quant_type.startswith("tl"):
|
||||
if quant_type in ["tl1", "tl2"]:
|
||||
run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", quant_type, "--quant-embd"], log_step="convert_to_tl")
|
||||
elif quant_type in ["tl2-loss"]:
|
||||
run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", "tl2", "--quant-embd", "--loss", "--outfile", model_dir + str("/ggml-model-tl2-loss.gguf")], log_step="convert_to_tl")
|
||||
else: # i2s
|
||||
# convert to f32
|
||||
run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", "f32"], log_step="convert_to_f32_gguf")
|
||||
@@ -156,11 +158,20 @@ def gen_code():
|
||||
shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h")
|
||||
shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl2.ini"), "include/kernel_config.ini")
|
||||
if get_model_name() == "bitnet_b1_58-large":
|
||||
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen")
|
||||
if args.quant_type == "tl2-loss":
|
||||
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
else:
|
||||
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen")
|
||||
elif get_model_name() in llama3_f3_models:
|
||||
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen")
|
||||
if args.quant_type == "tl2-loss":
|
||||
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
else:
|
||||
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen")
|
||||
elif get_model_name() == "bitnet_b1_58-3B":
|
||||
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
|
||||
if args.quant_type == "tl2-loss":
|
||||
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
else:
|
||||
run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
@@ -172,11 +183,20 @@ def gen_code():
|
||||
sys.exit(1)
|
||||
shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h")
|
||||
if get_model_name() == "bitnet_b1_58-large":
|
||||
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
if args.quant_type == "tl2-loss":
|
||||
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
else:
|
||||
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
elif get_model_name() in llama3_f3_models:
|
||||
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen")
|
||||
if args.quant_type == "tl2-loss":
|
||||
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
else:
|
||||
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen")
|
||||
elif get_model_name() == "bitnet_b1_58-3B":
|
||||
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
if args.quant_type == "tl2-loss":
|
||||
run_command([sys.executable, "utils/codegen_tl2_loss.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
else:
|
||||
run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -192,7 +212,10 @@ def compile():
|
||||
logging.error(f"Arch {arch} is not supported yet")
|
||||
exit(0)
|
||||
logging.info("Compiling the code using CMake.")
|
||||
run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files")
|
||||
if args.quant_type == "tl2-loss":
|
||||
run_command(["cmake", "-B", "build", "-DBITNET_TL2_LOSS=ON", *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files")
|
||||
else:
|
||||
run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files")
|
||||
# run_command(["cmake", "--build", "build", "--target", "llama-cli", "--config", "Release"])
|
||||
run_command(["cmake", "--build", "build", "--config", "Release"], log_step="compile")
|
||||
|
||||
|
||||
@@ -154,6 +154,80 @@ size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const stru
|
||||
return wsize;
|
||||
}
|
||||
|
||||
int ggml_bitnet_get_type_bits(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_TL2:
|
||||
return 2;
|
||||
case GGML_TYPE_Q4_0:
|
||||
return 4;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(GGML_BITNET_TL2_LOSS)
|
||||
void ggml_bitnet_init(void) {
|
||||
// LOG(INFO) << "ggml_bitnet_init";
|
||||
|
||||
if (initialized) {
|
||||
return;
|
||||
}
|
||||
initialized = true;
|
||||
|
||||
// if (wrapper == nullptr) {
|
||||
// wrapper = new BITNET::BITNETGeMMWrapper<bitnet_bitnet_float_type>();
|
||||
// }
|
||||
if (bitnet_tensor_extras == nullptr) {
|
||||
bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES];
|
||||
}
|
||||
bitnet_tensor_extras_index = 0;
|
||||
}
|
||||
|
||||
// Tear down the TL2-loss bitnet runtime; counterpart of ggml_bitnet_init.
// NOTE(review): per-tensor qweights/scales are deliberately not freed here —
// the original had those frees commented out (qweights aliases tensor->data,
// which ggml owns). Confirm the scales lifetime before adding frees.
void ggml_bitnet_free(void) {
    if (!initialized) {
        return;
    }
    initialized = false;

    delete[] bitnet_tensor_extras;
    bitnet_tensor_extras = nullptr;
}
|
||||
|
||||
bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
|
||||
if ((is_type_supported(src0->type)) &&
|
||||
src1->type == GGML_TYPE_F32 &&
|
||||
dst->type == GGML_TYPE_F32 &&
|
||||
src0->backend == GGML_BACKEND_TYPE_CPU) {
|
||||
if (src1->ne[1] <= 1) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
|
||||
const size_t ne01 = src0->ne[1];
|
||||
const size_t ne10 = src1->ne[0];
|
||||
const size_t ne11 = src1->ne[1];
|
||||
|
||||
size_t wsize = ne10 * ne11 * 11 * sizeof(int8_t) + 2 * ne11 * 2 * sizeof(bitnet_float_type);
|
||||
if (sizeof(bitnet_float_type) == 2) {
|
||||
// Need fp32 to fp16 conversion
|
||||
wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type);
|
||||
}
|
||||
wsize = ((wsize - 1) / 64 + 1) * 64;
|
||||
return wsize;
|
||||
}
|
||||
|
||||
int ggml_bitnet_get_type_bits(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_TL2:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -517,6 +517,92 @@ def preprocess_weights_tl1(
|
||||
return weight
|
||||
|
||||
|
||||
def preprocess_two_weights_tl2_loss(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
|
||||
weight = np.reshape(weight, (weight_num // 2, 2))
|
||||
hi_weight = np.multiply(np.split(weight, 2, axis=1)[0], 3)
|
||||
lo_weight = np.split(weight, 2, axis=1)[1]
|
||||
|
||||
weight = np.reshape((hi_weight + lo_weight), weight_num // 2)
|
||||
weight = weight + 4
|
||||
weight = np.reshape(weight, (M, K // 2)).astype(np.uint8)
|
||||
weight = weight.reshape((M // BM, BM, K // 2)).transpose(0, 2, 1)
|
||||
weight = weight.reshape((M // BM, K // BY, BY // 2, BM)).transpose(0, 1, 3, 2)
|
||||
weight = weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 2)).transpose(0, 1, 2, 4, 3)
|
||||
weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, by // 2, bm)).transpose(0, 1, 2, 3, 5, 4)
|
||||
weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, bm, by // 2))
|
||||
weight_0 = weight[:, :, :, :, :, 0]
|
||||
weight_1 = weight[:, :, :, :, :, 1]
|
||||
weight_0 = weight_0 << 4
|
||||
weight_1 = weight_1
|
||||
weight = weight_0 + weight_1
|
||||
weight = weight.reshape(M * K // bm // by, bm).reshape(M * K // by // 16, 16)
|
||||
|
||||
for i in range(weight.shape[0]):
|
||||
final_weight.append(weight[i, :])
|
||||
|
||||
def preprocess_three_weights_tl2_loss(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
    # Pack ternary weights in groups of three for the TL2-loss layout.
    # Each triple (w0, w1, w2) becomes a base-3 value w0*9 + w1*3 + w2 in
    # [-13, 13]; magnitude and sign are stored separately.  Packed magnitude
    # rows and then interleaved sign-bit rows are appended to final_weight
    # in place; nothing is returned.
    #
    # NOTE(review): the indexing below assumes by % 3 == 0, by // 3 even,
    # BY % (by * 4) == 0, and bm in {16, 32} — confirm against the caller
    # (invoked with BK=96-style configs in this commit).
    weight = np.reshape(weight, (weight_num // 3, 3))
    split_weights = np.split(weight, 3, axis=1)
    first_weight = np.multiply(split_weights[0], 9)
    second_weight = np.multiply(split_weights[1], 3)
    third_weight = split_weights[2]

    weight = np.reshape((first_weight + second_weight + third_weight), weight_num // 3)
    # Sign bit: 1 for positive values, 0 otherwise (zero maps to 0).
    sign_weight = np.sign(weight)
    sign_weight = np.where(sign_weight < 1, 0, sign_weight)
    weight = np.abs(weight)

    weight = np.reshape(weight, (M, K // 3)).astype(np.uint8)
    sign_weight = np.reshape(sign_weight, (M, K // 3)).astype(np.uint8)

    # Re-tile magnitudes to (M//BM, K//BY, BM//bm, BY//by, bm, by//3).
    weight = weight.reshape((M // BM, BM, K // 3)).transpose(0, 2, 1)
    weight = weight.reshape((M // BM, K // BY, BY // 3, BM)).transpose(0, 1, 3, 2)
    weight = weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 3)).transpose(0, 1, 2, 4, 3)
    weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, by // 3, bm)).transpose(0, 1, 2, 3, 5, 4)
    weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, bm, by // 3))

    # Split the trailing by//3 axis into separate planes.
    weight_list = []
    for i in range(by // 3):
        weight_list.append(weight[:, :, :, :, :, i])

    # Fuse plane i (high nibble) with plane i + by//6 (low nibble) into bytes,
    # then flatten each fused plane to rows of 16 bytes.
    for i in range(by // 3 // 2):
        weight_list[i] = weight_list[i] << 4
        weight_list[i + by // 3 // 2] = weight_list[i + by // 3 // 2]  # no-op, kept for symmetry
        weight_list[i] = weight_list[i] + weight_list[i + by // 3 // 2]
        weight_list[i] = weight_list[i].reshape(M * K // bm // by, bm).reshape(M * K // by // 16, 16)

    # Emit fused magnitude rows, interleaving the by//6 planes row by row.
    for i in range(weight_list[0].shape[0]):
        for j in range(by // 3 // 2):
            final_weight.append(weight_list[j][i, :])

    # Re-tile the sign bits; the last reshape gathers 4 consecutive by-groups
    # so that 8 sign bits can be packed into one byte below.
    sign_weight = sign_weight.reshape((M // BM, BM, K // 3)).transpose(0, 2, 1)
    sign_weight = sign_weight.reshape((M // BM, K // BY, BY // 3, BM)).transpose(0, 1, 3, 2)
    sign_weight = sign_weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 3)).transpose(0, 1, 2, 4, 3)
    sign_weight = sign_weight.reshape((M // BM, K // BY, BM // bm, BY // (by * 4), by // 3 * 4, bm)).transpose(0, 1, 2, 3, 5, 4).astype(np.uint8)

    # One uint8 accumulator plane per fused-magnitude plane.
    combine_weight_list = []
    for i in range(by // 3 // 2):
        combine_weight = np.zeros((M // BM, K // BY, BM // bm, BY // (by * 4), bm), dtype=np.uint8)
        combine_weight_list.append(combine_weight)

    # Pack 8 sign bits per byte, MSB first.  For bm == 32 the lane order is
    # permuted (even lanes first, then odd) — presumably to match the SIMD
    # shuffle in the generated kernel; TODO confirm against codegen_tl2_loss.
    for i in range(8):
        for j in range(by // 3 // 2):
            if bm == 16:
                combine_weight_list[j] = combine_weight_list[j] + (sign_weight[:, :, :, :, :, by // 3 // 2 * i + j] << 7 - i)
            elif bm == 32:
                if i > 3 :
                    ti = (i - 4) * 2 + 1
                else:
                    ti = i * 2
                combine_weight_list[j] = combine_weight_list[j] + (sign_weight[:, :, :, :, :, by // 3 // 2 * ti + j] << 7 - i)

    # Flatten packed sign bytes to rows of 16.
    for i in range(by // 3 // 2):
        combine_weight_list[i] = combine_weight_list[i].reshape((M * K // (by * 4)) // 16, 16)

    # Emit sign rows after all magnitude rows, interleaving planes as above.
    for i in range(combine_weight_list[0].shape[0]):
        for j in range(by // 3 // 2):
            final_weight.append(combine_weight_list[j][i, :])
|
||||
|
||||
def preprocess_two_weights_tl2(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
|
||||
weight = np.reshape(weight, (weight_num // 2, 2))
|
||||
hi_weight = np.multiply(np.split(weight, 2, axis=1)[0], 3)
|
||||
@@ -603,7 +689,6 @@ def preprocess_weights_tl2(
|
||||
weight = w
|
||||
weight = np.where(np.abs(weight) < 1e-6, 0, weight).astype(np.float32)
|
||||
weight = np.sign(weight)
|
||||
weight_num = np.prod(weight.shape)
|
||||
|
||||
config.read('include/kernel_config.ini')
|
||||
BM = -1
|
||||
@@ -631,7 +716,8 @@ def preprocess_weights_tl2(
|
||||
|
||||
final_weight = []
|
||||
|
||||
preprocess_three_weights_tl2(three_weight.shape[0],
|
||||
if args.loss:
|
||||
preprocess_three_weights_tl2_loss(three_weight.shape[0],
|
||||
three_weight.shape[1],
|
||||
three_weight.shape[0] * three_weight.shape[1],
|
||||
BM,
|
||||
@@ -641,8 +727,29 @@ def preprocess_weights_tl2(
|
||||
three_weight,
|
||||
final_weight)
|
||||
|
||||
if (weight.shape[1] % BY != 0):
|
||||
preprocess_two_weights_tl2( two_weight.shape[0],
|
||||
if (weight.shape[1] % BY != 0):
|
||||
preprocess_two_weights_tl2_loss(two_weight.shape[0],
|
||||
two_weight.shape[1],
|
||||
two_weight.shape[0] * two_weight.shape[1],
|
||||
BM,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
two_weight,
|
||||
final_weight)
|
||||
else:
|
||||
preprocess_three_weights_tl2(three_weight.shape[0],
|
||||
three_weight.shape[1],
|
||||
three_weight.shape[0] * three_weight.shape[1],
|
||||
BM,
|
||||
BY,
|
||||
bm,
|
||||
by,
|
||||
three_weight,
|
||||
final_weight)
|
||||
|
||||
if (weight.shape[1] % BY != 0):
|
||||
preprocess_two_weights_tl2(two_weight.shape[0],
|
||||
two_weight.shape[1],
|
||||
two_weight.shape[0] * two_weight.shape[1],
|
||||
BM,
|
||||
@@ -652,8 +759,10 @@ def preprocess_weights_tl2(
|
||||
two_weight,
|
||||
final_weight)
|
||||
weight = np.array(final_weight, dtype=np.uint8).reshape(-1)
|
||||
weight = np.pad(weight, (0, (K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8 -
|
||||
weight.shape[0]), mode='constant', constant_values=0)
|
||||
pad_nums = (K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8
|
||||
pad_align_nums = 32 - ((K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8) % 32
|
||||
pad_nums = pad_nums + pad_align_nums
|
||||
weight = np.pad(weight, (0, pad_nums - weight.shape[0]), mode='constant', constant_values=0)
|
||||
return weight
|
||||
|
||||
def transform_to_tl1(x: np.ndarray):
|
||||
@@ -1116,6 +1225,7 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--model-name", type=str, default=None, help="name of the model")
|
||||
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||
parser.add_argument("--quant-embd", action="store_true", help="quantize the embedding layer")
|
||||
parser.add_argument("--loss", action="store_true", help="use loss tl2")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user