Files
BitNet/include/bitnet-lut-kernels.h
T
Eddie-Wang1120 4c736e3728 commit paper code
2025-02-16 15:03:25 +08:00

627 lines
28 KiB
C++

#if defined(GGML_BITNET_ARM_TL1)
#include "ggml-bitnet.h"
#define GGML_BITNET_MAX_NODES 8192
static bool initialized = false;
static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;
static size_t bitnet_tensor_extras_index = 0;
static void * aligned_malloc(size_t size) {{
#if defined(_WIN32)
return _aligned_malloc(size, 64);
#else
void * ptr = nullptr;
posix_memalign(&ptr, 64, size);
return ptr;
#endif
}}
static void aligned_free(void * ptr) {{
#if defined(_WIN32)
_aligned_free(ptr);
#else
free(ptr);
#endif
}}
void per_tensor_quant(int k, void* lut_scales_, void* b_) {{
bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
bitnet_float_type* b = (bitnet_float_type*)b_;
#ifdef __ARM_NEON
float32x4_t temp_max = vdupq_n_f32(0);
for (int i=0; i < k / 4; i++) {{
float32x4_t vec_bs = vld1q_f32(b + 4 * i);
float32x4_t abssum = vabsq_f32(vec_bs);
temp_max = vmaxq_f32(abssum, temp_max);
}}
float32_t scales = 127 / vmaxvq_f32(temp_max);
*lut_scales = scales;
#elif defined __AVX2__
__m256 max_vec = _mm256_set1_ps(0.f);
const __m256 vec_sign = _mm256_set1_ps(-0.0f);
// #pragma unroll
for (int i = 0; i < k / 8; i++) {{
__m256 vec_b = _mm256_loadu_ps(b + i * 8);
__m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);
max_vec = _mm256_max_ps(vec_babs, max_vec);
}}
__m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));
max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));
max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));
float scales = 127 / _mm_cvtss_f32(max1);
*lut_scales = scales;
#endif
}}
void partial_max_reset(void* lut_scales_) {{
bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
*lut_scales = 0.0;
}}
#ifdef __ARM_NEON
inline void Transpose_8_8(
int16x8_t *v0,
int16x8_t *v1,
int16x8_t *v2,
int16x8_t *v3,
int16x8_t *v4,
int16x8_t *v5,
int16x8_t *v6,
int16x8_t *v7)
{{
int16x8x2_t q04 = vzipq_s16(*v0, *v4);
int16x8x2_t q15 = vzipq_s16(*v1, *v5);
int16x8x2_t q26 = vzipq_s16(*v2, *v6);
int16x8x2_t q37 = vzipq_s16(*v3, *v7);
int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);
int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);
int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);
int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);
int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);
int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);
int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);
int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);
*v0 = q_fin_0.val[0];
*v1 = q_fin_0.val[1];
*v2 = q_fin_1.val[0];
*v3 = q_fin_1.val[1];
*v4 = q_fin_2.val[0];
*v5 = q_fin_2.val[1];
*v6 = q_fin_3.val[0];
*v7 = q_fin_3.val[1];
}}
#endif
template<int act_k>
inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{
#ifdef __ARM_NEON
int16x8_t vec_lut[16];
float32_t scales = *lut_scales;
uint8_t tbl_mask[16];
tbl_mask[0] = 0;
tbl_mask[1] = 2;
tbl_mask[2] = 4;
tbl_mask[3] = 6;
tbl_mask[4] = 8;
tbl_mask[5] = 10;
tbl_mask[6] = 12;
tbl_mask[7] = 14;
tbl_mask[8] = 1;
tbl_mask[9] = 3;
tbl_mask[10] = 5;
tbl_mask[11] = 7;
tbl_mask[12] = 9;
tbl_mask[13] = 11;
tbl_mask[14] = 13;
tbl_mask[15] = 15;
uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);
#pragma unroll
for (int k = 0; k < act_k / 16; ++k) {{
float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);
float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);
float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);
float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);
float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);
float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);
int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);
int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);
int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);
int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);
int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);
int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);
int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);
int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);
int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);
int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);
vec_lut[0] = vdupq_n_s16(0);
vec_lut[0] = vec_lut[0] - vec_bs_0;
vec_lut[0] = vec_lut[0] - vec_bs_1;
vec_lut[1] = vdupq_n_s16(0);
vec_lut[1] = vec_lut[1] - vec_bs_0;
vec_lut[2] = vdupq_n_s16(0);
vec_lut[2] = vec_lut[2] - vec_bs_0;
vec_lut[2] = vec_lut[2] + vec_bs_1;
vec_lut[3] = vdupq_n_s16(0);
vec_lut[3] = vec_lut[3] - vec_bs_1;
vec_lut[4] = vdupq_n_s16(0);
vec_lut[5] = vec_bs_1;
vec_lut[6] = vec_bs_0;
vec_lut[6] = vec_lut[6] - vec_bs_1;
vec_lut[7] = vec_bs_0;
vec_lut[8] = vec_bs_0;
vec_lut[8] = vec_lut[8] + vec_bs_1;
Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),
&(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));
Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),
&(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));
#pragma unroll
for (int idx = 0; idx < 8; idx++) {{
int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);
int8x8_t q0_low = vget_low_s8(q0_s);
int8x8_t q0_high = vget_high_s8(q0_s);
int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);
int8x8_t q1_low = vget_low_s8(q1_s);
int8x8_t q1_high = vget_high_s8(q1_s);
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);
}}
}}
#endif
}}
static bool is_type_supported(enum ggml_type type) {{
if (type == GGML_TYPE_Q4_0 ||
type == GGML_TYPE_TL1) {{
return true;
}} else {{
return false;
}}
}}
#include <arm_neon.h>
#define BM1536_4096 256
#define BBK1536_4096 128
inline void tbl_impl_1536_4096(int32_t* c, int8_t* lut, uint8_t* a) {
#ifdef __ARM_NEON
const int KK = BBK1536_4096 / 2;
const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
const int8x16_t vec_zero = vdupq_n_s16(0x0000);
int8x16_t vec_lut[2 * KK];
int16x8_t vec_c[4];
#pragma unroll
for (int k = 0; k < 2 * KK; k++) {
vec_lut[k] = vld1q_s8(lut + k * 16);
}
#pragma unroll
for (int i = 0; i < BM1536_4096; i += 32) {
#pragma unroll
for (int i=0; i<4; i++) {
vec_c[i] = vandq_s16(vec_c[i], vec_zero);
}
#pragma unroll
for (int k = 0; k < KK / 4; k++) {
uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
vec_c[0] += vec_v_left_0.val[0];
vec_c[0] += vec_v_right_0.val[0];
vec_c[1] += vec_v_left_0.val[1];
vec_c[1] += vec_v_right_0.val[1];
uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
vec_c[0] += vec_v_left_1.val[0];
vec_c[0] += vec_v_right_1.val[0];
vec_c[1] += vec_v_left_1.val[1];
vec_c[1] += vec_v_right_1.val[1];
uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
vec_c[2] += vec_v_left_2.val[0];
vec_c[2] += vec_v_right_2.val[0];
vec_c[3] += vec_v_left_2.val[1];
vec_c[3] += vec_v_right_2.val[1];
uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
vec_c[2] += vec_v_left_3.val[0];
vec_c[2] += vec_v_right_3.val[0];
vec_c[3] += vec_v_left_3.val[1];
vec_c[3] += vec_v_right_3.val[1];
}
int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
}
#endif
}
int32_t qgemm_lut_1536_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
alignas(32) uint32_t CBits[BM1536_4096];
memset(&(CBits[0]), 0, BM1536_4096 * sizeof(int32_t));
#pragma unroll
for (int32_t k_outer = 0; k_outer < 4096 / BBK1536_4096; ++k_outer) {
tbl_impl_1536_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_4096 / 2 / 2 * BM1536_4096)])));
}
#pragma unroll
for (int i = 0; i < BM1536_4096; i++) {
((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
}
return 0;
};
#include <arm_neon.h>
#define BM1536_1536 128
#define BBK1536_1536 64
inline void tbl_impl_1536_1536(int32_t* c, int8_t* lut, uint8_t* a) {
#ifdef __ARM_NEON
const int KK = BBK1536_1536 / 2;
const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
const int8x16_t vec_zero = vdupq_n_s16(0x0000);
int8x16_t vec_lut[2 * KK];
int16x8_t vec_c[8];
#pragma unroll
for (int k = 0; k < 2 * KK; k++) {
vec_lut[k] = vld1q_s8(lut + k * 16);
}
#pragma unroll
for (int i = 0; i < BM1536_1536; i += 64) {
#pragma unroll
for (int i=0; i<8; i++) {
vec_c[i] = vandq_s16(vec_c[i], vec_zero);
}
#pragma unroll
for (int k = 0; k < KK / 2; k++) {
uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top);
int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top);
int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot);
int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot);
int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
vec_c[0] += vec_v_left_0.val[0];
vec_c[0] += vec_v_right_0.val[0];
vec_c[1] += vec_v_left_0.val[1];
vec_c[1] += vec_v_right_0.val[1];
uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top);
int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top);
int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot);
int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot);
int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
vec_c[2] += vec_v_left_1.val[0];
vec_c[2] += vec_v_right_1.val[0];
vec_c[3] += vec_v_left_1.val[1];
vec_c[3] += vec_v_right_1.val[1];
uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top);
int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top);
int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot);
int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot);
int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
vec_c[4] += vec_v_left_2.val[0];
vec_c[4] += vec_v_right_2.val[0];
vec_c[5] += vec_v_left_2.val[1];
vec_c[5] += vec_v_right_2.val[1];
uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top);
int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top);
int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot);
int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot);
int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
vec_c[6] += vec_v_left_3.val[0];
vec_c[6] += vec_v_right_3.val[0];
vec_c[7] += vec_v_left_3.val[1];
vec_c[7] += vec_v_right_3.val[1];
}
int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4]));
int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]);
vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4);
vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4);
int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5]));
int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]);
vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5);
vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5);
int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6]));
int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]);
vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6);
vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6);
int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7]));
int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]);
vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7);
vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7);
}
#endif
}
int32_t qgemm_lut_1536_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
alignas(32) uint32_t CBits[BM1536_1536];
memset(&(CBits[0]), 0, BM1536_1536 * sizeof(int32_t));
#pragma unroll
for (int32_t k_outer = 0; k_outer < 1536 / BBK1536_1536; ++k_outer) {
tbl_impl_1536_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_1536 / 2 / 2 * BM1536_1536)])));
}
#pragma unroll
for (int i = 0; i < BM1536_1536; i++) {
((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
}
return 0;
};
#include <arm_neon.h>
#define BM4096_1536 256
#define BBK4096_1536 128
inline void tbl_impl_4096_1536(int32_t* c, int8_t* lut, uint8_t* a) {
#ifdef __ARM_NEON
const int KK = BBK4096_1536 / 2;
const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
const int8x16_t vec_zero = vdupq_n_s16(0x0000);
int8x16_t vec_lut[2 * KK];
int16x8_t vec_c[4];
#pragma unroll
for (int k = 0; k < 2 * KK; k++) {
vec_lut[k] = vld1q_s8(lut + k * 16);
}
#pragma unroll
for (int i = 0; i < BM4096_1536; i += 32) {
#pragma unroll
for (int i=0; i<4; i++) {
vec_c[i] = vandq_s16(vec_c[i], vec_zero);
}
#pragma unroll
for (int k = 0; k < KK / 4; k++) {
uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
vec_c[0] += vec_v_left_0.val[0];
vec_c[0] += vec_v_right_0.val[0];
vec_c[1] += vec_v_left_0.val[1];
vec_c[1] += vec_v_right_0.val[1];
uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
vec_c[0] += vec_v_left_1.val[0];
vec_c[0] += vec_v_right_1.val[0];
vec_c[1] += vec_v_left_1.val[1];
vec_c[1] += vec_v_right_1.val[1];
uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
vec_c[2] += vec_v_left_2.val[0];
vec_c[2] += vec_v_right_2.val[0];
vec_c[3] += vec_v_left_2.val[1];
vec_c[3] += vec_v_right_2.val[1];
uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
vec_c[2] += vec_v_left_3.val[0];
vec_c[2] += vec_v_right_3.val[0];
vec_c[3] += vec_v_left_3.val[1];
vec_c[3] += vec_v_right_3.val[1];
}
int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
}
#endif
}
int32_t qgemm_lut_4096_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
alignas(32) uint32_t CBits[BM4096_1536];
memset(&(CBits[0]), 0, BM4096_1536 * sizeof(int32_t));
#pragma unroll
for (int32_t k_outer = 0; k_outer < 1536 / BBK4096_1536; ++k_outer) {
tbl_impl_4096_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_1536 / 2 / 2 * BM4096_1536)])));
}
#pragma unroll
for (int i = 0; i < BM4096_1536; i++) {
((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
}
return 0;
};
template<int K>
void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{
partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));
per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));
lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));
}}
void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {
if (m == 1536 && k == 4096) {
preprocessor_k<4096>(B, LUT_Scales, QLUT);
}
else if (m == 1536 && k == 1536) {
preprocessor_k<1536>(B, LUT_Scales, QLUT);
}
else if (m == 4096 && k == 1536) {
preprocessor_k<1536>(B, LUT_Scales, QLUT);
}
}
void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
if (m == 1536 && k == 4096) {
qgemm_lut_1536_4096(A, LUT, Scales, LUT_Scales, C);
}
else if (m == 1536 && k == 1536) {
qgemm_lut_1536_1536(A, LUT, Scales, LUT_Scales, C);
}
else if (m == 4096 && k == 1536) {
qgemm_lut_4096_1536(A, LUT, Scales, LUT_Scales, C);
}
}
void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {
if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
return;
}
int k = tensor->ne[0];
int m = tensor->ne[1];
const int lut_scales_size = 1;
const int scales_size = 1;
int bk = 0;
int bm = 0;
if (m == 1536 && k == 4096) {
bm = BM1536_4096;
bk = BBK1536_4096;
}
else if (m == 1536 && k == 1536) {
bm = BM1536_1536;
bk = BBK1536_1536;
}
else if (m == 4096 && k == 1536) {
bm = BM4096_1536;
bk = BBK4096_1536;
}
const int n_tile_num = m / bm;
const int BK = bk;
uint8_t * qweights;
bitnet_float_type * scales;
scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));
qweights = (uint8_t *) tensor->data;
float * i2_scales = (float * )(qweights + k * m / 4);
scales[0] = (bitnet_float_type) i2_scales[0];
tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;
bitnet_tensor_extras[bitnet_tensor_extras_index++] = {
/* .lut_scales_size = */ lut_scales_size,
/* .BK = */ BK,
/* .n_tile_num = */ n_tile_num,
/* .qweights = */ qweights,
/* .scales = */ scales
};
}
#endif