mirror of
https://github.com/microsoft/BitNet.git
synced 2026-05-03 19:30:32 +00:00
initial commit
This commit is contained in:
@@ -0,0 +1,442 @@
|
||||
import argparse
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
|
||||
def gen_ctor_code():
    """Return the fixed (shape-independent) preamble of the generated TL1 kernel header.

    The returned C/C++ source provides:
      * aligned_malloc / aligned_free helpers (Win32 vs POSIX paths),
      * per_tensor_quant: activation scale = 127 / max(|b|) (NEON or AVX2),
      * partial_max_reset: zeroes the stored LUT scale,
      * Transpose_8_8: 8x8 int16 register transpose (NEON only),
      * lut_ctor<act_k>: builds the signed ternary lookup tables,
      * is_type_supported: gates GGML_TYPE_Q4_0 / GGML_TYPE_TL1 tensors.

    NOTE(review): braces are doubled ('{{' / '}}') as if destined for
    str.format, but the string is returned without .format being applied
    here — confirm a later formatting pass exists, otherwise literal
    '{{' reaches the emitted header.
    NOTE(review): lut_ctor assigns only vec_lut[0..8], yet the second
    Transpose_8_8 reads vec_lut[8..15]; entries 9..15 carry garbage —
    presumably never hit by the table-lookup indices, but verify.
    """
    # One line-continued literal; each '\n\' line is one line of generated C.
    kernel_code = "\n\
#include \"ggml-bitnet.h\"\n\
#define GGML_BITNET_MAX_NODES 8192\n\
static bool initialized = false;\n\
static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;\n\
static size_t bitnet_tensor_extras_index = 0;\n\
static void * aligned_malloc(size_t size) {{\n\
#if defined(_WIN32)\n\
return _aligned_malloc(size, 64);\n\
#else\n\
void * ptr = nullptr;\n\
posix_memalign(&ptr, 64, size);\n\
return ptr;\n\
#endif\n\
}}\n\
static void aligned_free(void * ptr) {{\n\
#if defined(_WIN32)\n\
_aligned_free(ptr);\n\
#else\n\
free(ptr);\n\
#endif\n\
}}\n\
\n\
void per_tensor_quant(int k, void* lut_scales_, void* b_) {{\n\
bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
bitnet_float_type* b = (bitnet_float_type*)b_;\n\
#ifdef __ARM_NEON\n\
float32x4_t temp_max = vdupq_n_f32(0);\n\
for (int i=0; i < k / 4; i++) {{\n\
float32x4_t vec_bs = vld1q_f32(b + 4 * i);\n\
float32x4_t abssum = vabsq_f32(vec_bs);\n\
temp_max = vmaxq_f32(abssum, temp_max);\n\
}}\n\
float32_t scales = 127 / vmaxvq_f32(temp_max);\n\
*lut_scales = scales;\n\
#elif defined __AVX2__\n\
__m256 max_vec = _mm256_set1_ps(0.f);\n\
const __m256 vec_sign = _mm256_set1_ps(-0.0f);\n\
// #pragma unroll\n\
for (int i = 0; i < k / 8; i++) {{\n\
__m256 vec_b = _mm256_loadu_ps(b + i * 8);\n\
__m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);\n\
max_vec = _mm256_max_ps(vec_babs, max_vec);\n\
}}\n\
__m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));\n\
max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));\n\
max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));\n\
float scales = 127 / _mm_cvtss_f32(max1);\n\
*lut_scales = scales;\n\
#endif\n\
}}\n\
\n\
void partial_max_reset(void* lut_scales_) {{\n\
bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
*lut_scales = 0.0;\n\
}}\n\
\n\
#ifdef __ARM_NEON\n\
inline void Transpose_8_8(\n\
int16x8_t *v0,\n\
int16x8_t *v1,\n\
int16x8_t *v2,\n\
int16x8_t *v3,\n\
int16x8_t *v4,\n\
int16x8_t *v5,\n\
int16x8_t *v6,\n\
int16x8_t *v7)\n\
{{\n\
int16x8x2_t q04 = vzipq_s16(*v0, *v4);\n\
int16x8x2_t q15 = vzipq_s16(*v1, *v5);\n\
int16x8x2_t q26 = vzipq_s16(*v2, *v6);\n\
int16x8x2_t q37 = vzipq_s16(*v3, *v7);\n\
\n\
int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);\n\
int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);\n\
int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);\n\
int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);\n\
\n\
int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);\n\
int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);\n\
int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);\n\
int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);\n\
\n\
*v0 = q_fin_0.val[0];\n\
*v1 = q_fin_0.val[1];\n\
*v2 = q_fin_1.val[0];\n\
*v3 = q_fin_1.val[1];\n\
*v4 = q_fin_2.val[0];\n\
*v5 = q_fin_2.val[1];\n\
*v6 = q_fin_3.val[0];\n\
*v7 = q_fin_3.val[1];\n\
}}\n\
#endif\n\
\n\
template<int act_k>\n\
inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{\n\
#ifdef __ARM_NEON\n\
int16x8_t vec_lut[16];\n\
float32_t scales = *lut_scales;\n\
uint8_t tbl_mask[16];\n\
tbl_mask[0] = 0;\n\
tbl_mask[1] = 2;\n\
tbl_mask[2] = 4;\n\
tbl_mask[3] = 6;\n\
tbl_mask[4] = 8;\n\
tbl_mask[5] = 10;\n\
tbl_mask[6] = 12;\n\
tbl_mask[7] = 14;\n\
tbl_mask[8] = 1;\n\
tbl_mask[9] = 3;\n\
tbl_mask[10] = 5;\n\
tbl_mask[11] = 7;\n\
tbl_mask[12] = 9;\n\
tbl_mask[13] = 11;\n\
tbl_mask[14] = 13;\n\
tbl_mask[15] = 15;\n\
uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);\n\
#pragma unroll\n\
for (int k = 0; k < act_k / 16; ++k) {{\n\
float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);\n\
float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);\n\
float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);\n\
float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);\n\
float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);\n\
float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);\n\
int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);\n\
int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);\n\
int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);\n\
int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);\n\
int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);\n\
int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);\n\
int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);\n\
int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);\n\
int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);\n\
int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);\n\
vec_lut[0] = vdupq_n_s16(0);\n\
vec_lut[0] = vec_lut[0] - vec_bs_0;\n\
vec_lut[0] = vec_lut[0] - vec_bs_1;\n\
vec_lut[1] = vdupq_n_s16(0);\n\
vec_lut[1] = vec_lut[1] - vec_bs_0;\n\
vec_lut[2] = vdupq_n_s16(0);\n\
vec_lut[2] = vec_lut[2] - vec_bs_0;\n\
vec_lut[2] = vec_lut[2] + vec_bs_1;\n\
vec_lut[3] = vdupq_n_s16(0);\n\
vec_lut[3] = vec_lut[3] - vec_bs_1;\n\
vec_lut[4] = vdupq_n_s16(0);\n\
vec_lut[5] = vec_bs_1;\n\
vec_lut[6] = vec_bs_0;\n\
vec_lut[6] = vec_lut[6] - vec_bs_1;\n\
vec_lut[7] = vec_bs_0;\n\
vec_lut[8] = vec_bs_0;\n\
vec_lut[8] = vec_lut[8] + vec_bs_1;\n\
Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),\n\
&(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));\n\
Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),\n\
&(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));\n\
#pragma unroll\n\
for (int idx = 0; idx < 8; idx++) {{\n\
int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);\n\
int8x8_t q0_low = vget_low_s8(q0_s);\n\
int8x8_t q0_high = vget_high_s8(q0_s);\n\
int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);\n\
int8x8_t q1_low = vget_low_s8(q1_s);\n\
int8x8_t q1_high = vget_high_s8(q1_s);\n\
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);\n\
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);\n\
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);\n\
vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);\n\
}}\n\
}}\n\
#endif\n\
}}\n\
\n\
static bool is_type_supported(enum ggml_type type) {{\n\
if (type == GGML_TYPE_Q4_0 ||\n\
type == GGML_TYPE_TL1) {{\n\
return true;\n\
}} else {{\n\
return false;\n\
}}\n\
}}\n\
"
    return kernel_code
|
||||
|
||||
def gen_body_core_code(bm, by):
|
||||
length = 4
|
||||
all_code = ""
|
||||
for i in range(length):
|
||||
core_code = "\n\
|
||||
uint8x16_t vec_a_{0} = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + {0} * 16);\n\
|
||||
uint8x16_t vec_a{0}_top = vshrq_n_u8(vec_a_{0}, 4);\n\
|
||||
uint8x16_t vec_a{0}_bot = vandq_u8(vec_a_{0}, vec_mask);\n\
|
||||
int8x16_t vec_v_{0}_left_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {2}], vec_a{0}_top);\n\
|
||||
int8x16_t vec_v_{0}_left_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {3}], vec_a{0}_top);\n\
|
||||
int8x16_t vec_v_{0}_right_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {4}], vec_a{0}_bot);\n\
|
||||
int8x16_t vec_v_{0}_right_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {5}], vec_a{0}_bot);\n\
|
||||
int8x16x2_t vec_v_left_{0} = vzipq_s8(vec_v_{0}_left_tmp1, vec_v_{0}_left_tmp0);\n\
|
||||
int8x16x2_t vec_v_right_{0} = vzipq_s8(vec_v_{0}_right_tmp1, vec_v_{0}_right_tmp0);\n\
|
||||
vec_c[{6}] += vec_v_left_{0}.val[0];\n\
|
||||
vec_c[{6}] += vec_v_right_{0}.val[0];\n\
|
||||
vec_c[{7}] += vec_v_left_{0}.val[1];\n\
|
||||
vec_c[{7}] += vec_v_right_{0}.val[1];\n\
|
||||
".format(i, 2 * by // 2, (4 * i) % (2 * by // 2), (4 * i + 1) % (2 * by // 2), (4 * i + 2) % (2 * by // 2), (4 * i + 3) % (2 * by // 2), (i * 2) // (by // 2) * 2 + 0, (i * 2) // (by // 2) * 2 + 1)
|
||||
|
||||
all_code = "".join([all_code, core_code])
|
||||
|
||||
all_code = "".join([all_code, "\n }\n\n"])
|
||||
|
||||
for i in range(bm // 8):
|
||||
core_code = "\
|
||||
int32x4_t vec_v_bot_low_low_{0} = vmovl_s16(vget_low_s16(vec_c[{0}]));\n\
|
||||
int32x4_t vec_v_bot_low_high_{0} = vmovl_high_s16(vec_c[{0}]);\n\
|
||||
vst1q_s32(c + i + {1}, vld1q_s32(c + i + {1}) + vec_v_bot_low_low_{0});\n\
|
||||
vst1q_s32(c + i + {2}, vld1q_s32(c + i + {2}) + vec_v_bot_low_high_{0});\n".format(i, i * 8, i * 8 + 4)
|
||||
all_code = "".join([all_code, core_code])
|
||||
|
||||
return all_code
|
||||
|
||||
def gen_tbl_impl(pre, BM, BK, bm, k):
    """Generate the shape-specialized TL1 NEON kernel pair.

    Emits ``tbl_impl_<pre>`` (the LUT-gather micro-kernel) and
    ``qgemm_lut_<pre>`` (the outer loop over BBK-deep slices plus the
    final rescale of the int32 accumulators into C).

    Args:
        pre: name suffix, "{M}_{K}" of the weight shape.
        BM:  row-block size, emitted as the ``BM<pre>`` macro.
        BK:  depth-block size, emitted as the ``BBK<pre>`` macro.
        bm:  SIMD tile height (32 or 64); bm // 8 int16x8 accumulators.
        k:   full K dimension; the outer loop runs k / BBK<pre> times.
    """
    # Kernel prologue.
    # NOTE(review): the generated C declares vec_zero as int8x16_t but
    # initializes it with vdupq_n_s16 (int16x8_t) and later uses it with
    # vandq_s16 — confirm this compiles as intended on the target toolchain.
    kernel_code = "\
#include <arm_neon.h>\n\
\n\
#define BM{0} {1}\n\
#define BBK{0} {2}\n\
inline void tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a) {{\n\
#ifdef __ARM_NEON\n\
const int KK = BBK{0} / 2;\n\
const uint8x16_t vec_mask = vdupq_n_u8(0x0f);\n\
const int8x16_t vec_zero = vdupq_n_s16(0x0000);\n\
int8x16_t vec_lut[2 * KK];\n\
".format(pre, BM, BK)

    # Accumulator register file: one int16x8_t per 8 output rows.
    kernel_code = "".join([kernel_code, " int16x8_t vec_c[{}];".format(bm // 8)])

    # Preload every 16-entry LUT slice for this depth block.
    kernel_code = "".join([kernel_code, "\n\
#pragma unroll\n\
for (int k = 0; k < 2 * KK; k++) {\n\
vec_lut[k] = vld1q_s8(lut + k * 16);\n\
}\n"])

    # Outer row loop + accumulator reset.
    # NOTE(review): the inner reset loop re-declares 'i', shadowing the outer
    # row-loop variable within its own scope — legal C, verify it's intended.
    pre_core_code = "\n\
#pragma unroll\n\
for (int i = 0; i < BM{}; i += {}) {{\n\
#pragma unroll\n\
for (int i=0; i<{}; i++) {{\n\
vec_c[i] = vandq_s16(vec_c[i], vec_zero);\n\
}}\n".format(pre, bm, bm // 8)

    # Depth loop wrapping the unrolled gather body from gen_body_core_code.
    body_core_pre_code = "\n\
#pragma unroll\n\
for (int k = 0; k < KK / {}; k++) {{\n\
".format(256 // bm // 2)

    body_core_post_code = "\n\
}\n\
\
#endif\n\
}\n"

    kernel_code = "".join([kernel_code, pre_core_code, body_core_pre_code, gen_body_core_code(bm, 256 // bm), body_core_post_code])

    # Outer GEMM entry: iterate k / BBK depth blocks, then rescale the int32
    # accumulators by the activation and weight scales into C.
    kernel_code = "".join([kernel_code, "\n\
int32_t qgemm_lut_{0}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
alignas({1}) uint32_t CBits[BM{0}];\n\
memset(&(CBits[0]), 0, BM{0} * sizeof(int32_t));\n\
#pragma unroll\n\
for (int32_t k_outer = 0; k_outer < {2} / BBK{0}; ++k_outer) {{\n\
tbl_impl_{0}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{0} / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{0} / 2 / 2 * BM{0})])));\n\
}}\n\
#pragma unroll\n\
for (int i = 0; i < BM{0}; i++) {{\n\
((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];\n\
}}\n\
return 0;\n\
}};\n".format(pre, min(32, BK), k)])

    return kernel_code
|
||||
|
||||
def gen_top_api(kernel_shapes):
    """Generate the shape-dispatch wrappers ggml_preprocessor / ggml_qgemm_lut.

    Each wrapper compares the runtime (m, k) pair against every supported
    kernel shape and forwards to the matching specialized function.

    Args:
        kernel_shapes: list of [M, K] pairs; the first becomes the leading
            ``if``, the rest become ``else if`` branches.

    Returns:
        C source for both dispatcher functions.
    """
    parts = ["void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {{\n\
if (m == {0} && k == {1}) {{\n\
preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\
}}\n\
".format(kernel_shapes[0][0], kernel_shapes[0][1])]
    for m, k in kernel_shapes[1:]:
        parts.append(" else if (m == {0} && k == {1}) {{\n\
preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\
}}\n".format(m, k))
    parts.append("}\n")

    parts.append("void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
if (m == {0} && k == {1}) {{\n\
qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\
}}\n\
".format(kernel_shapes[0][0], kernel_shapes[0][1]))
    for m, k in kernel_shapes[1:]:
        parts.append(" else if (m == {0} && k == {1}) {{\n\
qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\
}}\n\
".format(m, k))
    parts.append("}\n")

    return "".join(parts)
|
||||
|
||||
def gen_preprocess_code():
    """Return the C++ source of the templated activation-preprocessing entry.

    ``preprocessor_k<K>`` resets the stored partial max, computes the
    per-tensor activation scale, then builds the quantized lookup table
    via ``lut_ctor<K>``.
    """
    return "\n\
template<int K>\n\
void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{\n\
partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));\n\
per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));\n\
\n\
lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));\n\
}}\n"
|
||||
|
||||
def gen_transform_code(kernel_shape):
    """Generate ggml_bitnet_transform_tensor, the tensor-attach routine.

    The emitted C function validates the tensor, selects the (bm, bk)
    blocking macros for its (m, k) shape, reads the per-tensor weight
    scale from the end of the quantized data, and fills a
    bitnet_tensor_extra record.

    Args:
        kernel_shape: list of [M, K] pairs (same list passed to
            gen_top_api); one dispatch branch is emitted per pair.

    Returns:
        C source for the transform function.

    Fix: the original body referenced the *global* ``kernel_shapes``
    instead of this parameter, so the function only worked when called
    from the __main__ script; it now uses its own argument.
    """
    # Function prologue: validation + locals. Emitted verbatim (single
    # braces: this literal is never passed through str.format).
    kernel_code = "\n\
void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\
if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {\n\
return;\n\
}\n\
\n\
int k = tensor->ne[0];\n\
int m = tensor->ne[1];\n\
const int lut_scales_size = 1;\n\
const int scales_size = 1;\n\
int bk = 0;\n\
int bm = 0;\n"

    # Shape dispatch: leading 'if' for the first shape, 'else if' for the rest.
    kernel_code = "".join([kernel_code, "\n\
if (m == {0} && k == {1}) {{\n\
bm = BM{0}_{1};\n\
bk = BBK{0}_{1};\n\
}}\n".format(kernel_shape[0][0], kernel_shape[0][1])])

    for i in range(1, len(kernel_shape)):
        kernel_code = "".join([kernel_code, "else if (m == {0} && k == {1}) {{\n\
bm = BM{0}_{1};\n\
bk = BBK{0}_{1};\n\
}}\n".format(kernel_shape[i][0], kernel_shape[i][1])])

    # Epilogue: scale extraction + bitnet_tensor_extra registration.
    kernel_code = "".join([kernel_code, "\n\
const int n_tile_num = m / bm;\n\
const int BK = bk;\n\
uint8_t * qweights;\n\
bitnet_float_type * scales;\n\
\n\
scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));\n\
qweights = (uint8_t *) tensor->data;\n\
float * i2_scales = (float * )(qweights + k * m / 4);\n\
scales[0] = (bitnet_float_type) i2_scales[0];\n\
\n\
tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\
bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\
/* .lut_scales_size = */ lut_scales_size,\n\
/* .BK = */ BK,\n\
/* .n_tile_num = */ n_tile_num,\n\
/* .qweights = */ qweights,\n\
/* .scales = */ scales\n\
};\n\
}\n"])

    return kernel_code
|
||||
|
||||
if __name__ == "__main__":
    # (M, K) shapes of the linear-layer weights for each supported model.
    ModelShapeDict = {
        "bitnet_b1_58-large": [[1536, 4096],
                               [1536, 1536],
                               [4096, 1536]],
        "bitnet_b1_58-3B": [[3200, 8640],
                            [3200, 3200],
                            [8640, 3200]],
        "Llama3-8B-1.58-100B-tokens": [[14336, 4096],
                                       [4096, 14336],
                                       [1024, 4096],
                                       [4096, 4096]]
    }

    parser = argparse.ArgumentParser(description='gen impl')
    parser.add_argument('--model', default="input", type=str, dest="model",
                        help="choose from bitnet_b1_58-large/bitnet_b1_58-3B/Llama3-8B-1.58-100B-tokens.")
    parser.add_argument('--BM', default="input", type=str,
                        help="block length when cutting one weight (M, K) into M / BM weights (BM, K).")
    parser.add_argument('--BK', default="input", type=str,
                        help="block length when cutting one weight (M, K) into K / BK weights (M, BK).")
    parser.add_argument('--bm', default="input", type=str,
                        help="using simd instructions to compute (bm, 256 / bm) in one block")
    args = parser.parse_args()

    kernel_shapes = ModelShapeDict[args.model]

    # One blocking parameter per kernel shape, comma-separated on the CLI.
    BM_list = [int(item) for item in args.BM.split(',')]
    BK_list = [int(item) for item in args.BK.split(',')]
    bm_list = [int(item) for item in args.bm.split(',')]

    # Fix: "shoud" -> "should" in the message (typo in user-facing text).
    assert len(BM_list) == len(BK_list) == len(bm_list) == len(kernel_shapes), \
        "number of BM / BK / bm should be {}".format(len(kernel_shapes))

    for i in range(len(kernel_shapes)):
        # Fix: these messages are plain strings (no %-formatting applied),
        # so the original "%%" was shown literally; a single "%" is correct.
        assert kernel_shapes[i][0] % BM_list[i] == 0, "M % BM should be 0"
        assert kernel_shapes[i][1] % BK_list[i] == 0, "K % BK should be 0"
        assert bm_list[i] in [32, 64], "choose bm from [32, 64]"

    # Generate one specialized kernel per weight shape.
    tbl_impl_code = []
    for i in range(len(kernel_shapes)):
        tbl_impl_code.append(
            gen_tbl_impl("{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]),
                         BM_list[i], BK_list[i], bm_list[i], kernel_shapes[i][1])
        )
    api_code = gen_top_api(kernel_shapes)
    pre_code = gen_preprocess_code()
    ctor_code = gen_ctor_code()
    trans_code = gen_transform_code(kernel_shapes)

    # Emit the generated header into the project's include/ directory.
    output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "include")

    with open(os.path.join(output_dir, "bitnet-lut-kernels.h"), 'w') as f:
        f.write("#if defined(GGML_BITNET_ARM_TL1)")
        f.write(ctor_code)
        for code in tbl_impl_code:
            f.write(code)
        f.write(pre_code)
        f.write(api_code)
        f.write(trans_code)
        f.write("#endif")

    # Record the chosen blocking parameters so later stages can read them back.
    config = ConfigParser()
    for i in range(len(kernel_shapes)):
        section = 'Kernels_{}'.format(i)
        config.add_section(section)
        config.set(section, 'M', str(kernel_shapes[i][0]))
        config.set(section, 'K', str(kernel_shapes[i][1]))
        config.set(section, 'BM', str(BM_list[i]))
        config.set(section, 'BK', str(BK_list[i]))
        # NOTE(review): key is 'bmm' (not 'bm') in the original; preserved
        # because downstream readers may depend on it — confirm before renaming.
        config.set(section, 'bmm', str(bm_list[i]))

    with open(os.path.join(output_dir, "kernel_config.ini"), 'w') as configfile:
        config.write(configfile)
|
||||
@@ -0,0 +1,757 @@
|
||||
import argparse
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
|
||||
def gen_ctor_code():
    """Return the fixed (shape-independent) preamble of the generated TL2 (x86/AVX2) kernel header.

    The returned C/C++ source provides:
      * aligned_malloc / aligned_free helpers (Win32 vs POSIX paths),
      * _mm256_merge_epi32/epi64/si128 + Transpose_8_8: an 8x8 32-bit
        register transpose built from AVX2 unpack/permute primitives,
      * per_tensor_quant: activation scale = 127 / max(|b|),
      * partial_max_reset: zeroes ``bs`` stored LUT scales,
      * three_lut_ctor<act_k> / two_lut_ctor<act_k>: build the signed
        lookup tables for 3-weight and 2-weight groups respectively,
      * is_type_supported: gates GGML_TYPE_Q4_0 / GGML_TYPE_TL2 tensors.

    NOTE(review): the generated C declares ``__m256 vec_lut[16];`` (float
    lanes) in both lut ctors but assigns ``__m256i`` integer values to its
    elements — verify this compiles on the target toolchain or whether the
    declaration should be ``__m256i``.
    """
    # One line-continued literal; each '\n\' line is one line of generated C.
    # Unlike the TL1 generator, braces here are single — this string is
    # emitted as-is and never passed through str.format.
    kernel_code = "\n\
#include \"ggml-bitnet.h\"\n\
#include <cstring>\n\
#include <immintrin.h>\n\
#define GGML_BITNET_MAX_NODES 8192\n\
static bool initialized = false;\n\
static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;\n\
static size_t bitnet_tensor_extras_index = 0;\n\
static void * aligned_malloc(size_t size) {\n\
#if defined(_WIN32)\n\
return _aligned_malloc(size, 64);\n\
#else\n\
void * ptr = nullptr;\n\
posix_memalign(&ptr, 64, size);\n\
return ptr;\n\
#endif\n\
}\n\
\n\
static void aligned_free(void * ptr) {\n\
#if defined(_WIN32)\n\
_aligned_free(ptr);\n\
#else\n\
free(ptr);\n\
#endif\n\
}\n\
#define BK2 32\n\
#if defined __AVX2__\n\
inline void _mm256_merge_epi32(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)\n\
{\n\
__m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0));\n\
__m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0));\n\
*vl = _mm256_unpacklo_epi32(va, vb);\n\
*vh = _mm256_unpackhi_epi32(va, vb);\n\
}\n\
inline void _mm256_merge_epi64(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)\n\
{\n\
__m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0));\n\
__m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0));\n\
*vl = _mm256_unpacklo_epi64(va, vb);\n\
*vh = _mm256_unpackhi_epi64(va, vb);\n\
}\n\
inline void _mm256_merge_si128(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)\n\
{\n\
*vl = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 2, 0, 0));\n\
*vh = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 3, 0, 1));\n\
}\n\
inline void Transpose_8_8(\n\
__m256i *v0,\n\
__m256i *v1,\n\
__m256i *v2,\n\
__m256i *v3,\n\
__m256i *v4,\n\
__m256i *v5,\n\
__m256i *v6,\n\
__m256i *v7)\n\
{\n\
__m256i w0, w1, w2, w3, w4, w5, w6, w7;\n\
__m256i x0, x1, x2, x3, x4, x5, x6, x7;\n\
_mm256_merge_epi32(*v0, *v1, &w0, &w1);\n\
_mm256_merge_epi32(*v2, *v3, &w2, &w3);\n\
_mm256_merge_epi32(*v4, *v5, &w4, &w5);\n\
_mm256_merge_epi32(*v6, *v7, &w6, &w7);\n\
_mm256_merge_epi64(w0, w2, &x0, &x1);\n\
_mm256_merge_epi64(w1, w3, &x2, &x3);\n\
_mm256_merge_epi64(w4, w6, &x4, &x5);\n\
_mm256_merge_epi64(w5, w7, &x6, &x7);\n\
_mm256_merge_si128(x0, x4, v0, v1);\n\
_mm256_merge_si128(x1, x5, v2, v3);\n\
_mm256_merge_si128(x2, x6, v4, v5);\n\
_mm256_merge_si128(x3, x7, v6, v7);\n\
}\n\
#endif\n\
inline int32_t per_tensor_quant(int k, void* lut_scales_, void* b_) {\n\
bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
bitnet_float_type* b = (bitnet_float_type*)b_;\n\
#if defined __AVX2__\n\
__m256 max_vec = _mm256_set1_ps(0.f);\n\
const __m256 vec_sign = _mm256_set1_ps(-0.0f);\n\
for (int i = 0; i < k / 8; i++) {\n\
__m256 vec_b = _mm256_loadu_ps(b + i * 8);\n\
__m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);\n\
max_vec = _mm256_max_ps(vec_babs, max_vec);\n\
}\n\
__m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));\n\
max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));\n\
max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));\n\
float scales = 127 / _mm_cvtss_f32(max1);\n\
*lut_scales = scales;\n\
#endif\n\
return 0;\n\
}\n\
inline int32_t partial_max_reset(int32_t bs, void* lut_scales_) {\n\
bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
#pragma unroll\n\
for (int i=0; i< bs; i++) {\n\
lut_scales[i] = 0.0;\n\
}\n\
return 0;\n\
}\n\
template<int act_k>\n\
inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {\n\
#if defined __AVX2__\n\
__m256 vec_lut[16];\n\
const __m256i vec_bi = _mm256_set_epi32(84, 72, 60, 48, 36, 24, 12, 0);\n\
float scales = *lut_scales;\n\
__m256i shuffle_mask = _mm256_set_epi8(\n\
0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,\n\
0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00,\n\
0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,\n\
0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00\n\
);\n\
#pragma unroll\n\
for (int k = 0; k < act_k / 24; ++k) {\n\
__m256 vec_b0 = _mm256_i32gather_ps(b + k * 24 + 0, vec_bi, 1);\n\
__m256 vec_b1 = _mm256_i32gather_ps(b + k * 24 + 1, vec_bi, 1);\n\
__m256 vec_b2 = _mm256_i32gather_ps(b + k * 24 + 2, vec_bi, 1);\n\
\n\
__m256i vec_b0i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\
__m256i vec_b1i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\
__m256i vec_b2i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b2, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\
\n\
vec_lut[15] = _mm256_setzero_si256();\n\
vec_lut[14] = _mm256_setzero_si256();\n\
vec_lut[13] = vec_b0i;\n\
vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b1i);\n\
vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b2i);\n\
vec_lut[12] = vec_b0i;\n\
vec_lut[12] = _mm256_add_epi32(vec_lut[12], vec_b1i);\n\
vec_lut[11] = vec_b0i;\n\
vec_lut[11] = _mm256_add_epi32(vec_lut[11], vec_b1i);\n\
vec_lut[11] = _mm256_sub_epi32(vec_lut[11], vec_b2i);\n\
vec_lut[10] = vec_b0i;\n\
vec_lut[10] = _mm256_add_epi32(vec_lut[10], vec_b2i);\n\
vec_lut[9] = vec_b0i;\n\
vec_lut[8] = vec_b0i;\n\
vec_lut[8] = _mm256_sub_epi32(vec_lut[8], vec_b2i);\n\
vec_lut[7] = vec_b0i;\n\
vec_lut[7] = _mm256_sub_epi32(vec_lut[7], vec_b1i);\n\
vec_lut[7] = _mm256_add_epi32(vec_lut[7], vec_b2i);\n\
vec_lut[6] = vec_b0i;\n\
vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1i);\n\
vec_lut[5] = vec_b0i;\n\
vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b1i);\n\
vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b2i);\n\
vec_lut[4] = vec_b1i;\n\
vec_lut[4] = _mm256_add_epi32(vec_lut[4], vec_b2i);\n\
vec_lut[3] = vec_b1i;\n\
vec_lut[2] = vec_b1i;\n\
vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b2i);\n\
vec_lut[1] = vec_b2i;\n\
vec_lut[0] = _mm256_setzero_si256();\n\
__m256i ix[16];\n\
\n\
#pragma unroll\n\
for (int g = 0; g < 16; ++g) {\n\
ix[g] = vec_lut[g];\n\
}\n\
\n\
Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7]));\n\
Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15]));\n\
\n\
#pragma unroll\n\
for (int g = 0; g < 8; ++g) {\n\
ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]);\n\
ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));\n\
ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask);\n\
ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));\n\
}\n\
int8_t* qlut_i8 = reinterpret_cast<int8_t*>(qlut);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]);\n\
\n\
}\n\
\n\
*lut_scales = scales;\n\
#endif\n\
return 0;\n\
}\n\
\n\
template<int act_k>\n\
inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {\n\
#if defined __AVX2__\n\
__m256 vec_lut[16];\n\
const __m256i vec_bi = _mm256_set_epi32(56, 48, 40, 32, 24, 16, 8, 0);\n\
float scales = *lut_scales;\n\
__m256i shuffle_mask = _mm256_set_epi8(\n\
0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,\n\
0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00,\n\
0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,\n\
0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00\n\
);\n\
#pragma unroll\n\
for (int k = 0; k < act_k / 16; ++k) {\n\
__m256 vec_b0f = _mm256_i32gather_ps(b + k * 16 + 0, vec_bi, 1);\n\
__m256 vec_b1f = _mm256_i32gather_ps(b + k * 16 + 1, vec_bi, 1);\n\
\n\
__m256i vec_b0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\
__m256i vec_b1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\
vec_lut[15] = _mm256_setzero_si256();\n\
vec_lut[14] = _mm256_setzero_si256();\n\
vec_lut[13] = _mm256_setzero_si256();\n\
vec_lut[12] = _mm256_setzero_si256();\n\
vec_lut[11] = _mm256_setzero_si256();\n\
vec_lut[10] = _mm256_setzero_si256();\n\
vec_lut[9] = _mm256_setzero_si256();\n\
vec_lut[8] = vec_b0;\n\
vec_lut[8] = _mm256_add_epi32(vec_lut[8], vec_b1);\n\
vec_lut[7] = vec_b0;\n\
vec_lut[6] = vec_b0;\n\
vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1);\n\
vec_lut[5] = vec_b1;\n\
vec_lut[4] = _mm256_setzero_si256();\n\
vec_lut[3] = _mm256_setzero_si256();\n\
vec_lut[3] = _mm256_sub_epi32(vec_lut[3], vec_b1);\n\
vec_lut[2] = _mm256_setzero_si256();\n\
vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b0);\n\
vec_lut[2] = _mm256_add_epi32(vec_lut[2], vec_b1);\n\
vec_lut[1] = _mm256_setzero_si256();\n\
vec_lut[1] = _mm256_sub_epi32(vec_lut[1], vec_b0);\n\
vec_lut[0] = _mm256_setzero_si256();\n\
vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b0);\n\
vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b1);\n\
\n\
__m256i ix[16];\n\
#pragma unroll\n\
for (int g = 0; g < 16; ++g) {\n\
ix[g] = vec_lut[g];\n\
}\n\
\n\
Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7]));\n\
Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15]));\n\
\n\
#pragma unroll\n\
for (int g = 0; g < 8; ++g) {\n\
ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]);\n\
ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));\n\
ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask);\n\
ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));\n\
}\n\
\n\
int8_t* qlut_i8 = reinterpret_cast<int8_t*>(qlut);\n\
\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]);\n\
_mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]);\n\
\n\
}\n\
*lut_scales = scales;\n\
#endif\n\
return 0;\n\
}\n\
static bool is_type_supported(enum ggml_type type) {\n\
if (type == GGML_TYPE_Q4_0 ||\n\
type == GGML_TYPE_TL2) {\n\
return true;\n\
} else {\n\
return false;\n\
}\n\
}\n\
"
    return kernel_code
|
||||
|
||||
def gen_tbl_impl(pre, BM, BK, bm, k_list):
    """Generate the C++ source for one (M, K) kernel shape.

    Emits, as one Python string:
      * ``#define BM{pre}`` / ``#define BBK{pre}`` block-size macros,
      * ``three_tbl_impl_{pre}``: AVX2 LUT-gather inner kernel for the
        3-element-group ("three") portion of K,
      * ``two_tbl_impl{pre}``: AVX2 LUT-gather inner kernel for the
        2-element-group ("two") remainder of K,
      * ``three_qgemm_lut_{pre}`` / ``two_qgemm_lut_{pre}``: batch-size
        templated wrappers that tile over K and produce the outputs.

    Args:
        pre: name suffix "{M}_{K}" baked into the generated identifiers.
        BM: row-block size handled per tile (substituted as {1}).
        BK: K-block size of the "three" portion (substituted as {2}).
        bm: SIMD row width; accepted for interface symmetry with the
            caller but not used in the generated text.
        k_list: ``(two_k, three_k)`` split of K; ``k_list[1]`` (three_k)
            and ``k_list[0]`` (two_k) are substituted into the wrappers.

    Returns:
        str: the generated C++ code.
    """
    # Header, block-size macros, and the signature of the "three" inner
    # kernel.  ``{{`` / ``}}`` are literal C braces escaped for str.format.
    kernel_code = "\
#include <immintrin.h>\n\
\n\
#define BM{0} {1}\n\
#define BBK{0} {2}\n\
template<int batch_size, int K3>\n\
inline void three_tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) {{\n\
".format(pre, BM, BK)

    # Body of three_tbl_impl, then two_tbl_impl, then both qgemm wrappers.
    # The three-kernel's k-loop is hand-unrolled 4x (variable suffixes
    # _0.._3); each step gathers int8 LUT entries with
    # _mm256_shuffle_epi8 and applies per-group signs via the
    # slli/srai mask followed by the add/xor negate trick.
    kernel_code = "".join([kernel_code, "\
#ifdef __AVX2__\n\
    const __m256i vec_mask = _mm256_set1_epi8(0x0f);\n\
    const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000);\n\
    const __m256i vec_zero = _mm256_set1_epi8(0x00);\n\
    const __m256i vec_one = _mm256_set1_epi8(0xff);\n\
    const int KK = BBK{0} / 3;\n\
#pragma unroll\n\
    for (int i = 0; i < BM{0}; i += 32) {{\n\
        __m256i vec_as[KK / 2];\n\
        __m256i vec_signs[KK / 8];\n\
#pragma unroll\n\
        for (int ai = 0; ai < KK / 2; ai++) {{\n\
            vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32));\n\
        }}\n\
#pragma unroll\n\
        for (int as = 0; as < KK / 8; as++) {{\n\
            vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32));\n\
        }}\n\
#pragma unroll\n\
        for (int bs = 0; bs < batch_size; bs++) {{\n\
            __m256i vec_c0 = _mm256_setzero_si256();\n\
            __m256i vec_c1 = _mm256_setzero_si256();\n\
#pragma unroll\n\
            for (int k = 0; k < KK / 8; k++) {{\n\
                __m256i vec_sign = vec_signs[k];\n\
                __m256i vec_a_0 = vec_as[k * 4 + 0];\n\
                __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs));\n\
                __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs));\n\
                __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs));\n\
                __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs));\n\
                __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15);\n\
                __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15);\n\
                __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask);\n\
                __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0);\n\
                __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0);\n\
                __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15);\n\
                __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15);\n\
                __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask);\n\
                __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0);\n\
                __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0);\n\
                __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0);\n\
                __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0);\n\
                __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0);\n\
                __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0);\n\
                vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0);\n\
                vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0);\n\
                vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0);\n\
                vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0);\n\
                __m256i vec_a_1 = vec_as[k * 4 + 1];\n\
                __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs));\n\
                __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs));\n\
                __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs));\n\
                __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs));\n\
                __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15);\n\
                __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15);\n\
                __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask);\n\
                __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1);\n\
                __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1);\n\
                __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15);\n\
                __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15);\n\
                __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask);\n\
                __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1);\n\
                __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1);\n\
                __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1);\n\
                __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1);\n\
                __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1);\n\
                __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1);\n\
                vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1);\n\
                vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1);\n\
                vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1);\n\
                vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1);\n\
                __m256i vec_a_2 = vec_as[k * 4 + 2];\n\
                __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs));\n\
                __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs));\n\
                __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs));\n\
                __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs));\n\
                __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15);\n\
                __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15);\n\
                __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask);\n\
                __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2);\n\
                __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2);\n\
                __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15);\n\
                __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15);\n\
                __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask);\n\
                __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2);\n\
                __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2);\n\
                __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2);\n\
                __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2);\n\
                __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2);\n\
                __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2);\n\
                vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2);\n\
                vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2);\n\
                vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2);\n\
                vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2);\n\
                __m256i vec_a_3 = vec_as[k * 4 + 3];\n\
                __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs));\n\
                __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs));\n\
                __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs));\n\
                __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs));\n\
                __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15);\n\
                __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15);\n\
                __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask);\n\
                __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3);\n\
                __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3);\n\
                __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15);\n\
                __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15);\n\
                __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask);\n\
                __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3);\n\
                __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3);\n\
                __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3);\n\
                __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3);\n\
                __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3);\n\
                __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3);\n\
                vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3);\n\
                vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3);\n\
                vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3);\n\
                vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3);\n\
            }}\n\
            __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{0} * bs));\n\
            __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{0} * bs));\n\
            __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{0} * bs));\n\
            __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{0} * bs));\n\
            vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));\n\
            vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));\n\
            vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));\n\
            vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));\n\
            _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{0} * bs), vec_gc0);\n\
            _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{0} * bs), vec_gc1);\n\
            _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{0} * bs), vec_gc2);\n\
            _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{0} * bs), vec_gc3);\n\
        }}\n\
    }}\n\
#endif\n\
}}\n\
\n\
template<int batch_size, int K2>\n\
inline int32_t two_tbl_impl{0}(int32_t* c, int8_t* lut, uint8_t* a) {{\n\
#ifdef __AVX2__\n\
    const __m256i vec_mask = _mm256_set1_epi8(0x0f);\n\
    const int KK = BK2 / 2;\n\
#pragma unroll\n\
    for (int i = 0; i < BM{0}; i += 32) {{\n\
        __m256i vec_as[KK / 2];\n\
#pragma unroll\n\
        for (int ai = 0; ai < KK / 2; ai++) {{\n\
            vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32));\n\
        }}\n\
#pragma unroll\n\
        for (int bs = 0; bs < batch_size; bs++) {{\n\
            __m256i vec_c0 = _mm256_setzero_si256();\n\
            __m256i vec_c1 = _mm256_setzero_si256();\n\
#pragma unroll\n\
            for (int k = 0; k < KK / 8; k++) {{\n\
#pragma unroll\n\
                for (int j = 0; j < 4; j++) {{\n\
                    __m256i vec_a = vec_as[k * 4 + j];\n\
\n\
                    __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs));\n\
                    __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs));\n\
                    __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs));\n\
                    __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs));\n\
\n\
                    __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask);\n\
                    __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top);\n\
                    __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top);\n\
\n\
                    __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask);\n\
                    __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot);\n\
                    __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot);\n\
\n\
                    __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec);\n\
                    __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec);\n\
                    __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec);\n\
                    __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec);\n\
                    vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi);\n\
                    vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi);\n\
                    vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo);\n\
                    vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo);\n\
                }}\n\
            }}\n\
\n\
            __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{0} * bs));\n\
            __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{0} * bs));\n\
            __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{0} * bs));\n\
            __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{0} * bs));\n\
\n\
            vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));\n\
            vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));\n\
            vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));\n\
            vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));\n\
\n\
            _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{0} * bs), vec_gc0);\n\
            _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{0} * bs), vec_gc1);\n\
            _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{0} * bs), vec_gc2);\n\
            _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{0} * bs), vec_gc3);\n\
        }}\n\
    }}\n\
#endif\n\
    return 0;\n\
}}\n\
\n\
template<int BATCH_SIZE>\n\
int32_t three_qgemm_lut_{0}(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
    alignas(32) uint32_t CBits[BATCH_SIZE * BM{0}];\n\
    memset(&(CBits[0]), 0, BATCH_SIZE * BM{0} * sizeof(int32_t));\n\
#pragma unroll\n\
    for (int32_t k_outer = 0; k_outer < {1} / BBK{0}; ++k_outer) {{\n\
        three_tbl_impl_{0}<BATCH_SIZE, {1}>((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{0} / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{0} / 3 / 2 * BM{0})])), (&(((uint8_t*)sign)[(k_outer * BBK{0} / 3 / 8 * BM{0})])));\n\
    }}\n\
#pragma unroll\n\
    for (int bs = 0; bs < BATCH_SIZE; bs++) {{\n\
#pragma unroll\n\
        for (int i = 0; i < BM{0}; i++) {{\n\
            ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM{0}]);\n\
        }}\n\
    }}\n\
    return 0;\n\
}}\n\
\n\
template<int BATCH_SIZE>\n\
int32_t two_qgemm_lut_{0}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
    alignas(32) uint32_t CBits[BATCH_SIZE * BM{0}];\n\
    memset(&(CBits[0]), 0, BATCH_SIZE * BM{0} * sizeof(int32_t));\n\
#pragma unroll\n\
    for (int32_t k_outer = 0; k_outer < {2} / 32; ++k_outer) {{\n\
        two_tbl_impl{0}<BATCH_SIZE, {2}>((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM{0})])));\n\
    }}\n\
#pragma unroll\n\
    for (int bs = 0; bs < BATCH_SIZE; bs++) {{\n\
#pragma unroll\n\
        for (int i = 0; i < BM{0}; i++) {{\n\
            ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM{0}]);\n\
            ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0];\n\
        }}\n\
    }}\n\
    return 0;\n\
}}\n\
\n\
".format(pre, k_list[1], k_list[0])])
    return kernel_code
|
||||
|
||||
def gen_top_api(kernel_shapes, k_list):
    """Generate the C++ top-level dispatch functions.

    Produces two functions as one string:
      * ``ggml_preprocessor``: per-batch activation quantization plus
        LUT construction (three_lut_ctor / two_lut_ctor), dispatched on
        the (m, two_k, three_k) shape.
      * ``ggml_qgemm_lut``: dispatches on (m, k), then on BK (two_k vs
        three_k), then on the batch size bs (1/8/32/128/256/512) to the
        correct two_qgemm_lut_* / three_qgemm_lut_* template instance.

    Args:
        kernel_shapes: list of [M, K] pairs, one per kernel to emit.
        k_list: list of (two_k, three_k) splits parallel to
            kernel_shapes; k_list[i][0] is two_k, k_list[i][1] is three_k.

    Returns:
        str: the generated C++ code.
    """
    # First shape opens ggml_preprocessor with a plain `if`; later
    # shapes are appended as `else if` branches in the loop below.
    kernel_code = "void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT) {{\n\
    partial_max_reset(bs, (&(((float*)LUT_Scales)[0])));\n\
    if (m == {0} && two_k == {1} && three_k == {2}) {{\n\
        for (int32_t b = 0; b < bs; b++) {{\n\
            per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)])));\n\
            three_lut_ctor<{2}>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b])));\n\
            two_lut_ctor<{1}>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + {2}])), (&(((float*)LUT_Scales)[b])));\n\
        }}\n\
    }}\n\
".format(kernel_shapes[0][0], k_list[0][0], k_list[0][1])
    # One `else if` branch per remaining kernel shape.
    for i in range(1, len(kernel_shapes)):
        kernel_code = "".join([kernel_code, "    else if (m == {0} && two_k == {1} && three_k == {2}) {{\n\
        for (int32_t b = 0; b < bs; b++) {{\n\
            per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)])));\n\
            three_lut_ctor<{2}>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b])));\n\
            two_lut_ctor<{1}>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + {2}])), (&(((float*)LUT_Scales)[b])));\n\
        }}\n\
    }}\n".format(kernel_shapes[i][0], k_list[i][0], k_list[i][1])])
    # Close ggml_preprocessor.
    kernel_code = "".join([kernel_code, "}\n"])

    # ggml_qgemm_lut for the first shape.  {2} = two_k, {3} = three_k,
    # {4} = "M_K" name suffix of the generated qgemm templates.
    kernel_code = "".join([kernel_code, "void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
    if (m == {0} && k == {1}) {{\n\
        if (BK == {2}) {{\n\
            if (bs == 1) {{\n\
                two_qgemm_lut_{4}<1>(A, LUT, Scales, LUT_Scales, C);\n\
            }} else if (bs == 8) {{\n\
                two_qgemm_lut_{4}<8>(A, LUT, Scales, LUT_Scales, C);\n\
            }} else if (bs == 32) {{\n\
                two_qgemm_lut_{4}<32>(A, LUT, Scales, LUT_Scales, C);\n\
            }} else if (bs == 128) {{\n\
                two_qgemm_lut_{4}<128>(A, LUT, Scales, LUT_Scales, C);\n\
            }} else if (bs == 256) {{\n\
                two_qgemm_lut_{4}<256>(A, LUT, Scales, LUT_Scales, C);\n\
            }} else if (bs == 512) {{\n\
                two_qgemm_lut_{4}<512>(A, LUT, Scales, LUT_Scales, C);\n\
            }}\n\
        }}\n\
        else if (BK == {3}) {{\n\
            if (bs == 1) {{\n\
                three_qgemm_lut_{4}<1>(A, sign, LUT, Scales, LUT_Scales, C);\n\
            }}else if (bs == 8) {{\n\
                three_qgemm_lut_{4}<8>(A, sign, LUT, Scales, LUT_Scales, C);\n\
            }}else if (bs == 32) {{\n\
                three_qgemm_lut_{4}<32>(A, sign, LUT, Scales, LUT_Scales, C);\n\
            }}else if (bs == 128) {{\n\
                three_qgemm_lut_{4}<128>(A, sign, LUT, Scales, LUT_Scales, C);\n\
            }}else if (bs == 256) {{\n\
                three_qgemm_lut_{4}<256>(A, sign, LUT, Scales, LUT_Scales, C);\n\
            }}else if (bs == 512) {{\n\
                three_qgemm_lut_{4}<512>(A, sign, LUT, Scales, LUT_Scales, C);\n\
            }}\n\
        }}\n\
    }}\n\
".format(kernel_shapes[0][0], kernel_shapes[0][1], k_list[0][0], k_list[0][1], "{}_{}".format(kernel_shapes[0][0], kernel_shapes[0][1]))])
    # `else if` branches of ggml_qgemm_lut for the remaining shapes.
    for i in range(1, len(kernel_shapes)):
        kernel_code = "".join([kernel_code, "    else if (m == {0} && k == {1}) {{\n\
        if (BK == {2}) {{\n\
            if (bs == 1) {{\n\
                two_qgemm_lut_{4}<1>(A, LUT, Scales, LUT_Scales, C);\n\
            }} else if (bs == 8) {{\n\
                two_qgemm_lut_{4}<8>(A, LUT, Scales, LUT_Scales, C);\n\
            }} else if (bs == 32) {{\n\
                two_qgemm_lut_{4}<32>(A, LUT, Scales, LUT_Scales, C);\n\
            }} else if (bs == 128) {{\n\
                two_qgemm_lut_{4}<128>(A, LUT, Scales, LUT_Scales, C);\n\
            }} else if (bs == 256) {{\n\
                two_qgemm_lut_{4}<256>(A, LUT, Scales, LUT_Scales, C);\n\
            }} else if (bs == 512) {{\n\
                two_qgemm_lut_{4}<512>(A, LUT, Scales, LUT_Scales, C);\n\
            }}\n\
        }}\n\
        else if (BK == {3}) {{\n\
            if (bs == 1) {{\n\
                three_qgemm_lut_{4}<1>(A, sign, LUT, Scales, LUT_Scales, C);\n\
            }}else if (bs == 8) {{\n\
                three_qgemm_lut_{4}<8>(A, sign, LUT, Scales, LUT_Scales, C);\n\
            }}else if (bs == 32) {{\n\
                three_qgemm_lut_{4}<32>(A, sign, LUT, Scales, LUT_Scales, C);\n\
            }}else if (bs == 128) {{\n\
                three_qgemm_lut_{4}<128>(A, sign, LUT, Scales, LUT_Scales, C);\n\
            }}else if (bs == 256) {{\n\
                three_qgemm_lut_{4}<256>(A, sign, LUT, Scales, LUT_Scales, C);\n\
            }}else if (bs == 512) {{\n\
                three_qgemm_lut_{4}<512>(A, sign, LUT, Scales, LUT_Scales, C);\n\
            }}\n\
        }}\n\
    }}\n\
".format(kernel_shapes[i][0], kernel_shapes[i][1], k_list[i][0], k_list[i][1], "{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]))])
    # Close ggml_qgemm_lut.
    kernel_code = "".join([kernel_code, "}\n"])
    return kernel_code
|
||||
|
||||
def gen_transform_code(kernel_shapes):
    """Generate ggml_bitnet_transform_tensor in C++.

    The emitted function attaches a ``bitnet_tensor_extra`` record to a
    supported CPU-resident ggml tensor: it picks the (bm, bk) block
    sizes for the tensor's (m, k) shape, reads the per-tensor scale
    that is stored right after the packed weights in tensor->data, and
    registers the extra in the global ``bitnet_tensor_extras`` array
    (helpers defined by gen_ctor_code()).

    Args:
        kernel_shapes: list of [M, K] pairs; one if/else-if branch is
            emitted per shape to select BM{M}_{K} / BBK{M}_{K}.

    Returns:
        str: the generated C++ code.
    """
    # Prologue: guard clause and locals.  This string is not passed
    # through .format(), so C braces are written unescaped.
    kernel_code = "\n\
void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\
    if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {\n\
        return;\n\
    }\n\
\n\
    int k = tensor->ne[0];\n\
    int m = tensor->ne[1];\n\
    const int lut_scales_size = 1;\n\
    int bk = 0;\n\
    int bm = 0;\n"

    # Shape dispatch: first shape as `if`, the rest as `else if`.
    kernel_code = "".join([kernel_code, "\n\
    if (m == {0} && k == {1}) {{\n\
        bm = BM{0}_{1};\n\
        bk = BBK{0}_{1};\n\
    }}\n".format(kernel_shapes[0][0], kernel_shapes[0][1])])

    for i in range(1, len(kernel_shapes)):
        kernel_code = "".join([kernel_code, "else if (m == {0} && k == {1}) {{\n\
        bm = BM{0}_{1};\n\
        bk = BBK{0}_{1};\n\
    }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])])

    # Epilogue: locate the scale stored after the 32-byte-aligned packed
    # weights and register the tensor extra.
    # NOTE(review): the generated nbytes formula hard-codes a 256-wide
    # two-bit tail ((k - 256) three-bit + 256 two-bit elements); confirm
    # this matches the two_k/three_k split produced by get_three_k_two_k.
    kernel_code = "".join([kernel_code, "\n\
    const int n_tile_num = m / bm;\n\
    const int BK = bk;\n\
    uint8_t * qweights;\n\
    bitnet_float_type * scales;\n\
\n\
    scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));\n\
    qweights = (uint8_t *) tensor->data;\n\
    int nbytes = (k - 256) * m / 3 * 5 / 8 + 256 * m / 2 * 4 / 8;\n\
    if (nbytes % 32 != 0) nbytes = 32 - nbytes % 32 + nbytes;\n\
    float * i2_scales = (float * )(qweights + nbytes);\n\
    scales[0] = (bitnet_float_type) i2_scales[0];\n\
\n\
    tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\
    bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\
    /* .lut_scales_size = */ lut_scales_size,\n\
    /* .BK = */ BK,\n\
    /* .n_tile_num = */ n_tile_num,\n\
    /* .qweights = */ qweights,\n\
    /* .scales = */ scales\n\
    };\n\
}\n"])

    return kernel_code
|
||||
|
||||
def get_three_k_two_k(K, bk):
    """Split the K dimension into a bk-aligned part and a remainder.

    Args:
        K: total reduction-dimension length.
        bk: block length of the "three" kernels.

    Returns:
        (two_k, three_k): ``three_k`` is the largest multiple of ``bk``
        not exceeding ``K`` (handled by the three-element kernels);
        ``two_k = K - three_k`` is the leftover (handled by the
        two-element kernels).
    """
    three_k = (K // bk) * bk
    return K - three_k, three_k
|
||||
|
||||
if __name__ == "__main__":
    # Per-model list of [M, K] weight shapes that need generated kernels.
    ModelShapeDict = {
        "bitnet_b1_58-large" : [[1536, 4096],
                                [1536, 1536],
                                [4096, 1536]],
        "bitnet_b1_58-3B" : [[3200, 8640],
                             [3200, 3200],
                             [8640, 3200]],
        "Llama3-8B-1.58-100B-tokens" : [[14336, 4096],
                                        [4096, 14336],
                                        [1024, 4096],
                                        [4096, 4096]]
    }

    parser = argparse.ArgumentParser(description='gen impl')
    parser.add_argument('--model',default="input", type=str, dest="model",
                        help="choose from bitnet_b1_58-large/bitnet_b1_58-3B/Llama3-8B-1.58-100B-tokens.")
    # --BM/--BK/--bm are comma-separated lists, one value per kernel shape.
    parser.add_argument('--BM',default="input", type=str,
                        help="block length when cutting one weight (M, K) into M / BM weights (BM, K).")
    parser.add_argument('--BK',default="input", type=str,
                        help="block length when cutting one weight (M, K) into K / BK weights (M, BK).")
    parser.add_argument('--bm',default="input", type=str,
                        help="using simd instructions to compute (bm, 192 / bm) in one block")
    args = parser.parse_args()

    kernel_shapes = ModelShapeDict[args.model]

    BM_list = [int(item) for item in args.BM.split(',')]
    BK_list = [int(item) for item in args.BK.split(',')]
    bm_list = [int(item) for item in args.bm.split(',')]

    tbl_impl_code = []
    k_list = []

    # Split each K into (two_k, three_k) for the TL2 dual-kernel scheme.
    for i in range(len(kernel_shapes)):
        k_list.append(get_three_k_two_k(kernel_shapes[i][1], BK_list[i]))

    # Generate the per-shape kernel bodies; the "M_K" suffix keeps the
    # generated C++ identifiers unique across shapes.
    for i in range(len(kernel_shapes)):
        tbl_impl_code.append(
            gen_tbl_impl("{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]), BM_list[i], BK_list[i], bm_list[i], k_list[i])
        )

    # Validate CLI lists after use for generation; argparse gives strings,
    # so a wrong count would already have raised above on int()/indexing.
    assert(len(BM_list) == len(BK_list) == len(bm_list) == len(kernel_shapes)), "number of BM / BK / bm shoud be {}".format(len(kernel_shapes))

    for i in range(len(kernel_shapes)):
        assert kernel_shapes[i][0] % BM_list[i] == 0, "M %% BM should be 0"
        assert (kernel_shapes[i][1] % BK_list[i]) % 32 == 0, "K %% BK %% 32 should be 0"
        assert bm_list[i] in [32], "choose bm from [32]"

    ctor_code = gen_ctor_code()
    api_code = gen_top_api(kernel_shapes, k_list)
    trans_code = gen_transform_code(kernel_shapes)

    # Header is written into <repo>/include next to this script's parent dir.
    output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "include")

    # Emit the single generated header, guarded by the TL2 x86 macro.
    with open(''.join([output_dir, "/bitnet-lut-kernels.h"]), 'w') as f:
        f.write(''.join("#if defined(GGML_BITNET_X86_TL2)"))
        f.write(''.join(ctor_code))
        for code in tbl_impl_code:
            f.write(''.join(code))
        f.write(''.join(api_code))
        f.write(''.join(trans_code))
        f.write(''.join("#endif"))

    # Record the chosen block sizes so the runtime/build can read them back.
    config = ConfigParser()

    for i in range(len(kernel_shapes)):
        config.add_section('Kernels_{}'.format(i))
        config.set('Kernels_{}'.format(i), 'M'.format(i), str(kernel_shapes[i][0]))
        config.set('Kernels_{}'.format(i), 'K'.format(i), str(kernel_shapes[i][1]))
        config.set('Kernels_{}'.format(i), 'BM'.format(i), str(BM_list[i]))
        config.set('Kernels_{}'.format(i), 'BK'.format(i), str(BK_list[i]))
        config.set('Kernels_{}'.format(i), 'bmm'.format(i), str(bm_list[i]))

    with open(''.join([output_dir, "/kernel_config.ini"]), 'w') as configfile:
        config.write(configfile)
|
||||
File diff suppressed because it is too large
Load Diff
+1711
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import argparse
|
||||
import subprocess
|
||||
|
||||
def run_command(command, shell=False, log_step=None):
    """Run a system command and abort the whole program if it fails.

    Args:
        command: argv list (or string when shell=True) to execute.
        shell: forwarded to subprocess.run.
        log_step: when given, stdout/stderr are redirected to
            ``<args.log_dir>/<log_step>.log`` (reads the module-global
            ``args``); otherwise output goes to the console.
    """
    if log_step is None:
        try:
            subprocess.run(command, shell=shell, check=True)
        except subprocess.CalledProcessError as e:
            logging.error(f"Error occurred while running command: {e}")
            sys.exit(1)
        return

    log_file = os.path.join(args.log_dir, log_step + ".log")
    with open(log_file, "w") as f:
        try:
            subprocess.run(command, shell=shell, check=True, stdout=f, stderr=f)
        except subprocess.CalledProcessError as e:
            logging.error(f"Error occurred while running command: {e}, check details in {log_file}")
            sys.exit(1)
||||
|
||||
def run_benchmark():
    """Invoke llama-bench from the repo's build tree with the CLI options.

    Reads the module-global ``args`` (model path, token counts, thread
    count) and exits the program if the benchmark binary is missing.
    """
    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    bench_path = os.path.join(repo_root, "build/bin/llama-bench")
    if not os.path.exists(bench_path):
        logging.error(f"Benchmark binary not found, please build first.")
        sys.exit(1)

    # Fixed settings: CPU only (-ngl 0), batch size 1, 5 repetitions.
    command = [
        bench_path,
        '-m', args.model,
        '-n', str(args.n_token),
        '-ngl', '0',
        '-b', '1',
        '-t', str(args.threads),
        '-p', str(args.n_prompt),
        '-r', '5',
    ]
    run_command(command)
|
||||
|
||||
def parse_args():
    """Parse command-line options for the end-to-end benchmark run.

    Returns:
        argparse.Namespace with: model (required path), n_token,
        n_prompt, threads, and log_dir.
    """
    parser = argparse.ArgumentParser(description='Setup the environment for running the inference')
    parser.add_argument("-m", "--model", type=str, help="Path to model file", required=True)
    parser.add_argument("-n", "--n-token", type=int, help="Number of generated tokens", required=False, default=128)
    parser.add_argument("-p", "--n-prompt", type=int, help="Prompt to generate text from", required=False, default=512)
    parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2)
    # Fix: run_command() reads args.log_dir whenever a log_step is passed,
    # but no such option was defined, so that path raised AttributeError.
    parser.add_argument("--log-dir", type=str, help="Directory for per-step command logs", required=False, default="logs")
    return parser.parse_args()
|
||||
|
||||
if __name__ == "__main__":
    # Make run_command()/run_benchmark() error messages visible.
    logging.basicConfig(level=logging.INFO)
    # `args` is intentionally module-global: run_benchmark() reads it.
    args = parse_args()
    run_benchmark()
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user