mirror of
https://github.com/microsoft/BitNet.git
synced 2026-05-06 04:40:52 +00:00
update 3rdparty & fix tl2 bug
This commit is contained in:
Vendored
+1
-1
Submodule 3rdparty/llama.cpp updated: 957b59d220...5095a95664
+19
-7
@@ -5,6 +5,7 @@ from configparser import ConfigParser
|
||||
def gen_ctor_code():
|
||||
kernel_code = "\n\
|
||||
#include \"ggml-bitnet.h\"\n\
|
||||
#include \"ggml-cpu-impl.h\"\n\
|
||||
#include <cstring>\n\
|
||||
#include <immintrin.h>\n\
|
||||
#define GGML_BITNET_MAX_NODES 8192\n\
|
||||
@@ -105,7 +106,7 @@ inline int32_t partial_max_reset(int32_t bs, void* lut_scales_) {\n\
|
||||
template<int act_k>\n\
|
||||
inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {\n\
|
||||
#if defined __AVX2__\n\
|
||||
__m256i vec_lut[16];\n\
|
||||
__m256 vec_lut[16];\n\
|
||||
const __m256i vec_bi = _mm256_set_epi32(84, 72, 60, 48, 36, 24, 12, 0);\n\
|
||||
float scales = *lut_scales;\n\
|
||||
__m256i shuffle_mask = _mm256_set_epi8(\n\
|
||||
@@ -191,7 +192,7 @@ inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_t
|
||||
template<int act_k>\n\
|
||||
inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {\n\
|
||||
#if defined __AVX2__\n\
|
||||
__m256i vec_lut[16];\n\
|
||||
__m256 vec_lut[16];\n\
|
||||
const __m256i vec_bi = _mm256_set_epi32(56, 48, 40, 32, 24, 16, 8, 0);\n\
|
||||
float scales = *lut_scales;\n\
|
||||
__m256i shuffle_mask = _mm256_set_epi8(\n\
|
||||
@@ -623,7 +624,7 @@ def gen_top_api(kernel_shapes, k_list):
|
||||
kernel_code = "".join([kernel_code, "}\n"])
|
||||
return kernel_code
|
||||
|
||||
def gen_transform_code(kernel_shapes):
|
||||
def gen_transform_code(kernel_shapes, fp16):
|
||||
kernel_code = "\n\
|
||||
void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\
|
||||
if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {\n\
|
||||
@@ -657,10 +658,20 @@ void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\
|
||||
scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));\n\
|
||||
qweights = (uint8_t *) tensor->data;\n\
|
||||
int nbytes = (k - 256) * m / 3 * 5 / 8 + 256 * m / 2 * 4 / 8;\n\
|
||||
if (nbytes % 32 != 0) nbytes = 32 - nbytes % 32 + nbytes;\n\
|
||||
nbytes = 32 - nbytes % 32 + nbytes;\n\
|
||||
float * i2_scales = (float * )(qweights + nbytes);\n\
|
||||
scales[0] = (bitnet_float_type) i2_scales[0];\n\
|
||||
\n\
|
||||
\n"])
|
||||
|
||||
if fp16:
|
||||
kernel_code = "".join([kernel_code, "\
|
||||
ggml_fp16_t* fp16_scale = (ggml_fp16_t *)aligned_malloc(sizeof(ggml_fp16_t));\n\
|
||||
fp16_scale[0] = GGML_FP32_TO_FP16(i2_scales[0]);\n\
|
||||
scales[0] = (bitnet_float_type) GGML_FP16_TO_FP32(fp16_scale[0]);\n"])
|
||||
else:
|
||||
kernel_code = "".join([kernel_code, "\
|
||||
scales[0] = (bitnet_float_type) i2_scales[0];\n"])
|
||||
|
||||
kernel_code = "".join([kernel_code, "\n\
|
||||
tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\
|
||||
bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\
|
||||
/* .lut_scales_size = */ lut_scales_size,\n\
|
||||
@@ -702,6 +713,7 @@ if __name__ == "__main__":
|
||||
help="block length when cutting one weight (M, K) into K / BK weights (M, BK).")
|
||||
parser.add_argument('--bm',default="input", type=str,
|
||||
help="using simd instructions to compute (bm, 192 / bm) in one block")
|
||||
parser.add_argument('--fp16', action="store_true", help="convert scale to fp16")
|
||||
args = parser.parse_args()
|
||||
|
||||
kernel_shapes = ModelShapeDict[args.model]
|
||||
@@ -730,7 +742,7 @@ if __name__ == "__main__":
|
||||
|
||||
ctor_code = gen_ctor_code()
|
||||
api_code = gen_top_api(kernel_shapes, k_list)
|
||||
trans_code = gen_transform_code(kernel_shapes)
|
||||
trans_code = gen_transform_code(kernel_shapes, args.fp16)
|
||||
|
||||
output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "include")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user