commit 6cfd8831fd748cfeb91ba65df0ff542678f77ff0 Author: potassiummmm Date: Thu Oct 17 21:21:10 2024 +0800 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a8a2b3f --- /dev/null +++ b/.gitignore @@ -0,0 +1,46 @@ +# Extensions + +*.a +*.bat +*.bin +*.dll +*.dot +*.etag +*.exe +*.gcda +*.gcno +*.gcov +*.gguf +*.gguf.json +*.lastModified +*.log +*.metallib +*.o +*.so +*.tmp + +# IDE / OS + +.cache/ +.ccls-cache/ +.direnv/ +.DS_Store +.envrc +.idea/ +.swiftpm +.vs/ +.vscode/ +nppBackup + +# Models +models/* + +# Python + +/.venv +__pycache__/ +*/poetry.lock +poetry.toml + +build/ +logs/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..2b36e49 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "3rdparty/llama.cpp"] + path = 3rdparty/llama.cpp + url = https://github.com/Eddie-Wang1120/llama.cpp.git + branch = merge-dev diff --git a/3rdparty/llama.cpp b/3rdparty/llama.cpp new file mode 160000 index 0000000..5371710 --- /dev/null +++ b/3rdparty/llama.cpp @@ -0,0 +1 @@ +Subproject commit 5371710215b86ca760bdc51c298ea1f4be0449a6 diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..b7a0c99 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,73 @@ +cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories. +project("bitnet.cpp" C CXX) +include(CheckIncludeFileCXX) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif() + +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + +# option list +option(BITNET_ARM_TL1 "bitnet.cpp: use tl1 on arm platform" OFF) +option(BITNET_X86_TL2 "bitnet.cpp: use tl2 on x86 platform" OFF) + + +set(CMAKE_CXX_STANDARD_REQUIRED true) +set(CMAKE_C_STANDARD 11) +set(CMAKE_C_STANDARD_REQUIRED true) +set(THREADS_PREFER_PTHREAD_FLAG ON) + +# override ggml options +set(GGML_BITNET_ARM_TL1 ${BITNET_ARM_TL1}) +set(GGML_BITNET_X86_TL2 ${BITNET_X86_TL2}) + +if (GGML_BITNET_ARM_TL1) + add_compile_definitions(GGML_BITNET_ARM_TL1) +endif() +if (GGML_BITNET_X86_TL2) + add_compile_definitions(GGML_BITNET_X86_TL2) +endif() + +find_package(Threads REQUIRED) + +add_subdirectory(src) +add_subdirectory(3rdparty/llama.cpp) + +# install + +include(GNUInstallDirs) +include(CMakePackageConfigHelpers) + +set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} + CACHE PATH "Location of header files") +set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} + CACHE PATH "Location of library files") +set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} + CACHE PATH "Location of binary files") +set(LLAMA_BUILD_NUMBER ${BUILD_NUMBER}) +set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT}) +set(LLAMA_INSTALL_VERSION 0.0.${BUILD_NUMBER}) + +get_target_property(GGML_DIRECTORY ggml SOURCE_DIR) +get_directory_property(GGML_DIR_DEFINES DIRECTORY ${GGML_DIRECTORY} COMPILE_DEFINITIONS) +get_target_property(GGML_TARGET_DEFINES ggml COMPILE_DEFINITIONS) +set(GGML_TRANSIENT_DEFINES ${GGML_TARGET_DEFINES} ${GGML_DIR_DEFINES}) +get_target_property(GGML_LINK_LIBRARIES ggml LINK_LIBRARIES) + +get_directory_property(LLAMA_TRANSIENT_DEFINES COMPILE_DEFINITIONS) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake + VERSION ${LLAMA_INSTALL_VERSION} + COMPATIBILITY SameMajorVersion) + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake + ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama) + +set_target_properties(llama PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/llama.h) +install(TARGETS llama LIBRARY PUBLIC_HEADER) \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..f9ba8cf --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,9 @@ +# Microsoft Open Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). + +Resources: + +- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) +- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..9e841e7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/README.md b/README.md new file mode 100644 index 0000000..88c5d09 --- /dev/null +++ b/README.md @@ -0,0 +1,228 @@ +# bitnet.cpp +[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) +![version](https://img.shields.io/badge/version-1.0-blue) + +bitnet.cpp is the official inference framework for BitNet models (e.g., BitNet b1.58), optimized for CPU devices. It offers a suite of optimized kernels, that support lossless inference of 1.58-bit models on both x86 and ARM architectures. + +## Demo + +A demo of bitnet.cpp runing a BitNet b1.58 3B model on Apple M2: + +https://github.com/user-attachments/assets/7f46b736-edec-4828-b809-4be780a3e5b1 + +## Timeline + +- 10/17/2024 bitnet.cpp 1.0 released. +- 02/27/2024 [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764) +- 10/17/2023 [BitNet: Scaling 1-bit Transformers for Large Language Models](https://arxiv.org/abs/2310.11453) + +## Supported Models + +bitnet.cpp supports a list of 1-bit models available on [Hugging Face](https://huggingface.co/) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ModelParametersCPUKernel
I2_STL1TL2
bitnet_b1_58-large0.7Bx86
ARM
bitnet_b1_58-3B3.3Bx86
ARM
Llama3-8B-1.58-100B-tokens8.0Bx86
ARM
+ + + +## Installation + +### Requirements +- python>=3.9 +- cmake>=3.22 +- clang>=18 + - For Windows users, install [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/). In the installer, toggle on at least the following options(this also automatically installs the required additional tools like CMake): + - Desktop-development with C++ + - C++-CMake Tools for Windows + - Git for Windows + - C++-Clang Compiler for Windows + - MS-Build Support for LLVM-Toolset (clang) + - For Debian/Ubuntu users, you can download with [Automatic installation script](https://apt.llvm.org/) + + ` bash -c "$(wget -O - https://apt.llvm.org/llvm.sh)"` +- conda (highly recommend) + +### Build from source + +> [!IMPORTANT] +> If you are using Windows, please remember to always use a Developer Command Prompt / PowerShell for VS2022 for the following commands + +1. Clone the repo +```bash +git clone --recursive https://github.com/microsoft/BitNet.git +cd BitNet +``` +2. Install the dependencies +```bash +# (Recommended) Create a new conda environment +conda create -n bitnet-cpp python=3.9 +conda activate bitnet-cpp + +pip install -r requirements.txt +``` +3. Build the project +```bash +# Download the model from Hugging Face, convert it to quantized gguf format, and build the project +python setup_env.py --hf-repo HF1BitLLM/Llama3-8B-1.58-100B-tokens -q i2_s + +# Or you can manually download the model and run with local path +huggingface-cli download HF1BitLLM/Llama3-8B-1.58-100B-tokens --local-dir models/Llama3-8B-1.58-100B-tokens +python setup_env.py -md models/Llama3-8B-1.58-100B-tokens -q i2_s +``` +
+usage: setup_env.py [-h] [--hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens}] [--model-dir MODEL_DIR] [--log-dir LOG_DIR] [--quant-type {i2_s,tl1}] [--quant-embd]
+                    [--use-pretuned]
+
+Setup the environment for running inference
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens}, -hr {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens}
+                        Model used for inference
+  --model-dir MODEL_DIR, -md MODEL_DIR
+                        Directory to save/load the model
+  --log-dir LOG_DIR, -ld LOG_DIR
+                        Directory to save the logging info
+  --quant-type {i2_s,tl1}, -q {i2_s,tl1}
+                        Quantization type
+  --quant-embd          Quantize the embeddings to f16
+  --use-pretuned, -p    Use the pretuned kernel parameters
+
+## Usage +### Basic usage +```bash +# Run inference with the quantized model +python run_inference.py -m models/Llama3-8B-1.58-100B-tokens/ggml-model-i2_s.gguf -p "Daniel went back to the the the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:" -n 6 -temp 0 + +# Output: +# Daniel went back to the the the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary? +# Answer: Mary is in the garden. + +``` +
+usage: run_inference.py [-h] [-m MODEL] [-n N_PREDICT] -p PROMPT [-t THREADS] [-c CTX_SIZE] [-temp TEMPERATURE]
+
+Run inference
+
+optional arguments:
+  -h, --help            show this help message and exit
+  -m MODEL, --model MODEL
+                        Path to model file
+  -n N_PREDICT, --n-predict N_PREDICT
+                        Number of tokens to predict when generating text
+  -p PROMPT, --prompt PROMPT
+                        Prompt to generate text from
+  -t THREADS, --threads THREADS
+                        Number of threads to use
+  -c CTX_SIZE, --ctx-size CTX_SIZE
+                        Size of the prompt context
+  -temp TEMPERATURE, --temperature TEMPERATURE
+                        Temperature, a hyperparameter that controls the randomness of the generated text
+
+ +### Benchmark +We provide scripts to run the inference benchmark providing a model. + +``` +usage: e2e_benchmark.py -m MODEL [-n N_TOKEN] [-p N_PROMPT] [-t THREADS] + +Setup the environment for running the inference + +required arguments: + -m MODEL, --model MODEL + Path to the model file. + +optional arguments: + -h, --help + Show this help message and exit. + -n N_TOKEN, --n-token N_TOKEN + Number of generated tokens. + -p N_PROMPT, --n-prompt N_PROMPT + Prompt to generate text from. + -t THREADS, --threads THREADS + Number of threads to use. +``` + +Here's a brief explanation of each argument: + +- `-m`, `--model`: The path to the model file. This is a required argument that must be provided when running the script. +- `-n`, `--n-token`: The number of tokens to generate during the inference. It is an optional argument with a default value of 128. +- `-p`, `--n-prompt`: The number of prompt tokens to use for generating text. This is an optional argument with a default value of 512. +- `-t`, `--threads`: The number of threads to use for running the inference. It is an optional argument with a default value of 2. +- `-h`, `--help`: Show the help message and exit. Use this argument to display usage information. + +For example: + +```sh +python utils/e2e_benchmark.py -m /path/to/model -n 200 -p 256 -t 4 +``` + +This command would run the inference benchmark using the model located at `/path/to/model`, generating 200 tokens from a 256 token prompt, utilizing 4 threads. + +For the model layout that do not supported by any public model, we provide scripts to generate a dummy model with the given model layout, and run the benchmark on your machine: + +```bash +python utils/generate-dummy-bitnet-model.py models/bitnet_b1_58-large --outfile models/dummy-bitnet-125m.tl1.gguf --outtype tl1 --model-size 125M + +# Run benchmark with the generated model, use -m to specify the model path, -p to specify the prompt processed, -n to specify the number of token to generate +python utils/e2e_benchmark.py -m models/dummy-bitnet-125m.tl1.gguf -p 512 -n 128 +``` + +## Acknowledgements + +This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp) framework. We would like to thank all the authors for their contributions to the open-source community. We also thank [T-MAC](https://github.com/microsoft/T-MAC/) team for the helpful discussion on the LUT method for low-bit LLM inference. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..b3c89ef --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,41 @@ + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). + + diff --git a/assets/tl1.png b/assets/tl1.png new file mode 100644 index 0000000..f0b3e1e Binary files /dev/null and b/assets/tl1.png differ diff --git a/assets/tl2.png b/assets/tl2.png new file mode 100644 index 0000000..9587a1f Binary files /dev/null and b/assets/tl2.png differ diff --git a/docs/codegen.md b/docs/codegen.md new file mode 100644 index 0000000..f085309 --- /dev/null +++ b/docs/codegen.md @@ -0,0 +1,49 @@ +Codegen for TL1 and TL2 +------------------------ + +codegen_tl1.py and codegen_tl2.py are using params to generate kernel codes in different devices to achieve fastest performance for TL1 and TL2. + +We cutting weight into multiple compute blocks to best utilize hardware capabilities. + +### Example +bitnet_b1_58-large: + +- Make sure Mamtul kernels shapes \ +For example, bitnet_b1_58-large Matmul kernel shapes are:\ +[1536, 4096]\ +[1536, 1536]\ +[4096, 1536] + +- Make sure each BM, BK, bm for each kernel to meet the requirements below +- Generate codes\ +For example, for bitnet_b1_58-large, we can gencode like: + +```bash +# For TL1 +python utils/codegen_tl1.py --model bitnet_b1_58-large --BM 256,128,256 --BK 128,64,128 --bm 32,64,32 + +# For TL2 +python utils/codegen_tl2.py --model bitnet_b1_58-large --BM 256,128,256 --BK 96,192,96 --bm 32,32,32 +``` + +### TL1: +![TL1](../assets/tl1.png) + +For TL1, we cut weight into M / BM weights, each weight shape is (BM, K). Then we cut weight into K / BK weights, each weight shape is (BM, BK). As for (BM, BK) weight, we cut it the same way into (bm, compute_num / bm) compute blocks, and finish computing in it. + +Thus, we need to make sure +- M % BM == 0 +- K % BK == 0 +- BM % bm == 0 +- bm choose in [32, 64] + +### TL2: +![TL2](../assets/tl2.png) + +For TL2, things got a little more complicated. Due to TL2 needs BK % 6 == 0, we need to split K into threeK and twoK, in which compute in TL2 for (M, threeK), compute in TL1 for (M, two_K). + +Thus, we needs to make sure +- M % BM == 0 +- K % BK % 32 == 0 +- BM % bm == 0 +- bm choose in \[32\] \ No newline at end of file diff --git a/include/ggml-bitnet.h b/include/ggml-bitnet.h new file mode 100644 index 0000000..3f8571c --- /dev/null +++ b/include/ggml-bitnet.h @@ -0,0 +1,49 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __ARM_NEON +#include +typedef float32_t bitnet_float_type; +#else +typedef float bitnet_float_type; +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +struct bitnet_tensor_extra { + int lut_scales_size; + int BK; + int n_tile_num; + uint8_t * qweights; + bitnet_float_type * scales; +}; + +GGML_API void ggml_bitnet_init(void); +GGML_API void ggml_bitnet_free(void); +// src0->type == Q4_0/IQ2_XXS/IQ3_XXS +// bitnet.cpp currently only supports BitNet quantization or GPTQ-like quantization (only scales, without zeros) +// If use i-quantization gguf models, the results will be wrong +// TODO: add customized block types Q2_0/Q3_0 +GGML_API bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst); +GGML_API size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst); +GGML_API void ggml_bitnet_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits); +GGML_API void ggml_bitnet_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits); +GGML_API void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor); +GGML_API int ggml_bitnet_get_type_bits(enum ggml_type type); +GGML_API void ggml_bitnet_set_n_threads(int n_threads); +#if defined(GGML_BITNET_ARM_TL1) +GGML_API void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C); +GGML_API void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT); +#endif +#if defined(GGML_BITNET_X86_TL2) +GGML_API void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C); +GGML_API void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT); +#endif + +#ifdef __cplusplus +} +#endif diff --git a/media/benchmark.png b/media/benchmark.png new file mode 100644 index 0000000..615ac24 Binary files /dev/null and b/media/benchmark.png differ diff --git a/media/demo.mp4 b/media/demo.mp4 new file mode 100644 index 0000000..419e1b0 Binary files /dev/null and b/media/demo.mp4 differ diff --git a/preset_kernels/Llama3-8B-1.58-100B-tokens/bitnet-lut-kernels-tl1.h b/preset_kernels/Llama3-8B-1.58-100B-tokens/bitnet-lut-kernels-tl1.h new file mode 100644 index 0000000..024fb78 --- /dev/null +++ b/preset_kernels/Llama3-8B-1.58-100B-tokens/bitnet-lut-kernels-tl1.h @@ -0,0 +1,771 @@ +#if defined(GGML_BITNET_ARM_TL1) +#include "ggml-bitnet.h" +#define GGML_BITNET_MAX_NODES 8192 +static bool initialized = false; +static bitnet_tensor_extra * bitnet_tensor_extras = nullptr; +static size_t bitnet_tensor_extras_index = 0; +static void * aligned_malloc(size_t size) {{ +#if defined(_WIN32) + return _aligned_malloc(size, 64); +#else + void * ptr = nullptr; + posix_memalign(&ptr, 64, size); + return ptr; +#endif +}} +static void aligned_free(void * ptr) {{ +#if defined(_WIN32) + _aligned_free(ptr); +#else + free(ptr); +#endif +}} + +void per_tensor_quant(int k, void* lut_scales_, void* b_) {{ + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + bitnet_float_type* b = (bitnet_float_type*)b_; +#ifdef __ARM_NEON + float32x4_t temp_max = vdupq_n_f32(0); + for (int i=0; i < k / 4; i++) {{ + float32x4_t vec_bs = vld1q_f32(b + 4 * i); + float32x4_t abssum = vabsq_f32(vec_bs); + temp_max = vmaxq_f32(abssum, temp_max); + }} + float32_t scales = 127 / vmaxvq_f32(temp_max); + *lut_scales = scales; +#elif defined __AVX2__ + __m256 max_vec = _mm256_set1_ps(0.f); + const __m256 vec_sign = _mm256_set1_ps(-0.0f); + // #pragma unroll + for (int i = 0; i < k / 8; i++) {{ + __m256 vec_b = _mm256_loadu_ps(b + i * 8); + __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); + max_vec = _mm256_max_ps(vec_babs, max_vec); + }} + __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec)); + max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1)); + max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1)); + float scales = 127 / _mm_cvtss_f32(max1); + *lut_scales = scales; +#endif +}} + +void partial_max_reset(void* lut_scales_) {{ + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + *lut_scales = 0.0; +}} + +#ifdef __ARM_NEON +inline void Transpose_8_8( + int16x8_t *v0, + int16x8_t *v1, + int16x8_t *v2, + int16x8_t *v3, + int16x8_t *v4, + int16x8_t *v5, + int16x8_t *v6, + int16x8_t *v7) +{{ + int16x8x2_t q04 = vzipq_s16(*v0, *v4); + int16x8x2_t q15 = vzipq_s16(*v1, *v5); + int16x8x2_t q26 = vzipq_s16(*v2, *v6); + int16x8x2_t q37 = vzipq_s16(*v3, *v7); + + int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]); + int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]); + int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]); + int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]); + + int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]); + int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]); + int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]); + int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]); + + *v0 = q_fin_0.val[0]; + *v1 = q_fin_0.val[1]; + *v2 = q_fin_1.val[0]; + *v3 = q_fin_1.val[1]; + *v4 = q_fin_2.val[0]; + *v5 = q_fin_2.val[1]; + *v6 = q_fin_3.val[0]; + *v7 = q_fin_3.val[1]; +}} +#endif + +template +inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{ +#ifdef __ARM_NEON + int16x8_t vec_lut[16]; + float32_t scales = *lut_scales; + uint8_t tbl_mask[16]; + tbl_mask[0] = 0; + tbl_mask[1] = 2; + tbl_mask[2] = 4; + tbl_mask[3] = 6; + tbl_mask[4] = 8; + tbl_mask[5] = 10; + tbl_mask[6] = 12; + tbl_mask[7] = 14; + tbl_mask[8] = 1; + tbl_mask[9] = 3; + tbl_mask[10] = 5; + tbl_mask[11] = 7; + tbl_mask[12] = 9; + tbl_mask[13] = 11; + tbl_mask[14] = 13; + tbl_mask[15] = 15; + uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask); +#pragma unroll + for (int k = 0; k < act_k / 16; ++k) {{ + float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16); + float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8); + float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales); + float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales); + float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales); + float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales); + int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0); + int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1); + int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2); + int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3); + int16x4_t vec_b16_0 = vmovn_s32(vec_b_0); + int16x4_t vec_b16_1 = vmovn_s32(vec_b_1); + int16x4_t vec_b16_2 = vmovn_s32(vec_b_2); + int16x4_t vec_b16_3 = vmovn_s32(vec_b_3); + int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2); + int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3); + vec_lut[0] = vdupq_n_s16(0); + vec_lut[0] = vec_lut[0] - vec_bs_0; + vec_lut[0] = vec_lut[0] - vec_bs_1; + vec_lut[1] = vdupq_n_s16(0); + vec_lut[1] = vec_lut[1] - vec_bs_0; + vec_lut[2] = vdupq_n_s16(0); + vec_lut[2] = vec_lut[2] - vec_bs_0; + vec_lut[2] = vec_lut[2] + vec_bs_1; + vec_lut[3] = vdupq_n_s16(0); + vec_lut[3] = vec_lut[3] - vec_bs_1; + vec_lut[4] = vdupq_n_s16(0); + vec_lut[5] = vec_bs_1; + vec_lut[6] = vec_bs_0; + vec_lut[6] = vec_lut[6] - vec_bs_1; + vec_lut[7] = vec_bs_0; + vec_lut[8] = vec_bs_0; + vec_lut[8] = vec_lut[8] + vec_bs_1; + Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]), + &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7])); + Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]), + &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15])); +#pragma unroll + for (int idx = 0; idx < 8; idx++) {{ + int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q); + int8x8_t q0_low = vget_low_s8(q0_s); + int8x8_t q0_high = vget_high_s8(q0_s); + int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q); + int8x8_t q1_low = vget_low_s8(q1_s); + int8x8_t q1_high = vget_high_s8(q1_s); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low); + }} + }} +#endif +}} + +static bool is_type_supported(enum ggml_type type) {{ + if (type == GGML_TYPE_Q4_0 || + type == GGML_TYPE_TL1) {{ + return true; + }} else {{ + return false; + }} +}} +#include + +#define BM14336_4096 256 +#define BBK14336_4096 128 +inline void tbl_impl_14336_4096(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __ARM_NEON + const int KK = BBK14336_4096 / 2; + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + const int8x16_t vec_zero = vdupq_n_s16(0x0000); + int8x16_t vec_lut[2 * KK]; + int16x8_t vec_c[8]; +#pragma unroll + for (int k = 0; k < 2 * KK; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + +#pragma unroll + for (int i = 0; i < BM14336_4096; i += 64) { + #pragma unroll + for (int i=0; i<8; i++) { + vec_c[i] = vandq_s16(vec_c[i], vec_zero); + } + +#pragma unroll + for (int k = 0; k < KK / 2; k++) { + + uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); + uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); + uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); + int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top); + int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top); + int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot); + int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot); + int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); + int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); + vec_c[0] += vec_v_left_0.val[0]; + vec_c[0] += vec_v_right_0.val[0]; + vec_c[1] += vec_v_left_0.val[1]; + vec_c[1] += vec_v_right_0.val[1]; + + uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); + uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); + uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); + int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top); + int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top); + int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot); + int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot); + int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); + int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); + vec_c[2] += vec_v_left_1.val[0]; + vec_c[2] += vec_v_right_1.val[0]; + vec_c[3] += vec_v_left_1.val[1]; + vec_c[3] += vec_v_right_1.val[1]; + + uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); + uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); + uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); + int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top); + int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top); + int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot); + int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot); + int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); + int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); + vec_c[4] += vec_v_left_2.val[0]; + vec_c[4] += vec_v_right_2.val[0]; + vec_c[5] += vec_v_left_2.val[1]; + vec_c[5] += vec_v_right_2.val[1]; + + uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); + uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); + uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); + int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top); + int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top); + int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot); + int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot); + int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); + int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); + vec_c[6] += vec_v_left_3.val[0]; + vec_c[6] += vec_v_right_3.val[0]; + vec_c[7] += vec_v_left_3.val[1]; + vec_c[7] += vec_v_right_3.val[1]; + + } + + int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); + int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); + vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); + vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); + int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); + int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); + vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); + vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); + int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); + int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); + vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); + vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); + int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); + int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); + vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); + vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); + int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4])); + int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]); + vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4); + vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4); + int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5])); + int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]); + vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5); + vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5); + int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6])); + int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]); + vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6); + vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6); + int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7])); + int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]); + vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7); + vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7); + + } +#endif +} + +int32_t qgemm_lut_14336_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BM14336_4096]; + memset(&(CBits[0]), 0, BM14336_4096 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 4096 / BBK14336_4096; ++k_outer) { + tbl_impl_14336_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK14336_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK14336_4096 / 2 / 2 * BM14336_4096)]))); + } +#pragma unroll + for (int i = 0; i < BM14336_4096; i++) { + ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; + } + return 0; +}; +#include + +#define BM4096_14336 256 +#define BBK4096_14336 128 +inline void tbl_impl_4096_14336(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __ARM_NEON + const int KK = BBK4096_14336 / 2; + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + const int8x16_t vec_zero = vdupq_n_s16(0x0000); + int8x16_t vec_lut[2 * KK]; + int16x8_t vec_c[4]; +#pragma unroll + for (int k = 0; k < 2 * KK; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + +#pragma unroll + for (int i = 0; i < BM4096_14336; i += 32) { + #pragma unroll + for (int i=0; i<4; i++) { + vec_c[i] = vandq_s16(vec_c[i], vec_zero); + } + +#pragma unroll + for (int k = 0; k < KK / 4; k++) { + + uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); + uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); + uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); + int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); + int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); + int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); + int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); + int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); + int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); + vec_c[0] += vec_v_left_0.val[0]; + vec_c[0] += vec_v_right_0.val[0]; + vec_c[1] += vec_v_left_0.val[1]; + vec_c[1] += vec_v_right_0.val[1]; + + uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); + uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); + uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); + int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); + int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); + int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); + int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); + int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); + int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); + vec_c[0] += vec_v_left_1.val[0]; + vec_c[0] += vec_v_right_1.val[0]; + vec_c[1] += vec_v_left_1.val[1]; + vec_c[1] += vec_v_right_1.val[1]; + + uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); + uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); + uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); + int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); + int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); + int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); + int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); + int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); + int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); + vec_c[2] += vec_v_left_2.val[0]; + vec_c[2] += vec_v_right_2.val[0]; + vec_c[3] += vec_v_left_2.val[1]; + vec_c[3] += vec_v_right_2.val[1]; + + uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); + uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); + uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); + int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); + int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); + int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); + int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); + int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); + int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); + vec_c[2] += vec_v_left_3.val[0]; + vec_c[2] += vec_v_right_3.val[0]; + vec_c[3] += vec_v_left_3.val[1]; + vec_c[3] += vec_v_right_3.val[1]; + + } + + int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); + int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); + vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); + vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); + int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); + int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); + vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); + vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); + int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); + int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); + vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); + vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); + int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); + int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); + vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); + vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); + + } +#endif +} + +int32_t qgemm_lut_4096_14336(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BM4096_14336]; + memset(&(CBits[0]), 0, BM4096_14336 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 14336 / BBK4096_14336; ++k_outer) { + tbl_impl_4096_14336((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_14336 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_14336 / 2 / 2 * BM4096_14336)]))); + } +#pragma unroll + for (int i = 0; i < BM4096_14336; i++) { + ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; + } + return 0; +}; +#include + +#define BM1024_4096 128 +#define BBK1024_4096 64 +inline void tbl_impl_1024_4096(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __ARM_NEON + const int KK = BBK1024_4096 / 2; + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + const int8x16_t vec_zero = vdupq_n_s16(0x0000); + int8x16_t vec_lut[2 * KK]; + int16x8_t vec_c[8]; +#pragma unroll + for (int k = 0; k < 2 * KK; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + +#pragma unroll + for (int i = 0; i < BM1024_4096; i += 64) { + #pragma unroll + for (int i=0; i<8; i++) { + vec_c[i] = vandq_s16(vec_c[i], vec_zero); + } + +#pragma unroll + for (int k = 0; k < KK / 2; k++) { + + uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); + uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); + uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); + int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top); + int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top); + int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot); + int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot); + int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); + int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); + vec_c[0] += vec_v_left_0.val[0]; + vec_c[0] += vec_v_right_0.val[0]; + vec_c[1] += vec_v_left_0.val[1]; + vec_c[1] += vec_v_right_0.val[1]; + + uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); + uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); + uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); + int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top); + int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top); + int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot); + int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot); + int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); + int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); + vec_c[2] += vec_v_left_1.val[0]; + vec_c[2] += vec_v_right_1.val[0]; + vec_c[3] += vec_v_left_1.val[1]; + vec_c[3] += vec_v_right_1.val[1]; + + uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); + uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); + uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); + int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top); + int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top); + int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot); + int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot); + int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); + int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); + vec_c[4] += vec_v_left_2.val[0]; + vec_c[4] += vec_v_right_2.val[0]; + vec_c[5] += vec_v_left_2.val[1]; + vec_c[5] += vec_v_right_2.val[1]; + + uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); + uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); + uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); + int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top); + int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top); + int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot); + int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot); + int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); + int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); + vec_c[6] += vec_v_left_3.val[0]; + vec_c[6] += vec_v_right_3.val[0]; + vec_c[7] += vec_v_left_3.val[1]; + vec_c[7] += vec_v_right_3.val[1]; + + } + + int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); + int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); + vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); + vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); + int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); + int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); + vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); + vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); + int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); + int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); + vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); + vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); + int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); + int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); + vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); + vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); + int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4])); + int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]); + vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4); + vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4); + int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5])); + int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]); + vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5); + vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5); + int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6])); + int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]); + vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6); + vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6); + int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7])); + int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]); + vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7); + vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7); + + } +#endif +} + +int32_t qgemm_lut_1024_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BM1024_4096]; + memset(&(CBits[0]), 0, BM1024_4096 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 4096 / BBK1024_4096; ++k_outer) { + tbl_impl_1024_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1024_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1024_4096 / 2 / 2 * BM1024_4096)]))); + } +#pragma unroll + for (int i = 0; i < BM1024_4096; i++) { + ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; + } + return 0; +}; +#include + +#define BM4096_4096 128 +#define BBK4096_4096 64 +inline void tbl_impl_4096_4096(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __ARM_NEON + const int KK = BBK4096_4096 / 2; + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + const int8x16_t vec_zero = vdupq_n_s16(0x0000); + int8x16_t vec_lut[2 * KK]; + int16x8_t vec_c[4]; +#pragma unroll + for (int k = 0; k < 2 * KK; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + +#pragma unroll + for (int i = 0; i < BM4096_4096; i += 32) { + #pragma unroll + for (int i=0; i<4; i++) { + vec_c[i] = vandq_s16(vec_c[i], vec_zero); + } + +#pragma unroll + for (int k = 0; k < KK / 4; k++) { + + uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); + uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); + uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); + int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); + int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); + int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); + int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); + int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); + int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); + vec_c[0] += vec_v_left_0.val[0]; + vec_c[0] += vec_v_right_0.val[0]; + vec_c[1] += vec_v_left_0.val[1]; + vec_c[1] += vec_v_right_0.val[1]; + + uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); + uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); + uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); + int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); + int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); + int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); + int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); + int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); + int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); + vec_c[0] += vec_v_left_1.val[0]; + vec_c[0] += vec_v_right_1.val[0]; + vec_c[1] += vec_v_left_1.val[1]; + vec_c[1] += vec_v_right_1.val[1]; + + uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); + uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); + uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); + int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); + int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); + int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); + int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); + int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); + int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); + vec_c[2] += vec_v_left_2.val[0]; + vec_c[2] += vec_v_right_2.val[0]; + vec_c[3] += vec_v_left_2.val[1]; + vec_c[3] += vec_v_right_2.val[1]; + + uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); + uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); + uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); + int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); + int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); + int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); + int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); + int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); + int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); + vec_c[2] += vec_v_left_3.val[0]; + vec_c[2] += vec_v_right_3.val[0]; + vec_c[3] += vec_v_left_3.val[1]; + vec_c[3] += vec_v_right_3.val[1]; + + } + + int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); + int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); + vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); + vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); + int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); + int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); + vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); + vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); + int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); + int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); + vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); + vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); + int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); + int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); + vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); + vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); + + } +#endif +} + +int32_t qgemm_lut_4096_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BM4096_4096]; + memset(&(CBits[0]), 0, BM4096_4096 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 4096 / BBK4096_4096; ++k_outer) { + tbl_impl_4096_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_4096 / 2 / 2 * BM4096_4096)]))); + } +#pragma unroll + for (int i = 0; i < BM4096_4096; i++) { + ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; + } + return 0; +}; + +template +void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{ + partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0]))); + per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0]))); + + lut_ctor((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0]))); +}} +void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) { + if (m == 14336 && k == 4096) { + preprocessor_k<4096>(B, LUT_Scales, QLUT); + } + else if (m == 4096 && k == 14336) { + preprocessor_k<14336>(B, LUT_Scales, QLUT); + } + else if (m == 1024 && k == 4096) { + preprocessor_k<4096>(B, LUT_Scales, QLUT); + } + else if (m == 4096 && k == 4096) { + preprocessor_k<4096>(B, LUT_Scales, QLUT); + } +} +void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + if (m == 14336 && k == 4096) { + qgemm_lut_14336_4096(A, LUT, Scales, LUT_Scales, C); + } + else if (m == 4096 && k == 14336) { + qgemm_lut_4096_14336(A, LUT, Scales, LUT_Scales, C); + } + else if (m == 1024 && k == 4096) { + qgemm_lut_1024_4096(A, LUT, Scales, LUT_Scales, C); + } + else if (m == 4096 && k == 4096) { + qgemm_lut_4096_4096(A, LUT, Scales, LUT_Scales, C); + } +} + +void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) { + if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { + return; + } + + int k = tensor->ne[0]; + int m = tensor->ne[1]; + const int lut_scales_size = 1; + const int scales_size = 1; + int bk = 0; + int bm = 0; + + if (m == 14336 && k == 4096) { + bm = BM14336_4096; + bk = BBK14336_4096; + } +else if (m == 4096 && k == 14336) { + bm = BM4096_14336; + bk = BBK4096_14336; + } +else if (m == 1024 && k == 4096) { + bm = BM1024_4096; + bk = BBK1024_4096; + } +else if (m == 4096 && k == 4096) { + bm = BM4096_4096; + bk = BBK4096_4096; + } + + const int n_tile_num = m / bm; + const int BK = bk; + uint8_t * qweights; + bitnet_float_type * scales; + + scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type)); + qweights = (uint8_t *) tensor->data; + float * i2_scales = (float * )(qweights + k * m / 4); + scales[0] = (bitnet_float_type) i2_scales[0]; + + tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index; + bitnet_tensor_extras[bitnet_tensor_extras_index++] = { + /* .lut_scales_size = */ lut_scales_size, + /* .scales_size = */ scales_size, + /* .n_tile_num = */ n_tile_num, + /* .qweights = */ qweights, + /* .scales = */ scales + }; +} +#endif \ No newline at end of file diff --git a/preset_kernels/Llama3-8B-1.58-100B-tokens/bitnet-lut-kernels-tl2.h b/preset_kernels/Llama3-8B-1.58-100B-tokens/bitnet-lut-kernels-tl2.h new file mode 100644 index 0000000..88dc9e2 --- /dev/null +++ b/preset_kernels/Llama3-8B-1.58-100B-tokens/bitnet-lut-kernels-tl2.h @@ -0,0 +1,1454 @@ +#if defined(GGML_BITNET_X86_TL2) +#include "ggml-bitnet.h" +#define GGML_BITNET_MAX_NODES 8192 +static bool initialized = false; +static bitnet_tensor_extra * bitnet_tensor_extras = nullptr; +static size_t bitnet_tensor_extras_index = 0; +static void * aligned_malloc(size_t size) { +#if defined(_WIN32) + return _aligned_malloc(size, 64); +#else + void * ptr = nullptr; + posix_memalign(&ptr, 64, size); + return ptr; +#endif +} + +static void aligned_free(void * ptr) { +#if defined(_WIN32) + _aligned_free(ptr); +#else + free(ptr); +#endif +} +#define BK2 32 +#if defined __AVX2__ +inline void _mm256_merge_epi32(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0)); + *vl = _mm256_unpacklo_epi32(va, vb); + *vh = _mm256_unpackhi_epi32(va, vb); +} +inline void _mm256_merge_epi64(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0)); + *vl = _mm256_unpacklo_epi64(va, vb); + *vh = _mm256_unpackhi_epi64(va, vb); +} +inline void _mm256_merge_si128(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + *vl = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 2, 0, 0)); + *vh = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 3, 0, 1)); +} +inline void Transpose_8_8( + __m256i *v0, + __m256i *v1, + __m256i *v2, + __m256i *v3, + __m256i *v4, + __m256i *v5, + __m256i *v6, + __m256i *v7) +{ + __m256i w0, w1, w2, w3, w4, w5, w6, w7; + __m256i x0, x1, x2, x3, x4, x5, x6, x7; + _mm256_merge_epi32(*v0, *v1, &w0, &w1); + _mm256_merge_epi32(*v2, *v3, &w2, &w3); + _mm256_merge_epi32(*v4, *v5, &w4, &w5); + _mm256_merge_epi32(*v6, *v7, &w6, &w7); + _mm256_merge_epi64(w0, w2, &x0, &x1); + _mm256_merge_epi64(w1, w3, &x2, &x3); + _mm256_merge_epi64(w4, w6, &x4, &x5); + _mm256_merge_epi64(w5, w7, &x6, &x7); + _mm256_merge_si128(x0, x4, v0, v1); + _mm256_merge_si128(x1, x5, v2, v3); + _mm256_merge_si128(x2, x6, v4, v5); + _mm256_merge_si128(x3, x7, v6, v7); +} +#endif +inline int32_t per_tensor_quant(int k, void* lut_scales_, void* b_) { + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + bitnet_float_type* b = (bitnet_float_type*)b_; +#if defined __AVX2__ + __m256 max_vec = _mm256_set1_ps(0.f); + const __m256 vec_sign = _mm256_set1_ps(-0.0f); + for (int i = 0; i < k / 8; i++) { + __m256 vec_b = _mm256_loadu_ps(b + i * 8); + __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); + max_vec = _mm256_max_ps(vec_babs, max_vec); + } + __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec)); + max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1)); + max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1)); + float scales = 127 / _mm_cvtss_f32(max1); + *lut_scales = scales; +#endif + return 0; +} +inline int32_t partial_max_reset(int32_t bs, void* lut_scales_) { + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + #pragma unroll + for (int i=0; i< bs; i++) { + lut_scales[i] = 0.0; + } + return 0; +} +template +inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) { +#if defined __AVX2__ + __m256 vec_lut[16]; + const __m256i vec_bi = _mm256_set_epi32(84, 72, 60, 48, 36, 24, 12, 0); + float scales = *lut_scales; + __m256i shuffle_mask = _mm256_set_epi8( + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00, + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00 + ); +#pragma unroll + for (int k = 0; k < act_k / 24; ++k) { + __m256 vec_b0 = _mm256_i32gather_ps(b + k * 24 + 0, vec_bi, 1); + __m256 vec_b1 = _mm256_i32gather_ps(b + k * 24 + 1, vec_bi, 1); + __m256 vec_b2 = _mm256_i32gather_ps(b + k * 24 + 2, vec_bi, 1); + + __m256i vec_b0i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b1i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b2i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b2, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + vec_lut[15] = _mm256_setzero_si256(); + vec_lut[14] = _mm256_setzero_si256(); + vec_lut[13] = vec_b0i; + vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b1i); + vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b2i); + vec_lut[12] = vec_b0i; + vec_lut[12] = _mm256_add_epi32(vec_lut[12], vec_b1i); + vec_lut[11] = vec_b0i; + vec_lut[11] = _mm256_add_epi32(vec_lut[11], vec_b1i); + vec_lut[11] = _mm256_sub_epi32(vec_lut[11], vec_b2i); + vec_lut[10] = vec_b0i; + vec_lut[10] = _mm256_add_epi32(vec_lut[10], vec_b2i); + vec_lut[9] = vec_b0i; + vec_lut[8] = vec_b0i; + vec_lut[8] = _mm256_sub_epi32(vec_lut[8], vec_b2i); + vec_lut[7] = vec_b0i; + vec_lut[7] = _mm256_sub_epi32(vec_lut[7], vec_b1i); + vec_lut[7] = _mm256_add_epi32(vec_lut[7], vec_b2i); + vec_lut[6] = vec_b0i; + vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1i); + vec_lut[5] = vec_b0i; + vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b1i); + vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b2i); + vec_lut[4] = vec_b1i; + vec_lut[4] = _mm256_add_epi32(vec_lut[4], vec_b2i); + vec_lut[3] = vec_b1i; + vec_lut[2] = vec_b1i; + vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b2i); + vec_lut[1] = vec_b2i; + vec_lut[0] = _mm256_setzero_si256(); + __m256i ix[16]; + +#pragma unroll + for (int g = 0; g < 16; ++g) { + ix[g] = vec_lut[g]; + } + + Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7])); + Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15])); + +#pragma unroll + for (int g = 0; g < 8; ++g) { + ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + } + int8_t* qlut_i8 = reinterpret_cast(qlut); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]); + + } + + *lut_scales = scales; +#endif + return 0; +} + +template +inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) { +#if defined __AVX2__ + __m256 vec_lut[16]; + const __m256i vec_bi = _mm256_set_epi32(56, 48, 40, 32, 24, 16, 8, 0); + float scales = *lut_scales; + __m256i shuffle_mask = _mm256_set_epi8( + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00, + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00 + ); +#pragma unroll + for (int k = 0; k < act_k / 16; ++k) { + __m256 vec_b0f = _mm256_i32gather_ps(b + k * 16 + 0, vec_bi, 1); + __m256 vec_b1f = _mm256_i32gather_ps(b + k * 16 + 1, vec_bi, 1); + + __m256i vec_b0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + vec_lut[15] = _mm256_setzero_si256(); + vec_lut[14] = _mm256_setzero_si256(); + vec_lut[13] = _mm256_setzero_si256(); + vec_lut[12] = _mm256_setzero_si256(); + vec_lut[11] = _mm256_setzero_si256(); + vec_lut[10] = _mm256_setzero_si256(); + vec_lut[9] = _mm256_setzero_si256(); + vec_lut[8] = vec_b0; + vec_lut[8] = _mm256_add_epi32(vec_lut[8], vec_b1); + vec_lut[7] = vec_b0; + vec_lut[6] = vec_b0; + vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1); + vec_lut[5] = vec_b1; + vec_lut[4] = _mm256_setzero_si256(); + vec_lut[3] = _mm256_setzero_si256(); + vec_lut[3] = _mm256_sub_epi32(vec_lut[3], vec_b1); + vec_lut[2] = _mm256_setzero_si256(); + vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b0); + vec_lut[2] = _mm256_add_epi32(vec_lut[2], vec_b1); + vec_lut[1] = _mm256_setzero_si256(); + vec_lut[1] = _mm256_sub_epi32(vec_lut[1], vec_b0); + vec_lut[0] = _mm256_setzero_si256(); + vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b0); + vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b1); + + __m256i ix[16]; +#pragma unroll + for (int g = 0; g < 16; ++g) { + ix[g] = vec_lut[g]; + } + + Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7])); + Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15])); + +#pragma unroll + for (int g = 0; g < 8; ++g) { + ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + } + + int8_t* qlut_i8 = reinterpret_cast(qlut); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]); + + } + *lut_scales = scales; +#endif + return 0; +} +static bool is_type_supported(enum ggml_type type) { + if (type == GGML_TYPE_Q4_0 || + type == GGML_TYPE_TL2) { + return true; + } else { + return false; + } +} +#include + +#define BM14336_4096 256 +#define BBK14336_4096 96 +template +inline void three_tbl_impl_14336_4096(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000); + const __m256i vec_zero = _mm256_set1_epi8(0x00); + const __m256i vec_one = _mm256_set1_epi8(0xff); + const int KK = BBK14336_4096 / 3; +#pragma unroll + for (int i = 0; i < BM14336_4096; i += 32) { + __m256i vec_as[KK / 2]; + __m256i vec_signs[KK / 8]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } + #pragma unroll + for (int as = 0; as < KK / 8; as++) { + vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + __m256i vec_sign = vec_signs[k]; + __m256i vec_a_0 = vec_as[k * 4 + 0]; + __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); + __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); + __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); + __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); + __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); + __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); + __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); + __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); + __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); + __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); + __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); + __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); + __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); + __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); + __m256i vec_a_1 = vec_as[k * 4 + 1]; + __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); + __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); + __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); + __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); + __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); + __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); + __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); + __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); + __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); + __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); + __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); + __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); + __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); + __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); + __m256i vec_a_2 = vec_as[k * 4 + 2]; + __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); + __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); + __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); + __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); + __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); + __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); + __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); + __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); + __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); + __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); + __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); + __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); + __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); + __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); + __m256i vec_a_3 = vec_as[k * 4 + 3]; + __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15); + __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15); + __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask); + __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3); + __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3); + __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15); + __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15); + __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask); + __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3); + __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3); + __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3); + __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3); + __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3); + __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3); + } + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM14336_4096 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM14336_4096 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM14336_4096 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM14336_4096 * bs)); + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM14336_4096 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM14336_4096 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM14336_4096 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM14336_4096 * bs), vec_gc3); + } + } +#endif +} + +template +inline int32_t two_tbl_impl14336_4096(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const int KK = BK2 / 2; +#pragma unroll + for (int i = 0; i < BM14336_4096; i += 32) { + __m256i vec_as[KK / 2]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + __m256i vec_a = vec_as[k * 4 + j]; + + __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs)); + __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs)); + __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs)); + __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs)); + + __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask); + __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top); + __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top); + + __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask); + __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot); + __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot); + + __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec); + __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); + } + } + + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM14336_4096 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM14336_4096 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM14336_4096 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM14336_4096 * bs)); + + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM14336_4096 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM14336_4096 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM14336_4096 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM14336_4096 * bs), vec_gc3); + } + } +#endif + return 0; +} + +template +int32_t three_qgemm_lut_14336_4096(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM14336_4096]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM14336_4096 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 4032 / BBK14336_4096; ++k_outer) { + three_tbl_impl_14336_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK14336_4096 / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK14336_4096 / 3 / 2 * BM14336_4096)])), (&(((uint8_t*)sign)[(k_outer * BBK14336_4096 / 3 / 8 * BM14336_4096)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM14336_4096; i++) { + ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM14336_4096]); + } + } + return 0; +} + +template +int32_t two_qgemm_lut_14336_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM14336_4096]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM14336_4096 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 64 / 32; ++k_outer) { + two_tbl_impl14336_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM14336_4096)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM14336_4096; i++) { + ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM14336_4096]); + ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0]; + } + } + return 0; +} + +#include + +#define BM4096_14336 128 +#define BBK4096_14336 96 +template +inline void three_tbl_impl_4096_14336(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000); + const __m256i vec_zero = _mm256_set1_epi8(0x00); + const __m256i vec_one = _mm256_set1_epi8(0xff); + const int KK = BBK4096_14336 / 3; +#pragma unroll + for (int i = 0; i < BM4096_14336; i += 32) { + __m256i vec_as[KK / 2]; + __m256i vec_signs[KK / 8]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } + #pragma unroll + for (int as = 0; as < KK / 8; as++) { + vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + __m256i vec_sign = vec_signs[k]; + __m256i vec_a_0 = vec_as[k * 4 + 0]; + __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); + __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); + __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); + __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); + __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); + __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); + __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); + __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); + __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); + __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); + __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); + __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); + __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); + __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); + __m256i vec_a_1 = vec_as[k * 4 + 1]; + __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); + __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); + __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); + __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); + __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); + __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); + __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); + __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); + __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); + __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); + __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); + __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); + __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); + __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); + __m256i vec_a_2 = vec_as[k * 4 + 2]; + __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); + __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); + __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); + __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); + __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); + __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); + __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); + __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); + __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); + __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); + __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); + __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); + __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); + __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); + __m256i vec_a_3 = vec_as[k * 4 + 3]; + __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15); + __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15); + __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask); + __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3); + __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3); + __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15); + __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15); + __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask); + __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3); + __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3); + __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3); + __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3); + __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3); + __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3); + } + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM4096_14336 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM4096_14336 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM4096_14336 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM4096_14336 * bs)); + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM4096_14336 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM4096_14336 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM4096_14336 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM4096_14336 * bs), vec_gc3); + } + } +#endif +} + +template +inline int32_t two_tbl_impl4096_14336(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const int KK = BK2 / 2; +#pragma unroll + for (int i = 0; i < BM4096_14336; i += 32) { + __m256i vec_as[KK / 2]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + __m256i vec_a = vec_as[k * 4 + j]; + + __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs)); + __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs)); + __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs)); + __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs)); + + __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask); + __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top); + __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top); + + __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask); + __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot); + __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot); + + __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec); + __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); + } + } + + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM4096_14336 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM4096_14336 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM4096_14336 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM4096_14336 * bs)); + + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM4096_14336 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM4096_14336 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM4096_14336 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM4096_14336 * bs), vec_gc3); + } + } +#endif + return 0; +} + +template +int32_t three_qgemm_lut_4096_14336(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM4096_14336]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM4096_14336 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 14304 / BBK4096_14336; ++k_outer) { + three_tbl_impl_4096_14336((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_14336 / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_14336 / 3 / 2 * BM4096_14336)])), (&(((uint8_t*)sign)[(k_outer * BBK4096_14336 / 3 / 8 * BM4096_14336)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM4096_14336; i++) { + ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM4096_14336]); + } + } + return 0; +} + +template +int32_t two_qgemm_lut_4096_14336(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM4096_14336]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM4096_14336 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 32 / 32; ++k_outer) { + two_tbl_impl4096_14336((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM4096_14336)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM4096_14336; i++) { + ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM4096_14336]); + ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0]; + } + } + return 0; +} + +#include + +#define BM1024_4096 256 +#define BBK1024_4096 96 +template +inline void three_tbl_impl_1024_4096(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000); + const __m256i vec_zero = _mm256_set1_epi8(0x00); + const __m256i vec_one = _mm256_set1_epi8(0xff); + const int KK = BBK1024_4096 / 3; +#pragma unroll + for (int i = 0; i < BM1024_4096; i += 32) { + __m256i vec_as[KK / 2]; + __m256i vec_signs[KK / 8]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } + #pragma unroll + for (int as = 0; as < KK / 8; as++) { + vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + __m256i vec_sign = vec_signs[k]; + __m256i vec_a_0 = vec_as[k * 4 + 0]; + __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); + __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); + __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); + __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); + __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); + __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); + __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); + __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); + __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); + __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); + __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); + __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); + __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); + __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); + __m256i vec_a_1 = vec_as[k * 4 + 1]; + __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); + __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); + __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); + __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); + __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); + __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); + __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); + __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); + __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); + __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); + __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); + __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); + __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); + __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); + __m256i vec_a_2 = vec_as[k * 4 + 2]; + __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); + __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); + __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); + __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); + __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); + __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); + __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); + __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); + __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); + __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); + __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); + __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); + __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); + __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); + __m256i vec_a_3 = vec_as[k * 4 + 3]; + __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15); + __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15); + __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask); + __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3); + __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3); + __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15); + __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15); + __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask); + __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3); + __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3); + __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3); + __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3); + __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3); + __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3); + } + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM1024_4096 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM1024_4096 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM1024_4096 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM1024_4096 * bs)); + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM1024_4096 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM1024_4096 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM1024_4096 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM1024_4096 * bs), vec_gc3); + } + } +#endif +} + +template +inline int32_t two_tbl_impl1024_4096(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const int KK = BK2 / 2; +#pragma unroll + for (int i = 0; i < BM1024_4096; i += 32) { + __m256i vec_as[KK / 2]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + __m256i vec_a = vec_as[k * 4 + j]; + + __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs)); + __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs)); + __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs)); + __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs)); + + __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask); + __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top); + __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top); + + __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask); + __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot); + __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot); + + __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec); + __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); + } + } + + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM1024_4096 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM1024_4096 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM1024_4096 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM1024_4096 * bs)); + + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM1024_4096 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM1024_4096 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM1024_4096 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM1024_4096 * bs), vec_gc3); + } + } +#endif + return 0; +} + +template +int32_t three_qgemm_lut_1024_4096(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM1024_4096]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM1024_4096 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 4032 / BBK1024_4096; ++k_outer) { + three_tbl_impl_1024_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1024_4096 / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1024_4096 / 3 / 2 * BM1024_4096)])), (&(((uint8_t*)sign)[(k_outer * BBK1024_4096 / 3 / 8 * BM1024_4096)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM1024_4096; i++) { + ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM1024_4096]); + } + } + return 0; +} + +template +int32_t two_qgemm_lut_1024_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM1024_4096]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM1024_4096 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 64 / 32; ++k_outer) { + two_tbl_impl1024_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM1024_4096)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM1024_4096; i++) { + ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM1024_4096]); + ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0]; + } + } + return 0; +} + +#include + +#define BM4096_4096 128 +#define BBK4096_4096 96 +template +inline void three_tbl_impl_4096_4096(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000); + const __m256i vec_zero = _mm256_set1_epi8(0x00); + const __m256i vec_one = _mm256_set1_epi8(0xff); + const int KK = BBK4096_4096 / 3; +#pragma unroll + for (int i = 0; i < BM4096_4096; i += 32) { + __m256i vec_as[KK / 2]; + __m256i vec_signs[KK / 8]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } + #pragma unroll + for (int as = 0; as < KK / 8; as++) { + vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + __m256i vec_sign = vec_signs[k]; + __m256i vec_a_0 = vec_as[k * 4 + 0]; + __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); + __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); + __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); + __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); + __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); + __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); + __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); + __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); + __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); + __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); + __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); + __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); + __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); + __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); + __m256i vec_a_1 = vec_as[k * 4 + 1]; + __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); + __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); + __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); + __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); + __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); + __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); + __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); + __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); + __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); + __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); + __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); + __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); + __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); + __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); + __m256i vec_a_2 = vec_as[k * 4 + 2]; + __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); + __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); + __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); + __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); + __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); + __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); + __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); + __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); + __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); + __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); + __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); + __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); + __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); + __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); + __m256i vec_a_3 = vec_as[k * 4 + 3]; + __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15); + __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15); + __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask); + __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3); + __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3); + __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15); + __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15); + __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask); + __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3); + __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3); + __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3); + __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3); + __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3); + __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3); + } + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM4096_4096 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM4096_4096 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM4096_4096 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM4096_4096 * bs)); + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM4096_4096 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM4096_4096 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM4096_4096 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM4096_4096 * bs), vec_gc3); + } + } +#endif +} + +template +inline int32_t two_tbl_impl4096_4096(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const int KK = BK2 / 2; +#pragma unroll + for (int i = 0; i < BM4096_4096; i += 32) { + __m256i vec_as[KK / 2]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + __m256i vec_a = vec_as[k * 4 + j]; + + __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs)); + __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs)); + __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs)); + __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs)); + + __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask); + __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top); + __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top); + + __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask); + __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot); + __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot); + + __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec); + __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); + } + } + + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM4096_4096 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM4096_4096 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM4096_4096 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM4096_4096 * bs)); + + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM4096_4096 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM4096_4096 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM4096_4096 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM4096_4096 * bs), vec_gc3); + } + } +#endif + return 0; +} + +template +int32_t three_qgemm_lut_4096_4096(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM4096_4096]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM4096_4096 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 4032 / BBK4096_4096; ++k_outer) { + three_tbl_impl_4096_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_4096 / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_4096 / 3 / 2 * BM4096_4096)])), (&(((uint8_t*)sign)[(k_outer * BBK4096_4096 / 3 / 8 * BM4096_4096)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM4096_4096; i++) { + ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM4096_4096]); + } + } + return 0; +} + +template +int32_t two_qgemm_lut_4096_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM4096_4096]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM4096_4096 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 64 / 32; ++k_outer) { + two_tbl_impl4096_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM4096_4096)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM4096_4096; i++) { + ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM4096_4096]); + ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0]; + } + } + return 0; +} + +void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT) { + partial_max_reset(bs, (&(((float*)LUT_Scales)[0]))); + if (m == 14336 && two_k == 64 && three_k == 4032) { + for (int32_t b = 0; b < bs; b++) { + per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)]))); + three_lut_ctor<4032>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b]))); + two_lut_ctor<64>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + 4032])), (&(((float*)LUT_Scales)[b]))); + } + } + else if (m == 4096 && two_k == 32 && three_k == 14304) { + for (int32_t b = 0; b < bs; b++) { + per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)]))); + three_lut_ctor<14304>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b]))); + two_lut_ctor<32>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + 14304])), (&(((float*)LUT_Scales)[b]))); + } + } + else if (m == 1024 && two_k == 64 && three_k == 4032) { + for (int32_t b = 0; b < bs; b++) { + per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)]))); + three_lut_ctor<4032>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b]))); + two_lut_ctor<64>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + 4032])), (&(((float*)LUT_Scales)[b]))); + } + } + else if (m == 4096 && two_k == 64 && three_k == 4032) { + for (int32_t b = 0; b < bs; b++) { + per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)]))); + three_lut_ctor<4032>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b]))); + two_lut_ctor<64>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + 4032])), (&(((float*)LUT_Scales)[b]))); + } + } +} +void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + if (m == 14336 && k == 4096) { + if (BK == 64) { + if (bs == 1) { + two_qgemm_lut_14336_4096<1>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 8) { + two_qgemm_lut_14336_4096<8>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 32) { + two_qgemm_lut_14336_4096<32>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 128) { + two_qgemm_lut_14336_4096<128>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 256) { + two_qgemm_lut_14336_4096<256>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 512) { + two_qgemm_lut_14336_4096<512>(A, LUT, Scales, LUT_Scales, C); + } + } + else if (BK == 4032) { + if (bs == 1) { + three_qgemm_lut_14336_4096<1>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 8) { + three_qgemm_lut_14336_4096<8>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 32) { + three_qgemm_lut_14336_4096<32>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 128) { + three_qgemm_lut_14336_4096<128>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 256) { + three_qgemm_lut_14336_4096<256>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 512) { + three_qgemm_lut_14336_4096<512>(A, sign, LUT, Scales, LUT_Scales, C); + } + } + } + else if (m == 4096 && k == 14336) { + if (BK == 32) { + if (bs == 1) { + two_qgemm_lut_4096_14336<1>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 8) { + two_qgemm_lut_4096_14336<8>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 32) { + two_qgemm_lut_4096_14336<32>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 128) { + two_qgemm_lut_4096_14336<128>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 256) { + two_qgemm_lut_4096_14336<256>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 512) { + two_qgemm_lut_4096_14336<512>(A, LUT, Scales, LUT_Scales, C); + } + } + else if (BK == 14304) { + if (bs == 1) { + three_qgemm_lut_4096_14336<1>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 8) { + three_qgemm_lut_4096_14336<8>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 32) { + three_qgemm_lut_4096_14336<32>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 128) { + three_qgemm_lut_4096_14336<128>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 256) { + three_qgemm_lut_4096_14336<256>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 512) { + three_qgemm_lut_4096_14336<512>(A, sign, LUT, Scales, LUT_Scales, C); + } + } + } + else if (m == 1024 && k == 4096) { + if (BK == 64) { + if (bs == 1) { + two_qgemm_lut_1024_4096<1>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 8) { + two_qgemm_lut_1024_4096<8>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 32) { + two_qgemm_lut_1024_4096<32>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 128) { + two_qgemm_lut_1024_4096<128>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 256) { + two_qgemm_lut_1024_4096<256>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 512) { + two_qgemm_lut_1024_4096<512>(A, LUT, Scales, LUT_Scales, C); + } + } + else if (BK == 4032) { + if (bs == 1) { + three_qgemm_lut_1024_4096<1>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 8) { + three_qgemm_lut_1024_4096<8>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 32) { + three_qgemm_lut_1024_4096<32>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 128) { + three_qgemm_lut_1024_4096<128>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 256) { + three_qgemm_lut_1024_4096<256>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 512) { + three_qgemm_lut_1024_4096<512>(A, sign, LUT, Scales, LUT_Scales, C); + } + } + } + else if (m == 4096 && k == 4096) { + if (BK == 64) { + if (bs == 1) { + two_qgemm_lut_4096_4096<1>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 8) { + two_qgemm_lut_4096_4096<8>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 32) { + two_qgemm_lut_4096_4096<32>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 128) { + two_qgemm_lut_4096_4096<128>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 256) { + two_qgemm_lut_4096_4096<256>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 512) { + two_qgemm_lut_4096_4096<512>(A, LUT, Scales, LUT_Scales, C); + } + } + else if (BK == 4032) { + if (bs == 1) { + three_qgemm_lut_4096_4096<1>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 8) { + three_qgemm_lut_4096_4096<8>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 32) { + three_qgemm_lut_4096_4096<32>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 128) { + three_qgemm_lut_4096_4096<128>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 256) { + three_qgemm_lut_4096_4096<256>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 512) { + three_qgemm_lut_4096_4096<512>(A, sign, LUT, Scales, LUT_Scales, C); + } + } + } +} + +void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) { + if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { + return; + } + + int k = tensor->ne[0]; + int m = tensor->ne[1]; + const int lut_scales_size = 1; + int bk = 0; + int bm = 0; + + if (m == 14336 && k == 4096) { + bm = BM14336_4096; + bk = BBK14336_4096; + } +else if (m == 4096 && k == 14336) { + bm = BM4096_14336; + bk = BBK4096_14336; + } +else if (m == 1024 && k == 4096) { + bm = BM1024_4096; + bk = BBK1024_4096; + } +else if (m == 4096 && k == 4096) { + bm = BM4096_4096; + bk = BBK4096_4096; + } + + const int n_tile_num = m / bm; + const int BK = bk; + uint8_t * qweights; + bitnet_float_type * scales; + + scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type)); + qweights = (uint8_t *) tensor->data; + float * i2_scales = (float * )(qweights + k * m / 4); + scales[0] = (bitnet_float_type) i2_scales[0]; + + tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index; + bitnet_tensor_extras[bitnet_tensor_extras_index++] = { + /* .lut_scales_size = */ lut_scales_size, + /* .BK = */ BK, + /* .n_tile_num = */ n_tile_num, + /* .qweights = */ qweights, + /* .scales = */ scales + }; +} +#endif \ No newline at end of file diff --git a/preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl1.ini b/preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl1.ini new file mode 100644 index 0000000..b27f477 --- /dev/null +++ b/preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl1.ini @@ -0,0 +1,28 @@ +[Kernels_0] +m = 14336 +k = 4096 +bm = 256 +bk = 128 +bmm = 64 + +[Kernels_1] +m = 4096 +k = 14336 +bm = 256 +bk = 128 +bmm = 32 + +[Kernels_2] +m = 1024 +k = 4096 +bm = 128 +bk = 64 +bmm = 64 + +[Kernels_3] +m = 4096 +k = 4096 +bm = 128 +bk = 64 +bmm = 32 + diff --git a/preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl2.ini b/preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl2.ini new file mode 100644 index 0000000..a767353 --- /dev/null +++ b/preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl2.ini @@ -0,0 +1,28 @@ +[Kernels_0] +m = 14336 +k = 4096 +bm = 256 +bk = 96 +bmm = 32 + +[Kernels_1] +m = 4096 +k = 14336 +bm = 128 +bk = 96 +bmm = 32 + +[Kernels_2] +m = 1024 +k = 4096 +bm = 256 +bk = 96 +bmm = 32 + +[Kernels_3] +m = 4096 +k = 4096 +bm = 128 +bk = 96 +bmm = 32 + diff --git a/preset_kernels/bitnet_b1_58-3B/bitnet-lut-kernels-tl1.h b/preset_kernels/bitnet_b1_58-3B/bitnet-lut-kernels-tl1.h new file mode 100644 index 0000000..3f3f551 --- /dev/null +++ b/preset_kernels/bitnet_b1_58-3B/bitnet-lut-kernels-tl1.h @@ -0,0 +1,627 @@ +#if defined(GGML_BITNET_ARM_TL1) +#include "ggml-bitnet.h" +#define GGML_BITNET_MAX_NODES 8192 +static bool initialized = false; +static bitnet_tensor_extra * bitnet_tensor_extras = nullptr; +static size_t bitnet_tensor_extras_index = 0; +static void * aligned_malloc(size_t size) {{ +#if defined(_WIN32) + return _aligned_malloc(size, 64); +#else + void * ptr = nullptr; + posix_memalign(&ptr, 64, size); + return ptr; +#endif +}} +static void aligned_free(void * ptr) {{ +#if defined(_WIN32) + _aligned_free(ptr); +#else + free(ptr); +#endif +}} + +void per_tensor_quant(int k, void* lut_scales_, void* b_) {{ + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + bitnet_float_type* b = (bitnet_float_type*)b_; +#ifdef __ARM_NEON + float32x4_t temp_max = vdupq_n_f32(0); + for (int i=0; i < k / 4; i++) {{ + float32x4_t vec_bs = vld1q_f32(b + 4 * i); + float32x4_t abssum = vabsq_f32(vec_bs); + temp_max = vmaxq_f32(abssum, temp_max); + }} + float32_t scales = 127 / vmaxvq_f32(temp_max); + *lut_scales = scales; +#elif defined __AVX2__ + __m256 max_vec = _mm256_set1_ps(0.f); + const __m256 vec_sign = _mm256_set1_ps(-0.0f); + // #pragma unroll + for (int i = 0; i < k / 8; i++) {{ + __m256 vec_b = _mm256_loadu_ps(b + i * 8); + __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); + max_vec = _mm256_max_ps(vec_babs, max_vec); + }} + __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec)); + max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1)); + max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1)); + float scales = 127 / _mm_cvtss_f32(max1); + *lut_scales = scales; +#endif +}} + +void partial_max_reset(void* lut_scales_) {{ + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + *lut_scales = 0.0; +}} + +#ifdef __ARM_NEON +inline void Transpose_8_8( + int16x8_t *v0, + int16x8_t *v1, + int16x8_t *v2, + int16x8_t *v3, + int16x8_t *v4, + int16x8_t *v5, + int16x8_t *v6, + int16x8_t *v7) +{{ + int16x8x2_t q04 = vzipq_s16(*v0, *v4); + int16x8x2_t q15 = vzipq_s16(*v1, *v5); + int16x8x2_t q26 = vzipq_s16(*v2, *v6); + int16x8x2_t q37 = vzipq_s16(*v3, *v7); + + int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]); + int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]); + int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]); + int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]); + + int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]); + int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]); + int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]); + int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]); + + *v0 = q_fin_0.val[0]; + *v1 = q_fin_0.val[1]; + *v2 = q_fin_1.val[0]; + *v3 = q_fin_1.val[1]; + *v4 = q_fin_2.val[0]; + *v5 = q_fin_2.val[1]; + *v6 = q_fin_3.val[0]; + *v7 = q_fin_3.val[1]; +}} +#endif + +template +inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{ +#ifdef __ARM_NEON + int16x8_t vec_lut[16]; + float32_t scales = *lut_scales; + uint8_t tbl_mask[16]; + tbl_mask[0] = 0; + tbl_mask[1] = 2; + tbl_mask[2] = 4; + tbl_mask[3] = 6; + tbl_mask[4] = 8; + tbl_mask[5] = 10; + tbl_mask[6] = 12; + tbl_mask[7] = 14; + tbl_mask[8] = 1; + tbl_mask[9] = 3; + tbl_mask[10] = 5; + tbl_mask[11] = 7; + tbl_mask[12] = 9; + tbl_mask[13] = 11; + tbl_mask[14] = 13; + tbl_mask[15] = 15; + uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask); +#pragma unroll + for (int k = 0; k < act_k / 16; ++k) {{ + float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16); + float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8); + float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales); + float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales); + float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales); + float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales); + int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0); + int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1); + int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2); + int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3); + int16x4_t vec_b16_0 = vmovn_s32(vec_b_0); + int16x4_t vec_b16_1 = vmovn_s32(vec_b_1); + int16x4_t vec_b16_2 = vmovn_s32(vec_b_2); + int16x4_t vec_b16_3 = vmovn_s32(vec_b_3); + int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2); + int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3); + vec_lut[0] = vdupq_n_s16(0); + vec_lut[0] = vec_lut[0] - vec_bs_0; + vec_lut[0] = vec_lut[0] - vec_bs_1; + vec_lut[1] = vdupq_n_s16(0); + vec_lut[1] = vec_lut[1] - vec_bs_0; + vec_lut[2] = vdupq_n_s16(0); + vec_lut[2] = vec_lut[2] - vec_bs_0; + vec_lut[2] = vec_lut[2] + vec_bs_1; + vec_lut[3] = vdupq_n_s16(0); + vec_lut[3] = vec_lut[3] - vec_bs_1; + vec_lut[4] = vdupq_n_s16(0); + vec_lut[5] = vec_bs_1; + vec_lut[6] = vec_bs_0; + vec_lut[6] = vec_lut[6] - vec_bs_1; + vec_lut[7] = vec_bs_0; + vec_lut[8] = vec_bs_0; + vec_lut[8] = vec_lut[8] + vec_bs_1; + Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]), + &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7])); + Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]), + &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15])); +#pragma unroll + for (int idx = 0; idx < 8; idx++) {{ + int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q); + int8x8_t q0_low = vget_low_s8(q0_s); + int8x8_t q0_high = vget_high_s8(q0_s); + int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q); + int8x8_t q1_low = vget_low_s8(q1_s); + int8x8_t q1_high = vget_high_s8(q1_s); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low); + }} + }} +#endif +}} + +static bool is_type_supported(enum ggml_type type) {{ + if (type == GGML_TYPE_Q4_0 || + type == GGML_TYPE_TL1) {{ + return true; + }} else {{ + return false; + }} +}} +#include + +#define BM3200_8640 160 +#define BBK3200_8640 64 +inline void tbl_impl_3200_8640(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __ARM_NEON + const int KK = BBK3200_8640 / 2; + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + const int8x16_t vec_zero = vdupq_n_s16(0x0000); + int8x16_t vec_lut[2 * KK]; + int16x8_t vec_c[4]; +#pragma unroll + for (int k = 0; k < 2 * KK; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + +#pragma unroll + for (int i = 0; i < BM3200_8640; i += 32) { + #pragma unroll + for (int i=0; i<4; i++) { + vec_c[i] = vandq_s16(vec_c[i], vec_zero); + } + +#pragma unroll + for (int k = 0; k < KK / 4; k++) { + + uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); + uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); + uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); + int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); + int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); + int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); + int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); + int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); + int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); + vec_c[0] += vec_v_left_0.val[0]; + vec_c[0] += vec_v_right_0.val[0]; + vec_c[1] += vec_v_left_0.val[1]; + vec_c[1] += vec_v_right_0.val[1]; + + uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); + uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); + uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); + int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); + int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); + int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); + int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); + int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); + int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); + vec_c[0] += vec_v_left_1.val[0]; + vec_c[0] += vec_v_right_1.val[0]; + vec_c[1] += vec_v_left_1.val[1]; + vec_c[1] += vec_v_right_1.val[1]; + + uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); + uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); + uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); + int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); + int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); + int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); + int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); + int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); + int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); + vec_c[2] += vec_v_left_2.val[0]; + vec_c[2] += vec_v_right_2.val[0]; + vec_c[3] += vec_v_left_2.val[1]; + vec_c[3] += vec_v_right_2.val[1]; + + uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); + uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); + uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); + int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); + int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); + int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); + int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); + int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); + int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); + vec_c[2] += vec_v_left_3.val[0]; + vec_c[2] += vec_v_right_3.val[0]; + vec_c[3] += vec_v_left_3.val[1]; + vec_c[3] += vec_v_right_3.val[1]; + + } + + int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); + int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); + vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); + vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); + int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); + int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); + vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); + vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); + int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); + int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); + vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); + vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); + int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); + int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); + vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); + vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); + + } +#endif +} + +int32_t qgemm_lut_3200_8640(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BM3200_8640]; + memset(&(CBits[0]), 0, BM3200_8640 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 8640 / BBK3200_8640; ++k_outer) { + tbl_impl_3200_8640((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_8640 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_8640 / 2 / 2 * BM3200_8640)]))); + } +#pragma unroll + for (int i = 0; i < BM3200_8640; i++) { + ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; + } + return 0; +}; +#include + +#define BM3200_3200 320 +#define BBK3200_3200 128 +inline void tbl_impl_3200_3200(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __ARM_NEON + const int KK = BBK3200_3200 / 2; + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + const int8x16_t vec_zero = vdupq_n_s16(0x0000); + int8x16_t vec_lut[2 * KK]; + int16x8_t vec_c[8]; +#pragma unroll + for (int k = 0; k < 2 * KK; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + +#pragma unroll + for (int i = 0; i < BM3200_3200; i += 64) { + #pragma unroll + for (int i=0; i<8; i++) { + vec_c[i] = vandq_s16(vec_c[i], vec_zero); + } + +#pragma unroll + for (int k = 0; k < KK / 2; k++) { + + uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); + uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); + uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); + int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top); + int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top); + int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot); + int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot); + int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); + int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); + vec_c[0] += vec_v_left_0.val[0]; + vec_c[0] += vec_v_right_0.val[0]; + vec_c[1] += vec_v_left_0.val[1]; + vec_c[1] += vec_v_right_0.val[1]; + + uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); + uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); + uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); + int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top); + int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top); + int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot); + int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot); + int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); + int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); + vec_c[2] += vec_v_left_1.val[0]; + vec_c[2] += vec_v_right_1.val[0]; + vec_c[3] += vec_v_left_1.val[1]; + vec_c[3] += vec_v_right_1.val[1]; + + uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); + uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); + uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); + int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top); + int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top); + int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot); + int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot); + int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); + int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); + vec_c[4] += vec_v_left_2.val[0]; + vec_c[4] += vec_v_right_2.val[0]; + vec_c[5] += vec_v_left_2.val[1]; + vec_c[5] += vec_v_right_2.val[1]; + + uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); + uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); + uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); + int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top); + int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top); + int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot); + int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot); + int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); + int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); + vec_c[6] += vec_v_left_3.val[0]; + vec_c[6] += vec_v_right_3.val[0]; + vec_c[7] += vec_v_left_3.val[1]; + vec_c[7] += vec_v_right_3.val[1]; + + } + + int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); + int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); + vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); + vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); + int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); + int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); + vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); + vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); + int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); + int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); + vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); + vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); + int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); + int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); + vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); + vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); + int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4])); + int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]); + vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4); + vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4); + int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5])); + int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]); + vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5); + vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5); + int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6])); + int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]); + vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6); + vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6); + int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7])); + int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]); + vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7); + vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7); + + } +#endif +} + +int32_t qgemm_lut_3200_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BM3200_3200]; + memset(&(CBits[0]), 0, BM3200_3200 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 3200 / BBK3200_3200; ++k_outer) { + tbl_impl_3200_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_3200 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_3200 / 2 / 2 * BM3200_3200)]))); + } +#pragma unroll + for (int i = 0; i < BM3200_3200; i++) { + ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; + } + return 0; +}; +#include + +#define BM8640_3200 320 +#define BBK8640_3200 64 +inline void tbl_impl_8640_3200(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __ARM_NEON + const int KK = BBK8640_3200 / 2; + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + const int8x16_t vec_zero = vdupq_n_s16(0x0000); + int8x16_t vec_lut[2 * KK]; + int16x8_t vec_c[4]; +#pragma unroll + for (int k = 0; k < 2 * KK; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + +#pragma unroll + for (int i = 0; i < BM8640_3200; i += 32) { + #pragma unroll + for (int i=0; i<4; i++) { + vec_c[i] = vandq_s16(vec_c[i], vec_zero); + } + +#pragma unroll + for (int k = 0; k < KK / 4; k++) { + + uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); + uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); + uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); + int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); + int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); + int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); + int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); + int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); + int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); + vec_c[0] += vec_v_left_0.val[0]; + vec_c[0] += vec_v_right_0.val[0]; + vec_c[1] += vec_v_left_0.val[1]; + vec_c[1] += vec_v_right_0.val[1]; + + uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); + uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); + uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); + int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); + int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); + int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); + int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); + int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); + int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); + vec_c[0] += vec_v_left_1.val[0]; + vec_c[0] += vec_v_right_1.val[0]; + vec_c[1] += vec_v_left_1.val[1]; + vec_c[1] += vec_v_right_1.val[1]; + + uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); + uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); + uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); + int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); + int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); + int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); + int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); + int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); + int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); + vec_c[2] += vec_v_left_2.val[0]; + vec_c[2] += vec_v_right_2.val[0]; + vec_c[3] += vec_v_left_2.val[1]; + vec_c[3] += vec_v_right_2.val[1]; + + uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); + uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); + uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); + int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); + int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); + int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); + int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); + int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); + int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); + vec_c[2] += vec_v_left_3.val[0]; + vec_c[2] += vec_v_right_3.val[0]; + vec_c[3] += vec_v_left_3.val[1]; + vec_c[3] += vec_v_right_3.val[1]; + + } + + int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); + int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); + vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); + vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); + int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); + int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); + vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); + vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); + int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); + int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); + vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); + vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); + int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); + int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); + vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); + vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); + + } +#endif +} + +int32_t qgemm_lut_8640_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BM8640_3200]; + memset(&(CBits[0]), 0, BM8640_3200 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 3200 / BBK8640_3200; ++k_outer) { + tbl_impl_8640_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK8640_3200 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK8640_3200 / 2 / 2 * BM8640_3200)]))); + } +#pragma unroll + for (int i = 0; i < BM8640_3200; i++) { + ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; + } + return 0; +}; + +template +void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{ + partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0]))); + per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0]))); + + lut_ctor((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0]))); +}} +void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) { + if (m == 3200 && k == 8640) { + preprocessor_k<8640>(B, LUT_Scales, QLUT); + } + else if (m == 3200 && k == 3200) { + preprocessor_k<3200>(B, LUT_Scales, QLUT); + } + else if (m == 8640 && k == 3200) { + preprocessor_k<3200>(B, LUT_Scales, QLUT); + } +} +void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + if (m == 3200 && k == 8640) { + qgemm_lut_3200_8640(A, LUT, Scales, LUT_Scales, C); + } + else if (m == 3200 && k == 3200) { + qgemm_lut_3200_3200(A, LUT, Scales, LUT_Scales, C); + } + else if (m == 8640 && k == 3200) { + qgemm_lut_8640_3200(A, LUT, Scales, LUT_Scales, C); + } +} + +void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) { + if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { + return; + } + + int k = tensor->ne[0]; + int m = tensor->ne[1]; + const int lut_scales_size = 1; + const int scales_size = 1; + int bk = 0; + int bm = 0; + + if (m == 3200 && k == 8640) { + bm = BM3200_8640; + bk = BBK3200_8640; + } +else if (m == 3200 && k == 3200) { + bm = BM3200_3200; + bk = BBK3200_3200; + } +else if (m == 8640 && k == 3200) { + bm = BM8640_3200; + bk = BBK8640_3200; + } + + const int n_tile_num = m / bm; + const int BK = bk; + uint8_t * qweights; + bitnet_float_type * scales; + + scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type)); + qweights = (uint8_t *) tensor->data; + float * i2_scales = (float * )(qweights + k * m / 4); + scales[0] = (bitnet_float_type) i2_scales[0]; + + tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index; + bitnet_tensor_extras[bitnet_tensor_extras_index++] = { + /* .lut_scales_size = */ lut_scales_size, + /* .scales_size = */ scales_size, + /* .n_tile_num = */ n_tile_num, + /* .qweights = */ qweights, + /* .scales = */ scales + }; +} +#endif \ No newline at end of file diff --git a/preset_kernels/bitnet_b1_58-3B/bitnet-lut-kernels-tl2.h b/preset_kernels/bitnet_b1_58-3B/bitnet-lut-kernels-tl2.h new file mode 100644 index 0000000..678b0f3 --- /dev/null +++ b/preset_kernels/bitnet_b1_58-3B/bitnet-lut-kernels-tl2.h @@ -0,0 +1,1167 @@ +#if defined(GGML_BITNET_X86_TL2) +#include "ggml-bitnet.h" +#define GGML_BITNET_MAX_NODES 8192 +static bool initialized = false; +static bitnet_tensor_extra * bitnet_tensor_extras = nullptr; +static size_t bitnet_tensor_extras_index = 0; +static void * aligned_malloc(size_t size) { +#if defined(_WIN32) + return _aligned_malloc(size, 64); +#else + void * ptr = nullptr; + posix_memalign(&ptr, 64, size); + return ptr; +#endif +} + +static void aligned_free(void * ptr) { +#if defined(_WIN32) + _aligned_free(ptr); +#else + free(ptr); +#endif +} +#define BK2 32 +#if defined __AVX2__ +inline void _mm256_merge_epi32(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0)); + *vl = _mm256_unpacklo_epi32(va, vb); + *vh = _mm256_unpackhi_epi32(va, vb); +} +inline void _mm256_merge_epi64(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0)); + *vl = _mm256_unpacklo_epi64(va, vb); + *vh = _mm256_unpackhi_epi64(va, vb); +} +inline void _mm256_merge_si128(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + *vl = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 2, 0, 0)); + *vh = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 3, 0, 1)); +} +inline void Transpose_8_8( + __m256i *v0, + __m256i *v1, + __m256i *v2, + __m256i *v3, + __m256i *v4, + __m256i *v5, + __m256i *v6, + __m256i *v7) +{ + __m256i w0, w1, w2, w3, w4, w5, w6, w7; + __m256i x0, x1, x2, x3, x4, x5, x6, x7; + _mm256_merge_epi32(*v0, *v1, &w0, &w1); + _mm256_merge_epi32(*v2, *v3, &w2, &w3); + _mm256_merge_epi32(*v4, *v5, &w4, &w5); + _mm256_merge_epi32(*v6, *v7, &w6, &w7); + _mm256_merge_epi64(w0, w2, &x0, &x1); + _mm256_merge_epi64(w1, w3, &x2, &x3); + _mm256_merge_epi64(w4, w6, &x4, &x5); + _mm256_merge_epi64(w5, w7, &x6, &x7); + _mm256_merge_si128(x0, x4, v0, v1); + _mm256_merge_si128(x1, x5, v2, v3); + _mm256_merge_si128(x2, x6, v4, v5); + _mm256_merge_si128(x3, x7, v6, v7); +} +#endif +inline int32_t per_tensor_quant(int k, void* lut_scales_, void* b_) { + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + bitnet_float_type* b = (bitnet_float_type*)b_; +#if defined __AVX2__ + __m256 max_vec = _mm256_set1_ps(0.f); + const __m256 vec_sign = _mm256_set1_ps(-0.0f); + for (int i = 0; i < k / 8; i++) { + __m256 vec_b = _mm256_loadu_ps(b + i * 8); + __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); + max_vec = _mm256_max_ps(vec_babs, max_vec); + } + __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec)); + max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1)); + max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1)); + float scales = 127 / _mm_cvtss_f32(max1); + *lut_scales = scales; +#endif + return 0; +} +inline int32_t partial_max_reset(int32_t bs, void* lut_scales_) { + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + #pragma unroll + for (int i=0; i< bs; i++) { + lut_scales[i] = 0.0; + } + return 0; +} +template +inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) { +#if defined __AVX2__ + __m256 vec_lut[16]; + const __m256i vec_bi = _mm256_set_epi32(84, 72, 60, 48, 36, 24, 12, 0); + float scales = *lut_scales; + __m256i shuffle_mask = _mm256_set_epi8( + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00, + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00 + ); +#pragma unroll + for (int k = 0; k < act_k / 24; ++k) { + __m256 vec_b0 = _mm256_i32gather_ps(b + k * 24 + 0, vec_bi, 1); + __m256 vec_b1 = _mm256_i32gather_ps(b + k * 24 + 1, vec_bi, 1); + __m256 vec_b2 = _mm256_i32gather_ps(b + k * 24 + 2, vec_bi, 1); + + __m256i vec_b0i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b1i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b2i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b2, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + vec_lut[15] = _mm256_setzero_si256(); + vec_lut[14] = _mm256_setzero_si256(); + vec_lut[13] = vec_b0i; + vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b1i); + vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b2i); + vec_lut[12] = vec_b0i; + vec_lut[12] = _mm256_add_epi32(vec_lut[12], vec_b1i); + vec_lut[11] = vec_b0i; + vec_lut[11] = _mm256_add_epi32(vec_lut[11], vec_b1i); + vec_lut[11] = _mm256_sub_epi32(vec_lut[11], vec_b2i); + vec_lut[10] = vec_b0i; + vec_lut[10] = _mm256_add_epi32(vec_lut[10], vec_b2i); + vec_lut[9] = vec_b0i; + vec_lut[8] = vec_b0i; + vec_lut[8] = _mm256_sub_epi32(vec_lut[8], vec_b2i); + vec_lut[7] = vec_b0i; + vec_lut[7] = _mm256_sub_epi32(vec_lut[7], vec_b1i); + vec_lut[7] = _mm256_add_epi32(vec_lut[7], vec_b2i); + vec_lut[6] = vec_b0i; + vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1i); + vec_lut[5] = vec_b0i; + vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b1i); + vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b2i); + vec_lut[4] = vec_b1i; + vec_lut[4] = _mm256_add_epi32(vec_lut[4], vec_b2i); + vec_lut[3] = vec_b1i; + vec_lut[2] = vec_b1i; + vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b2i); + vec_lut[1] = vec_b2i; + vec_lut[0] = _mm256_setzero_si256(); + __m256i ix[16]; + +#pragma unroll + for (int g = 0; g < 16; ++g) { + ix[g] = vec_lut[g]; + } + + Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7])); + Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15])); + +#pragma unroll + for (int g = 0; g < 8; ++g) { + ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + } + int8_t* qlut_i8 = reinterpret_cast(qlut); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]); + + } + + *lut_scales = scales; +#endif + return 0; +} + +template +inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) { +#if defined __AVX2__ + __m256 vec_lut[16]; + const __m256i vec_bi = _mm256_set_epi32(56, 48, 40, 32, 24, 16, 8, 0); + float scales = *lut_scales; + __m256i shuffle_mask = _mm256_set_epi8( + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00, + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00 + ); +#pragma unroll + for (int k = 0; k < act_k / 16; ++k) { + __m256 vec_b0f = _mm256_i32gather_ps(b + k * 16 + 0, vec_bi, 1); + __m256 vec_b1f = _mm256_i32gather_ps(b + k * 16 + 1, vec_bi, 1); + + __m256i vec_b0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + vec_lut[15] = _mm256_setzero_si256(); + vec_lut[14] = _mm256_setzero_si256(); + vec_lut[13] = _mm256_setzero_si256(); + vec_lut[12] = _mm256_setzero_si256(); + vec_lut[11] = _mm256_setzero_si256(); + vec_lut[10] = _mm256_setzero_si256(); + vec_lut[9] = _mm256_setzero_si256(); + vec_lut[8] = vec_b0; + vec_lut[8] = _mm256_add_epi32(vec_lut[8], vec_b1); + vec_lut[7] = vec_b0; + vec_lut[6] = vec_b0; + vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1); + vec_lut[5] = vec_b1; + vec_lut[4] = _mm256_setzero_si256(); + vec_lut[3] = _mm256_setzero_si256(); + vec_lut[3] = _mm256_sub_epi32(vec_lut[3], vec_b1); + vec_lut[2] = _mm256_setzero_si256(); + vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b0); + vec_lut[2] = _mm256_add_epi32(vec_lut[2], vec_b1); + vec_lut[1] = _mm256_setzero_si256(); + vec_lut[1] = _mm256_sub_epi32(vec_lut[1], vec_b0); + vec_lut[0] = _mm256_setzero_si256(); + vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b0); + vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b1); + + __m256i ix[16]; +#pragma unroll + for (int g = 0; g < 16; ++g) { + ix[g] = vec_lut[g]; + } + + Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7])); + Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15])); + +#pragma unroll + for (int g = 0; g < 8; ++g) { + ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + } + + int8_t* qlut_i8 = reinterpret_cast(qlut); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]); + + } + *lut_scales = scales; +#endif + return 0; +} +static bool is_type_supported(enum ggml_type type) { + if (type == GGML_TYPE_Q4_0 || + type == GGML_TYPE_TL2) { + return true; + } else { + return false; + } +} +#include + +#define BM3200_8640 160 +#define BBK3200_8640 96 +template +inline void three_tbl_impl_3200_8640(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000); + const __m256i vec_zero = _mm256_set1_epi8(0x00); + const __m256i vec_one = _mm256_set1_epi8(0xff); + const int KK = BBK3200_8640 / 3; +#pragma unroll + for (int i = 0; i < BM3200_8640; i += 32) { + __m256i vec_as[KK / 2]; + __m256i vec_signs[KK / 8]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } + #pragma unroll + for (int as = 0; as < KK / 8; as++) { + vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + __m256i vec_sign = vec_signs[k]; + __m256i vec_a_0 = vec_as[k * 4 + 0]; + __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); + __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); + __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); + __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); + __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); + __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); + __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); + __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); + __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); + __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); + __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); + __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); + __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); + __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); + __m256i vec_a_1 = vec_as[k * 4 + 1]; + __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); + __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); + __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); + __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); + __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); + __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); + __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); + __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); + __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); + __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); + __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); + __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); + __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); + __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); + __m256i vec_a_2 = vec_as[k * 4 + 2]; + __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); + __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); + __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); + __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); + __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); + __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); + __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); + __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); + __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); + __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); + __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); + __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); + __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); + __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); + __m256i vec_a_3 = vec_as[k * 4 + 3]; + __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15); + __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15); + __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask); + __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3); + __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3); + __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15); + __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15); + __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask); + __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3); + __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3); + __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3); + __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3); + __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3); + __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3); + } + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_8640 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_8640 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_8640 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_8640 * bs)); + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_8640 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_8640 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_8640 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_8640 * bs), vec_gc3); + } + } +#endif +} + +template +inline int32_t two_tbl_impl3200_8640(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const int KK = BK2 / 2; +#pragma unroll + for (int i = 0; i < BM3200_8640; i += 32) { + __m256i vec_as[KK / 2]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + __m256i vec_a = vec_as[k * 4 + j]; + + __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs)); + __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs)); + __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs)); + __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs)); + + __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask); + __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top); + __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top); + + __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask); + __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot); + __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot); + + __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec); + __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); + } + } + + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_8640 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_8640 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_8640 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_8640 * bs)); + + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_8640 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_8640 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_8640 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_8640 * bs), vec_gc3); + } + } +#endif + return 0; +} + +template +int32_t three_qgemm_lut_3200_8640(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM3200_8640]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM3200_8640 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 8640 / BBK3200_8640; ++k_outer) { + three_tbl_impl_3200_8640((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_8640 / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_8640 / 3 / 2 * BM3200_8640)])), (&(((uint8_t*)sign)[(k_outer * BBK3200_8640 / 3 / 8 * BM3200_8640)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM3200_8640; i++) { + ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM3200_8640]); + } + } + return 0; +} + +template +int32_t two_qgemm_lut_3200_8640(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM3200_8640]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM3200_8640 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 0 / 32; ++k_outer) { + two_tbl_impl3200_8640((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM3200_8640)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM3200_8640; i++) { + ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM3200_8640]); + ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0]; + } + } + return 0; +} + +#include + +#define BM3200_3200 320 +#define BBK3200_3200 96 +template +inline void three_tbl_impl_3200_3200(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000); + const __m256i vec_zero = _mm256_set1_epi8(0x00); + const __m256i vec_one = _mm256_set1_epi8(0xff); + const int KK = BBK3200_3200 / 3; +#pragma unroll + for (int i = 0; i < BM3200_3200; i += 32) { + __m256i vec_as[KK / 2]; + __m256i vec_signs[KK / 8]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } + #pragma unroll + for (int as = 0; as < KK / 8; as++) { + vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + __m256i vec_sign = vec_signs[k]; + __m256i vec_a_0 = vec_as[k * 4 + 0]; + __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); + __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); + __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); + __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); + __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); + __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); + __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); + __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); + __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); + __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); + __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); + __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); + __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); + __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); + __m256i vec_a_1 = vec_as[k * 4 + 1]; + __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); + __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); + __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); + __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); + __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); + __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); + __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); + __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); + __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); + __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); + __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); + __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); + __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); + __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); + __m256i vec_a_2 = vec_as[k * 4 + 2]; + __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); + __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); + __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); + __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); + __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); + __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); + __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); + __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); + __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); + __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); + __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); + __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); + __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); + __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); + __m256i vec_a_3 = vec_as[k * 4 + 3]; + __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15); + __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15); + __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask); + __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3); + __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3); + __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15); + __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15); + __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask); + __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3); + __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3); + __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3); + __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3); + __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3); + __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3); + } + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_3200 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_3200 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_3200 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_3200 * bs)); + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_3200 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_3200 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_3200 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_3200 * bs), vec_gc3); + } + } +#endif +} + +template +inline int32_t two_tbl_impl3200_3200(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const int KK = BK2 / 2; +#pragma unroll + for (int i = 0; i < BM3200_3200; i += 32) { + __m256i vec_as[KK / 2]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + __m256i vec_a = vec_as[k * 4 + j]; + + __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs)); + __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs)); + __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs)); + __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs)); + + __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask); + __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top); + __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top); + + __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask); + __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot); + __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot); + + __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec); + __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); + } + } + + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_3200 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_3200 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_3200 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_3200 * bs)); + + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_3200 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_3200 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_3200 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_3200 * bs), vec_gc3); + } + } +#endif + return 0; +} + +template +int32_t three_qgemm_lut_3200_3200(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM3200_3200]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM3200_3200 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 3168 / BBK3200_3200; ++k_outer) { + three_tbl_impl_3200_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_3200 / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_3200 / 3 / 2 * BM3200_3200)])), (&(((uint8_t*)sign)[(k_outer * BBK3200_3200 / 3 / 8 * BM3200_3200)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM3200_3200; i++) { + ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM3200_3200]); + } + } + return 0; +} + +template +int32_t two_qgemm_lut_3200_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM3200_3200]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM3200_3200 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 32 / 32; ++k_outer) { + two_tbl_impl3200_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM3200_3200)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM3200_3200; i++) { + ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM3200_3200]); + ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0]; + } + } + return 0; +} + +#include + +#define BM8640_3200 320 +#define BBK8640_3200 96 +template +inline void three_tbl_impl_8640_3200(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000); + const __m256i vec_zero = _mm256_set1_epi8(0x00); + const __m256i vec_one = _mm256_set1_epi8(0xff); + const int KK = BBK8640_3200 / 3; +#pragma unroll + for (int i = 0; i < BM8640_3200; i += 32) { + __m256i vec_as[KK / 2]; + __m256i vec_signs[KK / 8]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } + #pragma unroll + for (int as = 0; as < KK / 8; as++) { + vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + __m256i vec_sign = vec_signs[k]; + __m256i vec_a_0 = vec_as[k * 4 + 0]; + __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); + __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); + __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); + __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); + __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); + __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); + __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); + __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); + __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); + __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); + __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); + __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); + __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); + __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); + __m256i vec_a_1 = vec_as[k * 4 + 1]; + __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); + __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); + __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); + __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); + __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); + __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); + __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); + __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); + __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); + __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); + __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); + __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); + __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); + __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); + __m256i vec_a_2 = vec_as[k * 4 + 2]; + __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); + __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); + __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); + __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); + __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); + __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); + __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); + __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); + __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); + __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); + __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); + __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); + __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); + __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); + __m256i vec_a_3 = vec_as[k * 4 + 3]; + __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15); + __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15); + __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask); + __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3); + __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3); + __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15); + __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15); + __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask); + __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3); + __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3); + __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3); + __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3); + __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3); + __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3); + } + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM8640_3200 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM8640_3200 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM8640_3200 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM8640_3200 * bs)); + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM8640_3200 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM8640_3200 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM8640_3200 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM8640_3200 * bs), vec_gc3); + } + } +#endif +} + +template +inline int32_t two_tbl_impl8640_3200(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const int KK = BK2 / 2; +#pragma unroll + for (int i = 0; i < BM8640_3200; i += 32) { + __m256i vec_as[KK / 2]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + __m256i vec_a = vec_as[k * 4 + j]; + + __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs)); + __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs)); + __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs)); + __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs)); + + __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask); + __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top); + __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top); + + __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask); + __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot); + __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot); + + __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec); + __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); + } + } + + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM8640_3200 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM8640_3200 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM8640_3200 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM8640_3200 * bs)); + + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM8640_3200 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM8640_3200 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM8640_3200 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM8640_3200 * bs), vec_gc3); + } + } +#endif + return 0; +} + +template +int32_t three_qgemm_lut_8640_3200(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM8640_3200]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM8640_3200 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 3168 / BBK8640_3200; ++k_outer) { + three_tbl_impl_8640_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK8640_3200 / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK8640_3200 / 3 / 2 * BM8640_3200)])), (&(((uint8_t*)sign)[(k_outer * BBK8640_3200 / 3 / 8 * BM8640_3200)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM8640_3200; i++) { + ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM8640_3200]); + } + } + return 0; +} + +template +int32_t two_qgemm_lut_8640_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM8640_3200]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM8640_3200 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 32 / 32; ++k_outer) { + two_tbl_impl8640_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM8640_3200)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM8640_3200; i++) { + ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM8640_3200]); + ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0]; + } + } + return 0; +} + +void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT) { + partial_max_reset(bs, (&(((float*)LUT_Scales)[0]))); + if (m == 3200 && two_k == 0 && three_k == 8640) { + for (int32_t b = 0; b < bs; b++) { + per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)]))); + three_lut_ctor<8640>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b]))); + two_lut_ctor<0>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + 8640])), (&(((float*)LUT_Scales)[b]))); + } + } + else if (m == 3200 && two_k == 32 && three_k == 3168) { + for (int32_t b = 0; b < bs; b++) { + per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)]))); + three_lut_ctor<3168>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b]))); + two_lut_ctor<32>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + 3168])), (&(((float*)LUT_Scales)[b]))); + } + } + else if (m == 8640 && two_k == 32 && three_k == 3168) { + for (int32_t b = 0; b < bs; b++) { + per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)]))); + three_lut_ctor<3168>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b]))); + two_lut_ctor<32>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + 3168])), (&(((float*)LUT_Scales)[b]))); + } + } +} +void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + if (m == 3200 && k == 8640) { + if (BK == 0) { + if (bs == 1) { + two_qgemm_lut_3200_8640<1>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 8) { + two_qgemm_lut_3200_8640<8>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 32) { + two_qgemm_lut_3200_8640<32>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 128) { + two_qgemm_lut_3200_8640<128>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 256) { + two_qgemm_lut_3200_8640<256>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 512) { + two_qgemm_lut_3200_8640<512>(A, LUT, Scales, LUT_Scales, C); + } + } + else if (BK == 8640) { + if (bs == 1) { + three_qgemm_lut_3200_8640<1>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 8) { + three_qgemm_lut_3200_8640<8>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 32) { + three_qgemm_lut_3200_8640<32>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 128) { + three_qgemm_lut_3200_8640<128>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 256) { + three_qgemm_lut_3200_8640<256>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 512) { + three_qgemm_lut_3200_8640<512>(A, sign, LUT, Scales, LUT_Scales, C); + } + } + } + else if (m == 3200 && k == 3200) { + if (BK == 32) { + if (bs == 1) { + two_qgemm_lut_3200_3200<1>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 8) { + two_qgemm_lut_3200_3200<8>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 32) { + two_qgemm_lut_3200_3200<32>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 128) { + two_qgemm_lut_3200_3200<128>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 256) { + two_qgemm_lut_3200_3200<256>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 512) { + two_qgemm_lut_3200_3200<512>(A, LUT, Scales, LUT_Scales, C); + } + } + else if (BK == 3168) { + if (bs == 1) { + three_qgemm_lut_3200_3200<1>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 8) { + three_qgemm_lut_3200_3200<8>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 32) { + three_qgemm_lut_3200_3200<32>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 128) { + three_qgemm_lut_3200_3200<128>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 256) { + three_qgemm_lut_3200_3200<256>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 512) { + three_qgemm_lut_3200_3200<512>(A, sign, LUT, Scales, LUT_Scales, C); + } + } + } + else if (m == 8640 && k == 3200) { + if (BK == 32) { + if (bs == 1) { + two_qgemm_lut_8640_3200<1>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 8) { + two_qgemm_lut_8640_3200<8>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 32) { + two_qgemm_lut_8640_3200<32>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 128) { + two_qgemm_lut_8640_3200<128>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 256) { + two_qgemm_lut_8640_3200<256>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 512) { + two_qgemm_lut_8640_3200<512>(A, LUT, Scales, LUT_Scales, C); + } + } + else if (BK == 3168) { + if (bs == 1) { + three_qgemm_lut_8640_3200<1>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 8) { + three_qgemm_lut_8640_3200<8>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 32) { + three_qgemm_lut_8640_3200<32>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 128) { + three_qgemm_lut_8640_3200<128>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 256) { + three_qgemm_lut_8640_3200<256>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 512) { + three_qgemm_lut_8640_3200<512>(A, sign, LUT, Scales, LUT_Scales, C); + } + } + } +} + +void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) { + if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { + return; + } + + int k = tensor->ne[0]; + int m = tensor->ne[1]; + const int lut_scales_size = 1; + int bk = 0; + int bm = 0; + + if (m == 3200 && k == 8640) { + bm = BM3200_8640; + bk = BBK3200_8640; + } +else if (m == 3200 && k == 3200) { + bm = BM3200_3200; + bk = BBK3200_3200; + } +else if (m == 8640 && k == 3200) { + bm = BM8640_3200; + bk = BBK8640_3200; + } + + const int n_tile_num = m / bm; + const int BK = bk; + uint8_t * qweights; + bitnet_float_type * scales; + + scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type)); + qweights = (uint8_t *) tensor->data; + float * i2_scales = (float * )(qweights + k * m / 4); + scales[0] = (bitnet_float_type) i2_scales[0]; + + tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index; + bitnet_tensor_extras[bitnet_tensor_extras_index++] = { + /* .lut_scales_size = */ lut_scales_size, + /* .BK = */ BK, + /* .n_tile_num = */ n_tile_num, + /* .qweights = */ qweights, + /* .scales = */ scales + }; +} +#endif \ No newline at end of file diff --git a/preset_kernels/bitnet_b1_58-3B/kernel_config_tl1.ini b/preset_kernels/bitnet_b1_58-3B/kernel_config_tl1.ini new file mode 100644 index 0000000..b8640ad --- /dev/null +++ b/preset_kernels/bitnet_b1_58-3B/kernel_config_tl1.ini @@ -0,0 +1,21 @@ +[Kernels_0] +m = 3200 +k = 8640 +bm = 160 +bk = 64 +bmm = 32 + +[Kernels_1] +m = 3200 +k = 3200 +bm = 320 +bk = 128 +bmm = 64 + +[Kernels_2] +m = 8640 +k = 3200 +bm = 320 +bk = 64 +bmm = 32 + diff --git a/preset_kernels/bitnet_b1_58-3B/kernel_config_tl2.ini b/preset_kernels/bitnet_b1_58-3B/kernel_config_tl2.ini new file mode 100644 index 0000000..46468bf --- /dev/null +++ b/preset_kernels/bitnet_b1_58-3B/kernel_config_tl2.ini @@ -0,0 +1,21 @@ +[Kernels_0] +m = 3200 +k = 8640 +bm = 160 +bk = 96 +bmm = 32 + +[Kernels_1] +m = 3200 +k = 3200 +bm = 320 +bk = 96 +bmm = 32 + +[Kernels_2] +m = 8640 +k = 3200 +bm = 320 +bk = 96 +bmm = 32 + diff --git a/preset_kernels/bitnet_b1_58-large/bitnet-lut-kernels-tl1.h b/preset_kernels/bitnet_b1_58-large/bitnet-lut-kernels-tl1.h new file mode 100644 index 0000000..d38806b --- /dev/null +++ b/preset_kernels/bitnet_b1_58-large/bitnet-lut-kernels-tl1.h @@ -0,0 +1,627 @@ +#if defined(GGML_BITNET_ARM_TL1) +#include "ggml-bitnet.h" +#define GGML_BITNET_MAX_NODES 8192 +static bool initialized = false; +static bitnet_tensor_extra * bitnet_tensor_extras = nullptr; +static size_t bitnet_tensor_extras_index = 0; +static void * aligned_malloc(size_t size) {{ +#if defined(_WIN32) + return _aligned_malloc(size, 64); +#else + void * ptr = nullptr; + posix_memalign(&ptr, 64, size); + return ptr; +#endif +}} +static void aligned_free(void * ptr) {{ +#if defined(_WIN32) + _aligned_free(ptr); +#else + free(ptr); +#endif +}} + +void per_tensor_quant(int k, void* lut_scales_, void* b_) {{ + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + bitnet_float_type* b = (bitnet_float_type*)b_; +#ifdef __ARM_NEON + float32x4_t temp_max = vdupq_n_f32(0); + for (int i=0; i < k / 4; i++) {{ + float32x4_t vec_bs = vld1q_f32(b + 4 * i); + float32x4_t abssum = vabsq_f32(vec_bs); + temp_max = vmaxq_f32(abssum, temp_max); + }} + float32_t scales = 127 / vmaxvq_f32(temp_max); + *lut_scales = scales; +#elif defined __AVX2__ + __m256 max_vec = _mm256_set1_ps(0.f); + const __m256 vec_sign = _mm256_set1_ps(-0.0f); + // #pragma unroll + for (int i = 0; i < k / 8; i++) {{ + __m256 vec_b = _mm256_loadu_ps(b + i * 8); + __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); + max_vec = _mm256_max_ps(vec_babs, max_vec); + }} + __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec)); + max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1)); + max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1)); + float scales = 127 / _mm_cvtss_f32(max1); + *lut_scales = scales; +#endif +}} + +void partial_max_reset(void* lut_scales_) {{ + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + *lut_scales = 0.0; +}} + +#ifdef __ARM_NEON +inline void Transpose_8_8( + int16x8_t *v0, + int16x8_t *v1, + int16x8_t *v2, + int16x8_t *v3, + int16x8_t *v4, + int16x8_t *v5, + int16x8_t *v6, + int16x8_t *v7) +{{ + int16x8x2_t q04 = vzipq_s16(*v0, *v4); + int16x8x2_t q15 = vzipq_s16(*v1, *v5); + int16x8x2_t q26 = vzipq_s16(*v2, *v6); + int16x8x2_t q37 = vzipq_s16(*v3, *v7); + + int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]); + int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]); + int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]); + int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]); + + int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]); + int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]); + int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]); + int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]); + + *v0 = q_fin_0.val[0]; + *v1 = q_fin_0.val[1]; + *v2 = q_fin_1.val[0]; + *v3 = q_fin_1.val[1]; + *v4 = q_fin_2.val[0]; + *v5 = q_fin_2.val[1]; + *v6 = q_fin_3.val[0]; + *v7 = q_fin_3.val[1]; +}} +#endif + +template +inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{ +#ifdef __ARM_NEON + int16x8_t vec_lut[16]; + float32_t scales = *lut_scales; + uint8_t tbl_mask[16]; + tbl_mask[0] = 0; + tbl_mask[1] = 2; + tbl_mask[2] = 4; + tbl_mask[3] = 6; + tbl_mask[4] = 8; + tbl_mask[5] = 10; + tbl_mask[6] = 12; + tbl_mask[7] = 14; + tbl_mask[8] = 1; + tbl_mask[9] = 3; + tbl_mask[10] = 5; + tbl_mask[11] = 7; + tbl_mask[12] = 9; + tbl_mask[13] = 11; + tbl_mask[14] = 13; + tbl_mask[15] = 15; + uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask); +#pragma unroll + for (int k = 0; k < act_k / 16; ++k) {{ + float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16); + float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8); + float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales); + float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales); + float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales); + float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales); + int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0); + int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1); + int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2); + int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3); + int16x4_t vec_b16_0 = vmovn_s32(vec_b_0); + int16x4_t vec_b16_1 = vmovn_s32(vec_b_1); + int16x4_t vec_b16_2 = vmovn_s32(vec_b_2); + int16x4_t vec_b16_3 = vmovn_s32(vec_b_3); + int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2); + int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3); + vec_lut[0] = vdupq_n_s16(0); + vec_lut[0] = vec_lut[0] - vec_bs_0; + vec_lut[0] = vec_lut[0] - vec_bs_1; + vec_lut[1] = vdupq_n_s16(0); + vec_lut[1] = vec_lut[1] - vec_bs_0; + vec_lut[2] = vdupq_n_s16(0); + vec_lut[2] = vec_lut[2] - vec_bs_0; + vec_lut[2] = vec_lut[2] + vec_bs_1; + vec_lut[3] = vdupq_n_s16(0); + vec_lut[3] = vec_lut[3] - vec_bs_1; + vec_lut[4] = vdupq_n_s16(0); + vec_lut[5] = vec_bs_1; + vec_lut[6] = vec_bs_0; + vec_lut[6] = vec_lut[6] - vec_bs_1; + vec_lut[7] = vec_bs_0; + vec_lut[8] = vec_bs_0; + vec_lut[8] = vec_lut[8] + vec_bs_1; + Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]), + &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7])); + Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]), + &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15])); +#pragma unroll + for (int idx = 0; idx < 8; idx++) {{ + int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q); + int8x8_t q0_low = vget_low_s8(q0_s); + int8x8_t q0_high = vget_high_s8(q0_s); + int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q); + int8x8_t q1_low = vget_low_s8(q1_s); + int8x8_t q1_high = vget_high_s8(q1_s); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low); + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low); + }} + }} +#endif +}} + +static bool is_type_supported(enum ggml_type type) {{ + if (type == GGML_TYPE_Q4_0 || + type == GGML_TYPE_TL1) {{ + return true; + }} else {{ + return false; + }} +}} +#include + +#define BM1536_4096 256 +#define BBK1536_4096 128 +inline void tbl_impl_1536_4096(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __ARM_NEON + const int KK = BBK1536_4096 / 2; + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + const int8x16_t vec_zero = vdupq_n_s16(0x0000); + int8x16_t vec_lut[2 * KK]; + int16x8_t vec_c[4]; +#pragma unroll + for (int k = 0; k < 2 * KK; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + +#pragma unroll + for (int i = 0; i < BM1536_4096; i += 32) { + #pragma unroll + for (int i=0; i<4; i++) { + vec_c[i] = vandq_s16(vec_c[i], vec_zero); + } + +#pragma unroll + for (int k = 0; k < KK / 4; k++) { + + uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); + uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); + uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); + int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); + int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); + int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); + int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); + int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); + int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); + vec_c[0] += vec_v_left_0.val[0]; + vec_c[0] += vec_v_right_0.val[0]; + vec_c[1] += vec_v_left_0.val[1]; + vec_c[1] += vec_v_right_0.val[1]; + + uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); + uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); + uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); + int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); + int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); + int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); + int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); + int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); + int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); + vec_c[0] += vec_v_left_1.val[0]; + vec_c[0] += vec_v_right_1.val[0]; + vec_c[1] += vec_v_left_1.val[1]; + vec_c[1] += vec_v_right_1.val[1]; + + uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); + uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); + uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); + int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); + int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); + int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); + int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); + int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); + int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); + vec_c[2] += vec_v_left_2.val[0]; + vec_c[2] += vec_v_right_2.val[0]; + vec_c[3] += vec_v_left_2.val[1]; + vec_c[3] += vec_v_right_2.val[1]; + + uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); + uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); + uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); + int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); + int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); + int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); + int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); + int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); + int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); + vec_c[2] += vec_v_left_3.val[0]; + vec_c[2] += vec_v_right_3.val[0]; + vec_c[3] += vec_v_left_3.val[1]; + vec_c[3] += vec_v_right_3.val[1]; + + } + + int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); + int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); + vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); + vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); + int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); + int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); + vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); + vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); + int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); + int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); + vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); + vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); + int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); + int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); + vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); + vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); + + } +#endif +} + +int32_t qgemm_lut_1536_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BM1536_4096]; + memset(&(CBits[0]), 0, BM1536_4096 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 4096 / BBK1536_4096; ++k_outer) { + tbl_impl_1536_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_4096 / 2 / 2 * BM1536_4096)]))); + } +#pragma unroll + for (int i = 0; i < BM1536_4096; i++) { + ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; + } + return 0; +}; +#include + +#define BM1536_1536 128 +#define BBK1536_1536 64 +inline void tbl_impl_1536_1536(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __ARM_NEON + const int KK = BBK1536_1536 / 2; + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + const int8x16_t vec_zero = vdupq_n_s16(0x0000); + int8x16_t vec_lut[2 * KK]; + int16x8_t vec_c[8]; +#pragma unroll + for (int k = 0; k < 2 * KK; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + +#pragma unroll + for (int i = 0; i < BM1536_1536; i += 64) { + #pragma unroll + for (int i=0; i<8; i++) { + vec_c[i] = vandq_s16(vec_c[i], vec_zero); + } + +#pragma unroll + for (int k = 0; k < KK / 2; k++) { + + uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); + uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); + uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); + int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top); + int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top); + int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot); + int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot); + int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); + int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); + vec_c[0] += vec_v_left_0.val[0]; + vec_c[0] += vec_v_right_0.val[0]; + vec_c[1] += vec_v_left_0.val[1]; + vec_c[1] += vec_v_right_0.val[1]; + + uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); + uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); + uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); + int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top); + int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top); + int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot); + int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot); + int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); + int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); + vec_c[2] += vec_v_left_1.val[0]; + vec_c[2] += vec_v_right_1.val[0]; + vec_c[3] += vec_v_left_1.val[1]; + vec_c[3] += vec_v_right_1.val[1]; + + uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); + uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); + uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); + int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top); + int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top); + int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot); + int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot); + int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); + int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); + vec_c[4] += vec_v_left_2.val[0]; + vec_c[4] += vec_v_right_2.val[0]; + vec_c[5] += vec_v_left_2.val[1]; + vec_c[5] += vec_v_right_2.val[1]; + + uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); + uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); + uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); + int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top); + int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top); + int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot); + int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot); + int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); + int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); + vec_c[6] += vec_v_left_3.val[0]; + vec_c[6] += vec_v_right_3.val[0]; + vec_c[7] += vec_v_left_3.val[1]; + vec_c[7] += vec_v_right_3.val[1]; + + } + + int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); + int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); + vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); + vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); + int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); + int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); + vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); + vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); + int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); + int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); + vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); + vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); + int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); + int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); + vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); + vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); + int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4])); + int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]); + vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4); + vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4); + int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5])); + int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]); + vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5); + vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5); + int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6])); + int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]); + vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6); + vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6); + int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7])); + int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]); + vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7); + vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7); + + } +#endif +} + +int32_t qgemm_lut_1536_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BM1536_1536]; + memset(&(CBits[0]), 0, BM1536_1536 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 1536 / BBK1536_1536; ++k_outer) { + tbl_impl_1536_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_1536 / 2 / 2 * BM1536_1536)]))); + } +#pragma unroll + for (int i = 0; i < BM1536_1536; i++) { + ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; + } + return 0; +}; +#include + +#define BM4096_1536 256 +#define BBK4096_1536 128 +inline void tbl_impl_4096_1536(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __ARM_NEON + const int KK = BBK4096_1536 / 2; + const uint8x16_t vec_mask = vdupq_n_u8(0x0f); + const int8x16_t vec_zero = vdupq_n_s16(0x0000); + int8x16_t vec_lut[2 * KK]; + int16x8_t vec_c[4]; +#pragma unroll + for (int k = 0; k < 2 * KK; k++) { + vec_lut[k] = vld1q_s8(lut + k * 16); + } + +#pragma unroll + for (int i = 0; i < BM4096_1536; i += 32) { + #pragma unroll + for (int i=0; i<4; i++) { + vec_c[i] = vandq_s16(vec_c[i], vec_zero); + } + +#pragma unroll + for (int k = 0; k < KK / 4; k++) { + + uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); + uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); + uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); + int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); + int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); + int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); + int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); + int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); + int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); + vec_c[0] += vec_v_left_0.val[0]; + vec_c[0] += vec_v_right_0.val[0]; + vec_c[1] += vec_v_left_0.val[1]; + vec_c[1] += vec_v_right_0.val[1]; + + uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); + uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); + uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); + int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); + int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); + int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); + int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); + int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); + int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); + vec_c[0] += vec_v_left_1.val[0]; + vec_c[0] += vec_v_right_1.val[0]; + vec_c[1] += vec_v_left_1.val[1]; + vec_c[1] += vec_v_right_1.val[1]; + + uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); + uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); + uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); + int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); + int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); + int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); + int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); + int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); + int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); + vec_c[2] += vec_v_left_2.val[0]; + vec_c[2] += vec_v_right_2.val[0]; + vec_c[3] += vec_v_left_2.val[1]; + vec_c[3] += vec_v_right_2.val[1]; + + uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); + uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); + uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); + int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); + int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); + int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); + int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); + int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); + int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); + vec_c[2] += vec_v_left_3.val[0]; + vec_c[2] += vec_v_right_3.val[0]; + vec_c[3] += vec_v_left_3.val[1]; + vec_c[3] += vec_v_right_3.val[1]; + + } + + int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); + int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); + vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); + vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); + int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); + int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); + vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); + vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); + int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); + int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); + vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); + vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); + int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); + int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); + vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); + vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); + + } +#endif +} + +int32_t qgemm_lut_4096_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BM4096_1536]; + memset(&(CBits[0]), 0, BM4096_1536 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 1536 / BBK4096_1536; ++k_outer) { + tbl_impl_4096_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_1536 / 2 / 2 * BM4096_1536)]))); + } +#pragma unroll + for (int i = 0; i < BM4096_1536; i++) { + ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; + } + return 0; +}; + +template +void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{ + partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0]))); + per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0]))); + + lut_ctor((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0]))); +}} +void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) { + if (m == 1536 && k == 4096) { + preprocessor_k<4096>(B, LUT_Scales, QLUT); + } + else if (m == 1536 && k == 1536) { + preprocessor_k<1536>(B, LUT_Scales, QLUT); + } + else if (m == 4096 && k == 1536) { + preprocessor_k<1536>(B, LUT_Scales, QLUT); + } +} +void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + if (m == 1536 && k == 4096) { + qgemm_lut_1536_4096(A, LUT, Scales, LUT_Scales, C); + } + else if (m == 1536 && k == 1536) { + qgemm_lut_1536_1536(A, LUT, Scales, LUT_Scales, C); + } + else if (m == 4096 && k == 1536) { + qgemm_lut_4096_1536(A, LUT, Scales, LUT_Scales, C); + } +} + +void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) { + if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { + return; + } + + int k = tensor->ne[0]; + int m = tensor->ne[1]; + const int lut_scales_size = 1; + const int scales_size = 1; + int bk = 0; + int bm = 0; + + if (m == 1536 && k == 4096) { + bm = BM1536_4096; + bk = BBK1536_4096; + } +else if (m == 1536 && k == 1536) { + bm = BM1536_1536; + bk = BBK1536_1536; + } +else if (m == 4096 && k == 1536) { + bm = BM4096_1536; + bk = BBK4096_1536; + } + + const int n_tile_num = m / bm; + const int BK = bk; + uint8_t * qweights; + bitnet_float_type * scales; + + scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type)); + qweights = (uint8_t *) tensor->data; + float * i2_scales = (float * )(qweights + k * m / 4); + scales[0] = (bitnet_float_type) i2_scales[0]; + + tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index; + bitnet_tensor_extras[bitnet_tensor_extras_index++] = { + /* .lut_scales_size = */ lut_scales_size, + /* .scales_size = */ scales_size, + /* .n_tile_num = */ n_tile_num, + /* .qweights = */ qweights, + /* .scales = */ scales + }; +} +#endif \ No newline at end of file diff --git a/preset_kernels/bitnet_b1_58-large/bitnet-lut-kernels-tl2.h b/preset_kernels/bitnet_b1_58-large/bitnet-lut-kernels-tl2.h new file mode 100644 index 0000000..92bda56 --- /dev/null +++ b/preset_kernels/bitnet_b1_58-large/bitnet-lut-kernels-tl2.h @@ -0,0 +1,1167 @@ +#if defined(GGML_BITNET_X86_TL2) +#include "ggml-bitnet.h" +#define GGML_BITNET_MAX_NODES 8192 +static bool initialized = false; +static bitnet_tensor_extra * bitnet_tensor_extras = nullptr; +static size_t bitnet_tensor_extras_index = 0; +static void * aligned_malloc(size_t size) { +#if defined(_WIN32) + return _aligned_malloc(size, 64); +#else + void * ptr = nullptr; + posix_memalign(&ptr, 64, size); + return ptr; +#endif +} + +static void aligned_free(void * ptr) { +#if defined(_WIN32) + _aligned_free(ptr); +#else + free(ptr); +#endif +} +#define BK2 32 +#if defined __AVX2__ +inline void _mm256_merge_epi32(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0)); + *vl = _mm256_unpacklo_epi32(va, vb); + *vh = _mm256_unpackhi_epi32(va, vb); +} +inline void _mm256_merge_epi64(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0)); + *vl = _mm256_unpacklo_epi64(va, vb); + *vh = _mm256_unpackhi_epi64(va, vb); +} +inline void _mm256_merge_si128(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh) +{ + *vl = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 2, 0, 0)); + *vh = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 3, 0, 1)); +} +inline void Transpose_8_8( + __m256i *v0, + __m256i *v1, + __m256i *v2, + __m256i *v3, + __m256i *v4, + __m256i *v5, + __m256i *v6, + __m256i *v7) +{ + __m256i w0, w1, w2, w3, w4, w5, w6, w7; + __m256i x0, x1, x2, x3, x4, x5, x6, x7; + _mm256_merge_epi32(*v0, *v1, &w0, &w1); + _mm256_merge_epi32(*v2, *v3, &w2, &w3); + _mm256_merge_epi32(*v4, *v5, &w4, &w5); + _mm256_merge_epi32(*v6, *v7, &w6, &w7); + _mm256_merge_epi64(w0, w2, &x0, &x1); + _mm256_merge_epi64(w1, w3, &x2, &x3); + _mm256_merge_epi64(w4, w6, &x4, &x5); + _mm256_merge_epi64(w5, w7, &x6, &x7); + _mm256_merge_si128(x0, x4, v0, v1); + _mm256_merge_si128(x1, x5, v2, v3); + _mm256_merge_si128(x2, x6, v4, v5); + _mm256_merge_si128(x3, x7, v6, v7); +} +#endif +inline int32_t per_tensor_quant(int k, void* lut_scales_, void* b_) { + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + bitnet_float_type* b = (bitnet_float_type*)b_; +#if defined __AVX2__ + __m256 max_vec = _mm256_set1_ps(0.f); + const __m256 vec_sign = _mm256_set1_ps(-0.0f); + for (int i = 0; i < k / 8; i++) { + __m256 vec_b = _mm256_loadu_ps(b + i * 8); + __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); + max_vec = _mm256_max_ps(vec_babs, max_vec); + } + __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec)); + max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1)); + max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1)); + float scales = 127 / _mm_cvtss_f32(max1); + *lut_scales = scales; +#endif + return 0; +} +inline int32_t partial_max_reset(int32_t bs, void* lut_scales_) { + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; + #pragma unroll + for (int i=0; i< bs; i++) { + lut_scales[i] = 0.0; + } + return 0; +} +template +inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) { +#if defined __AVX2__ + __m256 vec_lut[16]; + const __m256i vec_bi = _mm256_set_epi32(84, 72, 60, 48, 36, 24, 12, 0); + float scales = *lut_scales; + __m256i shuffle_mask = _mm256_set_epi8( + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00, + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00 + ); +#pragma unroll + for (int k = 0; k < act_k / 24; ++k) { + __m256 vec_b0 = _mm256_i32gather_ps(b + k * 24 + 0, vec_bi, 1); + __m256 vec_b1 = _mm256_i32gather_ps(b + k * 24 + 1, vec_bi, 1); + __m256 vec_b2 = _mm256_i32gather_ps(b + k * 24 + 2, vec_bi, 1); + + __m256i vec_b0i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b1i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b2i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b2, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + vec_lut[15] = _mm256_setzero_si256(); + vec_lut[14] = _mm256_setzero_si256(); + vec_lut[13] = vec_b0i; + vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b1i); + vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b2i); + vec_lut[12] = vec_b0i; + vec_lut[12] = _mm256_add_epi32(vec_lut[12], vec_b1i); + vec_lut[11] = vec_b0i; + vec_lut[11] = _mm256_add_epi32(vec_lut[11], vec_b1i); + vec_lut[11] = _mm256_sub_epi32(vec_lut[11], vec_b2i); + vec_lut[10] = vec_b0i; + vec_lut[10] = _mm256_add_epi32(vec_lut[10], vec_b2i); + vec_lut[9] = vec_b0i; + vec_lut[8] = vec_b0i; + vec_lut[8] = _mm256_sub_epi32(vec_lut[8], vec_b2i); + vec_lut[7] = vec_b0i; + vec_lut[7] = _mm256_sub_epi32(vec_lut[7], vec_b1i); + vec_lut[7] = _mm256_add_epi32(vec_lut[7], vec_b2i); + vec_lut[6] = vec_b0i; + vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1i); + vec_lut[5] = vec_b0i; + vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b1i); + vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b2i); + vec_lut[4] = vec_b1i; + vec_lut[4] = _mm256_add_epi32(vec_lut[4], vec_b2i); + vec_lut[3] = vec_b1i; + vec_lut[2] = vec_b1i; + vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b2i); + vec_lut[1] = vec_b2i; + vec_lut[0] = _mm256_setzero_si256(); + __m256i ix[16]; + +#pragma unroll + for (int g = 0; g < 16; ++g) { + ix[g] = vec_lut[g]; + } + + Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7])); + Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15])); + +#pragma unroll + for (int g = 0; g < 8; ++g) { + ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + } + int8_t* qlut_i8 = reinterpret_cast(qlut); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]); + + } + + *lut_scales = scales; +#endif + return 0; +} + +template +inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) { +#if defined __AVX2__ + __m256 vec_lut[16]; + const __m256i vec_bi = _mm256_set_epi32(56, 48, 40, 32, 24, 16, 8, 0); + float scales = *lut_scales; + __m256i shuffle_mask = _mm256_set_epi8( + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00, + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01, + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00 + ); +#pragma unroll + for (int k = 0; k < act_k / 16; ++k) { + __m256 vec_b0f = _mm256_i32gather_ps(b + k * 16 + 0, vec_bi, 1); + __m256 vec_b1f = _mm256_i32gather_ps(b + k * 16 + 1, vec_bi, 1); + + __m256i vec_b0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i vec_b1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + vec_lut[15] = _mm256_setzero_si256(); + vec_lut[14] = _mm256_setzero_si256(); + vec_lut[13] = _mm256_setzero_si256(); + vec_lut[12] = _mm256_setzero_si256(); + vec_lut[11] = _mm256_setzero_si256(); + vec_lut[10] = _mm256_setzero_si256(); + vec_lut[9] = _mm256_setzero_si256(); + vec_lut[8] = vec_b0; + vec_lut[8] = _mm256_add_epi32(vec_lut[8], vec_b1); + vec_lut[7] = vec_b0; + vec_lut[6] = vec_b0; + vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1); + vec_lut[5] = vec_b1; + vec_lut[4] = _mm256_setzero_si256(); + vec_lut[3] = _mm256_setzero_si256(); + vec_lut[3] = _mm256_sub_epi32(vec_lut[3], vec_b1); + vec_lut[2] = _mm256_setzero_si256(); + vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b0); + vec_lut[2] = _mm256_add_epi32(vec_lut[2], vec_b1); + vec_lut[1] = _mm256_setzero_si256(); + vec_lut[1] = _mm256_sub_epi32(vec_lut[1], vec_b0); + vec_lut[0] = _mm256_setzero_si256(); + vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b0); + vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b1); + + __m256i ix[16]; +#pragma unroll + for (int g = 0; g < 16; ++g) { + ix[g] = vec_lut[g]; + } + + Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7])); + Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15])); + +#pragma unroll + for (int g = 0; g < 8; ++g) { + ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask); + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0)); + } + + int8_t* qlut_i8 = reinterpret_cast(qlut); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]); + + } + *lut_scales = scales; +#endif + return 0; +} +static bool is_type_supported(enum ggml_type type) { + if (type == GGML_TYPE_Q4_0 || + type == GGML_TYPE_TL2) { + return true; + } else { + return false; + } +} +#include + +#define BM1536_4096 256 +#define BBK1536_4096 96 +template +inline void three_tbl_impl_1536_4096(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000); + const __m256i vec_zero = _mm256_set1_epi8(0x00); + const __m256i vec_one = _mm256_set1_epi8(0xff); + const int KK = BBK1536_4096 / 3; +#pragma unroll + for (int i = 0; i < BM1536_4096; i += 32) { + __m256i vec_as[KK / 2]; + __m256i vec_signs[KK / 8]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } + #pragma unroll + for (int as = 0; as < KK / 8; as++) { + vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + __m256i vec_sign = vec_signs[k]; + __m256i vec_a_0 = vec_as[k * 4 + 0]; + __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); + __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); + __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); + __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); + __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); + __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); + __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); + __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); + __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); + __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); + __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); + __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); + __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); + __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); + __m256i vec_a_1 = vec_as[k * 4 + 1]; + __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); + __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); + __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); + __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); + __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); + __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); + __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); + __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); + __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); + __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); + __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); + __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); + __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); + __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); + __m256i vec_a_2 = vec_as[k * 4 + 2]; + __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); + __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); + __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); + __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); + __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); + __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); + __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); + __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); + __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); + __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); + __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); + __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); + __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); + __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); + __m256i vec_a_3 = vec_as[k * 4 + 3]; + __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15); + __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15); + __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask); + __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3); + __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3); + __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15); + __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15); + __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask); + __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3); + __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3); + __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3); + __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3); + __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3); + __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3); + } + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM1536_4096 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM1536_4096 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM1536_4096 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM1536_4096 * bs)); + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM1536_4096 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM1536_4096 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM1536_4096 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM1536_4096 * bs), vec_gc3); + } + } +#endif +} + +template +inline int32_t two_tbl_impl1536_4096(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const int KK = BK2 / 2; +#pragma unroll + for (int i = 0; i < BM1536_4096; i += 32) { + __m256i vec_as[KK / 2]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + __m256i vec_a = vec_as[k * 4 + j]; + + __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs)); + __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs)); + __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs)); + __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs)); + + __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask); + __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top); + __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top); + + __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask); + __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot); + __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot); + + __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec); + __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); + } + } + + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM1536_4096 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM1536_4096 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM1536_4096 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM1536_4096 * bs)); + + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM1536_4096 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM1536_4096 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM1536_4096 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM1536_4096 * bs), vec_gc3); + } + } +#endif + return 0; +} + +template +int32_t three_qgemm_lut_1536_4096(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM1536_4096]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM1536_4096 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 4032 / BBK1536_4096; ++k_outer) { + three_tbl_impl_1536_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_4096 / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_4096 / 3 / 2 * BM1536_4096)])), (&(((uint8_t*)sign)[(k_outer * BBK1536_4096 / 3 / 8 * BM1536_4096)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM1536_4096; i++) { + ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM1536_4096]); + } + } + return 0; +} + +template +int32_t two_qgemm_lut_1536_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM1536_4096]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM1536_4096 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 64 / 32; ++k_outer) { + two_tbl_impl1536_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM1536_4096)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM1536_4096; i++) { + ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM1536_4096]); + ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0]; + } + } + return 0; +} + +#include + +#define BM1536_1536 128 +#define BBK1536_1536 192 +template +inline void three_tbl_impl_1536_1536(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000); + const __m256i vec_zero = _mm256_set1_epi8(0x00); + const __m256i vec_one = _mm256_set1_epi8(0xff); + const int KK = BBK1536_1536 / 3; +#pragma unroll + for (int i = 0; i < BM1536_1536; i += 32) { + __m256i vec_as[KK / 2]; + __m256i vec_signs[KK / 8]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } + #pragma unroll + for (int as = 0; as < KK / 8; as++) { + vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + __m256i vec_sign = vec_signs[k]; + __m256i vec_a_0 = vec_as[k * 4 + 0]; + __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); + __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); + __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); + __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); + __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); + __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); + __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); + __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); + __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); + __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); + __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); + __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); + __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); + __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); + __m256i vec_a_1 = vec_as[k * 4 + 1]; + __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); + __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); + __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); + __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); + __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); + __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); + __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); + __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); + __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); + __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); + __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); + __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); + __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); + __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); + __m256i vec_a_2 = vec_as[k * 4 + 2]; + __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); + __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); + __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); + __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); + __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); + __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); + __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); + __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); + __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); + __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); + __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); + __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); + __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); + __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); + __m256i vec_a_3 = vec_as[k * 4 + 3]; + __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15); + __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15); + __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask); + __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3); + __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3); + __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15); + __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15); + __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask); + __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3); + __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3); + __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3); + __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3); + __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3); + __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3); + } + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM1536_1536 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM1536_1536 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM1536_1536 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM1536_1536 * bs)); + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM1536_1536 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM1536_1536 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM1536_1536 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM1536_1536 * bs), vec_gc3); + } + } +#endif +} + +template +inline int32_t two_tbl_impl1536_1536(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const int KK = BK2 / 2; +#pragma unroll + for (int i = 0; i < BM1536_1536; i += 32) { + __m256i vec_as[KK / 2]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + __m256i vec_a = vec_as[k * 4 + j]; + + __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs)); + __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs)); + __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs)); + __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs)); + + __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask); + __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top); + __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top); + + __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask); + __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot); + __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot); + + __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec); + __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); + } + } + + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM1536_1536 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM1536_1536 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM1536_1536 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM1536_1536 * bs)); + + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM1536_1536 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM1536_1536 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM1536_1536 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM1536_1536 * bs), vec_gc3); + } + } +#endif + return 0; +} + +template +int32_t three_qgemm_lut_1536_1536(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM1536_1536]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM1536_1536 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 1536 / BBK1536_1536; ++k_outer) { + three_tbl_impl_1536_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_1536 / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_1536 / 3 / 2 * BM1536_1536)])), (&(((uint8_t*)sign)[(k_outer * BBK1536_1536 / 3 / 8 * BM1536_1536)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM1536_1536; i++) { + ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM1536_1536]); + } + } + return 0; +} + +template +int32_t two_qgemm_lut_1536_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM1536_1536]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM1536_1536 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 0 / 32; ++k_outer) { + two_tbl_impl1536_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM1536_1536)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM1536_1536; i++) { + ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM1536_1536]); + ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0]; + } + } + return 0; +} + +#include + +#define BM4096_1536 256 +#define BBK4096_1536 96 +template +inline void three_tbl_impl_4096_1536(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000); + const __m256i vec_zero = _mm256_set1_epi8(0x00); + const __m256i vec_one = _mm256_set1_epi8(0xff); + const int KK = BBK4096_1536 / 3; +#pragma unroll + for (int i = 0; i < BM4096_1536; i += 32) { + __m256i vec_as[KK / 2]; + __m256i vec_signs[KK / 8]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } + #pragma unroll + for (int as = 0; as < KK / 8; as++) { + vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + __m256i vec_sign = vec_signs[k]; + __m256i vec_a_0 = vec_as[k * 4 + 0]; + __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15); + __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15); + __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask); + __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0); + __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0); + __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15); + __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15); + __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask); + __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0); + __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0); + __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0); + __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0); + __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0); + __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0); + __m256i vec_a_1 = vec_as[k * 4 + 1]; + __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15); + __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15); + __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask); + __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1); + __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1); + __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15); + __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15); + __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask); + __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1); + __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1); + __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1); + __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1); + __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1); + __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1); + __m256i vec_a_2 = vec_as[k * 4 + 2]; + __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15); + __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15); + __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask); + __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2); + __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2); + __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15); + __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15); + __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask); + __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2); + __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2); + __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2); + __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2); + __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2); + __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2); + __m256i vec_a_3 = vec_as[k * 4 + 3]; + __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs)); + __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs)); + __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs)); + __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs)); + __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15); + __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15); + __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask); + __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3); + __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3); + __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15); + __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15); + __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask); + __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3); + __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3); + __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3); + __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3); + __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3); + __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3); + } + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM4096_1536 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM4096_1536 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM4096_1536 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM4096_1536 * bs)); + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM4096_1536 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM4096_1536 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM4096_1536 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM4096_1536 * bs), vec_gc3); + } + } +#endif +} + +template +inline int32_t two_tbl_impl4096_1536(int32_t* c, int8_t* lut, uint8_t* a) { +#ifdef __AVX2__ + const __m256i vec_mask = _mm256_set1_epi8(0x0f); + const int KK = BK2 / 2; +#pragma unroll + for (int i = 0; i < BM4096_1536; i += 32) { + __m256i vec_as[KK / 2]; + #pragma unroll + for (int ai = 0; ai < KK / 2; ai++) { + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32)); + } +#pragma unroll + for (int bs = 0; bs < batch_size; bs++) { + __m256i vec_c0 = _mm256_setzero_si256(); + __m256i vec_c1 = _mm256_setzero_si256(); +#pragma unroll + for (int k = 0; k < KK / 8; k++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + __m256i vec_a = vec_as[k * 4 + j]; + + __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs)); + __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs)); + __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs)); + __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs)); + + __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask); + __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top); + __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top); + + __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask); + __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot); + __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot); + + __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec); + __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec); + __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi); + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo); + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); + } + } + + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM4096_1536 * bs)); + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM4096_1536 * bs)); + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM4096_1536 * bs)); + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM4096_1536 * bs)); + + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0))); + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1))); + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1))); + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1))); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM4096_1536 * bs), vec_gc0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM4096_1536 * bs), vec_gc1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM4096_1536 * bs), vec_gc2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM4096_1536 * bs), vec_gc3); + } + } +#endif + return 0; +} + +template +int32_t three_qgemm_lut_4096_1536(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM4096_1536]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM4096_1536 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 1536 / BBK4096_1536; ++k_outer) { + three_tbl_impl_4096_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_1536 / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_1536 / 3 / 2 * BM4096_1536)])), (&(((uint8_t*)sign)[(k_outer * BBK4096_1536 / 3 / 8 * BM4096_1536)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM4096_1536; i++) { + ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM4096_1536]); + } + } + return 0; +} + +template +int32_t two_qgemm_lut_4096_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { + alignas(32) uint32_t CBits[BATCH_SIZE * BM4096_1536]; + memset(&(CBits[0]), 0, BATCH_SIZE * BM4096_1536 * sizeof(int32_t)); +#pragma unroll + for (int32_t k_outer = 0; k_outer < 0 / 32; ++k_outer) { + two_tbl_impl4096_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM4096_1536)]))); + } +#pragma unroll + for (int bs = 0; bs < BATCH_SIZE; bs++) { +#pragma unroll + for (int i = 0; i < BM4096_1536; i++) { + ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM4096_1536]); + ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0]; + } + } + return 0; +} + +void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT) { + partial_max_reset(bs, (&(((float*)LUT_Scales)[0]))); + if (m == 1536 && two_k == 64 && three_k == 4032) { + for (int32_t b = 0; b < bs; b++) { + per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)]))); + three_lut_ctor<4032>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b]))); + two_lut_ctor<64>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + 4032])), (&(((float*)LUT_Scales)[b]))); + } + } + else if (m == 1536 && two_k == 0 && three_k == 1536) { + for (int32_t b = 0; b < bs; b++) { + per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)]))); + three_lut_ctor<1536>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b]))); + two_lut_ctor<0>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + 1536])), (&(((float*)LUT_Scales)[b]))); + } + } + else if (m == 4096 && two_k == 0 && three_k == 1536) { + for (int32_t b = 0; b < bs; b++) { + per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)]))); + three_lut_ctor<1536>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b]))); + two_lut_ctor<0>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + 1536])), (&(((float*)LUT_Scales)[b]))); + } + } +} +void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) { + if (m == 1536 && k == 4096) { + if (BK == 64) { + if (bs == 1) { + two_qgemm_lut_1536_4096<1>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 8) { + two_qgemm_lut_1536_4096<8>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 32) { + two_qgemm_lut_1536_4096<32>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 128) { + two_qgemm_lut_1536_4096<128>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 256) { + two_qgemm_lut_1536_4096<256>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 512) { + two_qgemm_lut_1536_4096<512>(A, LUT, Scales, LUT_Scales, C); + } + } + else if (BK == 4032) { + if (bs == 1) { + three_qgemm_lut_1536_4096<1>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 8) { + three_qgemm_lut_1536_4096<8>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 32) { + three_qgemm_lut_1536_4096<32>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 128) { + three_qgemm_lut_1536_4096<128>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 256) { + three_qgemm_lut_1536_4096<256>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 512) { + three_qgemm_lut_1536_4096<512>(A, sign, LUT, Scales, LUT_Scales, C); + } + } + } + else if (m == 1536 && k == 1536) { + if (BK == 0) { + if (bs == 1) { + two_qgemm_lut_1536_1536<1>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 8) { + two_qgemm_lut_1536_1536<8>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 32) { + two_qgemm_lut_1536_1536<32>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 128) { + two_qgemm_lut_1536_1536<128>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 256) { + two_qgemm_lut_1536_1536<256>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 512) { + two_qgemm_lut_1536_1536<512>(A, LUT, Scales, LUT_Scales, C); + } + } + else if (BK == 1536) { + if (bs == 1) { + three_qgemm_lut_1536_1536<1>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 8) { + three_qgemm_lut_1536_1536<8>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 32) { + three_qgemm_lut_1536_1536<32>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 128) { + three_qgemm_lut_1536_1536<128>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 256) { + three_qgemm_lut_1536_1536<256>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 512) { + three_qgemm_lut_1536_1536<512>(A, sign, LUT, Scales, LUT_Scales, C); + } + } + } + else if (m == 4096 && k == 1536) { + if (BK == 0) { + if (bs == 1) { + two_qgemm_lut_4096_1536<1>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 8) { + two_qgemm_lut_4096_1536<8>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 32) { + two_qgemm_lut_4096_1536<32>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 128) { + two_qgemm_lut_4096_1536<128>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 256) { + two_qgemm_lut_4096_1536<256>(A, LUT, Scales, LUT_Scales, C); + } else if (bs == 512) { + two_qgemm_lut_4096_1536<512>(A, LUT, Scales, LUT_Scales, C); + } + } + else if (BK == 1536) { + if (bs == 1) { + three_qgemm_lut_4096_1536<1>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 8) { + three_qgemm_lut_4096_1536<8>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 32) { + three_qgemm_lut_4096_1536<32>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 128) { + three_qgemm_lut_4096_1536<128>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 256) { + three_qgemm_lut_4096_1536<256>(A, sign, LUT, Scales, LUT_Scales, C); + }else if (bs == 512) { + three_qgemm_lut_4096_1536<512>(A, sign, LUT, Scales, LUT_Scales, C); + } + } + } +} + +void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) { + if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { + return; + } + + int k = tensor->ne[0]; + int m = tensor->ne[1]; + const int lut_scales_size = 1; + int bk = 0; + int bm = 0; + + if (m == 1536 && k == 4096) { + bm = BM1536_4096; + bk = BBK1536_4096; + } +else if (m == 1536 && k == 1536) { + bm = BM1536_1536; + bk = BBK1536_1536; + } +else if (m == 4096 && k == 1536) { + bm = BM4096_1536; + bk = BBK4096_1536; + } + + const int n_tile_num = m / bm; + const int BK = bk; + uint8_t * qweights; + bitnet_float_type * scales; + + scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type)); + qweights = (uint8_t *) tensor->data; + float * i2_scales = (float * )(qweights + k * m / 4); + scales[0] = (bitnet_float_type) i2_scales[0]; + + tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index; + bitnet_tensor_extras[bitnet_tensor_extras_index++] = { + /* .lut_scales_size = */ lut_scales_size, + /* .BK = */ BK, + /* .n_tile_num = */ n_tile_num, + /* .qweights = */ qweights, + /* .scales = */ scales + }; +} +#endif \ No newline at end of file diff --git a/preset_kernels/bitnet_b1_58-large/kernel_config_tl1.ini b/preset_kernels/bitnet_b1_58-large/kernel_config_tl1.ini new file mode 100644 index 0000000..5d94318 --- /dev/null +++ b/preset_kernels/bitnet_b1_58-large/kernel_config_tl1.ini @@ -0,0 +1,21 @@ +[Kernels_0] +m = 1536 +k = 4096 +bm = 256 +bk = 128 +bmm = 32 + +[Kernels_1] +m = 1536 +k = 1536 +bm = 128 +bk = 64 +bmm = 64 + +[Kernels_2] +m = 4096 +k = 1536 +bm = 256 +bk = 128 +bmm = 32 + diff --git a/preset_kernels/bitnet_b1_58-large/kernel_config_tl2.ini b/preset_kernels/bitnet_b1_58-large/kernel_config_tl2.ini new file mode 100644 index 0000000..54d4b83 --- /dev/null +++ b/preset_kernels/bitnet_b1_58-large/kernel_config_tl2.ini @@ -0,0 +1,21 @@ +[Kernels_0] +m = 1536 +k = 4096 +bm = 256 +bk = 96 +bmm = 32 + +[Kernels_1] +m = 1536 +k = 1536 +bm = 128 +bk = 192 +bmm = 32 + +[Kernels_2] +m = 4096 +k = 1536 +bm = 256 +bk = 96 +bmm = 64 + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..3f5c547 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +# These requirements include all dependencies for all top-level python scripts +# for llama.cpp. Avoid adding packages here directly. +# +# Package versions must stay compatible across all top-level python scripts. +# + +-r 3rdparty/llama.cpp/requirements/requirements-convert_legacy_llama.txt +-r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +-r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +-r 3rdparty/llama.cpp/requirements/requirements-convert_llama_ggml_to_gguf.txt +-r 3rdparty/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt \ No newline at end of file diff --git a/run_inference.py b/run_inference.py new file mode 100644 index 0000000..fe14e0e --- /dev/null +++ b/run_inference.py @@ -0,0 +1,53 @@ +import os +import sys +import signal +import platform +import argparse +import subprocess + +def run_command(command, shell=False): + """Run a system command and ensure it succeeds.""" + try: + subprocess.run(command, shell=shell, check=True) + except subprocess.CalledProcessError as e: + print(f"Error occurred while running command: {e}") + sys.exit(1) + +def run_inference(): + build_dir = "build" + if platform.system() == "Windows": + main_path = os.path.join(build_dir, "bin", "Release", "llama-cli.exe") + if not os.path.exists(main_path): + main_path = os.path.join(build_dir, "bin", "llama-cli") + else: + main_path = os.path.join(build_dir, "bin", "llama-cli") + command = [ + f'{main_path}', + '-m', args.model, + '-n', str(args.n_predict), + '-t', str(args.threads), + '-p', args.prompt, + '-ngl', '0', + '-c', str(args.ctx_size), + '--temp', str(args.temperature), + "-b", "1" + ] + run_command(command) + +def signal_handler(sig, frame): + print("Ctrl+C pressed, exiting...") + sys.exit(0) + +if __name__ == "__main__": + signal.signal(signal.SIGINT, signal_handler) + # Usage: python run_inference.py -p "Microsoft Corporation is an American multinational corporation and technology company headquartered in Redmond, Washington." + parser = argparse.ArgumentParser(description='Run inference') + parser.add_argument("-m", "--model", type=str, help="Path to model file", required=False, default="models/bitnet_b1_58-3B/ggml-model-i2_s.gguf") + parser.add_argument("-n", "--n-predict", type=int, help="Number of tokens to predict when generating text", required=False, default=128) + parser.add_argument("-p", "--prompt", type=str, help="Prompt to generate text from", required=True) + parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2) + parser.add_argument("-c", "--ctx-size", type=int, help="Size of the prompt context", required=False, default=2048) + parser.add_argument("-temp", "--temperature", type=float, help="Temperature, a hyperparameter that controls the randomness of the generated text", required=False, default=0.8) + + args = parser.parse_args() + run_inference() \ No newline at end of file diff --git a/setup_env.py b/setup_env.py new file mode 100644 index 0000000..b9bf5fc --- /dev/null +++ b/setup_env.py @@ -0,0 +1,202 @@ +import subprocess +import signal +import sys +import os +import platform +import argparse +import logging +import shutil +from pathlib import Path + +logger = logging.getLogger("setup_env") + +SUPPORTED_HF_MODELS = { + "1bitLLM/bitnet_b1_58-large": { + "model_name": "bitnet_b1_58-large", + }, + "1bitLLM/bitnet_b1_58-3B": { + "model_name": "bitnet_b1_58-3B", + }, + "HF1BitLLM/Llama3-8B-1.58-100B-tokens": { + "model_name": "Llama3-8B-1.58-100B-tokens", + } +} + +SUPPORTED_QUANT_TYPES = { + "arm64": ["i2_s", "tl1"], + "x86_64": ["i2_s", "tl2"] +} + +COMPILER_EXTRA_ARGS = { + "arm64": ["-DBITNET_ARM_TL1=ON"], + "x86_64": ["-DBITNET_X86_TL2=ON"] +} + +OS_EXTRA_ARGS = { + "Windows":["-T", "ClangCL"], + "Linux": ["-DCMAKE_C_COMPILER=clang", "-DCMAKE_CXX_COMPILER=clang++"] +} + +ARCH_ALIAS = { + "AMD64": "x86_64", + "x86": "x86_64", + "x86_64": "x86_64", + "aarch64": "arm64", + "arm64": "arm64", + "ARM64": "arm64", +} + +def system_info(): + return platform.system(), ARCH_ALIAS[platform.machine()] + +def get_model_name(): + if args.hf_repo: + return SUPPORTED_HF_MODELS[args.hf_repo]["model_name"] + return os.path.basename(os.path.normpath(args.model_dir)) + +def run_command(command, shell=False, log_step=None): + """Run a system command and ensure it succeeds.""" + if log_step: + log_file = os.path.join(args.log_dir, log_step + ".log") + with open(log_file, "w") as f: + try: + subprocess.run(command, shell=shell, check=True, stdout=f, stderr=f) + except subprocess.CalledProcessError as e: + logging.error(f"Error occurred while running command: {e}, check details in {log_file}") + sys.exit(1) + else: + try: + subprocess.run(command, shell=shell, check=True) + except subprocess.CalledProcessError as e: + logging.error(f"Error occurred while running command: {e}") + sys.exit(1) + +def prepare_model(): + _, arch = system_info() + hf_url = args.hf_repo + model_dir = args.model_dir + quant_type = args.quant_type + quant_embd = args.quant_embd + if hf_url is not None: + # download the model + model_dir = os.path.join(model_dir, SUPPORTED_HF_MODELS[hf_url]["model_name"]) + Path(model_dir).mkdir(parents=True, exist_ok=True) + logging.info(f"Downloading model {hf_url} from HuggingFace to {model_dir}...") + run_command(["huggingface-cli", "download", hf_url, "--local-dir", model_dir], log_step="download_model") + elif not os.path.exists(model_dir): + logging.error(f"Model directory {model_dir} does not exist.") + sys.exit(1) + else: + logging.info(f"Loading model from directory {model_dir}.") + gguf_path = os.path.join(model_dir, "ggml-model-" + quant_type + ".gguf") + if not os.path.exists(gguf_path) or os.path.getsize(gguf_path) == 0: + logging.info(f"Converting HF model to GGUF format...") + if quant_type.startswith("tl"): + run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", quant_type, "--quant-embd"], log_step="convert_to_tl") + else: # i2s + # convert to f32 + run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", "f32"], log_step="convert_to_f32_gguf") + f32_model = os.path.join(model_dir, "ggml-model-f32.gguf") + i2s_model = os.path.join(model_dir, "ggml-model-i2_s.gguf") + # quantize to i2s + if platform.system() != "Windows": + if quant_embd: + run_command(["./build/bin/llama-quantize", "--token-embedding-type", "f16", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s") + else: + run_command(["./build/bin/llama-quantize", f32_model, i2s_model, "I2_S", "1"], log_step="quantize_to_i2s") + else: + if quant_embd: + run_command(["./build/bin/Release/llama-quantize", "--token-embedding-type", "f16", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s") + else: + run_command(["./build/bin/Release/llama-quantize", f32_model, i2s_model, "I2_S", "1"], log_step="quantize_to_i2s") + + logging.info(f"GGUF model saved at {gguf_path}") + else: + logging.info(f"GGUF model already exists at {gguf_path}") + +def setup_gguf(): + # Install the pip package + run_command([sys.executable, "-m", "pip", "install", "3rdparty/llama.cpp/gguf-py"], log_step="install_gguf") + +def gen_code(): + _, arch = system_info() + if arch == "arm64": + if args.use_pretuned: + pretuned_kernels = os.path.join("preset_kernels", get_model_name()) + if not os.path.exists(pretuned_kernels): + logging.error(f"Pretuned kernels not found for model {args.hf_repo}") + sys.exit(1) + if args.quant_type == "tl1": + shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl1.h"), "include/bitnet-lut-kernels.h") + shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl1.ini"), "include/kernel_config.ini") + elif args.quant_type == "tl2": + shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h") + shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl2.ini"), "include/kernel_config.ini") + if get_model_name() == "bitnet_b1_58-large": + run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen") + elif get_model_name() == "Llama3-8B-1.58-100B-tokens": + run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen") + elif get_model_name() == "bitnet_b1_58-3B": + run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen") + else: + raise NotImplementedError() + else: + if args.use_pretuned: + # cp preset_kernels/model_name/bitnet-lut-kernels_tl1.h to include/bitnet-lut-kernels.h + pretuned_kernels = os.path.join("preset_kernels", get_model_name()) + if not os.path.exists(pretuned_kernels): + logging.error(f"Pretuned kernels not found for model {args.hf_repo}") + sys.exit(1) + shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h") + if get_model_name() == "bitnet_b1_58-large": + run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen") + elif get_model_name() == "Llama3-8B-1.58-100B-tokens": + run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen") + elif get_model_name() == "bitnet_b1_58-3B": + run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen") + else: + raise NotImplementedError() + + +def compile(): + # Check if cmake is installed + cmake_exists = subprocess.run(["cmake", "--version"], capture_output=True) + if cmake_exists.returncode != 0: + logging.error("Cmake is not available. Please install CMake and try again.") + sys.exit(1) + _, arch = system_info() + if arch not in COMPILER_EXTRA_ARGS.keys(): + logging.error(f"Arch {arch} is not supported yet") + exit(0) + logging.info("Compiling the code using CMake.") + run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files") + # run_command(["cmake", "--build", "build", "--target", "llama-cli", "--config", "Release"]) + run_command(["cmake", "--build", "build", "--config", "Release"], log_step="compile") + +def main(): + setup_gguf() + gen_code() + compile() + prepare_model() + +def parse_args(): + _, arch = system_info() + parser = argparse.ArgumentParser(description='Setup the environment for running the inference') + parser.add_argument("--hf-repo", "-hr", type=str, help="Model used for inference", choices=SUPPORTED_HF_MODELS.keys()) + parser.add_argument("--model-dir", "-md", type=str, help="Directory to save/load the model", default="models") + parser.add_argument("--log-dir", "-ld", type=str, help="Directory to save the logging info", default="logs") + parser.add_argument("--quant-type", "-q", type=str, help="Quantization type", choices=SUPPORTED_QUANT_TYPES[arch], default="i2_s") + parser.add_argument("--quant-embd", action="store_true", help="Quantize the embeddings to f16") + parser.add_argument("--use-pretuned", "-p", action="store_true", help="Use the pretuned kernel parameters") + return parser.parse_args() + +def signal_handler(sig, frame): + logging.info("Ctrl+C pressed, exiting...") + sys.exit(0) + +if __name__ == "__main__": + signal.signal(signal.SIGINT, signal_handler) + args = parse_args() + Path(args.log_dir).mkdir(parents=True, exist_ok=True) + logging.basicConfig(level=logging.INFO) + main() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000..9cead70 --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,10 @@ +set(GGML_HEADERS_BITNET ../include/ggml-bitnet.h) +set(GGML_SOURCES_BITNET ggml-bitnet-mad.cpp) +set(GGML_SOURCES_BITNET ggml-bitnet-lut.cpp) + +include_directories(3rdparty/llama.cpp/ggml/include) + +if ((NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") OR +(NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")) + message(FATAL_ERROR "Clang is required for Bitnet.cpp compilation") +endif() \ No newline at end of file diff --git a/src/ggml-bitnet-lut.cpp b/src/ggml-bitnet-lut.cpp new file mode 100644 index 0000000..d6bea16 --- /dev/null +++ b/src/ggml-bitnet-lut.cpp @@ -0,0 +1,163 @@ +#include +#include + +#include "ggml-bitnet.h" +#include "ggml-quants.h" +#include "bitnet-lut-kernels.h" + +#if defined(GGML_BITNET_ARM_TL1) + +void ggml_bitnet_init(void) { + // LOG(INFO) << "ggml_bitnet_init"; + + if (initialized) { + return; + } + initialized = true; + + // if (wrapper == nullptr) { + // wrapper = new BITNET::BITNETGeMMWrapper(); + // } + if (bitnet_tensor_extras == nullptr) { + bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES]; + } + bitnet_tensor_extras_index = 0; +} + +void ggml_bitnet_free(void) { + // LOG(INFO) << "ggml_bitnet_free"; + + if (!initialized) { + return; + } + initialized = false; + + // delete wrapper; + // wrapper = nullptr; + for (size_t i = 0; i < bitnet_tensor_extras_index; i++) { + // aligned_free(bitnet_tensor_extras[i].qweights); + // aligned_free(bitnet_tensor_extras[i].scales); + } + delete[] bitnet_tensor_extras; + bitnet_tensor_extras = nullptr; +} + +static bool do_permutate(enum ggml_type type) { + if (type == GGML_TYPE_TL1) { + // Add additional args to decide if permuted I2 or naive I2 + return false; + } else { + return true; + } +} + +bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { + if ((is_type_supported(src0->type)) && + src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && + src0->backend == GGML_BACKEND_TYPE_CPU) { + if (src1->ne[1] <= 1) { + return true; + } + } + return false; +} + +size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { + const size_t ne01 = src0->ne[1]; + const size_t ne10 = src1->ne[0]; + const size_t ne11 = src1->ne[1]; + const int bits = ggml_bitnet_get_type_bits(src0->type); + + size_t wsize = ne10 * ne11 * 15 * sizeof(int8_t) + 1 * ne11 * 2 * sizeof(bitnet_float_type); + if (sizeof(bitnet_float_type) == 2) { + // Need fp32 to fp16 conversion + wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type); + } + wsize = ((wsize - 1) / 64 + 1) * 64; + return wsize; +} + +int ggml_bitnet_get_type_bits(enum ggml_type type) { + switch (type) { + case GGML_TYPE_TL1: + return 2; + case GGML_TYPE_Q4_0: + return 4; + default: + return 0; + } +} + +#endif +#if defined(GGML_BITNET_X86_TL2) +void ggml_bitnet_init(void) { + // LOG(INFO) << "ggml_bitnet_init"; + + if (initialized) { + return; + } + initialized = true; + + // if (wrapper == nullptr) { + // wrapper = new BITNET::BITNETGeMMWrapper(); + // } + if (bitnet_tensor_extras == nullptr) { + bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES]; + } + bitnet_tensor_extras_index = 0; +} + +void ggml_bitnet_free(void) { + // LOG(INFO) << "ggml_bitnet_free"; + + if (!initialized) { + return; + } + initialized = false; + + // delete wrapper; + // wrapper = nullptr; + for (size_t i = 0; i < bitnet_tensor_extras_index; i++) { + // aligned_free(bitnet_tensor_extras[i].qweights); + // aligned_free(bitnet_tensor_extras[i].scales); + } + delete[] bitnet_tensor_extras; + bitnet_tensor_extras = nullptr; +} + +bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { + if ((is_type_supported(src0->type)) && + src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && + src0->backend == GGML_BACKEND_TYPE_CPU) { + return true; + } + return false; +} + +size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { + const size_t ne01 = src0->ne[1]; + const size_t ne10 = src1->ne[0]; + const size_t ne11 = src1->ne[1]; + + size_t wsize = ne10 * ne11 * 11 * sizeof(int8_t) + 2 * ne11 * 2 * sizeof(bitnet_float_type); + if (sizeof(bitnet_float_type) == 2) { + // Need fp32 to fp16 conversion + wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type); + } + wsize = ((wsize - 1) / 64 + 1) * 64; + return wsize; +} + +int ggml_bitnet_get_type_bits(enum ggml_type type) { + switch (type) { + case GGML_TYPE_TL2: + return 2; + case GGML_TYPE_Q4_0: + return 4; + default: + return 0; + } +} +#endif \ No newline at end of file diff --git a/src/ggml-bitnet-mad.cpp b/src/ggml-bitnet-mad.cpp new file mode 100644 index 0000000..f75e6ca --- /dev/null +++ b/src/ggml-bitnet-mad.cpp @@ -0,0 +1,361 @@ +#include +#include + +#include "ggml-bitnet.h" +#include "ggml-quants.h" +#include +#include + +#define QK_I2_S 128 +#define QK_I2 128 + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) +#include +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} +#elif defined(__loongarch_asx) +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + + __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11); + __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00); + + __m128i tmp1_128 = lasx_extracti128_lo(tmp1); + __m128i tmp2_128 = lasx_extracti128_lo(tmp2); + + __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128); + + __m128i ev = __lsx_vpickev_w(sum128, sum128); + __m128i od = __lsx_vpickod_w(sum128, sum128); + __m128i sum64 = __lsx_vadd_w(ev, od); + + int sum64_1, sum64_2; + sum64_1 = __lsx_vpickve2gr_w(sum64, 0); + sum64_2 = __lsx_vpickve2gr_w(sum64, 1); + + return sum64_1 + sum64_2; +} +#endif + +size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + // 2 bits per weight + + size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row); + + int n = nrow * n_per_row; + + // f32 -> q8 + double max = 0; + for (int i = 0; i < n; ++i) { + max = fmax(max, (double)fabs((double)src[i])); + } + double i2_scale = max; + + uint8_t* q8 = (uint8_t*)malloc(n * sizeof(uint8_t)); + for (int i=0; i 0 ? 2 : 0; + } + + memset(dst, 0, n * sizeof(uint8_t) / 4); + + // q8 -> 0, 1, 2 + // | | | + // -1, 0, 1 + + uint8_t* i2_weight = (uint8_t*)dst; + for (int i = 0; i < n / QK_I2; i++) { + for (int j = 0; j < QK_I2; j++) { + int group_idx = j / 32; + int group_pos = j % 32; + uint8_t temp = (q8[i * QK_I2 + j] << (6 - 2 * group_idx)); + i2_weight[i * 32 + group_pos] |= temp; + } + } + + float* scale_ptr = (float*)((char*)i2_weight + n / 4); + scale_ptr[0] = i2_scale; + + // 32B for alignment + return nrow * row_size / 4 + 32; +} + +void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + const uint8_t * x = (uint8_t *)vx; + const int8_t * y = (int8_t *)vy; + + const int nb = n / QK_I2_S; + const int group32_num = nb / 32; + const int la_num = nb % 32; + const int groupla_num = nb % 32 != 0 ? 1 : 0; + +#if defined(__AVX2__) + + __m256i mask = _mm256_set1_epi8(0x03); + __m256i accu = _mm256_setzero_si256(); + + for (int i=0; i < group32_num; i++){ + __m256i accu32 = _mm256_setzero_si256(); + for (int j=0; j < 32; j++) { + // 128 index + __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + i * 32 * 32 + j * 32)); + __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2); + __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4); + __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6); + + // each 32 index + xq8_3 = _mm256_and_si256(xq8_3, mask); + xq8_2 = _mm256_and_si256(xq8_2, mask); + xq8_1 = _mm256_and_si256(xq8_1, mask); + xq8_0 = _mm256_and_si256(xq8_0, mask); + + // each 32 index + __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 0)); + __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 32)); + __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 64)); + __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 96)); + + // 128 index accumulation add + // split into 32 accumulation block + // each block each 128 index accumulated 4index + // each index maximum 256 + // each block maximum 4 * 256 + // each block accumulation maximum 127 * 256 + // each 32 group index (128 index in one group) needs cast to int32 + xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0); + xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1); + xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2); + xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3); + + accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_0, xq8_1)); + accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_2, xq8_3)); + } + accu = _mm256_add_epi32(_mm256_madd_epi16(accu32, _mm256_set1_epi16(1)), accu); + } + + for (int i = 0; i < groupla_num; i++){ + __m256i accula = _mm256_setzero_si256(); + for (int j = 0; j < la_num; j++) { + // 128 index + __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + group32_num * 32 * 32 + j * 32)); + __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2); + __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4); + __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6); + + // each 32 index + xq8_3 = _mm256_and_si256(xq8_3, mask); + xq8_2 = _mm256_and_si256(xq8_2, mask); + xq8_1 = _mm256_and_si256(xq8_1, mask); + xq8_0 = _mm256_and_si256(xq8_0, mask); + + // each 32 index + __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 0)); + __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 32)); + __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 64)); + __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 96)); + + // 128 index accumulation add + // split into 32 accumulation block + // each block each 128 index accumulated 4index + // each index maximum 256 + // each block maximum 4 * 256 + // each block accumulation maximum 127 * 256 + // each 32 group index (128 index in one group) needs cast to int32 + xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0); + xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1); + xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2); + xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3); + + accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_0, xq8_1)); + accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_2, xq8_3)); + } + accu = _mm256_add_epi32(accu, _mm256_madd_epi16(accula, _mm256_set1_epi16(1))); + } + int sumi = hsum_i32_8(accu); + *s = (float)sumi; + +#elif defined(__ARM_NEON) + + int32x4_t accu_0 = vdupq_n_s32(0); + int32x4_t accu_1 = vdupq_n_s32(0); + int32x4_t accu_2 = vdupq_n_s32(0); + int32x4_t accu_3 = vdupq_n_s32(0); + const uint8x16_t mask = vdupq_n_u8(3); + + for (int i=0; i < group32_num; i++) { + +#if defined(__ARM_FEATURE_DOTPROD) + +#else + int16x8_t accu32_0 = vdupq_n_s16(0); + int16x8_t accu32_1 = vdupq_n_s16(0); + int16x8_t accu32_2 = vdupq_n_s16(0); + int16x8_t accu32_3 = vdupq_n_s16(0); +#endif + + for (int j=0; j < 32; j++) { + uint8x16_t xq8_6 = vld1q_u8(x + i * 32 * 32 + j * 32); + uint8x16_t xq8_7 = vld1q_u8(x + i * 32 * 32 + j * 32 + 16); + uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2); + uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2); + uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4); + uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4); + uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6); + uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6); + + int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); + int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); + int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); + int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask)); + int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask)); + int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask)); + int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask)); + int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask)); + + const int8x16_t yq8_0 = vld1q_s8(y + i * 128 * 32 + j * 128 + 0); + const int8x16_t yq8_1 = vld1q_s8(y + i * 128 * 32 + j * 128 + 16); + const int8x16_t yq8_2 = vld1q_s8(y + i * 128 * 32 + j * 128 + 32); + const int8x16_t yq8_3 = vld1q_s8(y + i * 128 * 32 + j * 128 + 48); + const int8x16_t yq8_4 = vld1q_s8(y + i * 128 * 32 + j * 128 + 64); + const int8x16_t yq8_5 = vld1q_s8(y + i * 128 * 32 + j * 128 + 80); + const int8x16_t yq8_6 = vld1q_s8(y + i * 128 * 32 + j * 128 + 96); + const int8x16_t yq8_7 = vld1q_s8(y + i * 128 * 32 + j * 128 + 112); + +#if defined(__ARM_FEATURE_DOTPROD) + accu_0 = vdotq_s32(accu_0, q8_0, yq8_0); + accu_1 = vdotq_s32(accu_1, q8_1, yq8_1); + accu_2 = vdotq_s32(accu_2, q8_2, yq8_2); + accu_3 = vdotq_s32(accu_3, q8_3, yq8_3); + accu_0 = vdotq_s32(accu_0, q8_4, yq8_4); + accu_1 = vdotq_s32(accu_1, q8_5, yq8_5); + accu_2 = vdotq_s32(accu_2, q8_6, yq8_6); + accu_3 = vdotq_s32(accu_3, q8_7, yq8_7); +#else + accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_0), vget_low_s8(yq8_0)); + accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_0), vget_high_s8(yq8_0)); + accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_1), vget_low_s8(yq8_1)); + accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_1), vget_high_s8(yq8_1)); + accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_2), vget_low_s8(yq8_2)); + accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_2), vget_high_s8(yq8_2)); + accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_3), vget_low_s8(yq8_3)); + accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_3), vget_high_s8(yq8_3)); + accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_4), vget_low_s8(yq8_4)); + accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_4), vget_high_s8(yq8_4)); + accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_5), vget_low_s8(yq8_5)); + accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_5), vget_high_s8(yq8_5)); + accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_6), vget_low_s8(yq8_6)); + accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_6), vget_high_s8(yq8_6)); + accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_7), vget_low_s8(yq8_7)); + accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_7), vget_high_s8(yq8_7)); +#endif + } + +#if defined(__ARM_FEATURE_DOTPROD) + +#else + accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accu32_0))); + accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accu32_0)); + accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accu32_1))); + accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accu32_1)); + accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accu32_2))); + accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accu32_2)); + accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accu32_3))); + accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accu32_3)); +#endif + } + + for (int i = 0; i < groupla_num; i++){ +#if defined(__ARM_FEATURE_DOTPROD) + +#else + int16x8_t accula_0 = vdupq_n_s16(0); + int16x8_t accula_1 = vdupq_n_s16(0); + int16x8_t accula_2 = vdupq_n_s16(0); + int16x8_t accula_3 = vdupq_n_s16(0); +#endif + for (int j = 0; j < la_num; j++) { + uint8x16_t xq8_6 = vld1q_u8(x + group32_num * 32 * 32 + j * 32); + uint8x16_t xq8_7 = vld1q_u8(x + group32_num * 32 * 32 + j * 32 + 16); + uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2); + uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2); + uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4); + uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4); + uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6); + uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6); + + int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); + int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); + int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); + int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask)); + int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask)); + int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask)); + int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask)); + int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask)); + + const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 0); + const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 16); + const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 32); + const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 48); + const int8x16_t yq8_4 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 64); + const int8x16_t yq8_5 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 80); + const int8x16_t yq8_6 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 96); + const int8x16_t yq8_7 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 112); + +#if defined(__ARM_FEATURE_DOTPROD) + accu_0 = vdotq_s32(accu_0, q8_0, yq8_0); + accu_1 = vdotq_s32(accu_1, q8_1, yq8_1); + accu_2 = vdotq_s32(accu_2, q8_2, yq8_2); + accu_3 = vdotq_s32(accu_3, q8_3, yq8_3); + accu_0 = vdotq_s32(accu_0, q8_4, yq8_4); + accu_1 = vdotq_s32(accu_1, q8_5, yq8_5); + accu_2 = vdotq_s32(accu_2, q8_6, yq8_6); + accu_3 = vdotq_s32(accu_3, q8_7, yq8_7); +#else + accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_0), vget_low_s8(yq8_0)); + accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_0), vget_high_s8(yq8_0)); + accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_1), vget_low_s8(yq8_1)); + accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_1), vget_high_s8(yq8_1)); + accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_2), vget_low_s8(yq8_2)); + accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_2), vget_high_s8(yq8_2)); + accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_3), vget_low_s8(yq8_3)); + accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_3), vget_high_s8(yq8_3)); + accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_4), vget_low_s8(yq8_4)); + accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_4), vget_high_s8(yq8_4)); + accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_5), vget_low_s8(yq8_5)); + accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_5), vget_high_s8(yq8_5)); + accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_6), vget_low_s8(yq8_6)); + accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_6), vget_high_s8(yq8_6)); + accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_7), vget_low_s8(yq8_7)); + accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_7), vget_high_s8(yq8_7)); +#endif + } +#if defined(__ARM_FEATURE_DOTPROD) + +#else + accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accula_0))); + accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accula_0)); + accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accula_1))); + accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accula_1)); + accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accula_2))); + accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accula_2)); + accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accula_3))); + accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accula_3)); +#endif + } + accu_0 = vaddq_s32(accu_0, accu_1); + accu_2 = vaddq_s32(accu_2, accu_3); + accu_0 = vaddq_s32(accu_0, accu_2); + int sumi = vaddlvq_s32(accu_0); + *s = (float)sumi; + +#endif +} \ No newline at end of file diff --git a/utils/codegen_tl1.py b/utils/codegen_tl1.py new file mode 100644 index 0000000..4c2e7dd --- /dev/null +++ b/utils/codegen_tl1.py @@ -0,0 +1,442 @@ +import argparse +import os +from configparser import ConfigParser + +def gen_ctor_code(): + kernel_code = "\n\ +#include \"ggml-bitnet.h\"\n\ +#define GGML_BITNET_MAX_NODES 8192\n\ +static bool initialized = false;\n\ +static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;\n\ +static size_t bitnet_tensor_extras_index = 0;\n\ +static void * aligned_malloc(size_t size) {{\n\ +#if defined(_WIN32)\n\ + return _aligned_malloc(size, 64);\n\ +#else\n\ + void * ptr = nullptr;\n\ + posix_memalign(&ptr, 64, size);\n\ + return ptr;\n\ +#endif\n\ +}}\n\ +static void aligned_free(void * ptr) {{\n\ +#if defined(_WIN32)\n\ + _aligned_free(ptr);\n\ +#else\n\ + free(ptr);\n\ +#endif\n\ +}}\n\ +\n\ +void per_tensor_quant(int k, void* lut_scales_, void* b_) {{\n\ + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\ + bitnet_float_type* b = (bitnet_float_type*)b_;\n\ +#ifdef __ARM_NEON\n\ + float32x4_t temp_max = vdupq_n_f32(0);\n\ + for (int i=0; i < k / 4; i++) {{\n\ + float32x4_t vec_bs = vld1q_f32(b + 4 * i);\n\ + float32x4_t abssum = vabsq_f32(vec_bs);\n\ + temp_max = vmaxq_f32(abssum, temp_max);\n\ + }}\n\ + float32_t scales = 127 / vmaxvq_f32(temp_max);\n\ + *lut_scales = scales;\n\ +#elif defined __AVX2__\n\ + __m256 max_vec = _mm256_set1_ps(0.f);\n\ + const __m256 vec_sign = _mm256_set1_ps(-0.0f);\n\ + // #pragma unroll\n\ + for (int i = 0; i < k / 8; i++) {{\n\ + __m256 vec_b = _mm256_loadu_ps(b + i * 8);\n\ + __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);\n\ + max_vec = _mm256_max_ps(vec_babs, max_vec);\n\ + }}\n\ + __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));\n\ + max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));\n\ + max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));\n\ + float scales = 127 / _mm_cvtss_f32(max1);\n\ + *lut_scales = scales;\n\ +#endif\n\ +}}\n\ +\n\ +void partial_max_reset(void* lut_scales_) {{\n\ + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\ + *lut_scales = 0.0;\n\ +}}\n\ +\n\ +#ifdef __ARM_NEON\n\ +inline void Transpose_8_8(\n\ + int16x8_t *v0,\n\ + int16x8_t *v1,\n\ + int16x8_t *v2,\n\ + int16x8_t *v3,\n\ + int16x8_t *v4,\n\ + int16x8_t *v5,\n\ + int16x8_t *v6,\n\ + int16x8_t *v7)\n\ +{{\n\ + int16x8x2_t q04 = vzipq_s16(*v0, *v4);\n\ + int16x8x2_t q15 = vzipq_s16(*v1, *v5);\n\ + int16x8x2_t q26 = vzipq_s16(*v2, *v6);\n\ + int16x8x2_t q37 = vzipq_s16(*v3, *v7);\n\ +\n\ + int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);\n\ + int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);\n\ + int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);\n\ + int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);\n\ +\n\ + int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);\n\ + int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);\n\ + int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);\n\ + int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);\n\ +\n\ + *v0 = q_fin_0.val[0];\n\ + *v1 = q_fin_0.val[1];\n\ + *v2 = q_fin_1.val[0];\n\ + *v3 = q_fin_1.val[1];\n\ + *v4 = q_fin_2.val[0];\n\ + *v5 = q_fin_2.val[1];\n\ + *v6 = q_fin_3.val[0];\n\ + *v7 = q_fin_3.val[1];\n\ +}}\n\ +#endif\n\ +\n\ +template\n\ +inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{\n\ +#ifdef __ARM_NEON\n\ + int16x8_t vec_lut[16];\n\ + float32_t scales = *lut_scales;\n\ + uint8_t tbl_mask[16];\n\ + tbl_mask[0] = 0;\n\ + tbl_mask[1] = 2;\n\ + tbl_mask[2] = 4;\n\ + tbl_mask[3] = 6;\n\ + tbl_mask[4] = 8;\n\ + tbl_mask[5] = 10;\n\ + tbl_mask[6] = 12;\n\ + tbl_mask[7] = 14;\n\ + tbl_mask[8] = 1;\n\ + tbl_mask[9] = 3;\n\ + tbl_mask[10] = 5;\n\ + tbl_mask[11] = 7;\n\ + tbl_mask[12] = 9;\n\ + tbl_mask[13] = 11;\n\ + tbl_mask[14] = 13;\n\ + tbl_mask[15] = 15;\n\ + uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);\n\ +#pragma unroll\n\ + for (int k = 0; k < act_k / 16; ++k) {{\n\ + float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);\n\ + float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);\n\ + float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);\n\ + float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);\n\ + float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);\n\ + float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);\n\ + int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);\n\ + int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);\n\ + int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);\n\ + int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);\n\ + int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);\n\ + int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);\n\ + int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);\n\ + int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);\n\ + int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);\n\ + int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);\n\ + vec_lut[0] = vdupq_n_s16(0);\n\ + vec_lut[0] = vec_lut[0] - vec_bs_0;\n\ + vec_lut[0] = vec_lut[0] - vec_bs_1;\n\ + vec_lut[1] = vdupq_n_s16(0);\n\ + vec_lut[1] = vec_lut[1] - vec_bs_0;\n\ + vec_lut[2] = vdupq_n_s16(0);\n\ + vec_lut[2] = vec_lut[2] - vec_bs_0;\n\ + vec_lut[2] = vec_lut[2] + vec_bs_1;\n\ + vec_lut[3] = vdupq_n_s16(0);\n\ + vec_lut[3] = vec_lut[3] - vec_bs_1;\n\ + vec_lut[4] = vdupq_n_s16(0);\n\ + vec_lut[5] = vec_bs_1;\n\ + vec_lut[6] = vec_bs_0;\n\ + vec_lut[6] = vec_lut[6] - vec_bs_1;\n\ + vec_lut[7] = vec_bs_0;\n\ + vec_lut[8] = vec_bs_0;\n\ + vec_lut[8] = vec_lut[8] + vec_bs_1;\n\ + Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),\n\ + &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));\n\ + Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),\n\ + &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));\n\ +#pragma unroll\n\ + for (int idx = 0; idx < 8; idx++) {{\n\ + int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);\n\ + int8x8_t q0_low = vget_low_s8(q0_s);\n\ + int8x8_t q0_high = vget_high_s8(q0_s);\n\ + int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);\n\ + int8x8_t q1_low = vget_low_s8(q1_s);\n\ + int8x8_t q1_high = vget_high_s8(q1_s);\n\ + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);\n\ + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);\n\ + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);\n\ + vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);\n\ + }}\n\ + }}\n\ +#endif\n\ +}}\n\ +\n\ +static bool is_type_supported(enum ggml_type type) {{\n\ + if (type == GGML_TYPE_Q4_0 ||\n\ + type == GGML_TYPE_TL1) {{\n\ + return true;\n\ + }} else {{\n\ + return false;\n\ + }}\n\ +}}\n\ +" + return kernel_code + +def gen_body_core_code(bm, by): + length = 4 + all_code = "" + for i in range(length): + core_code = "\n\ + uint8x16_t vec_a_{0} = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + {0} * 16);\n\ + uint8x16_t vec_a{0}_top = vshrq_n_u8(vec_a_{0}, 4);\n\ + uint8x16_t vec_a{0}_bot = vandq_u8(vec_a_{0}, vec_mask);\n\ + int8x16_t vec_v_{0}_left_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {2}], vec_a{0}_top);\n\ + int8x16_t vec_v_{0}_left_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {3}], vec_a{0}_top);\n\ + int8x16_t vec_v_{0}_right_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {4}], vec_a{0}_bot);\n\ + int8x16_t vec_v_{0}_right_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {5}], vec_a{0}_bot);\n\ + int8x16x2_t vec_v_left_{0} = vzipq_s8(vec_v_{0}_left_tmp1, vec_v_{0}_left_tmp0);\n\ + int8x16x2_t vec_v_right_{0} = vzipq_s8(vec_v_{0}_right_tmp1, vec_v_{0}_right_tmp0);\n\ + vec_c[{6}] += vec_v_left_{0}.val[0];\n\ + vec_c[{6}] += vec_v_right_{0}.val[0];\n\ + vec_c[{7}] += vec_v_left_{0}.val[1];\n\ + vec_c[{7}] += vec_v_right_{0}.val[1];\n\ + ".format(i, 2 * by // 2, (4 * i) % (2 * by // 2), (4 * i + 1) % (2 * by // 2), (4 * i + 2) % (2 * by // 2), (4 * i + 3) % (2 * by // 2), (i * 2) // (by // 2) * 2 + 0, (i * 2) // (by // 2) * 2 + 1) + + all_code = "".join([all_code, core_code]) + + all_code = "".join([all_code, "\n }\n\n"]) + + for i in range(bm // 8): + core_code = "\ + int32x4_t vec_v_bot_low_low_{0} = vmovl_s16(vget_low_s16(vec_c[{0}]));\n\ + int32x4_t vec_v_bot_low_high_{0} = vmovl_high_s16(vec_c[{0}]);\n\ + vst1q_s32(c + i + {1}, vld1q_s32(c + i + {1}) + vec_v_bot_low_low_{0});\n\ + vst1q_s32(c + i + {2}, vld1q_s32(c + i + {2}) + vec_v_bot_low_high_{0});\n".format(i, i * 8, i * 8 + 4) + all_code = "".join([all_code, core_code]) + + return all_code + +def gen_tbl_impl(pre, BM, BK, bm, k): + + kernel_code = "\ +#include \n\ +\n\ +#define BM{0} {1}\n\ +#define BBK{0} {2}\n\ +inline void tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a) {{\n\ +#ifdef __ARM_NEON\n\ + const int KK = BBK{0} / 2;\n\ + const uint8x16_t vec_mask = vdupq_n_u8(0x0f);\n\ + const int8x16_t vec_zero = vdupq_n_s16(0x0000);\n\ + int8x16_t vec_lut[2 * KK];\n\ +".format(pre, BM, BK) + + kernel_code = "".join([kernel_code, " int16x8_t vec_c[{}];".format(bm // 8)]) + + kernel_code = "".join([kernel_code, "\n\ +#pragma unroll\n\ + for (int k = 0; k < 2 * KK; k++) {\n\ + vec_lut[k] = vld1q_s8(lut + k * 16);\n\ + }\n"]) + + pre_core_code = "\n\ +#pragma unroll\n\ + for (int i = 0; i < BM{}; i += {}) {{\n\ + #pragma unroll\n\ + for (int i=0; i<{}; i++) {{\n\ + vec_c[i] = vandq_s16(vec_c[i], vec_zero);\n\ + }}\n".format(pre, bm, bm // 8) + + body_core_pre_code = "\n\ +#pragma unroll\n\ + for (int k = 0; k < KK / {}; k++) {{\n\ + ".format(256 // bm // 2) + + body_core_post_code = "\n\ + }\n\ +\ +#endif\n\ +}\n" + + kernel_code = "".join([kernel_code, pre_core_code, body_core_pre_code, gen_body_core_code(bm, 256 // bm), body_core_post_code]) + + kernel_code = "".join([kernel_code, "\n\ +int32_t qgemm_lut_{0}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\ + alignas({1}) uint32_t CBits[BM{0}];\n\ + memset(&(CBits[0]), 0, BM{0} * sizeof(int32_t));\n\ +#pragma unroll\n\ + for (int32_t k_outer = 0; k_outer < {2} / BBK{0}; ++k_outer) {{\n\ + tbl_impl_{0}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{0} / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{0} / 2 / 2 * BM{0})])));\n\ + }}\n\ +#pragma unroll\n\ + for (int i = 0; i < BM{0}; i++) {{\n\ + ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];\n\ + }}\n\ + return 0;\n\ +}};\n".format(pre, min(32, BK), k)]) + + return kernel_code + +def gen_top_api(kernel_shapes): + + kernel_code = "void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {{\n\ + if (m == {0} && k == {1}) {{\n\ + preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\ + }}\n\ +".format(kernel_shapes[0][0], kernel_shapes[0][1]) + for i in range(1, len(kernel_shapes)): + kernel_code = "".join([kernel_code, " else if (m == {0} && k == {1}) {{\n\ + preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\ + }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])]) + kernel_code = "".join([kernel_code, "}\n"]) + kernel_code = "".join([kernel_code, "void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\ + if (m == {0} && k == {1}) {{\n\ + qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\ + }}\n\ +".format(kernel_shapes[0][0], kernel_shapes[0][1])]) + for i in range(1, len(kernel_shapes)): + kernel_code = "".join([kernel_code, " else if (m == {0} && k == {1}) {{\n\ + qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\ + }}\n\ +".format(kernel_shapes[i][0], kernel_shapes[i][1])]) + kernel_code = "".join([kernel_code, "}\n"]) + return kernel_code + +def gen_preprocess_code(): + kernel_code = "\n\ +template\n\ +void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{\n\ + partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));\n\ + per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));\n\ + \n\ + lut_ctor((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));\n\ +}}\n" + return kernel_code + +def gen_transform_code(kernel_shape): + kernel_code = "\n\ +void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\ + if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {\n\ + return;\n\ + }\n\ +\n\ + int k = tensor->ne[0];\n\ + int m = tensor->ne[1];\n\ + const int lut_scales_size = 1;\n\ + const int scales_size = 1;\n\ + int bk = 0;\n\ + int bm = 0;\n" + + kernel_code = "".join([kernel_code, "\n\ + if (m == {0} && k == {1}) {{\n\ + bm = BM{0}_{1};\n\ + bk = BBK{0}_{1};\n\ + }}\n".format(kernel_shapes[0][0], kernel_shapes[0][1])]) + + for i in range(1, len(kernel_shapes)): + kernel_code = "".join([kernel_code, "else if (m == {0} && k == {1}) {{\n\ + bm = BM{0}_{1};\n\ + bk = BBK{0}_{1};\n\ + }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])]) + + kernel_code = "".join([kernel_code, "\n\ + const int n_tile_num = m / bm;\n\ + const int BK = bk;\n\ + uint8_t * qweights;\n\ + bitnet_float_type * scales;\n\ +\n\ + scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));\n\ + qweights = (uint8_t *) tensor->data;\n\ + float * i2_scales = (float * )(qweights + k * m / 4);\n\ + scales[0] = (bitnet_float_type) i2_scales[0];\n\ +\n\ + tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\ + bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\ + /* .lut_scales_size = */ lut_scales_size,\n\ + /* .BK = */ BK,\n\ + /* .n_tile_num = */ n_tile_num,\n\ + /* .qweights = */ qweights,\n\ + /* .scales = */ scales\n\ + };\n\ +}\n"]) + + return kernel_code + +if __name__ == "__main__": + ModelShapeDict = { + "bitnet_b1_58-large" : [[1536, 4096], + [1536, 1536], + [4096, 1536]], + "bitnet_b1_58-3B" : [[3200, 8640], + [3200, 3200], + [8640, 3200]], + "Llama3-8B-1.58-100B-tokens" : [[14336, 4096], + [4096, 14336], + [1024, 4096], + [4096, 4096]] + } + + parser = argparse.ArgumentParser(description='gen impl') + parser.add_argument('--model',default="input", type=str, dest="model", + help="choose from bitnet_b1_58-large/bitnet_b1_58-3B/Llama3-8B-1.58-100B-tokens.") + parser.add_argument('--BM',default="input", type=str, + help="block length when cutting one weight (M, K) into M / BM weights (BM, K).") + parser.add_argument('--BK',default="input", type=str, + help="block length when cutting one weight (M, K) into K / BK weights (M, BK).") + parser.add_argument('--bm',default="input", type=str, + help="using simd instructions to compute (bm, 256 / bm) in one block") + args = parser.parse_args() + + kernel_shapes = ModelShapeDict[args.model] + + BM_list = [int(item) for item in args.BM.split(',')] + BK_list = [int(item) for item in args.BK.split(',')] + bm_list = [int(item) for item in args.bm.split(',')] + + assert(len(BM_list) == len(BK_list) == len(bm_list) == len(kernel_shapes)), "number of BM / BK / bm shoud be {}".format(len(kernel_shapes)) + + for i in range(len(kernel_shapes)): + assert kernel_shapes[i][0] % BM_list[i] == 0, "M %% BM should be 0" + assert kernel_shapes[i][1] % BK_list[i] == 0, "K %% BK should be 0" + assert bm_list[i] in [32, 64], "choose bm from [32, 64]" + + tbl_impl_code = [] + + for i in range(len(kernel_shapes)): + tbl_impl_code.append( + gen_tbl_impl("{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]), BM_list[i], BK_list[i], bm_list[i], kernel_shapes[i][1]) + ) + api_code = gen_top_api(kernel_shapes) + pre_code = gen_preprocess_code() + ctor_code = gen_ctor_code() + trans_code = gen_transform_code(kernel_shapes) + + output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "include") + + with open(''.join([output_dir, "/bitnet-lut-kernels.h"]), 'w') as f: + f.write(''.join("#if defined(GGML_BITNET_ARM_TL1)")) + f.write(''.join(ctor_code)) + for code in tbl_impl_code: + f.write(''.join(code)) + f.write(''.join(pre_code)) + f.write(''.join(api_code)) + f.write(''.join(trans_code)) + f.write(''.join("#endif")) + + config = ConfigParser() + + for i in range(len(kernel_shapes)): + config.add_section('Kernels_{}'.format(i)) + config.set('Kernels_{}'.format(i), 'M'.format(i), str(kernel_shapes[i][0])) + config.set('Kernels_{}'.format(i), 'K'.format(i), str(kernel_shapes[i][1])) + config.set('Kernels_{}'.format(i), 'BM'.format(i), str(BM_list[i])) + config.set('Kernels_{}'.format(i), 'BK'.format(i), str(BK_list[i])) + config.set('Kernels_{}'.format(i), 'bmm'.format(i), str(bm_list[i])) + + with open(''.join([output_dir, "/kernel_config.ini"]), 'w') as configfile: + config.write(configfile) \ No newline at end of file diff --git a/utils/codegen_tl2.py b/utils/codegen_tl2.py new file mode 100644 index 0000000..44d2418 --- /dev/null +++ b/utils/codegen_tl2.py @@ -0,0 +1,757 @@ +import argparse +import os +from configparser import ConfigParser + +def gen_ctor_code(): + kernel_code = "\n\ +#include \"ggml-bitnet.h\"\n\ +#include \n\ +#include \n\ +#define GGML_BITNET_MAX_NODES 8192\n\ +static bool initialized = false;\n\ +static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;\n\ +static size_t bitnet_tensor_extras_index = 0;\n\ +static void * aligned_malloc(size_t size) {\n\ +#if defined(_WIN32)\n\ + return _aligned_malloc(size, 64);\n\ +#else\n\ + void * ptr = nullptr;\n\ + posix_memalign(&ptr, 64, size);\n\ + return ptr;\n\ +#endif\n\ +}\n\ +\n\ +static void aligned_free(void * ptr) {\n\ +#if defined(_WIN32)\n\ + _aligned_free(ptr);\n\ +#else\n\ + free(ptr);\n\ +#endif\n\ +}\n\ +#define BK2 32\n\ +#if defined __AVX2__\n\ +inline void _mm256_merge_epi32(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)\n\ +{\n\ + __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0));\n\ + __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0));\n\ + *vl = _mm256_unpacklo_epi32(va, vb);\n\ + *vh = _mm256_unpackhi_epi32(va, vb);\n\ +}\n\ +inline void _mm256_merge_epi64(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)\n\ +{\n\ + __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0));\n\ + __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0));\n\ + *vl = _mm256_unpacklo_epi64(va, vb);\n\ + *vh = _mm256_unpackhi_epi64(va, vb);\n\ +}\n\ +inline void _mm256_merge_si128(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)\n\ +{\n\ + *vl = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 2, 0, 0));\n\ + *vh = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 3, 0, 1));\n\ +}\n\ +inline void Transpose_8_8(\n\ + __m256i *v0,\n\ + __m256i *v1,\n\ + __m256i *v2,\n\ + __m256i *v3,\n\ + __m256i *v4,\n\ + __m256i *v5,\n\ + __m256i *v6,\n\ + __m256i *v7)\n\ +{\n\ + __m256i w0, w1, w2, w3, w4, w5, w6, w7;\n\ + __m256i x0, x1, x2, x3, x4, x5, x6, x7;\n\ + _mm256_merge_epi32(*v0, *v1, &w0, &w1);\n\ + _mm256_merge_epi32(*v2, *v3, &w2, &w3);\n\ + _mm256_merge_epi32(*v4, *v5, &w4, &w5);\n\ + _mm256_merge_epi32(*v6, *v7, &w6, &w7);\n\ + _mm256_merge_epi64(w0, w2, &x0, &x1);\n\ + _mm256_merge_epi64(w1, w3, &x2, &x3);\n\ + _mm256_merge_epi64(w4, w6, &x4, &x5);\n\ + _mm256_merge_epi64(w5, w7, &x6, &x7);\n\ + _mm256_merge_si128(x0, x4, v0, v1);\n\ + _mm256_merge_si128(x1, x5, v2, v3);\n\ + _mm256_merge_si128(x2, x6, v4, v5);\n\ + _mm256_merge_si128(x3, x7, v6, v7);\n\ +}\n\ +#endif\n\ +inline int32_t per_tensor_quant(int k, void* lut_scales_, void* b_) {\n\ + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\ + bitnet_float_type* b = (bitnet_float_type*)b_;\n\ +#if defined __AVX2__\n\ + __m256 max_vec = _mm256_set1_ps(0.f);\n\ + const __m256 vec_sign = _mm256_set1_ps(-0.0f);\n\ + for (int i = 0; i < k / 8; i++) {\n\ + __m256 vec_b = _mm256_loadu_ps(b + i * 8);\n\ + __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);\n\ + max_vec = _mm256_max_ps(vec_babs, max_vec);\n\ + }\n\ + __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));\n\ + max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));\n\ + max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));\n\ + float scales = 127 / _mm_cvtss_f32(max1);\n\ + *lut_scales = scales;\n\ +#endif\n\ + return 0;\n\ +}\n\ +inline int32_t partial_max_reset(int32_t bs, void* lut_scales_) {\n\ + bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\ + #pragma unroll\n\ + for (int i=0; i< bs; i++) {\n\ + lut_scales[i] = 0.0;\n\ + }\n\ + return 0;\n\ +}\n\ +template\n\ +inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {\n\ +#if defined __AVX2__\n\ + __m256 vec_lut[16];\n\ + const __m256i vec_bi = _mm256_set_epi32(84, 72, 60, 48, 36, 24, 12, 0);\n\ + float scales = *lut_scales;\n\ + __m256i shuffle_mask = _mm256_set_epi8(\n\ + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,\n\ + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00,\n\ + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,\n\ + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00\n\ + );\n\ +#pragma unroll\n\ + for (int k = 0; k < act_k / 24; ++k) {\n\ + __m256 vec_b0 = _mm256_i32gather_ps(b + k * 24 + 0, vec_bi, 1);\n\ + __m256 vec_b1 = _mm256_i32gather_ps(b + k * 24 + 1, vec_bi, 1);\n\ + __m256 vec_b2 = _mm256_i32gather_ps(b + k * 24 + 2, vec_bi, 1);\n\ +\n\ + __m256i vec_b0i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\ + __m256i vec_b1i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\ + __m256i vec_b2i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b2, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\ +\n\ + vec_lut[15] = _mm256_setzero_si256();\n\ + vec_lut[14] = _mm256_setzero_si256();\n\ + vec_lut[13] = vec_b0i;\n\ + vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b1i);\n\ + vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b2i);\n\ + vec_lut[12] = vec_b0i;\n\ + vec_lut[12] = _mm256_add_epi32(vec_lut[12], vec_b1i);\n\ + vec_lut[11] = vec_b0i;\n\ + vec_lut[11] = _mm256_add_epi32(vec_lut[11], vec_b1i);\n\ + vec_lut[11] = _mm256_sub_epi32(vec_lut[11], vec_b2i);\n\ + vec_lut[10] = vec_b0i;\n\ + vec_lut[10] = _mm256_add_epi32(vec_lut[10], vec_b2i);\n\ + vec_lut[9] = vec_b0i;\n\ + vec_lut[8] = vec_b0i;\n\ + vec_lut[8] = _mm256_sub_epi32(vec_lut[8], vec_b2i);\n\ + vec_lut[7] = vec_b0i;\n\ + vec_lut[7] = _mm256_sub_epi32(vec_lut[7], vec_b1i);\n\ + vec_lut[7] = _mm256_add_epi32(vec_lut[7], vec_b2i);\n\ + vec_lut[6] = vec_b0i;\n\ + vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1i);\n\ + vec_lut[5] = vec_b0i;\n\ + vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b1i);\n\ + vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b2i);\n\ + vec_lut[4] = vec_b1i;\n\ + vec_lut[4] = _mm256_add_epi32(vec_lut[4], vec_b2i);\n\ + vec_lut[3] = vec_b1i;\n\ + vec_lut[2] = vec_b1i;\n\ + vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b2i);\n\ + vec_lut[1] = vec_b2i;\n\ + vec_lut[0] = _mm256_setzero_si256();\n\ + __m256i ix[16];\n\ +\n\ +#pragma unroll\n\ + for (int g = 0; g < 16; ++g) {\n\ + ix[g] = vec_lut[g];\n\ + }\n\ +\n\ + Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7]));\n\ + Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15]));\n\ +\n\ +#pragma unroll\n\ + for (int g = 0; g < 8; ++g) {\n\ + ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]);\n\ + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));\n\ + ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask);\n\ + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));\n\ + }\n\ + int8_t* qlut_i8 = reinterpret_cast(qlut);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]);\n\ +\n\ + }\n\ +\n\ + *lut_scales = scales;\n\ +#endif\n\ + return 0;\n\ +}\n\ +\n\ +template\n\ +inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {\n\ +#if defined __AVX2__\n\ + __m256 vec_lut[16];\n\ + const __m256i vec_bi = _mm256_set_epi32(56, 48, 40, 32, 24, 16, 8, 0);\n\ + float scales = *lut_scales;\n\ + __m256i shuffle_mask = _mm256_set_epi8(\n\ + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,\n\ + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00,\n\ + 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,\n\ + 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00\n\ + );\n\ +#pragma unroll\n\ + for (int k = 0; k < act_k / 16; ++k) {\n\ + __m256 vec_b0f = _mm256_i32gather_ps(b + k * 16 + 0, vec_bi, 1);\n\ + __m256 vec_b1f = _mm256_i32gather_ps(b + k * 16 + 1, vec_bi, 1);\n\ +\n\ + __m256i vec_b0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\ + __m256i vec_b1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n\ + vec_lut[15] = _mm256_setzero_si256();\n\ + vec_lut[14] = _mm256_setzero_si256();\n\ + vec_lut[13] = _mm256_setzero_si256();\n\ + vec_lut[12] = _mm256_setzero_si256();\n\ + vec_lut[11] = _mm256_setzero_si256();\n\ + vec_lut[10] = _mm256_setzero_si256();\n\ + vec_lut[9] = _mm256_setzero_si256();\n\ + vec_lut[8] = vec_b0;\n\ + vec_lut[8] = _mm256_add_epi32(vec_lut[8], vec_b1);\n\ + vec_lut[7] = vec_b0;\n\ + vec_lut[6] = vec_b0;\n\ + vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1);\n\ + vec_lut[5] = vec_b1;\n\ + vec_lut[4] = _mm256_setzero_si256();\n\ + vec_lut[3] = _mm256_setzero_si256();\n\ + vec_lut[3] = _mm256_sub_epi32(vec_lut[3], vec_b1);\n\ + vec_lut[2] = _mm256_setzero_si256();\n\ + vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b0);\n\ + vec_lut[2] = _mm256_add_epi32(vec_lut[2], vec_b1);\n\ + vec_lut[1] = _mm256_setzero_si256();\n\ + vec_lut[1] = _mm256_sub_epi32(vec_lut[1], vec_b0);\n\ + vec_lut[0] = _mm256_setzero_si256();\n\ + vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b0);\n\ + vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b1);\n\ +\n\ + __m256i ix[16];\n\ +#pragma unroll\n\ + for (int g = 0; g < 16; ++g) {\n\ + ix[g] = vec_lut[g];\n\ + }\n\ +\n\ + Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7]));\n\ + Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15]));\n\ +\n\ +#pragma unroll\n\ + for (int g = 0; g < 8; ++g) {\n\ + ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]);\n\ + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));\n\ + ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask);\n\ + ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));\n\ + }\n\ +\n\ + int8_t* qlut_i8 = reinterpret_cast(qlut);\n\ +\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]);\n\ +\n\ + }\n\ + *lut_scales = scales;\n\ +#endif\n\ + return 0;\n\ +}\n\ +static bool is_type_supported(enum ggml_type type) {\n\ + if (type == GGML_TYPE_Q4_0 ||\n\ + type == GGML_TYPE_TL2) {\n\ + return true;\n\ + } else {\n\ + return false;\n\ + }\n\ +}\n\ +" + return kernel_code + +def gen_tbl_impl(pre, BM, BK, bm, k_list): + + kernel_code = "\ +#include \n\ +\n\ +#define BM{0} {1}\n\ +#define BBK{0} {2}\n\ +template\n\ +inline void three_tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) {{\n\ +".format(pre, BM, BK) + + kernel_code = "".join([kernel_code, "\ +#ifdef __AVX2__\n\ + const __m256i vec_mask = _mm256_set1_epi8(0x0f);\n\ + const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000);\n\ + const __m256i vec_zero = _mm256_set1_epi8(0x00);\n\ + const __m256i vec_one = _mm256_set1_epi8(0xff);\n\ + const int KK = BBK{0} / 3;\n\ +#pragma unroll\n\ + for (int i = 0; i < BM{0}; i += 32) {{\n\ + __m256i vec_as[KK / 2];\n\ + __m256i vec_signs[KK / 8];\n\ + #pragma unroll\n\ + for (int ai = 0; ai < KK / 2; ai++) {{\n\ + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32));\n\ + }}\n\ + #pragma unroll\n\ + for (int as = 0; as < KK / 8; as++) {{\n\ + vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32));\n\ + }}\n\ +#pragma unroll\n\ + for (int bs = 0; bs < batch_size; bs++) {{\n\ + __m256i vec_c0 = _mm256_setzero_si256();\n\ + __m256i vec_c1 = _mm256_setzero_si256();\n\ +#pragma unroll\n\ + for (int k = 0; k < KK / 8; k++) {{\n\ + __m256i vec_sign = vec_signs[k];\n\ + __m256i vec_a_0 = vec_as[k * 4 + 0];\n\ + __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs));\n\ + __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs));\n\ + __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs));\n\ + __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs));\n\ + __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15);\n\ + __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15);\n\ + __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask);\n\ + __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0);\n\ + __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0);\n\ + __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15);\n\ + __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15);\n\ + __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask);\n\ + __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0);\n\ + __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0);\n\ + __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0);\n\ + __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0);\n\ + __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0);\n\ + __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0);\n\ + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0);\n\ + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0);\n\ + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0);\n\ + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0);\n\ + __m256i vec_a_1 = vec_as[k * 4 + 1];\n\ + __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs));\n\ + __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs));\n\ + __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs));\n\ + __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs));\n\ + __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15);\n\ + __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15);\n\ + __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask);\n\ + __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1);\n\ + __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1);\n\ + __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15);\n\ + __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15);\n\ + __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask);\n\ + __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1);\n\ + __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1);\n\ + __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1);\n\ + __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1);\n\ + __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1);\n\ + __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1);\n\ + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1);\n\ + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1);\n\ + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1);\n\ + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1);\n\ + __m256i vec_a_2 = vec_as[k * 4 + 2];\n\ + __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs));\n\ + __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs));\n\ + __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs));\n\ + __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs));\n\ + __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15);\n\ + __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15);\n\ + __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask);\n\ + __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2);\n\ + __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2);\n\ + __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15);\n\ + __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15);\n\ + __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask);\n\ + __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2);\n\ + __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2);\n\ + __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2);\n\ + __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2);\n\ + __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2);\n\ + __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2);\n\ + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2);\n\ + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2);\n\ + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2);\n\ + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2);\n\ + __m256i vec_a_3 = vec_as[k * 4 + 3];\n\ + __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs));\n\ + __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs));\n\ + __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs));\n\ + __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs));\n\ + __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15);\n\ + __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15);\n\ + __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask);\n\ + __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3);\n\ + __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3);\n\ + __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15);\n\ + __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15);\n\ + __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask);\n\ + __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3);\n\ + __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3);\n\ + __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3);\n\ + __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3);\n\ + __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3);\n\ + __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3);\n\ + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3);\n\ + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3);\n\ + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3);\n\ + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3);\n\ + }}\n\ + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{0} * bs));\n\ + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{0} * bs));\n\ + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{0} * bs));\n\ + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{0} * bs));\n\ + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));\n\ + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));\n\ + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));\n\ + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{0} * bs), vec_gc0);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{0} * bs), vec_gc1);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{0} * bs), vec_gc2);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{0} * bs), vec_gc3);\n\ + }}\n\ + }}\n\ +#endif\n\ +}}\n\ +\n\ +template\n\ +inline int32_t two_tbl_impl{0}(int32_t* c, int8_t* lut, uint8_t* a) {{\n\ +#ifdef __AVX2__\n\ + const __m256i vec_mask = _mm256_set1_epi8(0x0f);\n\ + const int KK = BK2 / 2;\n\ +#pragma unroll\n\ + for (int i = 0; i < BM{0}; i += 32) {{\n\ + __m256i vec_as[KK / 2];\n\ + #pragma unroll\n\ + for (int ai = 0; ai < KK / 2; ai++) {{\n\ + vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32));\n\ + }}\n\ +#pragma unroll\n\ + for (int bs = 0; bs < batch_size; bs++) {{\n\ + __m256i vec_c0 = _mm256_setzero_si256();\n\ + __m256i vec_c1 = _mm256_setzero_si256();\n\ +#pragma unroll\n\ + for (int k = 0; k < KK / 8; k++) {{\n\ + #pragma unroll\n\ + for (int j = 0; j < 4; j++) {{\n\ + __m256i vec_a = vec_as[k * 4 + j];\n\ +\n\ + __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs));\n\ + __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs));\n\ + __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs));\n\ + __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs));\n\ +\n\ + __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask);\n\ + __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top);\n\ + __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top);\n\ +\n\ + __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask);\n\ + __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot);\n\ + __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot);\n\ +\n\ + __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec);\n\ + __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec);\n\ + __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec);\n\ + __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec);\n\ + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi);\n\ + vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi);\n\ + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo);\n\ + vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo); \n\ + }}\n\ + }}\n\ +\n\ + __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM{0} * bs));\n\ + __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{0} * bs));\n\ + __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{0} * bs));\n\ + __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{0} * bs));\n\ +\n\ + vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));\n\ + vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));\n\ + vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));\n\ + vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));\n\ +\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM{0} * bs), vec_gc0);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM{0} * bs), vec_gc1);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM{0} * bs), vec_gc2);\n\ + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM{0} * bs), vec_gc3);\n\ + }}\n\ + }}\n\ +#endif\n\ + return 0;\n\ +}}\n\ +\n\ +template\n\ +int32_t three_qgemm_lut_{0}(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\ + alignas(32) uint32_t CBits[BATCH_SIZE * BM{0}];\n\ + memset(&(CBits[0]), 0, BATCH_SIZE * BM{0} * sizeof(int32_t));\n\ +#pragma unroll\n\ + for (int32_t k_outer = 0; k_outer < {1} / BBK{0}; ++k_outer) {{\n\ + three_tbl_impl_{0}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{0} / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{0} / 3 / 2 * BM{0})])), (&(((uint8_t*)sign)[(k_outer * BBK{0} / 3 / 8 * BM{0})])));\n\ + }}\n\ +#pragma unroll\n\ + for (int bs = 0; bs < BATCH_SIZE; bs++) {{\n\ +#pragma unroll\n\ + for (int i = 0; i < BM{0}; i++) {{\n\ + ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM{0}]);\n\ + }}\n\ + }}\n\ + return 0;\n\ +}}\n\ +\n\ +template\n\ +int32_t two_qgemm_lut_{0}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\ + alignas(32) uint32_t CBits[BATCH_SIZE * BM{0}];\n\ + memset(&(CBits[0]), 0, BATCH_SIZE * BM{0} * sizeof(int32_t));\n\ +#pragma unroll\n\ + for (int32_t k_outer = 0; k_outer < {2} / 32; ++k_outer) {{\n\ + two_tbl_impl{0}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM{0})])));\n\ + }}\n\ +#pragma unroll\n\ + for (int bs = 0; bs < BATCH_SIZE; bs++) {{\n\ +#pragma unroll\n\ + for (int i = 0; i < BM{0}; i++) {{\n\ + ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM{0}]);\n\ + ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0];\n\ + }}\n\ + }}\n\ + return 0;\n\ +}}\n\ +\n\ +".format(pre, k_list[1], k_list[0])]) + return kernel_code + +def gen_top_api(kernel_shapes, k_list): + + kernel_code = "void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT) {{\n\ + partial_max_reset(bs, (&(((float*)LUT_Scales)[0])));\n\ + if (m == {0} && two_k == {1} && three_k == {2}) {{\n\ + for (int32_t b = 0; b < bs; b++) {{\n\ + per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)])));\n\ + three_lut_ctor<{2}>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b])));\n\ + two_lut_ctor<{1}>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + {2}])), (&(((float*)LUT_Scales)[b])));\n\ + }}\n\ + }}\n\ +".format(kernel_shapes[0][0], k_list[0][0], k_list[0][1]) + for i in range(1, len(kernel_shapes)): + kernel_code = "".join([kernel_code, " else if (m == {0} && two_k == {1} && three_k == {2}) {{\n\ + for (int32_t b = 0; b < bs; b++) {{\n\ + per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)])));\n\ + three_lut_ctor<{2}>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b])));\n\ + two_lut_ctor<{1}>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + {2}])), (&(((float*)LUT_Scales)[b])));\n\ + }}\n\ + }}\n".format(kernel_shapes[i][0], k_list[i][0], k_list[i][1])]) + kernel_code = "".join([kernel_code, "}\n"]) + + + kernel_code = "".join([kernel_code, "void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\ + if (m == {0} && k == {1}) {{\n\ + if (BK == {2}) {{\n\ + if (bs == 1) {{\n\ + two_qgemm_lut_{4}<1>(A, LUT, Scales, LUT_Scales, C);\n\ + }} else if (bs == 8) {{\n\ + two_qgemm_lut_{4}<8>(A, LUT, Scales, LUT_Scales, C);\n\ + }} else if (bs == 32) {{\n\ + two_qgemm_lut_{4}<32>(A, LUT, Scales, LUT_Scales, C);\n\ + }} else if (bs == 128) {{\n\ + two_qgemm_lut_{4}<128>(A, LUT, Scales, LUT_Scales, C);\n\ + }} else if (bs == 256) {{\n\ + two_qgemm_lut_{4}<256>(A, LUT, Scales, LUT_Scales, C);\n\ + }} else if (bs == 512) {{\n\ + two_qgemm_lut_{4}<512>(A, LUT, Scales, LUT_Scales, C);\n\ + }}\n\ + }}\n\ + else if (BK == {3}) {{\n\ + if (bs == 1) {{\n\ + three_qgemm_lut_{4}<1>(A, sign, LUT, Scales, LUT_Scales, C);\n\ + }}else if (bs == 8) {{\n\ + three_qgemm_lut_{4}<8>(A, sign, LUT, Scales, LUT_Scales, C);\n\ + }}else if (bs == 32) {{\n\ + three_qgemm_lut_{4}<32>(A, sign, LUT, Scales, LUT_Scales, C);\n\ + }}else if (bs == 128) {{\n\ + three_qgemm_lut_{4}<128>(A, sign, LUT, Scales, LUT_Scales, C);\n\ + }}else if (bs == 256) {{\n\ + three_qgemm_lut_{4}<256>(A, sign, LUT, Scales, LUT_Scales, C);\n\ + }}else if (bs == 512) {{\n\ + three_qgemm_lut_{4}<512>(A, sign, LUT, Scales, LUT_Scales, C);\n\ + }}\n\ + }}\n\ + }}\n\ +".format(kernel_shapes[0][0], kernel_shapes[0][1], k_list[0][0], k_list[0][1], "{}_{}".format(kernel_shapes[0][0], kernel_shapes[0][1]))]) + for i in range(1, len(kernel_shapes)): + kernel_code = "".join([kernel_code, " else if (m == {0} && k == {1}) {{\n\ + if (BK == {2}) {{\n\ + if (bs == 1) {{\n\ + two_qgemm_lut_{4}<1>(A, LUT, Scales, LUT_Scales, C);\n\ + }} else if (bs == 8) {{\n\ + two_qgemm_lut_{4}<8>(A, LUT, Scales, LUT_Scales, C);\n\ + }} else if (bs == 32) {{\n\ + two_qgemm_lut_{4}<32>(A, LUT, Scales, LUT_Scales, C);\n\ + }} else if (bs == 128) {{\n\ + two_qgemm_lut_{4}<128>(A, LUT, Scales, LUT_Scales, C);\n\ + }} else if (bs == 256) {{\n\ + two_qgemm_lut_{4}<256>(A, LUT, Scales, LUT_Scales, C);\n\ + }} else if (bs == 512) {{\n\ + two_qgemm_lut_{4}<512>(A, LUT, Scales, LUT_Scales, C);\n\ + }}\n\ + }}\n\ + else if (BK == {3}) {{\n\ + if (bs == 1) {{\n\ + three_qgemm_lut_{4}<1>(A, sign, LUT, Scales, LUT_Scales, C);\n\ + }}else if (bs == 8) {{\n\ + three_qgemm_lut_{4}<8>(A, sign, LUT, Scales, LUT_Scales, C);\n\ + }}else if (bs == 32) {{\n\ + three_qgemm_lut_{4}<32>(A, sign, LUT, Scales, LUT_Scales, C);\n\ + }}else if (bs == 128) {{\n\ + three_qgemm_lut_{4}<128>(A, sign, LUT, Scales, LUT_Scales, C);\n\ + }}else if (bs == 256) {{\n\ + three_qgemm_lut_{4}<256>(A, sign, LUT, Scales, LUT_Scales, C);\n\ + }}else if (bs == 512) {{\n\ + three_qgemm_lut_{4}<512>(A, sign, LUT, Scales, LUT_Scales, C);\n\ + }}\n\ + }}\n\ + }}\n\ +".format(kernel_shapes[i][0], kernel_shapes[i][1], k_list[i][0], k_list[i][1], "{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]))]) + kernel_code = "".join([kernel_code, "}\n"]) + return kernel_code + +def gen_transform_code(kernel_shapes): + kernel_code = "\n\ +void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\ + if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {\n\ + return;\n\ + }\n\ +\n\ + int k = tensor->ne[0];\n\ + int m = tensor->ne[1];\n\ + const int lut_scales_size = 1;\n\ + int bk = 0;\n\ + int bm = 0;\n" + + kernel_code = "".join([kernel_code, "\n\ + if (m == {0} && k == {1}) {{\n\ + bm = BM{0}_{1};\n\ + bk = BBK{0}_{1};\n\ + }}\n".format(kernel_shapes[0][0], kernel_shapes[0][1])]) + + for i in range(1, len(kernel_shapes)): + kernel_code = "".join([kernel_code, "else if (m == {0} && k == {1}) {{\n\ + bm = BM{0}_{1};\n\ + bk = BBK{0}_{1};\n\ + }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])]) + + kernel_code = "".join([kernel_code, "\n\ + const int n_tile_num = m / bm;\n\ + const int BK = bk;\n\ + uint8_t * qweights;\n\ + bitnet_float_type * scales;\n\ +\n\ + scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));\n\ + qweights = (uint8_t *) tensor->data;\n\ + int nbytes = (k - 256) * m / 3 * 5 / 8 + 256 * m / 2 * 4 / 8;\n\ + if (nbytes % 32 != 0) nbytes = 32 - nbytes % 32 + nbytes;\n\ + float * i2_scales = (float * )(qweights + nbytes);\n\ + scales[0] = (bitnet_float_type) i2_scales[0];\n\ +\n\ + tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\ + bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\ + /* .lut_scales_size = */ lut_scales_size,\n\ + /* .BK = */ BK,\n\ + /* .n_tile_num = */ n_tile_num,\n\ + /* .qweights = */ qweights,\n\ + /* .scales = */ scales\n\ + };\n\ +}\n"]) + + return kernel_code + +def get_three_k_two_k(K, bk): + bk_num = K // bk + three_k = bk_num * bk + two_k = K - three_k + return two_k, three_k + +if __name__ == "__main__": + ModelShapeDict = { + "bitnet_b1_58-large" : [[1536, 4096], + [1536, 1536], + [4096, 1536]], + "bitnet_b1_58-3B" : [[3200, 8640], + [3200, 3200], + [8640, 3200]], + "Llama3-8B-1.58-100B-tokens" : [[14336, 4096], + [4096, 14336], + [1024, 4096], + [4096, 4096]] + } + + parser = argparse.ArgumentParser(description='gen impl') + parser.add_argument('--model',default="input", type=str, dest="model", + help="choose from bitnet_b1_58-large/bitnet_b1_58-3B/Llama3-8B-1.58-100B-tokens.") + parser.add_argument('--BM',default="input", type=str, + help="block length when cutting one weight (M, K) into M / BM weights (BM, K).") + parser.add_argument('--BK',default="input", type=str, + help="block length when cutting one weight (M, K) into K / BK weights (M, BK).") + parser.add_argument('--bm',default="input", type=str, + help="using simd instructions to compute (bm, 192 / bm) in one block") + args = parser.parse_args() + + kernel_shapes = ModelShapeDict[args.model] + + BM_list = [int(item) for item in args.BM.split(',')] + BK_list = [int(item) for item in args.BK.split(',')] + bm_list = [int(item) for item in args.bm.split(',')] + + tbl_impl_code = [] + k_list = [] + + for i in range(len(kernel_shapes)): + k_list.append(get_three_k_two_k(kernel_shapes[i][1], BK_list[i])) + + for i in range(len(kernel_shapes)): + tbl_impl_code.append( + gen_tbl_impl("{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]), BM_list[i], BK_list[i], bm_list[i], k_list[i]) + ) + + assert(len(BM_list) == len(BK_list) == len(bm_list) == len(kernel_shapes)), "number of BM / BK / bm shoud be {}".format(len(kernel_shapes)) + + for i in range(len(kernel_shapes)): + assert kernel_shapes[i][0] % BM_list[i] == 0, "M %% BM should be 0" + assert (kernel_shapes[i][1] % BK_list[i]) % 32 == 0, "K %% BK %% 32 should be 0" + assert bm_list[i] in [32], "choose bm from [32]" + + ctor_code = gen_ctor_code() + api_code = gen_top_api(kernel_shapes, k_list) + trans_code = gen_transform_code(kernel_shapes) + + output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "include") + + with open(''.join([output_dir, "/bitnet-lut-kernels.h"]), 'w') as f: + f.write(''.join("#if defined(GGML_BITNET_X86_TL2)")) + f.write(''.join(ctor_code)) + for code in tbl_impl_code: + f.write(''.join(code)) + f.write(''.join(api_code)) + f.write(''.join(trans_code)) + f.write(''.join("#endif")) + + config = ConfigParser() + + for i in range(len(kernel_shapes)): + config.add_section('Kernels_{}'.format(i)) + config.set('Kernels_{}'.format(i), 'M'.format(i), str(kernel_shapes[i][0])) + config.set('Kernels_{}'.format(i), 'K'.format(i), str(kernel_shapes[i][1])) + config.set('Kernels_{}'.format(i), 'BM'.format(i), str(BM_list[i])) + config.set('Kernels_{}'.format(i), 'BK'.format(i), str(BK_list[i])) + config.set('Kernels_{}'.format(i), 'bmm'.format(i), str(bm_list[i])) + + with open(''.join([output_dir, "/kernel_config.ini"]), 'w') as configfile: + config.write(configfile) \ No newline at end of file diff --git a/utils/convert-hf-to-gguf-bitnet.py b/utils/convert-hf-to-gguf-bitnet.py new file mode 100644 index 0000000..55b27ae --- /dev/null +++ b/utils/convert-hf-to-gguf-bitnet.py @@ -0,0 +1,1161 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import logging +import argparse +import contextlib +import json +import os +import re +import sys +from abc import ABC, abstractmethod +from enum import IntEnum +from pathlib import Path +from hashlib import sha256 +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterator, Sequence, TypeVar, cast +import configparser + +import numpy as np +import torch + +if TYPE_CHECKING: + from torch import Tensor + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +import gguf + +from convert import LlamaHfVocab, permute + +logger = logging.getLogger("hf-to-gguf") + + +###### MODEL DEFINITIONS ###### + +class SentencePieceTokenTypes(IntEnum): + NORMAL = 1 + UNKNOWN = 2 + CONTROL = 3 + USER_DEFINED = 4 + UNUSED = 5 + BYTE = 6 + + +AnyModel = TypeVar("AnyModel", bound="type[Model]") + + +class Model(ABC): + _model_classes: dict[str, type[Model]] = {} + + def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool): + self.dir_model = dir_model + self.ftype = ftype + self.fname_out = fname_out + self.is_big_endian = is_big_endian + self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE + self.use_temp_file = use_temp_file + self.is_safetensors = self._is_model_safetensors() + self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin") + self.part_names = self._get_part_names() + self.hparams = Model.load_hparams(self.dir_model) + self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) + self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + @property + @abstractmethod + def model_arch(self) -> gguf.MODEL_ARCH: + pass + + def find_hparam(self, keys: Sequence[str], optional: bool = False) -> Any: + key = next((k for k in keys if k in self.hparams), None) + if key is not None: + return self.hparams[key] + if optional: + return None + raise KeyError(f"could not find any of: {keys}") + + def set_vocab(self): + self._set_vocab_gpt2() + + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + for part_name in self.part_names: + logger.info(f"gguf: loading model part '{part_name}'") + ctx: ContextManager[Any] + if self.is_safetensors: + from safetensors import safe_open + ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu")) + else: + ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True)) + + with ctx as model_part: + for name in model_part.keys(): + data = model_part.get_tensor(name) if self.is_safetensors else model_part[name] + yield name, data + + def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool: + if key not in gguf.MODEL_TENSORS[self.model_arch]: + return False + key_name: str = gguf.TENSOR_NAMES[key] + if "{bid}" in key_name: + if bid is None: + return False + key_name = key_name.format(bid=bid) + else: + if bid is not None: + return False + return name == (key_name + suffix) + + def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str: + new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes) + if new_name is None: + raise ValueError(f"Can not map tensor {name!r}") + return new_name + + def set_gguf_parameters(self): + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_block_count(self.block_count) + + if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None: + self.gguf_writer.add_context_length(n_ctx) + logger.info(f"gguf: context length = {n_ctx}") + + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + self.gguf_writer.add_embedding_length(n_embd) + logger.info(f"gguf: embedding length = {n_embd}") + + if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None: + self.gguf_writer.add_feed_forward_length(n_ff) + logger.info(f"gguf: feed forward length = {n_ff}") + + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + self.gguf_writer.add_head_count(n_head) + logger.info(f"gguf: head count = {n_head}") + + if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None: + self.gguf_writer.add_head_count_kv(n_head_kv) + logger.info(f"gguf: key-value head count = {n_head_kv}") + + if (rope_theta := self.hparams.get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base(rope_theta) + logger.info(f"gguf: rope theta = {rope_theta}") + if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None: + self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) + logger.info(f"gguf: rms norm epsilon = {f_rms_eps}") + if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None: + self.gguf_writer.add_layer_norm_eps(f_norm_eps) + logger.info(f"gguf: layer norm epsilon = {f_norm_eps}") + if (n_experts := self.hparams.get("num_local_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + logger.info(f"gguf: expert count = {n_experts}") + if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: + self.gguf_writer.add_expert_used_count(n_experts_used) + logger.info(f"gguf: experts used count = {n_experts_used}") + + self.gguf_writer.add_file_type(self.ftype) + logger.info(f"gguf: file type = {self.ftype}") + + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + raise ValueError(f"Can not map tensor {name!r}") + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")): + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + def write(self): + self.write_tensors() + self.gguf_writer.write_header_to_file() + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.write_tensors_to_file() + self.gguf_writer.close() + + def write_vocab(self): + self.gguf_writer.write_header_to_file() + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.close() + + @staticmethod + def count_model_parts(dir_model: Path, prefix: str) -> int: + num_parts = 0 + for filename in os.listdir(dir_model): + if filename.endswith(prefix): + num_parts += 1 + + return num_parts + + @staticmethod + def load_hparams(dir_model): + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + return json.load(f) + + @classmethod + def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: + assert names + + def func(modelcls: type[Model]): + for name in names: + cls._model_classes[name] = modelcls + return modelcls + return func + + @classmethod + def from_model_architecture(cls, arch): + try: + return cls._model_classes[arch] + except KeyError: + raise NotImplementedError(f'Architecture {arch!r} not supported!') from None + + def _is_model_safetensors(self) -> bool: + return Model.count_model_parts(self.dir_model, ".safetensors") > 0 + + def _get_part_names(self): + if self.is_safetensors: + if self.num_parts == 1: # there's only one .safetensors file + return ("model.safetensors",) + return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1)) + + if self.num_parts == 1: # there's only one .bin file + return ("pytorch_model.bin",) + return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1)) + + # used for GPT-2 BPE and WordPiece vocabs + def get_vocab_base(self) -> tuple[list[str], list[int], str]: + tokens: list[str] = [] + toktypes: list[int] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) + assert max(tokenizer.vocab.values()) < vocab_size + + tokpre = self.get_vocab_base_pre(tokenizer) + + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} + added_vocab = tokenizer.get_added_vocab() + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.USER_DEFINED) + elif reverse_vocab[i] in added_vocab: + tokens.append(reverse_vocab[i]) + if tokenizer.added_tokens_decoder[i].special: + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.USER_DEFINED) + else: + tokens.append(reverse_vocab[i]) + toktypes.append(gguf.TokenType.NORMAL) + + return tokens, toktypes, tokpre + + # NOTE: this function is generated by convert-hf-to-gguf-update.py + # do not modify it manually! + # ref: https://github.com/ggerganov/llama.cpp/pull/6920 + def get_vocab_base_pre(self, tokenizer) -> str: + # encoding this string and hashing the resulting tokens would (hopefully) give us a unique identifier that + # is specific for the BPE pre-tokenizer used by the model + # we will use this unique identifier to write a "tokenizer.ggml.pre" entry in the GGUF file which we can + # use in llama.cpp to implement the same pre-tokenizer + + chktxt = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶\u200d🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````""""......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL' + + chktok = tokenizer.encode(chktxt) + chkhsh = sha256(str(chktok).encode()).hexdigest() + + logger.debug(f"chktok: {chktok}") + logger.debug(f"chkhsh: {chkhsh}") + + res = None + + # NOTE: if you get an error here, you need to update the convert-hf-to-gguf-update.py script + # or pull the latest version of the model from Huggingface + # don't edit the hashes manually! + if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": + # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B + res = "llama-bpe" + if chkhsh == "049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754": + # ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base + res = "deepseek-llm" + if chkhsh == "347715f544604f9118bb75ed199f68779f423cabb20db6de6f31b908d04d7821": + # ref: https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base + res = "deepseek-coder" + if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed": + # ref: https://huggingface.co/tiiuae/falcon-7b + res = "falcon" + if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": + # ref: https://huggingface.co/BAAI/bge-small-en-v1.5 + res = "bert-bge" + if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166": + # ref: https://huggingface.co/mosaicml/mpt-7b + res = "mpt" + if chkhsh == "35d91631860c815f952d711435f48d356ebac988362536bed955d43bfa436e34": + # ref: https://huggingface.co/bigcode/starcoder2-3b + res = "starcoder" + if chkhsh == "3ce83efda5659b07b1ad37ca97ca5797ea4285d9b9ab0dc679e4a720c9da7454": + # ref: https://huggingface.co/openai-community/gpt2 + res = "gpt-2" + if chkhsh == "6221ad2852e85ce96f791f476e0b390cf9b474c9e3d1362f53a24a06dc8220ff": + # ref: https://huggingface.co/smallcloudai/Refact-1_6-base + res = "refact" + if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8": + # ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01 + res = "command-r" + + if res is None: + logger.warning("\n") + logger.warning("**************************************************************************************") + logger.warning("** WARNING: The BPE pre-tokenizer was not recognized!") + logger.warning("** There are 2 possible reasons for this:") + logger.warning("** - the model has not been added to convert-hf-to-gguf-update.py yet") + logger.warning("** - the pre-tokenization config has changed upstream") + logger.warning("** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.") + logger.warning("** ref: https://github.com/ggerganov/llama.cpp/pull/6920") + logger.warning("**") + logger.warning(f"** chkhsh: {chkhsh}") + logger.warning("**************************************************************************************") + logger.warning("\n") + raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()") + + logger.debug(f"tokenizer.ggml.pre: {repr(res)}") + logger.debug(f"chkhsh: {chkhsh}") + + return res + + def _set_vocab_gpt2(self) -> None: + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab.add_to_gguf(self.gguf_writer) + + def _set_vocab_sentencepiece(self): + from sentencepiece import SentencePieceProcessor + + tokenizer_path = self.dir_model / 'tokenizer.model' + + tokens: list[bytes] = [] + scores: list[float] = [] + toktypes: list[int] = [] + + if not tokenizer_path.is_file(): + raise FileNotFoundError(f"File not found: {tokenizer_path}") + + tokenizer = SentencePieceProcessor(str(tokenizer_path)) + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + for token_id in range(tokenizer.vocab_size()): + piece = tokenizer.id_to_piece(token_id) + text = piece.encode("utf-8") + score = tokenizer.get_score(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.is_unknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.is_control(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.is_unused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.is_byte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + added_tokens_file = self.dir_model / 'added_tokens.json' + if added_tokens_file.is_file(): + with open(added_tokens_file, "r", encoding="utf-8") as f: + added_tokens_json = json.load(f) + + for key in added_tokens_json: + key = key.encode("utf-8") + if key not in tokens: + tokens.append(key) + scores.append(-1000.0) + toktypes.append(SentencePieceTokenTypes.USER_DEFINED) + + if vocab_size > len(tokens): + pad_count = vocab_size - len(tokens) + logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") + for i in range(1, pad_count + 1): + tokens.append(f"[PAD{i}]") + scores.append(-1000.0) + toktypes.append(SentencePieceTokenTypes.UNUSED) + + assert len(tokens) == vocab_size + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + def _set_vocab_llama_hf(self): + vocab = LlamaHfVocab(self.dir_model) + tokens = [] + scores = [] + toktypes = [] + + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + assert len(tokens) == vocab.vocab_size + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + +# TL1 + +def process_tl1(weight, BM, BY, bm, by, M, K): + weight = weight.reshape((M, K // 2)).astype(np.uint8) + weight = weight.reshape((M // BM, BM, K // 2)).transpose(0, 2, 1) + weight = weight.reshape((M // BM, K // BY, BY // 2, BM)).transpose(0, 1, 3, 2) + weight = weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 2)).transpose(0, 1, 2, 4, 3) + weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, by // 2, bm)).transpose(0, 1, 2, 3, 5, 4) + weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, bm // 16, 16, by // 2)).transpose(0, 1, 2, 3, 4, 6, 5) + weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, bm // 16, by // 4, 4 // 2, 16)).transpose(0, 1, 2, 3, 4, 5, 7, 6) + weight = weight.reshape((M * K // 16 // 4, 16, 4 // 2)) + weight_0 = weight[:, :, 0] << 4 + weight_1 = weight[:, :, 1] + weight = weight_0 + weight_1 + return weight + +def preprocess_weights_tl1( + w: np.ndarray, + bits = 2, + g = 4, +) -> Tuple[np.ndarray, np.ndarray]: + from configparser import ConfigParser + config = ConfigParser() + + M, K = w.shape + weight = w + weight = np.where(np.abs(weight) < 1e-6, 0, weight).astype(np.float32) + weight = np.sign(weight) + weight_num = np.prod(weight.shape) + + config.read('include/kernel_config.ini') + BM = -1 + BY = -1 + bm = -1 + + for kernel in config.sections(): + if int(config.get(kernel, 'm')) == M and int(config.get(kernel, 'k')) == K: + BM = int(config.get(kernel, 'bm')) + BY = int(config.get(kernel, 'bk')) + bm = int(config.get(kernel, 'bmm')) + by = 256 // bm + break + + if BM == -1: + raise NotImplementedError + + weight = np.reshape(weight, (weight_num // 2, 2)) + hi_weight = np.multiply(np.split(weight, 2, axis=1)[0], 3) + lo_weight = np.split(weight, 2, axis=1)[1] + + weight = np.reshape((hi_weight + lo_weight), weight_num // 2) + + weight = weight + 4 + weight = np.reshape(weight, (M, K // 2)).astype(np.uint8) + + weight = process_tl1(weight, BM, BY, bm, by, M, K) + + return weight + + +def preprocess_two_weights_tl2(M, K, weight_num, BM, BY, bm, by, weight, final_weight): + weight = np.reshape(weight, (weight_num // 2, 2)) + hi_weight = np.multiply(np.split(weight, 2, axis=1)[0], 3) + lo_weight = np.split(weight, 2, axis=1)[1] + + weight = np.reshape((hi_weight + lo_weight), weight_num // 2) + + weight = weight + 4 + weight = np.reshape(weight, (M, K // 2)).astype(np.uint8) + weight = weight.reshape((M // BM, BM, K // 2)).transpose(0, 2, 1) + weight = weight.reshape((M // BM, K // BY, BY // 2, BM)).transpose(0, 1, 3, 2) + weight = weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 2)).transpose(0, 1, 2, 4, 3) + weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, by // 2, bm)).transpose(0, 1, 2, 3, 5, 4) + weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, bm, by // 2)) + weight_0 = weight[:, :, :, :, :, 0] + weight_1 = weight[:, :, :, :, :, 1] + weight_0 = weight_0 << 4 + weight_1 = weight_1 + weight = weight_0 + weight_1 + weight = weight.reshape((M * K // bm // by, bm // 8, 8)) + weight[:, [0, 1, 2, 3], :] = weight[:, [0, 2, 1, 3], :] + weight = weight.reshape(M * K // bm // by, bm) + + for i in range(weight.shape[0]): + final_weight.append(weight[i, :]) + +def preprocess_three_weights_tl2(M, K, weight_num, BM, BY, bm, by, weight, final_weight): + weight = np.reshape(weight, (weight_num // 3, 3)) + split_weights = np.split(weight, 3, axis=1) + first_weight = np.multiply(split_weights[0], 9) + second_weight = np.multiply(split_weights[1], 3) + third_weight = split_weights[2] + + weight = np.reshape((first_weight + second_weight + third_weight), weight_num // 3) + sign_weight = np.sign(weight) + 2 + sign_weight = np.where(sign_weight > 1, 0, sign_weight) + weight = np.abs(weight) + + weight = np.reshape(weight, (M, K // 3)).astype(np.uint8) + sign_weight = np.reshape(sign_weight, (M, K // 3)).astype(np.uint8) + + weight = weight.reshape((M // BM, BM, K // 3)).transpose(0, 2, 1) + weight = weight.reshape((M // BM, K // BY, BY // 3, BM)).transpose(0, 1, 3, 2) + weight = weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 3)).transpose(0, 1, 2, 4, 3) + weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, by // 3, bm)).transpose(0, 1, 2, 3, 5, 4) + weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, bm, by // 3)) + weight_0 = weight[:, :, :, :, :, 0] + weight_1 = weight[:, :, :, :, :, 1] + weight_0 = weight_0 << 4 + weight_1 = weight_1 + weight = weight_0 + weight_1 + weight = weight.reshape((M * K // bm // by, bm // 8, 8)) + weight[:, [0, 1, 2, 3], :] = weight[:, [0, 2, 1, 3], :] + weight = weight.reshape(M * K // bm // by, bm) + + for i in range(weight.shape[0]): + final_weight.append(weight[i, :]) + + sign_weight = sign_weight.reshape((M // BM, BM, K // 3)).transpose(0, 2, 1) + sign_weight = sign_weight.reshape((M // BM, K // BY, BY // 3, BM)).transpose(0, 1, 3, 2) + sign_weight = sign_weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 3)).transpose(0, 1, 2, 4, 3) + sign_weight = sign_weight.reshape((M // BM, K // BY, BM // bm, BY // (by * 4), by // 3 * 4, bm)).transpose(0, 1, 2, 3, 5, 4) + sign_weight = sign_weight.reshape((M // BM, K // BY, BM // bm, BY // (by * 4), bm, by // 3 * 4)).transpose(0, 1, 2, 3, 5, 4) + sign_weight = sign_weight.reshape((M // BM, K // BY, BM // bm, BY // (by * 4), by // 3 * 8, bm // 2)).astype(np.uint16) + combine_weight = np.zeros((M // BM, K // BY, BM // bm, BY // (by * 4), bm // 2), dtype=np.uint16) + for i in range(16): + temp_weight = sign_weight[:, :, :, :, i, :] << 15 - i + combine_weight += temp_weight + combine_weight = combine_weight.view(np.uint8) + combine_weight = combine_weight.reshape((M * K // bm // (by * 4)), bm) + + for i in range(combine_weight.shape[0]): + final_weight.append(combine_weight[i, :]) + +def preprocess_weights_tl2( + w: np.ndarray, + bits = 2, + g = 4, +) -> Tuple[np.ndarray, np.ndarray]: + from configparser import ConfigParser + config = ConfigParser() + + M, K = w.shape + weight = w + weight = np.where(np.abs(weight) < 1e-6, 0, weight).astype(np.float32) + weight = np.sign(weight) + weight_num = np.prod(weight.shape) + + config.read('include/kernel_config.ini') + BM = -1 + BY = -1 + bm = -1 + + for kernel in config.sections(): + if int(config.get(kernel, 'm')) == M and int(config.get(kernel, 'k')) == K: + BM = int(config.get(kernel, 'bm')) + BY = int(config.get(kernel, 'bk')) + bm = int(config.get(kernel, 'bmm')) + by = 192 // bm + break + + if BM == -1: + raise NotImplementedError + + if (weight.shape[1] % BY != 0): + slice_k_idx = weight.shape[1] - weight.shape[1] % BY + slice_weights = np.split(weight, [slice_k_idx], axis=1) + three_weight = slice_weights[0] + two_weight = slice_weights[1] + else: + three_weight = weight + + final_weight = [] + + preprocess_three_weights_tl2(three_weight.shape[0], + three_weight.shape[1], + three_weight.shape[0] * three_weight.shape[1], + BM, + BY, + bm, + by, + three_weight, + final_weight) + + if (weight.shape[1] % BY != 0): + preprocess_two_weights_tl2( two_weight.shape[0], + two_weight.shape[1], + two_weight.shape[0] * two_weight.shape[1], + BM, + 32, + 32, + 4, + two_weight, + final_weight) + weight = np.array(final_weight, dtype=np.uint8).reshape(-1) + weight = np.pad(weight, (0, (K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8 - + weight.shape[0]), mode='constant', constant_values=0) + return weight + +def transform_to_tl1(x: np.ndarray): + scale = np.max(np.abs(x)) + # res = np.round(x / scale + 2).astype(np.uint8) + res = preprocess_weights_tl1(x) + return res, scale + +def transform_to_tl2(x: np.ndarray): + scale = np.max(np.abs(x)) + # res = np.round(x / scale + 2).astype(np.uint8) + res = preprocess_weights_tl2(x) + return res, scale + + +def read_model_config(model_dir: str) -> dict[str, Any]: + config = os.path.join(model_dir, "config.json") + if not os.path.exists(config): + raise FileNotFoundError(f"Model config file not found: {config}") + with open(config, "r") as f: + return json.load(f) + +@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM") +class LlamaModel(Model): + model_arch = gguf.MODEL_ARCH.LLAMA + + def set_vocab(self): + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + try: + self._set_vocab_llama_hf() + except (FileNotFoundError, TypeError): + # Llama 3 + self._set_vocab_gpt2() + + # Apply to CodeLlama only (and ignore for Llama 3 with a vocab size of 128256) + if self.hparams.get("vocab_size", 32000) == 32016: + special_vocab = gguf.SpecialVocab( + self.dir_model, load_merges=False, + special_token_types = ['prefix', 'suffix', 'middle', 'eot'] + ) + special_vocab._set_special_token("prefix", 32007) + special_vocab._set_special_token("suffix", 32008) + special_vocab._set_special_token("middle", 32009) + special_vocab._set_special_token("eot", 32010) + special_vocab.add_to_gguf(self.gguf_writer) + + def write_tensors(self): + max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") + + scale_map = dict() + + for name, data_torch in self.get_tensors(): + if name.endswith(("weight_scale")): + data_torch = data_torch.to(torch.float32) + name = name.replace(".weight_scale", "") + scale_map[name] = data_torch + + for name, data_torch in self.get_tensors(): + if name.endswith(("weight_scale")): + continue + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + if name.replace(".weight", "") in scale_map: + data_torch = data_torch.to(torch.uint8) + origin_shape = data_torch.shape + shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(origin_shape))))) + data_torch = data_torch.unsqueeze(0).expand((4, *origin_shape)) >> shift + data_torch = data_torch & 3 + data_torch = (data_torch.float() - 1).reshape((origin_shape[0] * 4, *origin_shape[1:])) + data_torch = data_torch / scale_map[name.replace(".weight", "")].float() + + # use the first number-like part of the tensor name as the block id + bid = None + for part in name.split("."): + if part.isdecimal(): + bid = int(part) + break + + # old gguf bf16 not implenmented + # if data_torch.dtype == torch.bfloat16: + # for new_name, data in ((n, d) for n, d in self.modify_tensors(data_torch, name, bid)): + # shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}" + # # n_dims is implicit in the shape + # logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype}, shape = {shape_str}") + # self.gguf_writer.add_tensor(new_name, data, raw_shape=data.shape, raw_dtype=data.dtype) + # continue + + for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): + data: np.ndarray = data # type hint + data_shape = data.shape + n_dims = len(data.shape) + data_dtype = data.dtype + data_qtype: gguf.GGMLQuantizationType | None = None + + # when both are True, f32 should win + # extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims) + # extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims) + extra_f32 = False + extra_f16 = False + + # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors + # Conditions should closely match those in llama_model_quantize_internal in llama.cpp + extra_f32 = any(cond for cond in ( + extra_f32, + n_dims == 1, + new_name.endswith("_norm.weight"), + )) + + # Some tensor types are always in float32 + tensors_f32 = [ + gguf.MODEL_TENSOR.FFN_GATE_INP, + gguf.MODEL_TENSOR.FFN_GATE_INP, + gguf.MODEL_TENSOR.POS_EMBD, + gguf.MODEL_TENSOR.TOKEN_TYPES, + ] + if not args.quant_embd: + tensors_f32.append(gguf.MODEL_TENSOR.TOKEN_EMBD) + extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in tensors_f32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + extra_f16 = any(cond for cond in ( + extra_f16, + (name.endswith(".weight") and n_dims >= 2), + )) + + suit_i2 = True + if name.endswith('lm_head.weight') or name.endswith('norm.weight') or name.endswith('embed_tokens.weight'): + suit_i2 = False + + i2_scale = None + if self.ftype != gguf.GGMLQuantizationType.F32 and extra_f16 and not extra_f32: + if self.ftype == gguf.GGMLQuantizationType.TL1 and suit_i2: + data, i2_scale = transform_to_tl1(data) + assert data.dtype == np.uint8 + assert i2_scale.dtype == np.float32 + data_qtype = gguf.GGMLQuantizationType.TL1 + elif self.ftype == gguf.GGMLQuantizationType.TL2 and suit_i2: + data, i2_scale = transform_to_tl2(data) + assert data.dtype == np.uint8 + assert i2_scale.dtype == np.float32 + data_qtype = gguf.GGMLQuantizationType.TL2 + else: # default to float16 for quantized tensors + if data_dtype != np.float16: + data = data.astype(np.float16) + data_qtype = gguf.GGMLQuantizationType.F16 + + if data_qtype is None: # by default, convert to float32 + if data_dtype != np.float32: + data = data.astype(np.float32) + data_qtype = gguf.GGMLQuantizationType.F32 + + shape = data_shape + # shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape + # reverse shape to make it similar to the internal ggml dimension order + shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}" + + # n_dims is implicit in the shape + logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + + self.gguf_writer.add_tensor(new_name, data, raw_shape=shape, raw_dtype=data_qtype) + if i2_scale is not None: + self.gguf_writer.add_tensor(new_name + "_scale", i2_scale, raw_dtype=gguf.GGMLQuantizationType.F32) + + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + if "head_dim" in hparams: + rope_dim = hparams["head_dim"] + else: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(rope_dim) + + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "linear": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + if "add_prefix_space" in tokenizer_config_json: + self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) + + # Apply to granite small models only + if self.hparams.get("vocab_size", 32000) == 49152: + self.gguf_writer.add_add_bos_token(False) + + @staticmethod + def permute(weights: Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + + # process the experts separately + if name.find("block_sparse_moe.experts") != -1: + n_experts = self.hparams["num_local_experts"] + + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for wid in ["w1", "w2", "w3"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"layers.{bid}.feed_forward.experts.{wid}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): + if rope_scaling.get("rope_type", '').lower() == "llama3": + base = self.hparams.get("rope_theta", 10000.0) + dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + factor = rope_scaling.get("factor", 8.0) + low_freq_factor = rope_scaling.get("low_freq_factor", 1.0) + high_freq_factor = rope_scaling.get("high_freq_factor", 4.0) + old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + assert low_freq_wavelen != high_freq_wavelen + + rope_factors = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + rope_factors.append(1) + elif wavelen > low_freq_wavelen: + rope_factors.append(factor) + else: + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + rope_factors.append(1 / ((1 - smooth) / factor + smooth)) + + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + +@Model.register("BitnetForCausalLM") +class BitnetModel(Model): + model_arch = gguf.MODEL_ARCH.BITNET + + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(1.0) + + def weight_quant(self, weight): + dtype = weight.dtype + weight = weight.float() + s = 1 / weight.abs().mean().clamp(min=1e-5) + result = (weight * s).round().clamp(-1, 1) / s + return result.type(dtype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # quant weight to i2 (in fp16) + if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", + "down_proj.weight", "up_proj.weight", "gate_proj.weight", + "o_proj.weight")): + data_torch = self.weight_quant(data_torch) + + return [(self.map_tensor_name(name), data_torch)] + + def write_tensors(self): + max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") + + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + # use the first number-like part of the tensor name as the block id + bid = None + for part in name.split("."): + if part.isdecimal(): + bid = int(part) + break + + for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): + data: np.ndarray = data # type hint + data_shape = data.shape + n_dims = len(data.shape) + data_dtype = data.dtype + data_qtype: gguf.GGMLQuantizationType | None = None + + # when both are True, f32 should win + # extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims) + # extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims) + extra_f32 = False + extra_f16 = False + + # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors + # Conditions should closely match those in llama_model_quantize_internal in llama.cpp + extra_f32 = any(cond for cond in ( + extra_f32, + n_dims == 1, + new_name.endswith("_norm.weight"), + )) + + # Some tensor types are always in float32 + tensors_f32 = [ + gguf.MODEL_TENSOR.FFN_GATE_INP, + gguf.MODEL_TENSOR.FFN_GATE_INP, + gguf.MODEL_TENSOR.POS_EMBD, + gguf.MODEL_TENSOR.TOKEN_TYPES, + ] + if not args.quant_embd: + tensors_f32.append(gguf.MODEL_TENSOR.TOKEN_EMBD) + extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in tensors_f32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + extra_f16 = any(cond for cond in ( + extra_f16, + (name.endswith(".weight") and n_dims >= 2), + )) + + suit_i2 = True + if name.endswith('embed_tokens.weight') or name.endswith('norm.weight'): + suit_i2 = False + + i2_scale = None + if self.ftype != gguf.GGMLQuantizationType.F32 and extra_f16 and not extra_f32: + if self.ftype == gguf.GGMLQuantizationType.TL1 and suit_i2: + data, i2_scale = transform_to_tl1(data) + assert data.dtype == np.uint8 + assert i2_scale.dtype == np.float32 + data_qtype = gguf.GGMLQuantizationType.TL1 + elif self.ftype == gguf.GGMLQuantizationType.TL2 and suit_i2: + data, i2_scale = transform_to_tl2(data) + assert data.dtype == np.uint8 + assert i2_scale.dtype == np.float32 + data_qtype = gguf.GGMLQuantizationType.TL2 + else: # default to float16 for quantized tensors + if data_dtype != np.float16: + data = data.astype(np.float16) + data_qtype = gguf.GGMLQuantizationType.F16 + + if data_qtype is None: # by default, convert to float32 + if data_dtype != np.float32: + data = data.astype(np.float32) + data_qtype = gguf.GGMLQuantizationType.F32 + + shape = data_shape + # shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape + # reverse shape to make it similar to the internal ggml dimension order + shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}" + + # n_dims is implicit in the shape + logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + + self.gguf_writer.add_tensor(new_name, data, raw_shape=shape, raw_dtype=data_qtype) + if i2_scale is not None: + self.gguf_writer.add_tensor(new_name + "_scale", i2_scale, raw_dtype=gguf.GGMLQuantizationType.F32) + + +###### CONVERSION LOGIC ###### + + +ftype_map = { + "f32": gguf.GGMLQuantizationType.F32, + "f16": gguf.GGMLQuantizationType.F16, + "tl1" : gguf.GGMLQuantizationType.TL1, + "tl2" : gguf.GGMLQuantizationType.TL2, +} + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert a huggingface model to a GGML compatible file") + parser.add_argument( + "--vocab-only", action="store_true", + help="extract only the vocab", + ) + parser.add_argument( + "--awq-path", type=Path, default=None, + help="Path to scale awq cache file") + parser.add_argument( + "--outfile", type=Path, + help="path to write to; default: based on input", + ) + parser.add_argument( + "--outtype", type=str, choices=ftype_map.keys(), default="f32", + help="output format - use f32 for float32, f16 for float16", + ) + parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine") + parser.add_argument( + "model", type=Path, + help="directory containing model file", + ) + parser.add_argument("--use-temp-file", action="store_true", help="use the tempfile library while processing (helpful when running out of memory, process killed)") + parser.add_argument("--model-name", type=str, default=None, help="name of the model") + parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + parser.add_argument("--quant-embd", action="store_true", help="quantize the embedding layer") + + return parser.parse_args() + + +def main() -> None: + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + dir_model = args.model + + if not dir_model.is_dir(): + logger.error(f'Error: {args.model} is not a directory') + sys.exit(1) + + if args.outfile is not None: + fname_out = args.outfile + else: + # output in the same directory as the model by default + fname_out = dir_model / f'ggml-model-{args.outtype}.gguf' + + logger.info(f"Loading model: {dir_model.name}") + + hparams = Model.load_hparams(dir_model) + + with torch.inference_mode(): + model_class = Model.from_model_architecture(hparams["architectures"][0]) + model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file) + + logger.info("Set model parameters") + model_instance.set_gguf_parameters() + + logger.info("Set model tokenizer") + model_instance.set_vocab() + + if args.vocab_only: + logger.info(f"Exporting model vocab to '{fname_out}'") + model_instance.write_vocab() + else: + logger.info(f"Exporting model to '{fname_out}'") + model_instance.write() + + logger.info(f"Model successfully exported to '{fname_out}'") + + +if __name__ == '__main__': + args = parse_args() + + main() diff --git a/utils/convert.py b/utils/convert.py new file mode 100644 index 0000000..5938c42 --- /dev/null +++ b/utils/convert.py @@ -0,0 +1,1711 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import logging +import argparse +import concurrent.futures +import enum +import faulthandler +import functools +import itertools +import json +import math +import mmap +import os +import pickle +import re +import signal +import struct +import sys +import textwrap +import time +import zipfile +from abc import ABC, abstractmethod +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable, Tuple + +import configparser +import numpy as np +from sentencepiece import SentencePieceProcessor + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +import gguf + +if TYPE_CHECKING: + from typing_extensions import Self, TypeAlias + +logger = logging.getLogger("convert") + +if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'): + faulthandler.register(signal.SIGUSR1) + +NDArray: TypeAlias = 'np.ndarray[Any, Any]' + +ARCH = gguf.MODEL_ARCH.BITNET + +DEFAULT_CONCURRENCY = 16 + +ADDED_TOKENS_FILE = 'added_tokens.json' +FAST_TOKENIZER_FILE = 'tokenizer.json' + +# +# data types +# + + +@dataclass(frozen=True) +class DataType: + name: str + dtype: np.dtype[Any] + valid_conversions: list[str] + + def elements_to_bytes(self, n_elements: int) -> int: + return n_elements * self.dtype.itemsize + + +@dataclass(frozen=True) +class UnquantizedDataType(DataType): + pass + + +DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0']) +DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0', 'I2']) +DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = []) +DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0']) +DT_I2 = UnquantizedDataType('I2', dtype = np.dtype(np.uint8), valid_conversions = ['F32', 'F16', 'Q8_0']) + +@dataclass(frozen=True) +class QuantizedDataType(DataType): + block_size: int + quantized_dtype: np.dtype[Any] + ggml_type: gguf.GGMLQuantizationType + + def quantize(self, arr: NDArray) -> NDArray: + raise NotImplementedError(f'Quantization for {self.name} not implemented') + + def elements_to_bytes(self, n_elements: int) -> int: + assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}' + return self.quantized_dtype.itemsize * (n_elements // self.block_size) + + +@dataclass(frozen=True) +class Q8_0QuantizedDataType(QuantizedDataType): + # Mini Q8_0 quantization in Python! + def quantize(self, arr: NDArray) -> NDArray: + assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}' + assert arr.dtype == np.float32, f'Bad array type {arr.dtype}' + n_blocks = arr.size // self.block_size + blocks = arr.reshape((n_blocks, self.block_size)) + # Much faster implementation of block quantization contributed by @Cebtenzzre + + def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[tuple[Any, Any]]: + d = abs(blocks).max(axis = 1) / np.float32(127) + with np.errstate(divide = 'ignore'): + qs = (blocks / d[:, None]).round() + qs[d == 0] = 0 + yield from zip(d, qs) + return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype) + +# @dataclass(frozen=True) +# class TransformedDataType(DataType): +# transformed_dtype: np.dtype[Any] + +# def transform(self, arr: NDArray) -> NDArray: +# raise NotImplementedError(f'Transformation for {self.name} not implemented') + +# @dataclass(frozen=True) +# class I2TransformedDataType(TransformedDataType): +# # fp32 -> int2 (dtype is uint8) +# def transform(self, arr: NDArray) -> NDArray: +# assert(np.prod(arr.shape) % 4 == 0) +# # Much faster implementation of block quantization contributed by @Cebtenzzre + +# def transform_to_i2(x : NDArray) -> Iterable[tuple[Any, Any]]: +# x_num = np.prod(x.shape) +# x = np.reshape(x, x_num) +# for i in range(x_num): +# if x[i] != 0: +# d = x[i] +# break +# x = np.divide(x, d) +# x = x.astype(np.uint8) +# x = np.reshape(x, [x.shape[0] // 4, 4]) +# keep_bit = {0:192, 1:48, 2:12, 3:3} +# ans = np.zeros([x_num // 4], dtype=np.uint8) +# for i in range(4): +# x_bit_col = x[:, i] +# x_bit_shift = np.left_shift(x_bit_col, 6 - i * 2) +# x_bit_shift = np.bitwise_and(x_bit_shift, keep_bit[i]) +# ans = np.bitwise_or(ans, x_bit_shift) +# return ans +# return transform_to_i2(arr) + +# def elements_to_bytes(self, n_elements: int) -> int: +# return n_elements // 4 + + +DT_Q8_0 = Q8_0QuantizedDataType('Q8_0', + dtype = np.dtype(np.float32), valid_conversions = [], + ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32, + quantized_dtype = np.dtype([('d', ' DataType: + dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self) + if dt is None: + raise ValueError(self) + # Convert all 1D tensors to F32. Most of the codebase that takes in 1D tensors only handles F32 tensors, and most of the outputs tensors are F32. + # Also The 1d tensors aren't much of a performance/size issue. So instead of having to have separate F32 and F16 implementations of both, just convert everything to F32 for now. + dt = dt if len(tensor.shape) > 1 else DT_F32 + if name == "token_embd.weight" or name == "output.weight": + dt = DT_F32 + return dt + + +GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = { + GGMLFileType.AllF32 : DT_F32, + GGMLFileType.MostlyF16 : DT_F16, + GGMLFileType.MostlyI2 : DT_I2, + GGMLFileType.MostlyQ8_0: DT_Q8_0, +} + +# +# hparams loading +# + + +@dataclass +class Params: + n_vocab: int + n_embd: int + n_layer: int + n_ctx: int + n_ff: int + n_head: int + n_head_kv: int + n_experts: int | None = None + n_experts_used: int | None = None + f_norm_eps: float | None = None + + rope_scaling_type: gguf.RopeScalingType | None = None + f_rope_freq_base: float | None = None + f_rope_scale: float | None = None + n_orig_ctx: int | None = None + rope_finetuned: bool | None = None + + ftype: GGMLFileType | None = None + + # path to the directory containing the model files + path_model: Path | None = None + + @staticmethod + def guessed(model: LazyModel) -> Params: + # try transformer naming first + n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape + + # try transformer naming first + if "model.layers.0.self_attn.q_proj.weight" in model: + n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model) + elif "model.layers.0.self_attn.W_pack.weight" in model: # next: try baichuan naming + n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.W_pack.weight" not in model) + else: + n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model) + + if n_layer < 1: + msg = """\ + failed to guess 'n_layer'. This model is unknown or unsupported. + Suggestion: provide 'config.json' of the model in the same directory containing model files.""" + raise KeyError(textwrap.dedent(msg)) + + n_head = n_embd // 128 # guessed + n_mult = 256 # guessed + + # TODO: verify this + n_ff = int(2 * (4 * n_embd) / 3) + n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult) + + return Params( + n_vocab = n_vocab, + n_embd = n_embd, + n_layer = n_layer, + n_ctx = -1, + n_ff = n_ff, + n_head = n_head, + n_head_kv = n_head, + f_norm_eps = 1e-5, + ) + + @staticmethod + def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: + with open(config_path) as f: + config = json.load(f) + + rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None + rope_scaling = config.get("rope_scaling") + + if rope_scaling is not None and (typ := rope_scaling.get("type")): + rope_factor = rope_scaling.get("factor") + f_rope_scale = rope_factor + if typ == "linear": + rope_scaling_type = gguf.RopeScalingType.LINEAR + elif typ == "yarn": + rope_scaling_type = gguf.RopeScalingType.YARN + n_orig_ctx = rope_scaling['original_max_position_embeddings'] + rope_finetuned = rope_scaling['finetuned'] + else: + raise NotImplementedError(f'Unknown rope scaling type: {typ}') + + if "max_sequence_length" in config: + n_ctx = config["max_sequence_length"] + elif "max_position_embeddings" in config: + n_ctx = config["max_position_embeddings"] + else: + msg = """\ + failed to guess 'n_ctx'. This model is unknown or unsupported. + Suggestion: provide 'config.json' of the model in the same directory containing model files.""" + raise KeyError(textwrap.dedent(msg)) + + n_experts = None + n_experts_used = None + + if "num_local_experts" in config: + n_experts = config["num_local_experts"] + n_experts_used = config["num_experts_per_tok"] + + return Params( + n_vocab = config["vocab_size"], + n_embd = config["hidden_size"], + n_layer = config["num_hidden_layers"], + n_ctx = n_ctx, + n_ff = config["intermediate_size"], + n_head = (n_head := config["num_attention_heads"]), + n_head_kv = config.get("num_key_value_heads", n_head), + n_experts = n_experts, + n_experts_used = n_experts_used, + f_norm_eps = config["rms_norm_eps"], + f_rope_freq_base = config.get("rope_theta"), + rope_scaling_type = rope_scaling_type, + f_rope_scale = f_rope_scale, + n_orig_ctx = n_orig_ctx, + rope_finetuned = rope_finetuned, + ) + + # LLaMA v2 70B params.json + # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1} + @staticmethod + def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: + with open(config_path) as f: + config = json.load(f) + + n_experts = None + n_experts_used = None + f_rope_freq_base = None + + # hack to determine LLaMA v1 vs v2 vs CodeLlama + if config.get("moe"): + # Mixtral + n_ctx = 32768 + elif config.get("rope_theta") == 1000000: + # CodeLlama + n_ctx = 16384 + elif config["norm_eps"] == 1e-05: + # LLaMA v2 + n_ctx = 4096 + else: + # LLaMA v1 + n_ctx = 2048 + + if "layers.0.feed_forward.w1.weight" in model: + n_ff = model["layers.0.feed_forward.w1.weight"].shape[0] + + if config.get("moe"): + n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0] + n_experts = config["moe"]["num_experts"] + n_experts_used = config["moe"]["num_experts_per_tok"] + f_rope_freq_base = 1e6 + + return Params( + n_vocab = model["tok_embeddings.weight"].shape[0], + n_embd = config["dim"], + n_layer = config["n_layers"], + n_ctx = n_ctx, + n_ff = n_ff, + n_head = (n_head := config["n_heads"]), + n_head_kv = config.get("n_kv_heads", n_head), + n_experts = n_experts, + n_experts_used = n_experts_used, + f_norm_eps = config["norm_eps"], + f_rope_freq_base = config.get("rope_theta", f_rope_freq_base), + ) + + @staticmethod + def load(model_plus: ModelPlus) -> Params: + hf_config_path = model_plus.paths[0].parent / "config.json" + orig_config_path = model_plus.paths[0].parent / "params.json" + + if hf_config_path.exists(): + params = Params.loadHFTransformerJson(model_plus.model, hf_config_path) + elif orig_config_path.exists(): + params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path) + elif model_plus.format != 'none': + params = Params.guessed(model_plus.model) + else: + raise ValueError('Cannot guess params when model format is none') + + params.path_model = model_plus.paths[0].parent + + return params + + +# +# vocab +# + +@runtime_checkable +class BaseVocab(Protocol): + tokenizer_model: ClassVar[str] + name: ClassVar[str] + + +class NoVocab(BaseVocab): + tokenizer_model = "no_vocab" + name = "no_vocab" + + def __repr__(self) -> str: + return "" + + +@runtime_checkable +class Vocab(BaseVocab, Protocol): + vocab_size: int + added_tokens_dict: dict[str, int] + added_tokens_list: list[str] + fname_tokenizer: Path + + def __init__(self, base_path: Path): ... + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ... + + +class BpeVocab(Vocab): + tokenizer_model = "gpt2" + name = "bpe" + + def __init__(self, base_path: Path): + added_tokens: dict[str, int] = {} + + if (fname_tokenizer := base_path / 'vocab.json').exists(): + # "slow" tokenizer + with open(fname_tokenizer, encoding="utf-8") as f: + self.vocab = json.load(f) + + try: + # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab. + with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f: + added_tokens = json.load(f) + except FileNotFoundError: + pass + else: + # "fast" tokenizer + fname_tokenizer = base_path / FAST_TOKENIZER_FILE + + # if this fails, FileNotFoundError propagates to caller + with open(fname_tokenizer, encoding="utf-8") as f: + tokenizer_json = json.load(f) + + tokenizer_model: dict[str, Any] = tokenizer_json['model'] + if ( + tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False) + or tokenizer_json['decoder']['type'] != 'ByteLevel' + ): + raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer') + + self.vocab = tokenizer_model["vocab"] + + if (added := tokenizer_json.get('added_tokens')) is not None: + # Added tokens here can be duplicates of the main vocabulary. + added_tokens = {item['content']: item['id'] + for item in added + if item['content'] not in self.vocab} + + vocab_size = len(self.vocab) + expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) + actual_ids = sorted(added_tokens.values()) + if expected_ids != actual_ids: + expected_end_id = vocab_size + len(actual_ids) - 1 + raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range " + f"{vocab_size} - {expected_end_id}; got {actual_ids}") + + items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) + self.added_tokens_dict = added_tokens + self.added_tokens_list = [text for (text, idx) in items] + self.vocab_size_base = vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer + + def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()} + + for i, _ in enumerate(self.vocab): + yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL + + def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + for text in self.added_tokens_list: + score = -1000.0 + yield text.encode("utf-8"), score, gguf.TokenType.CONTROL + + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + yield from self.bpe_tokens() + yield from self.added_tokens() + + def __repr__(self) -> str: + return f"" + + +class SentencePieceVocab(Vocab): + tokenizer_model = "llama" + name = "spm" + + def __init__(self, base_path: Path): + added_tokens: dict[str, int] = {} + if (fname_tokenizer := base_path / 'tokenizer.model').exists(): + # normal location + try: + with open(base_path / ADDED_TOKENS_FILE, encoding="utf-8") as f: + added_tokens = json.load(f) + except FileNotFoundError: + pass + elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists(): + # not found in alternate location either + raise FileNotFoundError('Cannot find tokenizer.model') + + self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer)) + vocab_size = self.sentencepiece_tokenizer.vocab_size() + + new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size} + expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens))) + actual_new_ids = sorted(new_tokens.keys()) + + if expected_new_ids != actual_new_ids: + raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}") + + # Token pieces that were added to the base vocabulary. + self.added_tokens_dict = added_tokens + self.added_tokens_list = [new_tokens[id] for id in actual_new_ids] + self.vocab_size_base = vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer + + def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + tokenizer = self.sentencepiece_tokenizer + for i in range(tokenizer.vocab_size()): + piece = tokenizer.id_to_piece(i) + text = piece.encode("utf-8") + score: float = tokenizer.get_score(i) + + toktype = gguf.TokenType.NORMAL + if tokenizer.is_unknown(i): + toktype = gguf.TokenType.UNKNOWN + if tokenizer.is_control(i): + toktype = gguf.TokenType.CONTROL + + # NOTE: I think added_tokens are user defined. + # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto + # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED + + if tokenizer.is_unused(i): + toktype = gguf.TokenType.UNUSED + if tokenizer.is_byte(i): + toktype = gguf.TokenType.BYTE + + yield text, score, toktype + + def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + for text in self.added_tokens_list: + score = -1000.0 + yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED + + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + yield from self.sentencepiece_tokens() + yield from self.added_tokens() + + def __repr__(self) -> str: + return f"" + + +class LlamaHfVocab(Vocab): + tokenizer_model = "llama" + name = "hfft" + + def __init__(self, base_path: Path): + fname_tokenizer = base_path / FAST_TOKENIZER_FILE + # if this fails, FileNotFoundError propagates to caller + with open(fname_tokenizer, encoding='utf-8') as f: + tokenizer_json = json.load(f) + + # pre-check so we know if we need transformers + tokenizer_model: dict[str, Any] = tokenizer_json['model'] + is_llama3 = ( + tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False) + and not tokenizer_model.get('byte_fallback', True) + ) + if is_llama3: + raise TypeError('Llama 3 must be converted with BpeVocab') + + if not is_llama3 and ( + tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False) + or tokenizer_json['decoder']['type'] != 'Sequence' + ): + raise FileNotFoundError('Cannot find Llama BPE tokenizer') + + try: + from transformers import AutoTokenizer + except ImportError as e: + raise ImportError( + "To use LlamaHfVocab, please install the `transformers` package. " + "You can install it with `pip install transformers`." + ) from e + + # Allow the tokenizer to default to slow or fast versions. + # Explicitly set tokenizer to use local paths. + self.tokenizer = AutoTokenizer.from_pretrained( + base_path, + cache_dir=base_path, + local_files_only=True, + ) + assert self.tokenizer.is_fast # assume tokenizer.json is used + + # Initialize lists and dictionaries for added tokens + self.added_tokens_list = [] + self.added_tokens_dict = dict() + self.added_tokens_ids = set() + + # Process added tokens + for tok, tokidx in sorted( + self.tokenizer.get_added_vocab().items(), key=lambda x: x[1] + ): + # Only consider added tokens that are not in the base vocabulary + if tokidx >= self.tokenizer.vocab_size: + self.added_tokens_list.append(tok) + self.added_tokens_dict[tok] = tokidx + self.added_tokens_ids.add(tokidx) + + # Store special tokens and their IDs + self.specials = { + tok: self.tokenizer.get_vocab()[tok] + for tok in self.tokenizer.all_special_tokens + } + self.special_ids = set(self.tokenizer.all_special_ids) + + # Set vocabulary sizes + self.vocab_size_base = self.tokenizer.vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + + self.fname_tokenizer = fname_tokenizer + + def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + reverse_vocab = { + id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items() + } + + for token_id in range(self.vocab_size_base): + # Skip processing added tokens here + if token_id in self.added_tokens_ids: + continue + + # Convert token text to bytes + token_text = reverse_vocab[token_id].encode("utf-8") + + # Yield token text, score, and type + yield token_text, self.get_token_score(token_id), self.get_token_type( + token_id, token_text, self.special_ids # Reuse already stored special IDs + ) + + def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType: + # Special case for byte tokens + if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text): + return gguf.TokenType.BYTE + + # Determine token type based on whether it's a special token + return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL + + def get_token_score(self, token_id: int) -> float: + # Placeholder for actual logic to determine the token's score + # This needs to be implemented based on specific requirements + return -1000.0 # Default score + + def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + for text in self.added_tokens_list: + if text in self.specials: + toktype = self.get_token_type(self.specials[text], b'', self.special_ids) + score = self.get_token_score(self.specials[text]) + else: + toktype = gguf.TokenType.USER_DEFINED + score = -1000.0 + + yield text.encode("utf-8"), score, toktype + + def has_newline_token(self): + return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab + + def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: + yield from self.hf_tokens() + yield from self.added_tokens() + + def __repr__(self) -> str: + return f"" + + +# +# data loading +# TODO: reuse (probably move to gguf.py?) +# + + +def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray: + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + +class Tensor(ABC): + ndarray: NDArray + data_type: DataType + + @abstractmethod + def astype(self, data_type: DataType) -> Self: ... + @abstractmethod + def permute(self, n_head: int, n_head_kv: int) -> Self: ... + @abstractmethod + def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> Self: ... + @abstractmethod + def part(self, n_part: int) -> Self: ... + @abstractmethod + def to_ggml(self) -> GGMLCompatibleTensor: ... + + +def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray: + assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}" + fp32_arr = bf16_arr.astype(np.uint32) << 16 + return fp32_arr.view(np.float32) + +def preprocess_weights( + w: np.ndarray, + bits = 2, + g = 4, +) -> Tuple[np.ndarray, np.ndarray]: + M, K = w.shape + + cf=configparser.ConfigParser() + cf.read("./build/kcfg.ini") + secs=cf.sections() + for sec in secs: + sec_splits = str(sec).split('_') + if sec_splits[-4] == "m" + str(M*2) and sec_splits[-3] == "k" + str(K): + bm = int(cf.get(sec, 'bm')) + kfactor = int(cf.get(sec, 'kfactor')) + simd_n_in = int(cf.get(sec, 'simd_n_in')) + simd_n_out = int(cf.get(sec, 'simd_n_out')) + break + + M = M * bits + ngroups_per_elem = 8 // g + + # (M // bits, K, bits) + w = np.stack([(w >> ib) & 1 for ib in range(bits)], axis=-1) + # print(w) + # (M // bits, K, bits) -> (M // bits, bits, K) -> (M // bits, bits, K) -> (M // bits, bits, K // g, g) + w = w.transpose(0, 2, 1).reshape(M // bits, bits, K // g, g) + w = sum([(w[:, :, :, ig] << ig) for ig in range(g)]) + # print(w) + # 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31 + # for bits=3 + # bit0: [0, 8), bit1: [8, 16), bit2: [16, 24), bit0: [24, 32) + # (M // bits // simd_n_float16, bits, simd_n_float16, K // g) + w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) + mgroup = ngroups_per_elem * simd_n_in + w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) + # 0 1 2 3 4 5 + w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) + w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)]) + w = w.reshape(M // bm, K // g // kfactor, bm // mgroup, kfactor, simd_n_in) + # input size of current TVM API + w = w.reshape(M // bm, K // g, bm // ngroups_per_elem) + + return w + +def transform_to_i2(x : NDArray): + x_num = np.prod(x.shape) + tile_x = np.reshape(x, x_num) + scale = 1 + for i in range(x_num): + if tile_x[i] != 0: + scale = tile_x[i] + break + tile_x = np.divide(tile_x, scale) + tile_x = (tile_x.astype(np.int8) + 2).astype(np.uint8) + ans = np.reshape(tile_x, x.shape) + return ans, scale + +class UnquantizedTensor(Tensor): + def __init__(self, ndarray: NDArray, i2_scale: NDArray = None): + assert isinstance(ndarray, np.ndarray) + self.ndarray = ndarray + self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype] + self.i2_scale = i2_scale + + def astype(self, data_type: DataType) -> UnquantizedTensor: + dtype = data_type.dtype + if self.data_type == DT_BF16: + self.ndarray = bf16_to_fp32(self.ndarray) + if dtype == np.uint8: + self.ndarray, self.i2_scale = transform_to_i2(self.ndarray) + return UnquantizedTensor(self.ndarray.astype(dtype), self.i2_scale) + + def to_ggml(self) -> Self: + return self + + def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: + r = self.ndarray.shape[0] // 3 + return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv)) + + def part(self, n_part: int) -> UnquantizedTensor: + r = self.ndarray.shape[0] // 3 + return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...]) + + def permute(self, n_head: int, n_head_kv: int) -> UnquantizedTensor: + return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv)) + + +def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False) -> NDArray: + tensor = lazy_tensor.load() + assert isinstance(tensor, UnquantizedTensor) + + # double-check: + actual_shape = list(tensor.ndarray.shape) + assert actual_shape == lazy_tensor.shape, (actual_shape, lazy_tensor.shape) + if expected_dtype is not None and expected_dtype != tensor.ndarray.dtype: + if convert: + tensor.ndarray = tensor.ndarray.astype(expected_dtype) + else: + raise ValueError(f'expected this tensor to have dtype {expected_dtype}, got {tensor.ndarray.dtype}') + + return tensor.ndarray + + +GGMLCompatibleTensor = UnquantizedTensor + + +@dataclass +class LazyTensor: + _load: Callable[[], Tensor] + shape: list[int] + data_type: DataType + description: str + + def load(self) -> Tensor: + ret = self._load() + # Should be okay if it maps to the same numpy type? + assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \ + (self.data_type, ret.data_type, self.description) + return ret + + def astype(self, data_type: DataType) -> LazyTensor: + self.validate_conversion_to(data_type) + + def load() -> Tensor: + return self.load().astype(data_type) + return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}') + + def validate_conversion_to(self, data_type: DataType) -> None: + if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions: + raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.') + + +LazyModel: TypeAlias = 'dict[str, LazyTensor]' + + +@dataclass +class ModelPlus: + model: LazyModel + paths: list[Path] # Where this was read from. + format: Literal['ggml', 'torch', 'safetensors', 'none'] + vocab: BaseVocab | None # For GGML models (which have vocab built in), the vocab. + + +def merge_sharded(models: list[LazyModel]) -> LazyModel: + # Original LLaMA models have each file contain one part of each tensor. + # Use a dict instead of a set to preserve order. + names = {name: None for model in models for name in model} + + def convert(name: str) -> LazyTensor: + lazy_tensors = [model[name] for model in models] + if len(lazy_tensors) == 1: + # only one file; don't go through this procedure since there might + # be quantized tensors + return lazy_tensors[0] + if len(lazy_tensors[0].shape) == 1: + # the tensor is just duplicated in every file + return lazy_tensors[0] + if name.startswith('tok_embeddings.') or \ + name.endswith('.attention.wo.weight') or \ + name.endswith('.feed_forward.w2.weight'): + # split by columns + axis = 1 + else: + # split by rows + axis = 0 + concatenated_shape = list(lazy_tensors[0].shape) + concatenated_shape[axis] = sum(tensor.shape[axis] for tensor in lazy_tensors) + + def load() -> UnquantizedTensor: + ndarrays = [load_unquantized(tensor) for tensor in lazy_tensors] + concatenated = np.concatenate(ndarrays, axis=axis) + return UnquantizedTensor(concatenated) + description = 'concatenated[[' + '] | ['.join(lt.description for lt in lazy_tensors) + ']]' + return LazyTensor(load, concatenated_shape, lazy_tensors[0].data_type, description) + return {name: convert(name) for name in names} + + +def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus: + formats = set(mp.format for mp in models_plus) + assert len(formats) == 1, "different formats?" + format = formats.pop() + paths = [path for mp in models_plus for path in mp.paths] + # Use the first non-None vocab, if any. + try: + vocab = next(mp.vocab for mp in models_plus if mp.vocab is not None) + except StopIteration: + vocab = None + + if any("model.embed_tokens.weight" in mp.model for mp in models_plus): + # Transformers models put different tensors in different files, but + # don't split individual tensors between files. + model: LazyModel = {} + for mp in models_plus: + model.update(mp.model) + else: + model = merge_sharded([mp.model for mp in models_plus]) + + return ModelPlus(model, paths, format, vocab) # pytype: disable=wrong-arg-types + + +def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor: + def load() -> Tensor: + return lazy_tensor.load().permute(n_head, n_head_kv) + return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) + + +def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int) -> LazyTensor: + def load() -> Tensor: + return lazy_tensor.load().permute_part(n_part, n_head, n_head_kv) + s = lazy_tensor.shape.copy() + s[0] = s[0] // 3 + return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) + + +def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor: + def load() -> Tensor: + return lazy_tensor.load().part(n_part) + s = lazy_tensor.shape.copy() + s[0] = s[0] // 3 + return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description) + + +def pack_experts_lazy(lazy_tensors: list[LazyTensor]) -> LazyTensor: + def load() -> Tensor: + tensors = [lazy_tensor.load() for lazy_tensor in lazy_tensors] + return UnquantizedTensor(np.array([tensor.ndarray for tensor in tensors])) + s = lazy_tensors[0].shape.copy() + s.insert(0, len(lazy_tensors)) + return LazyTensor(load, s, lazy_tensors[0].data_type, 'pack_experts ' + ' | '.join(lt.description for lt in lazy_tensors)) + + +# Functionality that simulates `torch.load` but where individual tensors are +# only loaded into memory on demand, not all at once. +# PyTorch can't do this natively as of time of writing: +# - https://github.com/pytorch/pytorch/issues/64327 +# This allows us to de-shard without multiplying RAM usage, and also +# conveniently drops the PyTorch dependency (though we still need numpy). + + +@dataclass +class LazyStorageKind: + data_type: DataType + + +@dataclass +class LazyStorage: + load: Callable[[int, int], NDArray] + kind: LazyStorageKind + description: str + + +class LazyUnpickler(pickle.Unpickler): + def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile): + super().__init__(fp) + self.data_base_path = data_base_path + self.zip_file = zip_file + + def persistent_load(self, pid: Any) -> Any: + assert pid[0] == 'storage' + assert isinstance(pid[1], LazyStorageKind) + data_type = pid[1].data_type + filename_stem = pid[2] + filename = f'{self.data_base_path}/{filename_stem}' + info = self.zip_file.getinfo(filename) + + def load(offset: int, elm_count: int) -> NDArray: + dtype = data_type.dtype + with self.zip_file.open(info) as fp: + fp.seek(offset * dtype.itemsize) + size = elm_count * dtype.itemsize + data = fp.read(size) + assert len(data) == size + return np.frombuffer(data, dtype) + description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}' + return LazyStorage(load=load, kind=pid[1], description=description) + + @staticmethod + def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any, + requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor: + assert isinstance(storage, LazyStorage) + + def load() -> UnquantizedTensor: + elm_count = stride[0] * size[0] + return UnquantizedTensor(storage.load(storage_offset, elm_count).reshape(size)) + description = f'pickled storage_offset={storage_offset} in {storage.description}' + return LazyTensor(load, list(size), storage.kind.data_type, description) + + @staticmethod + def rebuild_from_type_v2(func, new_type, args, state): + return func(*args) + + CLASSES = { + # getattr used here as a workaround for mypy not being smart enough to determine + # the staticmethods have a __func__ attribute. + ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'), + ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'), + ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16), + ('torch', 'HalfStorage'): LazyStorageKind(DT_F16), + ('torch', 'FloatStorage'): LazyStorageKind(DT_F32), + ('torch', 'IntStorage'): LazyStorageKind(DT_I32), + ('torch', 'Tensor'): LazyTensor, + } + + def find_class(self, module: str, name: str) -> Any: + if not module.startswith('torch'): + return super().find_class(module, name) + return self.CLASSES[(module, name)] + + +def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus: + zf = zipfile.ZipFile(outer_fp) + pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')] + assert len(pickle_paths) == 1, pickle_paths + pickle_fp = zf.open(pickle_paths[0], 'r') + unpickler = LazyUnpickler(pickle_fp, + data_base_path=pickle_paths[0][:-4], + zip_file=zf) + model = unpickler.load() + if 'model' in model: model = model['model'] + as_dict = dict(model.items()) + return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None) + + +def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: + header_size, = struct.unpack(' LazyTensor: + data_type = SAFETENSORS_DATA_TYPES[info['dtype']] + numpy_dtype = data_type.dtype + shape: list[int] = info['shape'] + begin, end = info['data_offsets'] + assert 0 <= begin <= end <= len(byte_buf) + assert end - begin == math.prod(shape) * numpy_dtype.itemsize + buf = byte_buf[begin:end] + + def load() -> UnquantizedTensor: + return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)) + description = f'safetensors begin={begin} end={end} type={data_type} path={path}' + return LazyTensor(load, shape, data_type, description) + model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'} + return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None) + + +def must_read(fp: IO[bytes], length: int) -> bytes: + ret = fp.read(length) + if len(ret) < length: + raise EOFError("unexpectedly reached end of file") + return ret + + +@functools.lru_cache(maxsize=None) +def lazy_load_file(path: Path) -> ModelPlus: + fp = open(path, 'rb') + first8 = fp.read(8) + fp.seek(0) + if first8[:2] == b'PK': + # A zip file, i.e. PyTorch format + return lazy_load_torch_file(fp, path) + elif struct.unpack(' Iterable[Out]: + '''Parallel map, but with backpressure. If the caller doesn't call `next` + fast enough, this will stop calling `func` at some point rather than + letting results pile up in memory. Specifically, there is a max of one + output value buffered per thread.''' + if concurrency < 2: + yield from map(func, iterable) + # Not reached. + iterable = iter(iterable) + executor_class: type[ThreadPoolExecutor] | type[ProcessPoolExecutor] + if use_processpool_executor: + executor_class = ProcessPoolExecutor + else: + executor_class = ThreadPoolExecutor + with executor_class(max_workers=max_workers) as executor: + futures: list[concurrent.futures.Future[Out]] = [] + done = False + for _ in range(concurrency): + try: + futures.append(executor.submit(func, next(iterable))) + except StopIteration: + done = True + break + + while futures: + result = futures.pop(0).result() + while not done and len(futures) < concurrency: + try: + futures.append(executor.submit(func, next(iterable))) + except StopIteration: + done = True + break + yield result + + +def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False) -> None: + # Handle special case where the model's vocab size is not set + if params.n_vocab == -1: + raise ValueError( + "The model's vocab size is set to -1 in params.json. Please update it manually." + + (f" Maybe {vocab.vocab_size}?" if isinstance(vocab, Vocab) else ""), + ) + if not isinstance(vocab, Vocab): + return # model has no vocab + + # Check for a vocab size mismatch + if params.n_vocab == vocab.vocab_size: + logger.warning("Ignoring added_tokens.json since model matches vocab size without it.") + return + + if pad_vocab and params.n_vocab > vocab.vocab_size: + pad_count = params.n_vocab - vocab.vocab_size + logger.debug( + f"Padding vocab with {pad_count} token(s) - through " + ) + for i in range(1, pad_count + 1): + vocab.added_tokens_dict[f""] = -1 + vocab.added_tokens_list.append(f"") + vocab.vocab_size = params.n_vocab + return + + msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer} has {vocab.vocab_size})." + if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20: + msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})." + if vocab.vocab_size < params.n_vocab: + msg += " Add the --pad-vocab option and try again." + + raise ValueError(msg) + + +class OutputFile: + def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE): + self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess) + + def add_meta_arch(self, params: Params) -> None: + name = "LLaMA" + + # TODO: better logic to determine model name + if params.n_ctx == 4096: + name = "LLaMA v2" + elif params.path_model is not None: + name = str(params.path_model.parent).split('/')[-1] + + self.gguf.add_name (name) + self.gguf.add_vocab_size (params.n_vocab) + self.gguf.add_context_length (params.n_ctx) + self.gguf.add_embedding_length (params.n_embd) + self.gguf.add_block_count (params.n_layer) + self.gguf.add_feed_forward_length (params.n_ff) + self.gguf.add_rope_dimension_count(params.n_embd // params.n_head) + self.gguf.add_head_count (params.n_head) + self.gguf.add_head_count_kv (params.n_head_kv) + + if params.n_experts: + self.gguf.add_expert_count(params.n_experts) + + if params.n_experts_used: + self.gguf.add_expert_used_count(params.n_experts_used) + + if params.f_norm_eps: + self.gguf.add_layer_norm_rms_eps(params.f_norm_eps) + else: + raise ValueError('f_norm_eps is None') + + if params.f_rope_freq_base is not None: + self.gguf.add_rope_freq_base(params.f_rope_freq_base) + + if params.rope_scaling_type: + assert params.f_rope_scale is not None + self.gguf.add_rope_scaling_type(params.rope_scaling_type) + self.gguf.add_rope_scaling_factor(params.f_rope_scale) + + if params.n_orig_ctx is not None: + self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx) + + if params.rope_finetuned is not None: + self.gguf.add_rope_scaling_finetuned(params.rope_finetuned) + + if params.ftype is not None: + self.gguf.add_file_type(params.ftype) + + def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]: + tokens = [] + scores = [] + toktypes = [] + + # NOTE: `all_tokens` returns the base vocabulary and added tokens + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + assert len(tokens) == vocab.vocab_size + + return tokens, scores, toktypes + + def add_meta_vocab(self, vocab: Vocab) -> None: + # Ensure that tokenizer_model is added to the GGUF model + self.gguf.add_tokenizer_model(vocab.tokenizer_model) + + # Extract model vocabulary for model conversion + tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab) + + # Add extracted token information for model conversion + self.gguf.add_token_list(tokens) + self.gguf.add_token_scores(scores) + self.gguf.add_token_types(toktypes) + + def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None: + svocab.add_to_gguf(self.gguf) + + def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: + n_elements = int(np.prod(tensor.shape)) + raw_dtype = getattr(tensor.data_type, 'ggml_type', None) + data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype + data_nbytes = tensor.data_type.elements_to_bytes(n_elements) + if tensor.data_type.name == "I2": + # i2 * n + scale (fp32) + # print(tensor.shape) + # print(data_nbytes) + data_nbytes = data_nbytes // 4 + 32 + # print(data_nbytes) + # scale_name = name.replace('.weight', '_scale.weight') + # scale_shape = [1] + # scale_data_type = np.float32 + # scale_nbytes = 4 + # self.gguf.add_tensor_info(scale_name, scale_shape, scale_data_type, scale_nbytes, raw_dtype=None) + self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype) + + def write_meta(self) -> None: + self.gguf.write_header_to_file() + self.gguf.write_kv_data_to_file() + + def write_tensor_info(self) -> None: + self.gguf.write_ti_data_to_file() + + def write_tensor_data(self, ftype: GGMLFileType, model: LazyModel, concurrency: int) -> None: + ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency) + if ftype == GGMLFileType.MostlyQ8_0: + ndarrays = bounded_parallel_map( + OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency, + use_processpool_executor=True, + ) + # elif ftype == GGMLFileType.MostlyI2: + # # ndarrays = bounded_parallel_map( + # # OutputFile.maybe_do_transform, ndarrays_inner, concurrency=concurrency, max_workers=concurrency, use_processpool_executor=True,) + # ndarrays = map(OutputFile.maybe_do_transform, ndarrays_inner) + else: + ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) + + start = time.time() + for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): + ndarray, i2_scale = ndarray + elapsed = time.time() - start + size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) + padi = len(str(len(model))) + logger.info( + f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}" + ) + + if i2_scale is not None: + i2_scale = np.tile(i2_scale, 8) + ndarray = preprocess_weights(ndarray) + self.gguf.write_tensor_data(ndarray) + self.gguf.write_tensor_data(i2_scale) + else: + self.gguf.write_tensor_data(ndarray) + + def close(self) -> None: + self.gguf.close() + + @staticmethod + def write_vocab_only( + fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, + endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, + ) -> None: + check_vocab_size(params, vocab, pad_vocab=pad_vocab) + + of = OutputFile(fname_out, endianess=endianess) + + # meta data + of.add_meta_arch(params) + of.add_meta_vocab(vocab) + of.add_meta_special_vocab(svocab) + + of.write_meta() + + of.close() + + @staticmethod + def do_item(item: tuple[str, LazyTensor]) -> tuple[DataType, NDArray]: + name, lazy_tensor = item + tensor = lazy_tensor.load().to_ggml() + return (lazy_tensor.data_type, tensor.ndarray, tensor.i2_scale) + + @staticmethod + def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray: + dt, arr, i2_scale = item + if not isinstance(dt, QuantizedDataType): + return arr, i2_scale + return dt.quantize(arr) + + @staticmethod + def write_all( + fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab, + concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, + pad_vocab: bool = False, + ) -> None: + check_vocab_size(params, vocab, pad_vocab=pad_vocab) + + of = OutputFile(fname_out, endianess=endianess) + + # meta data + of.add_meta_arch(params) + if isinstance(vocab, Vocab): + of.add_meta_vocab(vocab) + of.add_meta_special_vocab(svocab) + else: # NoVocab + of.gguf.add_tokenizer_model(vocab.tokenizer_model) + + # tensor info + for name, lazy_tensor in model.items(): + of.add_tensor_info(name, lazy_tensor) + + of.write_meta() + of.write_tensor_info() + + # tensor data + of.write_tensor_data(ftype, model, concurrency) + + of.close() + + +def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType: + wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type + + if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)): + return GGMLFileType.AllF32 + if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16): + return GGMLFileType.MostlyF16 + if output_type_str == "q8_0": + return GGMLFileType.MostlyQ8_0 + if output_type_str == "i2": + return GGMLFileType.MostlyI2 + + name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()} + + raise ValueError(f"Unexpected combination of types: {name_to_type}") + + +def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel: + # for (name, tensor) in model.items(): + # print(name) + # print(tensor) + # print(output_type.type_for_tensor(name, tensor)) + # print(tensor.astype(output_type.type_for_tensor(name, tensor))) + return {name: tensor.astype(output_type.type_for_tensor(name, tensor)) + for (name, tensor) in model.items()} + + +def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> LazyModel: + tmap = gguf.TensorNameMap(ARCH, params.n_layer) + should_skip = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, [])) + + tmp = model + + # merge experts into one tensor + if params.n_experts and params.n_experts > 0: + for i_l in range(params.n_layer): + for w in range(1, 4): + experts = [] + for e in range(params.n_experts): + if f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight" in model: + experts.append(model[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"]) + del tmp[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"] + elif f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight" in model: + experts.append(model[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"]) + del tmp[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"] + else: + raise ValueError(f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.weight") + tmp[f"layers.{i_l}.feed_forward.experts.w{w}.weight"] = pack_experts_lazy(experts) + + # HF models permut or pack some of the tensors, so we need to undo that + for i in itertools.count(): + if f"model.layers.{i}.self_attn.q_proj.weight" in model: + logger.debug(f"Permuting layer {i}") + tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head, params.n_head) + tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv) + # tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"] + elif f"model.layers.{i}.self_attn.W_pack.weight" in model: + logger.debug(f"Unpacking and permuting layer {i}") + tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head) + tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv) + tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2) + del tmp[f"model.layers.{i}.self_attn.W_pack.weight"] + else: + break + + # check if is bitnet + if ARCH == 33: + del tmp['output.weight'] + + out: LazyModel = {} + for name, lazy_tensor in model.items(): + tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None) + if name_new is None: + if skip_unknown: + logger.warning(f"Unexpected tensor name: {name} - skipping") + continue + raise ValueError(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)") + + if tensor_type in should_skip: + logger.debug(f"skipping tensor {name_new}") + continue + + logger.debug(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}") + out[name_new] = lazy_tensor + + return out + + +def nth_multifile_path(path: Path, n: int) -> Path | None: + '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + the nth path in the model. + ''' + # Support the following patterns: + patterns = [ + # - x.00.pth, x.01.pth, etc. + (r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'), + # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc. + (r'-[0-9]{5}-of-(.*)$', fr'-{n:05}-of-\1'), + # x.bin, x.bin.1, etc. + (r'(\.[0-9]+)?$', r'\1' if n == 0 else fr'\1.{n}') + ] + for regex, replacement in patterns: + if re.search(regex, path.name): + new_path = path.with_name(re.sub(regex, replacement, path.name)) + if new_path.exists(): + return new_path + return None + + +def find_multifile_paths(path: Path) -> list[Path]: + '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + the whole list of paths in the model. + ''' + ret: list[Path] = [] + for i in itertools.count(): + nth_path = nth_multifile_path(path, i) + if nth_path is None: + break + ret.append(nth_path) + if not ret: + # No matches. This should only happen if the file was named, e.g., + # foo.0, and there was no file named foo. Oh well, try to process it + # as a single file. + return [path] + return ret + + +def load_some_model(path: Path) -> ModelPlus: + '''Load a model of any supported format.''' + # Be extra-friendly and accept either a file or a directory: + if path.is_dir(): + # Check if it's a set of safetensors files first + globs = ["model-00001-of-*.safetensors", "model.safetensors", "consolidated.safetensors", "model-int2.pth"] + files = [file for glob in globs for file in path.glob(glob)] + if not files: + # Try the PyTorch patterns too, with lower priority + globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"] + files = [file for glob in globs for file in path.glob(glob)] + if not files: + raise FileNotFoundError(f"Can't find model in directory {path}") + if len(files) > 1: + raise ValueError(f"Found multiple models in {path}, not sure which to pick: {files}") + path = files[0] + + paths = find_multifile_paths(path) + models_plus: list[ModelPlus] = [] + for path in paths: + logger.info(f"Loading model file {path}") + models_plus.append(lazy_load_file(path)) + + model_plus = merge_multifile_models(models_plus) + return model_plus + + +class VocabFactory: + _VOCAB_CLASSES: list[type[Vocab]] = [SentencePieceVocab, BpeVocab, LlamaHfVocab] + + def __init__(self, path: Path): + self.path = path + + def _create_special_vocab(self, vocab: BaseVocab, model_parent_path: Path) -> gguf.SpecialVocab: + load_merges = vocab.name == "bpe" + n_vocab = vocab.vocab_size if isinstance(vocab, Vocab) else None + return gguf.SpecialVocab( + model_parent_path, + load_merges=load_merges, + special_token_types=None, # Predetermined or passed as a parameter + n_vocab=n_vocab, + ) + + def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab: + vocab_classes: dict[str, type[Vocab]] = {cls.name: cls for cls in self._VOCAB_CLASSES} + selected_vocabs: dict[str, type[Vocab]] = {} + for vtype in vocab_types: + try: + selected_vocabs[vtype] = vocab_classes[vtype] + except KeyError: + raise ValueError(f"Unsupported vocabulary type {vtype}") from None + + for vtype, cls in selected_vocabs.items(): + try: + vocab = cls(self.path) + break + except FileNotFoundError: + pass # ignore unavailable tokenizers + else: + raise FileNotFoundError(f"Could not find a tokenizer matching any of {vocab_types}") + + logger.info(f"Loaded vocab file {vocab.fname_tokenizer!r}, type {vocab.name!r}") + return vocab + + def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) -> tuple[BaseVocab, gguf.SpecialVocab]: + vocab: BaseVocab + if vocab_types is None: + vocab = NoVocab() + else: + vocab = self._create_vocab_by_path(vocab_types) + # FIXME: Respect --vocab-dir? + special_vocab = self._create_special_vocab( + vocab, + model_parent_path, + ) + return vocab, special_vocab + + +def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path: + namestr = { + GGMLFileType.AllF32: "f32", + GGMLFileType.MostlyF16: "f16", + GGMLFileType.MostlyQ8_0:"q8_0", + GGMLFileType.MostlyI2: "i2", + }[file_type] + ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf" + if ret in model_paths: + logger.error( + f"Error: Default output path ({ret}) would overwrite the input. " + "Please explicitly specify a path using --outfile.") + sys.exit(1) + return ret + + +def do_dump_model(model_plus: ModelPlus) -> None: + print(f"model_plus.paths = {model_plus.paths!r}") # noqa: NP100 + print(f"model_plus.format = {model_plus.format!r}") # noqa: NP100 + print(f"model_plus.vocab = {model_plus.vocab!r}") # noqa: NP100 + for name, lazy_tensor in model_plus.model.items(): + print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}") # noqa: NP100 + + +def main(args_in: list[str] | None = None) -> None: + output_choices = ["f32", "f16", "i2"] + if np.uint32(1) == np.uint32(1).newbyteorder("<"): + # We currently only support Q8_0 output on little endian systems. + output_choices.append("q8_0") + parser = argparse.ArgumentParser(description="Convert a LLaMA model to a GGML compatible file") + parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") + parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") + parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") + parser.add_argument("--no-vocab", action="store_true", help="store model without the vocab") + parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)") + parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") + parser.add_argument("--vocab-type", help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft)", default="spm,hfft") + parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") + parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") + parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") + parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default=DEFAULT_CONCURRENCY) + parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine") + parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides") + parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing") + parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + + args = parser.parse_args(args_in) + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + elif args.dump_single or args.dump: + # Avoid printing anything besides the dump output + logging.basicConfig(level=logging.WARNING) + else: + logging.basicConfig(level=logging.INFO) + + if args.no_vocab and args.vocab_only: + raise ValueError("--vocab-only does not make sense with --no-vocab") + + if args.dump_single: + model_plus = lazy_load_file(args.model) + do_dump_model(model_plus) + return + + if not args.vocab_only: + model_plus = load_some_model(args.model) + else: + model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None) + + if args.dump: + do_dump_model(model_plus) + return + + endianess = gguf.GGUFEndian.LITTLE + if args.big_endian: + endianess = gguf.GGUFEndian.BIG + + params = Params.load(model_plus) + if params.n_ctx == -1: + if args.ctx is None: + msg = """\ + The model doesn't have a context size, and you didn't specify one with --ctx + Please specify one with --ctx: + - LLaMA v1: --ctx 2048 + - LLaMA v2: --ctx 4096""" + parser.error(textwrap.dedent(msg)) + params.n_ctx = args.ctx + + if args.outtype: + params.ftype = { + "f32": GGMLFileType.AllF32, + "f16": GGMLFileType.MostlyF16, + "i2" : GGMLFileType.MostlyI2, + "q8_0": GGMLFileType.MostlyQ8_0, + }[args.outtype] + + logger.info(f"params = {params}") + + model_parent_path = model_plus.paths[0].parent + vocab_path = Path(args.vocab_dir or args.model or model_parent_path) + vocab_factory = VocabFactory(vocab_path) + vocab_types = None if args.no_vocab else args.vocab_type.split(",") + vocab, special_vocab = vocab_factory.load_vocab(vocab_types, model_parent_path) + + if args.vocab_only: + assert isinstance(vocab, Vocab) + if not args.outfile: + raise ValueError("need --outfile if using --vocab-only") + outfile = args.outfile + OutputFile.write_vocab_only(outfile, params, vocab, special_vocab, + endianess=endianess, pad_vocab=args.pad_vocab) + logger.info(f"Wrote {outfile}") + return + + if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab: + vocab = model_plus.vocab + + logger.info(f"Vocab info: {vocab}") + logger.info(f"Special vocab info: {special_vocab}") + model = model_plus.model + model = convert_model_names(model, params, args.skip_unknown) + ftype = pick_output_type(model, args.outtype) + model = convert_to_output_type(model, ftype) + outfile = args.outfile or default_outfile(model_plus.paths, ftype) + + params.ftype = ftype + logger.info(f"Writing {outfile}, format {ftype}") + + OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, + concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab) + logger.info(f"Wrote {outfile}") + + +if __name__ == '__main__': + main() diff --git a/utils/e2e_benchmark.py b/utils/e2e_benchmark.py new file mode 100644 index 0000000..4789aac --- /dev/null +++ b/utils/e2e_benchmark.py @@ -0,0 +1,52 @@ +import os +import sys +import logging +import argparse +import subprocess + +def run_command(command, shell=False, log_step=None): + """Run a system command and ensure it succeeds.""" + if log_step: + log_file = os.path.join(args.log_dir, log_step + ".log") + with open(log_file, "w") as f: + try: + subprocess.run(command, shell=shell, check=True, stdout=f, stderr=f) + except subprocess.CalledProcessError as e: + logging.error(f"Error occurred while running command: {e}, check details in {log_file}") + sys.exit(1) + else: + try: + subprocess.run(command, shell=shell, check=True) + except subprocess.CalledProcessError as e: + logging.error(f"Error occurred while running command: {e}") + sys.exit(1) + +def run_benchmark(): + bench_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "build/bin/llama-bench") + if not os.path.exists(bench_path): + logging.error(f"Benchmark binary not found, please build first.") + sys.exit(1) + command = [ + f'{bench_path}', + '-m', args.model, + '-n', str(args.n_token), + '-ngl', '0', + '-b', '1', + '-t', str(args.threads), + '-p', str(args.n_prompt), + '-r', '5' + ] + run_command(command) + +def parse_args(): + parser = argparse.ArgumentParser(description='Setup the environment for running the inference') + parser.add_argument("-m", "--model", type=str, help="Path to model file", required=True) + parser.add_argument("-n", "--n-token", type=int, help="Number of generated tokens", required=False, default=128) + parser.add_argument("-p", "--n-prompt", type=int, help="Prompt to generate text from", required=False, default=512) + parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2) + return parser.parse_args() + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + args = parse_args() + run_benchmark() \ No newline at end of file diff --git a/utils/generate-dummy-bitnet-model.py b/utils/generate-dummy-bitnet-model.py new file mode 100644 index 0000000..be3f6cd --- /dev/null +++ b/utils/generate-dummy-bitnet-model.py @@ -0,0 +1,1048 @@ +#!/usr/bin/env python3 + +# dummy model generation script based on convert-hf-to-gguf-bitnet.py +from __future__ import annotations +import sys +from pathlib import Path + +import numpy as np +import configparser +import logging +import argparse +import contextlib +import json +import os +import re +import sys +from abc import ABC, abstractmethod +from enum import IntEnum +from pathlib import Path +from hashlib import sha256 +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterator, Sequence, TypeVar, cast, Tuple, Iterable + +# Necessary to load the local gguf package +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from gguf import GGUFWriter, GGUFReader, RopeScalingType, TokenType, GGMLQuantizationType # noqa: E402 +if TYPE_CHECKING: + from torch import Tensor + +import torch +import gguf +logger = logging.getLogger("generate-dummy-bitnet-model") + +###### MODEL HPARAMS CONFIGURATION ###### + +model_config = { + "125M": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_hidden_layers": 11, + "num_attention_heads": 12 + }, + "350M": { + "hidden_size": 1024, + "intermediate_size": 3072, + "num_hidden_layers": 24, + "num_attention_heads": 16 + }, + "1B": { + "hidden_size": 2048, + "intermediate_size": 3584, + "num_hidden_layers": 24, + "num_attention_heads": 32 + }, + "1.5B": { + "hidden_size": 1536, + "intermediate_size": 9216, + "num_hidden_layers": 28, + "num_attention_heads": 32 + }, + "2.7B": { + "hidden_size": 3072, + "intermediate_size": 7680, + "num_hidden_layers": 24, + "num_attention_heads": 32 + }, + "3.8B": { + "hidden_size": 3840, + "intermediate_size": 8192, + "num_hidden_layers": 24, + "num_attention_heads": 32 + }, + "7B": { + "hidden_size": 4096, + "intermediate_size": 12032, + "num_hidden_layers": 32, + "num_attention_heads": 32 + }, + "13B": { + "hidden_size": 5120, + "intermediate_size": 13824, + "num_hidden_layers": 40, + "num_attention_heads": 40 + }, + "30B": { + "hidden_size": 6656, + "intermediate_size": 16384, + "num_hidden_layers": 60, + "num_attention_heads": 52 + }, + "70B": { + "hidden_size": 8192, + "intermediate_size": 24576, + "num_hidden_layers": 80, + "num_attention_heads": 64 + }, + "100B": { + "hidden_size": 8192, + "intermediate_size": 45568, + "num_hidden_layers": 72, + "num_attention_heads": 64 + } +} + + +###### MODEL DEFINITIONS ###### + +class SentencePieceTokenTypes(IntEnum): + NORMAL = 1 + UNKNOWN = 2 + CONTROL = 3 + USER_DEFINED = 4 + UNUSED = 5 + BYTE = 6 + + +AnyModel = TypeVar("AnyModel", bound="type[Model]") + + +class Model(ABC): + _model_classes: dict[str, type[Model]] = {} + + def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool): + self.dir_model = dir_model + self.ftype = ftype + self.fname_out = fname_out + self.is_big_endian = is_big_endian + self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE + self.use_temp_file = use_temp_file + self.is_safetensors = self._is_model_safetensors() + self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin") + self.part_names = self._get_part_names() + self.hparams = Model.load_hparams(self.dir_model) + self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) + self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + @property + @abstractmethod + def model_arch(self) -> gguf.MODEL_ARCH: + pass + + def find_hparam(self, keys: Sequence[str], optional: bool = False) -> Any: + key = next((k for k in keys if k in self.hparams), None) + if key is not None: + return self.hparams[key] + if optional: + return None + raise KeyError(f"could not find any of: {keys}") + + def set_vocab(self): + self._set_vocab_gpt2() + + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + for part_name in self.part_names: + logger.info(f"gguf: loading model part '{part_name}'") + ctx: ContextManager[Any] + if self.is_safetensors: + from safetensors import safe_open + ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu")) + else: + ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True)) + + with ctx as model_part: + for name in model_part.keys(): + data = model_part.get_tensor(name) if self.is_safetensors else model_part[name] + yield name, data + + def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool: + if key not in gguf.MODEL_TENSORS[self.model_arch]: + return False + key_name: str = gguf.TENSOR_NAMES[key] + if "{bid}" in key_name: + if bid is None: + return False + key_name = key_name.format(bid=bid) + else: + if bid is not None: + return False + return name == (key_name + suffix) + + def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str: + new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes) + if new_name is None: + raise ValueError(f"Can not map tensor {name!r}") + return new_name + + def set_gguf_parameters(self): + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_block_count(self.block_count) + + if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None: + self.gguf_writer.add_context_length(n_ctx) + logger.info(f"gguf: context length = {n_ctx}") + + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + self.gguf_writer.add_embedding_length(n_embd) + logger.info(f"gguf: embedding length = {n_embd}") + + if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None: + self.gguf_writer.add_feed_forward_length(n_ff) + logger.info(f"gguf: feed forward length = {n_ff}") + + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + self.gguf_writer.add_head_count(n_head) + logger.info(f"gguf: head count = {n_head}") + + if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None: + self.gguf_writer.add_head_count_kv(n_head_kv) + logger.info(f"gguf: key-value head count = {n_head_kv}") + + if (rope_theta := self.hparams.get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base(rope_theta) + logger.info(f"gguf: rope theta = {rope_theta}") + if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None: + self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) + logger.info(f"gguf: rms norm epsilon = {f_rms_eps}") + if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None: + self.gguf_writer.add_layer_norm_eps(f_norm_eps) + logger.info(f"gguf: layer norm epsilon = {f_norm_eps}") + if (n_experts := self.hparams.get("num_local_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + logger.info(f"gguf: expert count = {n_experts}") + if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: + self.gguf_writer.add_expert_used_count(n_experts_used) + logger.info(f"gguf: experts used count = {n_experts_used}") + + self.gguf_writer.add_file_type(self.ftype) + logger.info(f"gguf: file type = {self.ftype}") + + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + raise ValueError(f"Can not map tensor {name!r}") + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")): + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + logger.info(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + def write(self): + self.write_tensors() + self.gguf_writer.write_header_to_file() + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.write_tensors_to_file() + self.gguf_writer.close() + + def write_vocab(self): + self.gguf_writer.write_header_to_file() + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.close() + + @staticmethod + def count_model_parts(dir_model: Path, prefix: str) -> int: + num_parts = 0 + for filename in os.listdir(dir_model): + if filename.endswith(prefix): + num_parts += 1 + + return num_parts + + @staticmethod + def load_hparams(dir_model): + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + return json.load(f) + + @classmethod + def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: + assert names + + def func(modelcls: type[Model]): + for name in names: + cls._model_classes[name] = modelcls + return modelcls + return func + + @classmethod + def from_model_architecture(cls, arch): + try: + return cls._model_classes[arch] + except KeyError: + raise NotImplementedError(f'Architecture {arch!r} not supported!') from None + + def _is_model_safetensors(self) -> bool: + return Model.count_model_parts(self.dir_model, ".safetensors") > 0 + + def _get_part_names(self): + if self.is_safetensors: + if self.num_parts == 1: # there's only one .safetensors file + return ("model.safetensors",) + return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1)) + + if self.num_parts == 1: # there's only one .bin file + return ("pytorch_model.bin",) + return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1)) + + # used for GPT-2 BPE and WordPiece vocabs + def get_vocab_base(self) -> tuple[list[str], list[int], str]: + tokens: list[str] = [] + toktypes: list[int] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) + assert max(tokenizer.vocab.values()) < vocab_size + + tokpre = self.get_vocab_base_pre(tokenizer) + + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} + added_vocab = tokenizer.get_added_vocab() + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.USER_DEFINED) + elif reverse_vocab[i] in added_vocab: + tokens.append(reverse_vocab[i]) + if tokenizer.added_tokens_decoder[i].special: + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.USER_DEFINED) + else: + tokens.append(reverse_vocab[i]) + toktypes.append(gguf.TokenType.NORMAL) + + return tokens, toktypes, tokpre + + # NOTE: this function is generated by convert-hf-to-gguf-update.py + # do not modify it manually! + # ref: https://github.com/ggerganov/llama.cpp/pull/6920 + def get_vocab_base_pre(self, tokenizer) -> str: + # encoding this string and hashing the resulting tokens would (hopefully) give us a unique identifier that + # is specific for the BPE pre-tokenizer used by the model + # we will use this unique identifier to write a "tokenizer.ggml.pre" entry in the GGUF file which we can + # use in llama.cpp to implement the same pre-tokenizer + + chktxt = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶\u200d🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````""""......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL' + + chktok = tokenizer.encode(chktxt) + chkhsh = sha256(str(chktok).encode()).hexdigest() + + logger.debug(f"chktok: {chktok}") + logger.debug(f"chkhsh: {chkhsh}") + + res = None + + # NOTE: if you get an error here, you need to update the convert-hf-to-gguf-update.py script + # or pull the latest version of the model from Huggingface + # don't edit the hashes manually! + if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": + # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B + res = "llama-bpe" + if chkhsh == "049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754": + # ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base + res = "deepseek-llm" + if chkhsh == "347715f544604f9118bb75ed199f68779f423cabb20db6de6f31b908d04d7821": + # ref: https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base + res = "deepseek-coder" + if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed": + # ref: https://huggingface.co/tiiuae/falcon-7b + res = "falcon" + if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": + # ref: https://huggingface.co/BAAI/bge-small-en-v1.5 + res = "bert-bge" + if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166": + # ref: https://huggingface.co/mosaicml/mpt-7b + res = "mpt" + if chkhsh == "35d91631860c815f952d711435f48d356ebac988362536bed955d43bfa436e34": + # ref: https://huggingface.co/bigcode/starcoder2-3b + res = "starcoder" + if chkhsh == "3ce83efda5659b07b1ad37ca97ca5797ea4285d9b9ab0dc679e4a720c9da7454": + # ref: https://huggingface.co/openai-community/gpt2 + res = "gpt-2" + if chkhsh == "6221ad2852e85ce96f791f476e0b390cf9b474c9e3d1362f53a24a06dc8220ff": + # ref: https://huggingface.co/smallcloudai/Refact-1_6-base + res = "refact" + if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8": + # ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01 + res = "command-r" + + if res is None: + logger.warning("\n") + logger.warning("**************************************************************************************") + logger.warning("** WARNING: The BPE pre-tokenizer was not recognized!") + logger.warning("** There are 2 possible reasons for this:") + logger.warning("** - the model has not been added to convert-hf-to-gguf-update.py yet") + logger.warning("** - the pre-tokenization config has changed upstream") + logger.warning("** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.") + logger.warning("** ref: https://github.com/ggerganov/llama.cpp/pull/6920") + logger.warning("**") + logger.warning(f"** chkhsh: {chkhsh}") + logger.warning("**************************************************************************************") + logger.warning("\n") + raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()") + + logger.debug(f"tokenizer.ggml.pre: {repr(res)}") + logger.debug(f"chkhsh: {chkhsh}") + + return res + + def _set_vocab_sentencepiece(self): + from sentencepiece import SentencePieceProcessor + + tokenizer_path = self.dir_model / 'tokenizer.model' + + tokens: list[bytes] = [] + scores: list[float] = [] + toktypes: list[int] = [] + + if not tokenizer_path.is_file(): + raise FileNotFoundError(f"File not found: {tokenizer_path}") + + tokenizer = SentencePieceProcessor(str(tokenizer_path)) + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + for token_id in range(tokenizer.vocab_size()): + piece = tokenizer.id_to_piece(token_id) + text = piece.encode("utf-8") + score = tokenizer.get_score(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.is_unknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.is_control(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.is_unused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.is_byte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + added_tokens_file = self.dir_model / 'added_tokens.json' + if added_tokens_file.is_file(): + with open(added_tokens_file, "r", encoding="utf-8") as f: + added_tokens_json = json.load(f) + + for key in added_tokens_json: + key = key.encode("utf-8") + if key not in tokens: + tokens.append(key) + scores.append(-1000.0) + toktypes.append(SentencePieceTokenTypes.USER_DEFINED) + + if vocab_size > len(tokens): + pad_count = vocab_size - len(tokens) + logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") + for i in range(1, pad_count + 1): + tokens.append(f"[PAD{i}]") + scores.append(-1000.0) + toktypes.append(SentencePieceTokenTypes.UNUSED) + + assert len(tokens) == vocab_size + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + +# TL1 + +def process_tl1(weight, BM, BY, bm, by, M, K): + final_weight = [] + + # split in row with size of BM (160) + outer_BM_weights = np.split(weight, (M // BM), axis=0) + for outer_BM_weight in outer_BM_weights: + # split in col with size of by (16index * 2 == 32nums) + outer_BY_weights = np.split(outer_BM_weight, (K // BY), axis=1) + for outer_BY_weight in outer_BY_weights: + # split in row with size of bm (32) + inner_bm_weights = np.split(outer_BY_weight, (BM // bm), axis=0) + for inner_bm_weight in inner_bm_weights: + # split in col with size of by (2index * 2 == 4nums) + inner_by_weights = np.split(inner_bm_weight, (BY // by), axis=1) + for inner_by_weight in inner_by_weights: + # 16 * 6 minor + minor_bm_weights = np.split(inner_by_weight, (bm // 16), axis=0) + for minor_bm_weight in minor_bm_weights: + minor_by_weights = np.split(minor_bm_weight, (by // 4), axis=1) + for minor in minor_by_weights: + minor_weight = np.split(minor, 2, axis=1) + hi_weight = minor_weight[0].astype(np.uint8) << 4 + lo_weight = minor_weight[1].astype(np.uint8) + func_weight = lo_weight + hi_weight + final_weight.append(func_weight) + + weight = np.array(final_weight, dtype=np.uint8) + return weight + +# based on t_mac.utils.preprocess_weights +def preprocess_weights_tl1( + w: np.ndarray, + bits = 2, + g = 4, +) -> Tuple[np.ndarray, np.ndarray]: + M, K = w.shape + weight = w + weight = np.where(np.abs(weight) < 1e-6, 0, weight).astype(np.float32) + weight = np.sign(weight) + weight_num = np.prod(weight.shape) + model_size = args.model_size + + KEMD = model_config[model_size]['hidden_size'] + # outer loop + BMEMD = 256 + BYEMD = 256 + + # inner loop (32row 32num/16index) + bmEMD = 32 + byEMD = 8 + + KGQA = model_config[model_size]['intermediate_size'] + + BMGQA = 256 + BYGQA = 256 + + bmGQA = 32 + byGQA = 8 + + weight = np.reshape(weight, (weight_num // 2, 2)) + hi_weight = np.multiply(np.split(weight, 2, axis=1)[0], 3) + lo_weight = np.split(weight, 2, axis=1)[1] + + weight = np.reshape((hi_weight + lo_weight), weight_num // 2) + + # row-major index + weight = weight + 4 + weight = np.reshape(weight, (M, K // 2)).astype(np.uint8) + + if K == KEMD: + weight = process_tl1(weight, BMEMD, BYEMD, bmEMD, byEMD, M, K) + elif K == KGQA: + weight = process_tl1(weight, BMGQA, BYGQA, bmGQA, byGQA, M, K) + else: + raise NotImplementedError + + return weight + + +def preprocess_two_weights_tl2(M, K, weight_num, BM, BY, bm, by, weight, final_weight): + weight = np.reshape(weight, (weight_num // 2, 2)) + hi_weight = np.multiply(np.split(weight, 2, axis=1)[0], 3) + lo_weight = np.split(weight, 2, axis=1)[1] + + weight = np.reshape((hi_weight + lo_weight), weight_num // 2) + + # row-major index + weight = weight + 4 + weight = np.reshape(weight, (M, K // 2)).astype(np.uint8) + + outer_BM_weights = np.split(weight, (M // BM), axis=0) + for outer_BM_weight in outer_BM_weights: + # split in col with size of by (32index * 3 == 96nums) + outer_BY_weights = np.split(outer_BM_weight, (K // BY), axis=1) + for outer_BY_weight in outer_BY_weights: + # split in row with size of bm (32) + inner_bm_weights = np.split(outer_BY_weight, (BM // bm), axis=0) + for inner_bm_weight in inner_bm_weights: + # split in col with size of by (2index * 2 == 4nums) + inner_by_weights = np.split(inner_bm_weight, (BY // by), axis=1) + for inner_by_weight in inner_by_weights: + func_weights = np.split(inner_by_weight, 2, axis=1) + + left_weight = func_weights[0] + left_sub_weights = np.split(left_weight, 4, axis=0) + new_left_weight = np.reshape( + np.concatenate([left_sub_weights[0], left_sub_weights[2], + left_sub_weights[1], left_sub_weights[3]], axis=0, dtype=np.uint8), + (bm)) + + right_weight = func_weights[1] + right_sub_weights = np.split(right_weight, 4, axis=0) + new_right_weight = np.reshape( + np.concatenate([right_sub_weights[0], right_sub_weights[2], + right_sub_weights[1], right_sub_weights[3]], axis=0, dtype=np.uint8), + (bm)) + hi_weight = new_left_weight.astype(np.uint8) << 4 + lo_weight = new_right_weight + func_weight = hi_weight + lo_weight + func_weight = np.reshape(func_weight, bm * by // 4) + final_weight.append(func_weight) + +def preprocess_three_weights_tl2(M, K, weight_num, BM, BY, bm, by, weight, final_weight): + weight = np.reshape(weight, (weight_num // 3, 3)) + split_weights = np.split(weight, 3, axis=1) + first_weight = np.multiply(split_weights[0], 9) + second_weight = np.multiply(split_weights[1], 3) + third_weight = split_weights[2] + + weight = np.reshape((first_weight + second_weight + third_weight), weight_num // 3) + sign_weight = np.sign(weight) + 2 + sign_weight = np.where(sign_weight > 1, 0, sign_weight) + weight = np.abs(weight) + + # row-major index + weight = np.reshape(weight, (M, K // 3)).astype(np.uint8) + sign_weight = np.reshape(sign_weight, (M, K // 3)).astype(np.uint8) + # print(weight) + + # split in row with size of BM (160) + outer_BM_weights = np.split(weight, (M // BM), axis=0) + for outer_BM_weight in outer_BM_weights: + # split in col with size of by (32index * 3 == 96nums) + outer_BY_weights = np.split(outer_BM_weight, (K // BY), axis=1) + for outer_BY_weight in outer_BY_weights: + # split in row with size of bm (32) + inner_bm_weights = np.split(outer_BY_weight, (BM // bm), axis=0) + for inner_bm_weight in inner_bm_weights: + # split in col with size of by (2index * 3 == 6nums) + inner_by_weights = np.split(inner_bm_weight, (BY // by), axis=1) + for inner_by_weight in inner_by_weights: + func_weights = np.split(inner_by_weight, 2, axis=1) + + left_weight = func_weights[0] + left_sub_weights = np.split(left_weight, 4, axis=0) + new_left_weight = np.reshape( + np.concatenate([left_sub_weights[0], left_sub_weights[2], + left_sub_weights[1], left_sub_weights[3]], axis=0, dtype=np.uint8), + (bm)) + + right_weight = func_weights[1] + right_sub_weights = np.split(right_weight, 4, axis=0) + + new_right_weight = np.reshape( + np.concatenate([right_sub_weights[0], right_sub_weights[2], + right_sub_weights[1], right_sub_weights[3]], axis=0, dtype=np.uint8), + (bm)) + hi_weight = new_left_weight.astype(np.uint8) << 4 + lo_weight = new_right_weight + func_weight = hi_weight + lo_weight + func_weight = np.reshape(func_weight, bm * by // 6) + final_weight.append(func_weight) + + sign_weight_list = [] + sign_outer_BM_weights = np.split(sign_weight, (M // BM), axis=0) + for sign_outer_BM_weight in sign_outer_BM_weights: + # split in col with size of by (32index * 3 == 96nums) + sign_outer_BY_weights = np.split(sign_outer_BM_weight, (K // BY), axis=1) + for sign_outer_BY_weight in sign_outer_BY_weights: + # split in row with size of bm (32) + sign_inner_bm_weights = np.split(sign_outer_BY_weight, (BM // bm), axis=0) + for sign_inner_bm_weight in sign_inner_bm_weights: + # split in col with size of by (4index * 3 == 12nums) + sign_inner_by_weights = np.split(sign_inner_bm_weight, (BY // (by * 4)), axis=1) + for sign_inner_by_weight in sign_inner_by_weights: + func_weight = np.split(sign_inner_by_weight, 8, axis=1) + combine_weight = np.zeros((16, 1), dtype=np.uint16) + for i in range(len(func_weight)): + min_weight = np.split(func_weight[i], 2) + min_top_weight = min_weight[0].astype(np.uint16) << 15 - (2 * i) + min_bot_weight = min_weight[1].astype(np.uint16) << 15 - (2 * i + 1) + combine_weight += min_top_weight + combine_weight += min_bot_weight + combine_weight = combine_weight.view(np.uint8) + # combine_weight = combine_weight[:, [1, 0]] + combine_weight = np.reshape(combine_weight, bm) + sign_weight_list.append(combine_weight) + final_weight.extend(sign_weight_list) + final_weight.extend(sign_weight_list) + + +def preprocess_weights_tl2( + w: np.ndarray, + bits = 2, + g = 4, +) -> Tuple[np.ndarray, np.ndarray]: + M, K = w.shape + weight = w + weight = np.where(np.abs(weight) < 1e-6, 0, weight).astype(np.float32) + weight = np.sign(weight) + weight_num = np.prod(weight.shape) + + # for three num 6 bit -> + + # outer loop + KEMD = 1536 + BMEMD = 256 + BYEMD = 96 + + KGQA = 4096 + BMGQA = 128 + BYGQA = 96 + + # inner loop (32row 32num/16index) + bm3 = 32 + by3 = 6 + + if K == KEMD: + BM3 = BMEMD + BY3 = BYEMD + elif K == KGQA: + BM3 = BMGQA + BY3 = BYGQA + else: + raise NotImplementedError + + BM2 = BM3 + BY2 = 32 + # inner loop (32row 32num/16index) + bm2 = 32 + by2 = 4 + + if (weight.shape[1] % BY3 != 0): + slice_k_idx = weight.shape[1] - weight.shape[1] % BY3 + slice_weights = np.split(weight, [slice_k_idx], axis=1) + three_weight = slice_weights[0] + two_weight = slice_weights[1] + else: + three_weight = weight + + final_weight = [] + + preprocess_three_weights_tl2(three_weight.shape[0], + three_weight.shape[1], + three_weight.shape[0] * three_weight.shape[1], + BM3, + BY3, + bm3, + by3, + three_weight, + final_weight) + + if (weight.shape[1] % BY3 != 0): + preprocess_two_weights_tl2( two_weight.shape[0], + two_weight.shape[1], + two_weight.shape[0] * two_weight.shape[1], + BM2, + BY2, + bm2, + by2, + two_weight, + final_weight) + + weight = np.array(final_weight, dtype=np.uint8) + + return weight + + +@Model.register("BitnetForCausalLM") +class BitnetModel(Model): + model_arch = gguf.MODEL_ARCH.BITNET + params: str = "" + + def set_params(self, params: str): + self.params = params + hp_config = model_config[self.params] + self.hparams["hidden_size"] = hp_config["hidden_size"] + self.hparams["intermediate_size"] = hp_config["intermediate_size"] + self.hparams["num_hidden_layers"] = hp_config["num_hidden_layers"] + self.hparams["num_attention_heads"] = hp_config["num_attention_heads"] + self.hparams["num_key_value_heads"] = hp_config["num_attention_heads"] + self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(1.0) + + def weight_quant(self, weight): + dtype = weight.dtype + weight = weight.float() + s = 1 / weight.abs().mean().clamp(min=1e-5) + result = (weight * s).round().clamp(-1, 1) / s + return result.type(dtype) + + def transform_to_tl1(self, x: np.ndarray): + scale = np.max(np.abs(x)) + # res = np.round(x / scale + 2).astype(np.uint8) + res = preprocess_weights_tl1(x) + return res, scale + + def transform_to_tl2(self, x: np.ndarray): + scale = np.max(np.abs(x)) + # res = np.round(x / scale + 2).astype(np.uint8) + res = preprocess_weights_tl2(x) + return res, scale + + # generate dummy model + def generate_tensors(self) -> Iterator[tuple[str, np.ndarray]]: + hp_config = model_config[self.params] + hidden_size = hp_config["hidden_size"] + intermediate_size = hp_config["intermediate_size"] + num_hidden_layers = hp_config["num_hidden_layers"] + num_attention_heads = hp_config["num_attention_heads"] + + # generate dummy tensors + tensor = torch.randn((32002, hidden_size), dtype=torch.float32) + yield ("model.embed_tokens.weight", tensor) + for i in range(num_hidden_layers): + yield f"model.layers.{i}.input_layernorm.weight", torch.randn((hidden_size,), dtype=torch.float32) + yield f"model.layers.{i}.mlp.down_proj.weight", torch.randn((hidden_size, intermediate_size), dtype=torch.float32) + yield f"model.layers.{i}.mlp.ffn_layernorm.weight", torch.randn((intermediate_size,), dtype=torch.float32) + yield f"model.layers.{i}.mlp.gate_proj.weight", torch.randn((intermediate_size, hidden_size), dtype=torch.float32) + yield f"model.layers.{i}.mlp.up_proj.weight", torch.randn((intermediate_size, hidden_size), dtype=torch.float32) + yield f"model.layers.{i}.post_attention_layernorm.weight", torch.randn((hidden_size), dtype=torch.float32) + yield f"model.layers.{i}.self_attn.inner_attn_ln.weight", torch.randn((hidden_size,), dtype=torch.float32) + yield f"model.layers.{i}.self_attn.k_proj.weight", torch.randn((hidden_size, hidden_size), dtype=torch.float32) + yield f"model.layers.{i}.self_attn.o_proj.weight", torch.randn((hidden_size, hidden_size), dtype=torch.float32) + yield f"model.layers.{i}.self_attn.q_proj.weight", torch.randn((hidden_size, hidden_size), dtype=torch.float32) + yield f"model.layers.{i}.self_attn.rotary_emb.inv_freq", torch.randn((hidden_size // (num_attention_heads * 2),), dtype=torch.float32) + yield f"model.layers.{i}.self_attn.v_proj.weight", torch.randn((hidden_size, hidden_size), dtype=torch.float32) + tensor = torch.randn((hidden_size,), dtype=torch.float32) + yield("model.norm.weight", tensor) + + + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # quant weight to i2 (in fp16) + if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight", + "down_proj.weight", "up_proj.weight", "gate_proj.weight", + "o_proj.weight")): + data_torch = self.weight_quant(data_torch) + + return [(self.map_tensor_name(name), data_torch)] + + def write_tensors(self): + max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") + + for name, data_torch in self.generate_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + # use the first number-like part of the tensor name as the block id + bid = None + for part in name.split("."): + if part.isdecimal(): + bid = int(part) + break + + for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): + data: np.ndarray = data # type hint + data_shape = data.shape + n_dims = len(data.shape) + data_dtype = data.dtype + data_qtype: gguf.GGMLQuantizationType | None = None + + # when both are True, f32 should win + # extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims) + # extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims) + extra_f32 = False + extra_f16 = False + + # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors + # Conditions should closely match those in llama_model_quantize_internal in llama.cpp + extra_f32 = any(cond for cond in ( + extra_f32, + n_dims == 1, + new_name.endswith("_norm.weight"), + )) + + # Some tensor types are always in float32 + extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in ( + gguf.MODEL_TENSOR.FFN_GATE_INP, + gguf.MODEL_TENSOR.POS_EMBD, + gguf.MODEL_TENSOR.TOKEN_TYPES, + # for debug / delete when inference + gguf.MODEL_TENSOR.TOKEN_EMBD, + )) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + extra_f16 = any(cond for cond in ( + extra_f16, + (name.endswith(".weight") and n_dims >= 2), + )) + + suit_i2 = True + if name.endswith('embed_tokens.weight') or name.endswith('norm.weight'): + suit_i2 = False + + i2_scale = None + if self.ftype != gguf.GGMLQuantizationType.F32 and extra_f16 and not extra_f32: + if self.ftype == gguf.GGMLQuantizationType.TL1 and suit_i2: + data, i2_scale = self.transform_to_tl1(data) + assert data.dtype == np.uint8 + assert i2_scale.dtype == np.float32 + data_qtype = gguf.GGMLQuantizationType.TL1 + elif self.ftype == gguf.GGMLQuantizationType.TL2 and suit_i2: + data, i2_scale = self.transform_to_tl2(data) + assert data.dtype == np.uint8 + assert i2_scale.dtype == np.float32 + data_qtype = gguf.GGMLQuantizationType.TL2 + else: # default to float16 for quantized tensors + if data_dtype != np.float16: + data = data.astype(np.float16) + data_qtype = gguf.GGMLQuantizationType.F16 + + if data_qtype is None: # by default, convert to float32 + if data_dtype != np.float32: + data = data.astype(np.float32) + data_qtype = gguf.GGMLQuantizationType.F32 + + shape = data_shape + # shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape + # reverse shape to make it similar to the internal ggml dimension order + shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}" + + # n_dims is implicit in the shape + logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + + self.gguf_writer.add_tensor(new_name, data, raw_shape=shape, raw_dtype=data_qtype) + if i2_scale is not None: + self.gguf_writer.add_tensor(new_name + "_scale", i2_scale, raw_dtype=gguf.GGMLQuantizationType.F32) + +ftype_map = { + "f32": gguf.GGMLQuantizationType.F32, + "f16": gguf.GGMLQuantizationType.F16, + "tl1" : gguf.GGMLQuantizationType.TL1, + "tl2" : gguf.GGMLQuantizationType.TL2, +} + +def main() -> None: + dir_model = args.model + fname_out = args.outfile + model_size = args.model_size + + hparams = Model.load_hparams(dir_model) + + with torch.inference_mode(): + model_class = Model.from_model_architecture(hparams["architectures"][0]) + model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file) + model_instance.set_params(model_size) + + logger.info("Set model parameters") + model_instance.set_gguf_parameters() + + logger.info("Set model tokenizer") + model_instance.set_vocab() + + if args.vocab_only: + logger.info(f"Exporting model vocab to '{fname_out}'") + model_instance.write_vocab() + else: + logger.info(f"Exporting model to '{fname_out}'") + model_instance.write() + + logger.info(f"Model successfully exported to '{fname_out}'") + +def read_gguf_file(gguf_file_path): + """ + Reads and prints key-value pairs and tensor information from a GGUF file in an improved format. + + Parameters: + - gguf_file_path: Path to the GGUF file. + """ + + reader = GGUFReader(gguf_file_path) + + # List all key-value pairs in a columnized format + print("Key-Value Pairs:") # noqa: NP100 + max_key_length = max(len(key) for key in reader.fields.keys()) + for key, field in reader.fields.items(): + value = field.parts[field.data[0]] + print(f"{key:{max_key_length}} : {value}") # noqa: NP100 + print("----") # noqa: NP100 + + # List all tensors + print("Tensors:") # noqa: NP100 + tensor_info_format = "{:<30} | Shape: {:<15} | Size: {:<12} | Quantization: {}" + print(tensor_info_format.format("Tensor Name", "Shape", "Size", "Quantization")) # noqa: NP100 + print("-" * 80) # noqa: NP100 + for tensor in reader.tensors: + shape_str = "x".join(map(str, tensor.shape)) + size_str = str(tensor.n_elements) + quantization_str = tensor.tensor_type.name + print(tensor_info_format.format(tensor.name, shape_str, size_str, quantization_str)) # noqa: NP100 + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Generate a dummy bitnet model with GGUF format") + parser.add_argument( + "--vocab-only", action="store_true", + help="extract only the vocab", + ) + parser.add_argument( + "--outfile", type=Path, + help="path to write to; default: based on input", + ) + parser.add_argument( + "--outtype", type=str, choices=ftype_map.keys(), default="f16", + help="output format - use f32 for float32, f16 for float16", + ) + parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine") + parser.add_argument( + "model", type=Path, + help="directory containing model file", + ) + parser.add_argument("--use-temp-file", action="store_true", help="use the tempfile library while processing (helpful when running out of memory, process killed)") + parser.add_argument("--model-name", type=str, default=None, help="name of the model") + parser.add_argument("--model-size", type=str, default="7B", help="size of the model") + parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + + return parser.parse_args() + +if __name__ == '__main__': + args = parse_args() + main() \ No newline at end of file diff --git a/utils/kernel_tuning.py b/utils/kernel_tuning.py new file mode 100644 index 0000000..e69de29