diff --git a/3rdparty/llama.cpp b/3rdparty/llama.cpp
index 40ed0f2..1f86f05 160000
--- a/3rdparty/llama.cpp
+++ b/3rdparty/llama.cpp
@@ -1 +1 @@
-Subproject commit 40ed0f290203a9a78540b8f7eb18bd828043fe21
+Subproject commit 1f86f058de0c3f4098dedae2ae8653c335c868a1
diff --git a/README.md b/README.md
index 798c0e9..2cc2a73 100644
--- a/README.md
+++ b/README.md
@@ -10,10 +10,10 @@ bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.5
The first release of bitnet.cpp is to support inference on CPUs. bitnet.cpp achieves speedups of **1.37x** to **5.07x** on ARM CPUs, with larger models experiencing greater performance gains. Additionally, it reduces energy consumption by **55.4%** to **70.0%**, further boosting overall efficiency. On x86 CPUs, speedups range from **2.37x** to **6.17x** with energy reductions between **71.9%** to **82.2%**. Furthermore, bitnet.cpp can run a 100B BitNet b1.58 model on a single CPU, achieving speeds comparable to human reading (5-7 tokens per second), significantly enhancing the potential for running LLMs on local devices. Please refer to the [technical report](https://arxiv.org/abs/2410.16144) for more details.
-
-
+**Latest optimization** introduces parallel kernel implementations with configurable tiling and embedding quantization support, achieving **1.15x to 2.1x** additional speedup over the original implementation across different hardware platforms and workloads. For detailed technical information, see the [optimization guide](src/README.md).
+
+
->The tested models are dummy setups used in a research context to demonstrate the inference performance of bitnet.cpp.
## Demo
@@ -22,7 +22,8 @@ A demo of bitnet.cpp running a BitNet b1.58 3B model on Apple M2:
https://github.com/user-attachments/assets/7f46b736-edec-4828-b809-4be780a3e5b1
## What's New:
-- 05/20/2025 [BitNet Official GPU inference kernel](https://github.com/microsoft/BitNet/blob/main/gpu/README.md) 
+- 01/15/2026 [BitNet CPU Inference Optimization](https://github.com/XsquirrelC/BitNet/blob/main/src/README.md) 
+- 05/20/2025 [BitNet Official GPU inference kernel](https://github.com/microsoft/BitNet/blob/main/gpu/README.md)
- 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
- 02/18/2025 [Bitnet.cpp: Efficient Edge Inference for Ternary LLMs](https://arxiv.org/abs/2502.11880)
- 11/08/2024 [BitNet a4.8: 4-bit Activations for 1-bit LLMs](https://arxiv.org/abs/2411.04965)
diff --git a/assets/intel_performance.jpg b/assets/intel_performance.jpg
deleted file mode 100644
index 38a1bcf..0000000
Binary files a/assets/intel_performance.jpg and /dev/null differ
diff --git a/assets/m2_performance.jpg b/assets/m2_performance.jpg
deleted file mode 100644
index 9b59348..0000000
Binary files a/assets/m2_performance.jpg and /dev/null differ
diff --git a/assets/performance.png b/assets/performance.png
new file mode 100644
index 0000000..078fd3f
Binary files /dev/null and b/assets/performance.png differ
diff --git a/include/gemm-config.h b/include/gemm-config.h
new file mode 100644
index 0000000..6a88c42
--- /dev/null
+++ b/include/gemm-config.h
@@ -0,0 +1,35 @@
+#define ACT_PARALLEL
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+#if defined(ACT_PARALLEL)
+ #define ROW_BLOCK_SIZE 4
+ #define COL_BLOCK_SIZE 128
+ #define PARALLEL_SIZE 4
+#else
+ #define ROW_BLOCK_SIZE 128
+ #define COL_BLOCK_SIZE 32
+ #define PARALLEL_SIZE 8
+#endif // ACT_PARALLEL
+#elif defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_DOTPROD)
+#if defined(ACT_PARALLEL)
+ #define ROW_BLOCK_SIZE 8
+ #define COL_BLOCK_SIZE 256
+ #define PARALLEL_SIZE 8
+#else
+ #define ROW_BLOCK_SIZE 64
+ #define COL_BLOCK_SIZE 16
+ #define PARALLEL_SIZE 2
+#endif // ACT_PARALLEL
+#else
+#if defined(ACT_PARALLEL)
+ #define ROW_BLOCK_SIZE 8
+ #define COL_BLOCK_SIZE 256
+ #define PARALLEL_SIZE 4
+#else
+ #define ROW_BLOCK_SIZE 128
+ #define COL_BLOCK_SIZE 32
+ #define PARALLEL_SIZE 4
+#endif // ACT_PARALLEL
+#endif // __ARM_FEATURE_DOTPROD
+#endif // __AVX__
+
diff --git a/setup_env.py b/setup_env.py
index f15d65f..3bf5fb8 100644
--- a/setup_env.py
+++ b/setup_env.py
@@ -64,8 +64,8 @@ SUPPORTED_QUANT_TYPES = {
}
COMPILER_EXTRA_ARGS = {
- "arm64": ["-DBITNET_ARM_TL1=ON"],
- "x86_64": ["-DBITNET_X86_TL2=ON"]
+ "arm64": ["-DBITNET_ARM_TL1=OFF"],
+ "x86_64": ["-DBITNET_X86_TL2=OFF"]
}
OS_EXTRA_ARGS = {
diff --git a/src/README.md b/src/README.md
new file mode 100644
index 0000000..f713b9a
--- /dev/null
+++ b/src/README.md
@@ -0,0 +1,205 @@
+# BitNet CPU Inference Optimization
+
+This update provides significant performance improvements for BitNet inference on CPU through paralleled kernel implementations, native I2_S GEMM/GEMV support, configurable tiling block size and embedding quantization.
+
+## Update
+
+- **Parallel Weight & Activation Computation**
+ Implemented parallel processing of weights and activations in the W2A8 vet_dot kernel, achieving improved throughput on both x86 and ARM architectures.
+
+- **Native I2_S GEMM & GEMV Support**
+ Integrated I2_S GEMM and GEMV operations into ggml library, making them fully compatible with the llama.cpp architecture. This enables seamless integration with existing inference pipelines.
+
+- **Configurable Tiling & Parallelism**
+ Introduced configurable GEMM & GEMV block sizes and parallelism levels, allowing performance fine-tuning for different CPU architectures.
+
+- **Embedding Quantization**
+ Added support for embedding layer quantization with Q6_K format, reducing memory footprint and improving inference speed while maintaining high accuracy.
+
+## Usage
+
+### Configuration Options
+
+The `include/gemm-config.h` file controls kernel behavior:
+
+```c
+#define ROW_BLOCK_SIZE 4
+#define COL_BLOCK_SIZE 128
+#define PARALLEL_SIZE 4
+```
+
+Modify these values based on your CPU cache size and architecture for optimal performance. Users can fine-tune performance on their machine through `include/gemm-config.h`.
+
+### Enabling Embedding Quantization
+
+To use embedding quantization for additional speedup:
+
+**Using setup_env.py:**
+```bash
+python setup_env.py --quant-embd
+```
+This automatically converts embeddings to Q6_K format.
+
+**Manual conversion:**
+```bash
+build/bin/llama-quantize --token-embedding-type Q6_K models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf models/BitNet-b1.58-2B-4T/ggml-model-i2_s-embed-q6_k.gguf I2_S 1 1
+```
+
+## Optimizations
+
+### 1. Weight & Activation Parallelism
+
+The kernel implements two parallelization strategies:
+
+- **Weight Parallel:** Processes multiple weight rows/columns in a single kernel call, reducing kernel launch overhead.
+
+- **Activation Parallel:** Built on top of weight parallel, amortizes the I2_S weight unpacking cost across multiple activation elements.
+
+**Recommendation:** For I2_S quantization format, activation parallel is recommended due to the unpack operation benefits. The current kernel defaults to activation parallel.
+
+**Kernel Performance Comparison:**
+
+
+
+Test configuration: AMD EPYC 7V13 (x86), 1 threads, time in milliseconds (mean±std)
+
+| Matrix Size | No Parallel | Weight Parallel | Activation Parallel |
+|:---:|:---:|:---:|:---:|
+| [1, 2048] × [2048, 2048] | 0.075±0.012 | **0.058±0.007** | 0.076±0.011 |
+| [32, 2048] × [2048, 2048] | 2.400±0.041 | 1.599±0.020 | **1.202±0.018** |
+| [128, 2048] × [2048, 2048] | 10.820±0.039 | 6.458±0.168 | **5.805±0.039** |
+| [256, 2048] × [2048, 2048] | 21.669±0.080 | 12.739±0.183 | **11.882±0.040** |
+| [512, 2048] × [2048, 2048] | 43.257±0.083 | 25.680±0.335 | **23.342±0.082** |
+| [2048, 2048] × [2048, 2048] | 173.175±0.214 | 103.112±0.552 | **93.276±0.612** |
+| [128, 2048] × [2048, 8192] | 43.345±0.090 | 25.541±0.239 | **23.528±0.052** |
+| [128, 8192] × [8192, 2048] | 38.085±0.162 | 23.866±0.096 | **22.569±0.132** |
+
+
+
+### 2. GEMM/GEMV Integration with llama.cpp
+
+Integrated I2_S quantization format into llama.cpp's compute graph:
+
+- **GEMV Operations:** Optimized matrix-vector multiplication for token generation.
+- **GEMM Operations:** Efficient matrix-matrix multiplication for prompt processing.
+- **Tiling Strategy:** Configurable block sizes for optimal cache utilization.
+
+### 3. Configuration Fine-tuning
+
+Fine-tuning kernel parameters for optimal performance on specific hardware:
+
+**Example Configuration (x86, AMD EPYC 7V13):**
+- Method: Activation Parallel
+- Threads: 8
+- Workload: 128 prompt tokens (pp128)
+
+**Fine-tuning Parameters:**
+- **Parallelism Degree:** [2, 4, 8]
+- **Row Block Size:** [2, 4, 8, 16, 32]
+- **Column Block Size:** [32, 64, 128, 256, 512, 1024]
+
+**Fine-tuning Results:**
+
+
+
+

+
+*Shows throughput (tokens/s) for various configurations.*
+
+
+
+**Optimal Configuration:** Under this setup (x86, 8 threads, pp128), the best performance is achieved with parallelism degree = 4, row block size = 4, and column block size = 128.
+
+### 4. Embedding Quantization
+
+Evaluated multiple embedding quantization formats to balance memory usage, model quality, and inference speed:
+
+**Perplexity Comparison:**
+
+
+
+Test configuration: BitNet-b1.58-2B-4T, TG128
+
+| Embedding Type | Wikitext | PTB | LAMBADA | IMDB | AG NEWS |
+|:---:|:---:|:---:|:---:|:---:|:---:|
+| **F32** | 17.1090±0.1278 | 33.0858±0.4886 | 43.2850±0.6363 | 29.3016±0.2890 | 36.7686±0.3920 |
+| **F16** | 17.1090±0.1278 | 33.0858±0.4886 | 43.2850±0.6363 | 29.3016±0.2890 | 36.7686±0.3920 |
+| **Q8_0** | 17.1197±0.1280 | 33.1181±0.4893 | 43.2891±0.6364 | 29.3133±0.2892 | 36.7740±0.3920 |
+| **Q6_K** | 17.1487±0.1282 | 33.2203±0.4914 | 43.3046±0.6362 | 29.3491±0.2897 | 36.7972±0.3921 |
+| **Q5_0** | 17.2379±0.1288 | 33.2439±0.4907 | 43.4631±0.6379 | 29.5481±0.2920 | 36.8539±0.3924 |
+| **Q4_0** | 17.3529±0.1300 | 33.7754±0.5001 | 44.4552±0.6559 | 30.1044±0.2978 | 37.3985±0.3997 |
+| **Q3_K** | 17.6434±0.1320 | 34.3914±0.5089 | 45.4591±0.6735 | 30.8476±0.3069 | 39.5692±0.4259 |
+| **I2_S** | N/A | N/A | N/A | N/A | N/A |
+
+**N/A indicates model failure due to extreme quantization.*
+
+
+
+**Inference Speed Comparison:**
+
+
+
+

+
+*Token generation throughput (tg128) for different embedding quantization types.*
+
+
+
+**Recommendation:** Based on comprehensive evaluation of memory footprint, perplexity preservation, and inference speed, **Q6_K** is selected as the optimal embedding quantization format.
+
+## Performance
+
+Comparison of optimized parallel kernels vs. original implementation:
+
+**Test Configuration:**
+- Model: BitNet-b1.58-2B-4T
+- Hardware: AMD EPYC 7V13
+- Threads: 1 / 2 / 4 / 8 / 12 / 16
+- Test: 128 prompt tokens (pp128) + 128 generated tokens (tg128)
+- Method: Activation Parallel
+
+
+
+

+
+
+
+**Test Configuration:**
+- Model: BitNet-b1.58-2B-4T
+- Hardware: Intel i7-13800H
+- Threads: 1 / 2 / 4 / 6
+- Test: 128 prompt tokens (pp128) + 128 generated tokens (tg128)
+- Method: Activation Parallel
+
+
+
+

+
+
+
+**Test Configuration:**
+- Model: BitNet-b1.58-2B-4T
+- Hardware: Cobalt 100
+- Threads: 1 / 2 / 4 / 8
+- Test: 128 prompt tokens (pp128) + 128 generated tokens (tg128)
+- Method: Activation Parallel
+
+
+
+

+
+
+
+## Technical Details
+
+### Key Files Modified
+
+- `src/ggml-bitnet-mad.cpp`: Parallel kernel implementations
+- `3rdparty/llama.cpp/ggml/src/ggml.c`: GEMM/GEMV integration
+- `include/gemm-config.h`: Configuration file
+
+### Supported Architectures
+
+- ✅ x86-64 with AVX2
+- ✅ ARM with NEON
+- ✅ ARM with DOTPROD extension
diff --git a/src/assets/embedding_throughput.png b/src/assets/embedding_throughput.png
new file mode 100644
index 0000000..b3ebb82
Binary files /dev/null and b/src/assets/embedding_throughput.png differ
diff --git a/src/assets/fine_tuning_result.png b/src/assets/fine_tuning_result.png
new file mode 100644
index 0000000..5bcab69
Binary files /dev/null and b/src/assets/fine_tuning_result.png differ
diff --git a/src/assets/performance_comparison_amd_epyc.png b/src/assets/performance_comparison_amd_epyc.png
new file mode 100644
index 0000000..6ebdb3d
Binary files /dev/null and b/src/assets/performance_comparison_amd_epyc.png differ
diff --git a/src/assets/performance_comparison_cobalt100_dotprod.png b/src/assets/performance_comparison_cobalt100_dotprod.png
new file mode 100644
index 0000000..4d0ef8c
Binary files /dev/null and b/src/assets/performance_comparison_cobalt100_dotprod.png differ
diff --git a/src/assets/performance_comparison_i7-13800h.png b/src/assets/performance_comparison_i7-13800h.png
new file mode 100644
index 0000000..e486d66
Binary files /dev/null and b/src/assets/performance_comparison_i7-13800h.png differ
diff --git a/src/ggml-bitnet-mad.cpp b/src/ggml-bitnet-mad.cpp
index eeca82b..4ba9d65 100644
--- a/src/ggml-bitnet-mad.cpp
+++ b/src/ggml-bitnet-mad.cpp
@@ -1,13 +1,18 @@
#include
#include
-
+#include
#include "ggml-bitnet.h"
#include "ggml-quants.h"
+#include "gemm-config.h"
+#include "ggml-cpu-impl.h"
#include
#include
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
#define QK_I2_S 128
-#define QK_I2 128
+#elif defined(__ARM_NEON)
+#define QK_I2_S 64
+#endif
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
#include
@@ -44,8 +49,8 @@ static inline int hsum_i32_8(const __m256i a) {
#endif
size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
- // 2 bits per weight
-
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+#if defined(ACT_PARALLEL)
size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
int n = nrow * n_per_row;
@@ -73,11 +78,11 @@ size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_
// -1, 0, 1
uint8_t* i2_weight = (uint8_t*)dst;
- for (int i = 0; i < n / QK_I2; i++) {
- for (int j = 0; j < QK_I2; j++) {
+ for (int i = 0; i < n / QK_I2_S; i++) {
+ for (int j = 0; j < QK_I2_S; j++) {
int group_idx = j / 32;
int group_pos = j % 32;
- uint8_t temp = (q8[i * QK_I2 + j] << (6 - 2 * group_idx));
+ uint8_t temp = (q8[i * QK_I2_S + j] << (6 - 2 * group_idx));
i2_weight[i * 32 + group_pos] |= temp;
}
}
@@ -89,9 +94,207 @@ size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_
// 32B for alignment
return nrow * row_size / 4 + 32;
+#else
+ assert((nrow % 4) == 0 && "quantize_i2_s_1x4 requires nrow % 4 == 0");
+
+ size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
+ int64_t n = nrow * n_per_row;
+
+ double max = 0;
+ for (int64_t i = 0; i < n; ++i) {
+ max = fmax(max, (double)fabs((double)src[i]));
+ }
+ double i2_scale = max;
+
+ uint8_t* q8 = (uint8_t*)malloc(n * sizeof(uint8_t));
+ for (int64_t i=0; i 0 ? 2 : 0;
+ }
+
+ uint8_t* out = (uint8_t*)dst;
+ memset(out, 0, (size_t)(n / 4));
+
+ // for each group of 4 rows, for each column, write one byte
+ int64_t nrow4 = nrow / 4;
+ for (int64_t rg = 0; rg < nrow4; rg++) {
+ int64_t r0 = rg * 4 + 0;
+ int64_t r1 = rg * 4 + 1;
+ int64_t r2 = rg * 4 + 2;
+ int64_t r3 = rg * 4 + 3;
+
+ int64_t base = rg * n_per_row;
+
+ for (int64_t col = 0; col < n_per_row; col++) {
+ uint8_t q0 = q8[r0 * n_per_row + col];
+ uint8_t q1 = q8[r1 * n_per_row + col];
+ uint8_t q2 = q8[r2 * n_per_row + col];
+ uint8_t q3 = q8[r3 * n_per_row + col];
+
+ uint8_t packed = (uint8_t)((q0 << 6) | (q1 << 4) | (q2 << 2) | (q3 << 0));
+ out[base + col] = packed;
+ }
+ }
+
+ // store scale at the end of quantized data (same location pattern as quantize_i2_s)
+ float* scale_ptr = (float*)((char*)out + n / 4);
+ scale_ptr[0] = (float)i2_scale;
+
+ free(q8);
+
+ // return size (keep same formula as quantize_i2_s)
+ return nrow * row_size / 4 + 32;
+#endif
+#elif defined(__ARM_NEON)
+ size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
+
+ int n = nrow * n_per_row;
+
+ // f32 -> q8
+ double max = 0;
+ for (int i = 0; i < n; ++i) {
+ max = fmax(max, (double)fabs((double)src[i]));
+ }
+ double i2_scale = max;
+
+ uint8_t* q8 = (uint8_t*)malloc(n * sizeof(uint8_t));
+ for (int i=0; i 0 ? 2 : 0;
+ }
+
+ memset(dst, 0, n * sizeof(uint8_t) / 4);
+
+ // q8 -> 0, 1, 2
+ // | | |
+ // -1, 0, 1
+
+ uint8_t* i2_weight = (uint8_t*)dst;
+ for (int i = 0; i < n / QK_I2_S; i++) {
+ for (int j = 0; j < QK_I2_S; j++) {
+ int group_idx = j / 16;
+ int group_pos = j % 16;
+ uint8_t temp = (q8[i * QK_I2_S + j] << (6 - 2 * group_idx));
+ i2_weight[i * 16 + group_pos] |= temp;
+ }
+ }
+
+ float* scale_ptr = (float*)((char*)i2_weight + n / 4);
+ scale_ptr[0] = i2_scale;
+
+ free(q8);
+
+ // 32B for alignment
+ return nrow * row_size / 4 + 32;
+#endif
}
-void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+void ggml_vec_dot_i2_i8_s_1x1(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if defined(__AVX2__)
+ const uint8_t * x = (uint8_t *)vx;
+ const int8_t * y = (int8_t *)vy;
+
+ const int nb = n / QK_I2_S;
+ const int group32_num = nb / 32;
+ const int la_num = nb % 32;
+ const int groupla_num = nb % 32 != 0 ? 1 : 0;
+
+ __m256i mask = _mm256_set1_epi8(0x03);
+ __m256i one16 = _mm256_set1_epi16(1);
+
+ // 处理多行,nrc表示要处理的行数
+ for (int row = 0; row < nrc; row++) {
+ __m256i accu = _mm256_setzero_si256();
+
+ // 计算当前行的x指针偏移
+ const uint8_t * x_row = x + row * bx / 4;
+
+ for (int i = 0; i < group32_num; i++) {
+ const uint8_t *px = x_row + i * 1024; // 32 * 32
+ const int8_t *py = y + i * 4096; // 32 * 128
+ __m256i accu32 = _mm256_setzero_si256();
+
+ for (int j = 0; j < 32; j++) {
+ // 128 index
+ __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(px));
+ __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
+ __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
+ __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
+
+ // each 32 index
+ xq8_3 = _mm256_and_si256(xq8_3, mask);
+ xq8_2 = _mm256_and_si256(xq8_2, mask);
+ xq8_1 = _mm256_and_si256(xq8_1, mask);
+ xq8_0 = _mm256_and_si256(xq8_0, mask);
+
+ // each 32 index
+ __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py));
+ __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(py + 32));
+ __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(py + 64));
+ __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(py + 96));
+
+ xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
+ xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
+ xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
+ xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
+
+ accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_0, xq8_1));
+ accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_2, xq8_3));
+
+ px += 32;
+ py += 128;
+ }
+ accu = _mm256_add_epi32(_mm256_madd_epi16(accu32, one16), accu);
+ }
+
+ for (int i = 0; i < groupla_num; i++) {
+ __m256i accula = _mm256_setzero_si256();
+ const uint8_t *px = x_row + group32_num * 1024; // 32 * 32
+ const int8_t *py = y + group32_num * 4096; // 32 * 128
+
+ for (int j = 0; j < la_num; j++) {
+ // 128 index
+ __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(px));
+ __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
+ __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
+ __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
+
+ // each 32 index
+ xq8_3 = _mm256_and_si256(xq8_3, mask);
+ xq8_2 = _mm256_and_si256(xq8_2, mask);
+ xq8_1 = _mm256_and_si256(xq8_1, mask);
+ xq8_0 = _mm256_and_si256(xq8_0, mask);
+
+ // each 32 index
+ __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py));
+ __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(py + 32));
+ __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(py + 64));
+ __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(py + 96));
+
+ xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
+ xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
+ xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
+ xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
+
+ accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_0, xq8_1));
+ accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_2, xq8_3));
+
+ px += 32;
+ py += 128;
+ }
+ accu = _mm256_add_epi32(accu, _mm256_madd_epi16(accula, one16));
+ }
+
+ int sumi = hsum_i32_8(accu);
+ s[row] = (float)sumi;
+ }
+#elif defined(__ARM_NEON)
const uint8_t * x = (uint8_t *)vx;
const int8_t * y = (int8_t *)vy;
@@ -100,264 +303,754 @@ void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t b
const int la_num = nb % 32;
const int groupla_num = nb % 32 != 0 ? 1 : 0;
-#if defined(__AVX2__)
-
- __m256i mask = _mm256_set1_epi8(0x03);
- __m256i accu = _mm256_setzero_si256();
-
- for (int i=0; i < group32_num; i++){
- __m256i accu32 = _mm256_setzero_si256();
- for (int j=0; j < 32; j++) {
- // 128 index
- __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + i * 32 * 32 + j * 32));
- __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
- __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
- __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
-
- // each 32 index
- xq8_3 = _mm256_and_si256(xq8_3, mask);
- xq8_2 = _mm256_and_si256(xq8_2, mask);
- xq8_1 = _mm256_and_si256(xq8_1, mask);
- xq8_0 = _mm256_and_si256(xq8_0, mask);
-
- // each 32 index
- __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 0));
- __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 32));
- __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 64));
- __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 96));
-
- // 128 index accumulation add
- // split into 32 accumulation block
- // each block each 128 index accumulated 4index
- // each index maximum 256
- // each block maximum 4 * 256
- // each block accumulation maximum 127 * 256
- // each 32 group index (128 index in one group) needs cast to int32
- xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
- xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
- xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
- xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
-
- accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_0, xq8_1));
- accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_2, xq8_3));
- }
- accu = _mm256_add_epi32(_mm256_madd_epi16(accu32, _mm256_set1_epi16(1)), accu);
- }
-
- for (int i = 0; i < groupla_num; i++){
- __m256i accula = _mm256_setzero_si256();
- for (int j = 0; j < la_num; j++) {
- // 128 index
- __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + group32_num * 32 * 32 + j * 32));
- __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
- __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
- __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
-
- // each 32 index
- xq8_3 = _mm256_and_si256(xq8_3, mask);
- xq8_2 = _mm256_and_si256(xq8_2, mask);
- xq8_1 = _mm256_and_si256(xq8_1, mask);
- xq8_0 = _mm256_and_si256(xq8_0, mask);
-
- // each 32 index
- __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 0));
- __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 32));
- __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 64));
- __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 96));
-
- // 128 index accumulation add
- // split into 32 accumulation block
- // each block each 128 index accumulated 4index
- // each index maximum 256
- // each block maximum 4 * 256
- // each block accumulation maximum 127 * 256
- // each 32 group index (128 index in one group) needs cast to int32
- xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
- xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
- xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
- xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
-
- accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_0, xq8_1));
- accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_2, xq8_3));
- }
- accu = _mm256_add_epi32(accu, _mm256_madd_epi16(accula, _mm256_set1_epi16(1)));
- }
- int sumi = hsum_i32_8(accu);
- *s = (float)sumi;
-
-#elif defined(__ARM_NEON)
-
- int32x4_t accu_0 = vdupq_n_s32(0);
- int32x4_t accu_1 = vdupq_n_s32(0);
- int32x4_t accu_2 = vdupq_n_s32(0);
- int32x4_t accu_3 = vdupq_n_s32(0);
const uint8x16_t mask = vdupq_n_u8(3);
- for (int i=0; i < group32_num; i++) {
+ // 处理多列,nrc表示要处理的列数
+ for (int row = 0; row < nrc; row++) {
+ int32x4_t accu = vdupq_n_s32(0);
+
+ // 计算当前行的x指针偏移
+ const uint8_t * x_row = x + row * bx / 4;
+
+ for (int i=0; i < group32_num; i++) {
#if defined(__ARM_FEATURE_DOTPROD)
#else
- int16x8_t accu32_0 = vdupq_n_s16(0);
- int16x8_t accu32_1 = vdupq_n_s16(0);
- int16x8_t accu32_2 = vdupq_n_s16(0);
- int16x8_t accu32_3 = vdupq_n_s16(0);
+ int16x8_t accu32 = vdupq_n_s16(0);
#endif
+ for (int j=0; j < 32; j++) {
+ uint8x16_t xq8_3 = vld1q_u8(x_row + i * 32 * 16 + j * 16);
+ uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
+ uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
+ uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
- for (int j=0; j < 32; j++) {
- uint8x16_t xq8_6 = vld1q_u8(x + i * 32 * 32 + j * 32);
- uint8x16_t xq8_7 = vld1q_u8(x + i * 32 * 32 + j * 32 + 16);
- uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2);
- uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2);
- uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4);
- uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4);
- uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6);
- uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6);
+ int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
+ int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
+ int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
+ int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
- int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
- int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
- int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
- int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
- int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask));
- int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask));
- int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask));
- int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask));
-
- const int8x16_t yq8_0 = vld1q_s8(y + i * 128 * 32 + j * 128 + 0);
- const int8x16_t yq8_1 = vld1q_s8(y + i * 128 * 32 + j * 128 + 16);
- const int8x16_t yq8_2 = vld1q_s8(y + i * 128 * 32 + j * 128 + 32);
- const int8x16_t yq8_3 = vld1q_s8(y + i * 128 * 32 + j * 128 + 48);
- const int8x16_t yq8_4 = vld1q_s8(y + i * 128 * 32 + j * 128 + 64);
- const int8x16_t yq8_5 = vld1q_s8(y + i * 128 * 32 + j * 128 + 80);
- const int8x16_t yq8_6 = vld1q_s8(y + i * 128 * 32 + j * 128 + 96);
- const int8x16_t yq8_7 = vld1q_s8(y + i * 128 * 32 + j * 128 + 112);
+ const int8x16_t yq8_0 = vld1q_s8(y + i * 32 * 64 + j * 64 + 0);
+ const int8x16_t yq8_1 = vld1q_s8(y + i * 32 * 64 + j * 64 + 16);
+ const int8x16_t yq8_2 = vld1q_s8(y + i * 32 * 64 + j * 64 + 32);
+ const int8x16_t yq8_3 = vld1q_s8(y + i * 32 * 64 + j * 64 + 48);
#if defined(__ARM_FEATURE_DOTPROD)
- accu_0 = vdotq_s32(accu_0, q8_0, yq8_0);
- accu_1 = vdotq_s32(accu_1, q8_1, yq8_1);
- accu_2 = vdotq_s32(accu_2, q8_2, yq8_2);
- accu_3 = vdotq_s32(accu_3, q8_3, yq8_3);
- accu_0 = vdotq_s32(accu_0, q8_4, yq8_4);
- accu_1 = vdotq_s32(accu_1, q8_5, yq8_5);
- accu_2 = vdotq_s32(accu_2, q8_6, yq8_6);
- accu_3 = vdotq_s32(accu_3, q8_7, yq8_7);
+ accu = vdotq_s32(accu, q8_0, yq8_0);
+ accu = vdotq_s32(accu, q8_1, yq8_1);
+ accu = vdotq_s32(accu, q8_2, yq8_2);
+ accu = vdotq_s32(accu, q8_3, yq8_3);
#else
- accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_0), vget_low_s8(yq8_0));
- accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_0), vget_high_s8(yq8_0));
- accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_1), vget_low_s8(yq8_1));
- accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_1), vget_high_s8(yq8_1));
- accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_2), vget_low_s8(yq8_2));
- accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_2), vget_high_s8(yq8_2));
- accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_3), vget_low_s8(yq8_3));
- accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_3), vget_high_s8(yq8_3));
- accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_4), vget_low_s8(yq8_4));
- accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_4), vget_high_s8(yq8_4));
- accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_5), vget_low_s8(yq8_5));
- accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_5), vget_high_s8(yq8_5));
- accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_6), vget_low_s8(yq8_6));
- accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_6), vget_high_s8(yq8_6));
- accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_7), vget_low_s8(yq8_7));
- accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_7), vget_high_s8(yq8_7));
+ accu32 = vmlal_s8(accu32, vget_low_s8(q8_0), vget_low_s8(yq8_0));
+ accu32 = vmlal_s8(accu32, vget_high_s8(q8_0), vget_high_s8(yq8_0));
+ accu32 = vmlal_s8(accu32, vget_low_s8(q8_1), vget_low_s8(yq8_1));
+ accu32 = vmlal_s8(accu32, vget_high_s8(q8_1), vget_high_s8(yq8_1));
+ accu32 = vmlal_s8(accu32, vget_low_s8(q8_2), vget_low_s8(yq8_2));
+ accu32 = vmlal_s8(accu32, vget_high_s8(q8_2), vget_high_s8(yq8_2));
+ accu32 = vmlal_s8(accu32, vget_low_s8(q8_3), vget_low_s8(yq8_3));
+ accu32 = vmlal_s8(accu32, vget_high_s8(q8_3), vget_high_s8(yq8_3));
+#endif
+ }
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+ accu = vaddq_s32(accu, vmovl_s16(vget_low_s16(accu32)));
+ accu = vaddq_s32(accu, vmovl_high_s16(accu32));
#endif
}
+ for (int i = 0; i < groupla_num; i++){
#if defined(__ARM_FEATURE_DOTPROD)
#else
- accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accu32_0)));
- accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accu32_0));
- accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accu32_1)));
- accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accu32_1));
- accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accu32_2)));
- accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accu32_2));
- accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accu32_3)));
- accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accu32_3));
+ int16x8_t accula = vdupq_n_s16(0);
#endif
- }
+ for (int j = 0; j < la_num; j++) {
+ uint8x16_t xq8_3 = vld1q_u8(x_row + group32_num * 32 * 16 + j * 16);
+ uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
+ uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
+ uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
+
+ int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
+ int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
+ int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
+ int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
+
+ const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 0);
+ const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 16);
+ const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 32);
+ const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 48);
- for (int i = 0; i < groupla_num; i++){
#if defined(__ARM_FEATURE_DOTPROD)
-
+ accu = vdotq_s32(accu, q8_0, yq8_0);
+ accu = vdotq_s32(accu, q8_1, yq8_1);
+ accu = vdotq_s32(accu, q8_2, yq8_2);
+ accu = vdotq_s32(accu, q8_3, yq8_3);
#else
- int16x8_t accula_0 = vdupq_n_s16(0);
- int16x8_t accula_1 = vdupq_n_s16(0);
- int16x8_t accula_2 = vdupq_n_s16(0);
- int16x8_t accula_3 = vdupq_n_s16(0);
+ accula = vmlal_s8(accula, vget_low_s8(q8_0), vget_low_s8(yq8_0));
+ accula = vmlal_s8(accula, vget_high_s8(q8_0), vget_high_s8(yq8_0));
+ accula = vmlal_s8(accula, vget_low_s8(q8_1), vget_low_s8(yq8_1));
+ accula = vmlal_s8(accula, vget_high_s8(q8_1), vget_high_s8(yq8_1));
+ accula = vmlal_s8(accula, vget_low_s8(q8_2), vget_low_s8(yq8_2));
+ accula = vmlal_s8(accula, vget_high_s8(q8_2), vget_high_s8(yq8_2));
+ accula = vmlal_s8(accula, vget_low_s8(q8_3), vget_low_s8(yq8_3));
+ accula = vmlal_s8(accula, vget_high_s8(q8_3), vget_high_s8(yq8_3));
#endif
- for (int j = 0; j < la_num; j++) {
- uint8x16_t xq8_6 = vld1q_u8(x + group32_num * 32 * 32 + j * 32);
- uint8x16_t xq8_7 = vld1q_u8(x + group32_num * 32 * 32 + j * 32 + 16);
- uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2);
- uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2);
- uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4);
- uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4);
- uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6);
- uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6);
-
- int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
- int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
- int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
- int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
- int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask));
- int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask));
- int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask));
- int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask));
-
- const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 0);
- const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 16);
- const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 32);
- const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 48);
- const int8x16_t yq8_4 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 64);
- const int8x16_t yq8_5 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 80);
- const int8x16_t yq8_6 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 96);
- const int8x16_t yq8_7 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 112);
-
+ }
#if defined(__ARM_FEATURE_DOTPROD)
- accu_0 = vdotq_s32(accu_0, q8_0, yq8_0);
- accu_1 = vdotq_s32(accu_1, q8_1, yq8_1);
- accu_2 = vdotq_s32(accu_2, q8_2, yq8_2);
- accu_3 = vdotq_s32(accu_3, q8_3, yq8_3);
- accu_0 = vdotq_s32(accu_0, q8_4, yq8_4);
- accu_1 = vdotq_s32(accu_1, q8_5, yq8_5);
- accu_2 = vdotq_s32(accu_2, q8_6, yq8_6);
- accu_3 = vdotq_s32(accu_3, q8_7, yq8_7);
+
#else
- accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_0), vget_low_s8(yq8_0));
- accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_0), vget_high_s8(yq8_0));
- accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_1), vget_low_s8(yq8_1));
- accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_1), vget_high_s8(yq8_1));
- accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_2), vget_low_s8(yq8_2));
- accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_2), vget_high_s8(yq8_2));
- accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_3), vget_low_s8(yq8_3));
- accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_3), vget_high_s8(yq8_3));
- accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_4), vget_low_s8(yq8_4));
- accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_4), vget_high_s8(yq8_4));
- accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_5), vget_low_s8(yq8_5));
- accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_5), vget_high_s8(yq8_5));
- accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_6), vget_low_s8(yq8_6));
- accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_6), vget_high_s8(yq8_6));
- accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_7), vget_low_s8(yq8_7));
- accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_7), vget_high_s8(yq8_7));
+ accu = vaddq_s32(accu, vmovl_s16(vget_low_s16(accula)));
+ accu = vaddq_s32(accu, vmovl_high_s16(accula));
#endif
}
+ int sumi = vaddlvq_s32(accu);
+ s[row] = (float)sumi;
+ }
+#endif
+}
+
+void ggml_vec_dot_i2_i8_s_1x4_32W(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if defined(__AVX2__)
+ const uint8_t * x = (uint8_t *)vx;
+ const int8_t * y = (int8_t *)vy;
+
+ const int nb = n / QK_I2_S;
+ const int group32_num = nb / 32;
+ const int la_num = nb % 32;
+ const int groupla_num = nb % 32 != 0 ? 1 : 0;
+
+ const __m256i mask = _mm256_set1_epi8(0x03);
+ const __m256i one16 = _mm256_set1_epi16(1);
+
+ // 处理多行,nrc表示要处理的行数
+ for (int row = 0; row < nrc; row+=4) {
+ __m256i accu[4];
+ for(int rb = 0; rb < 4; rb++) {
+ accu[rb] = _mm256_setzero_si256();
+ }
+ const uint8_t * x_row = x + (row) * bx / 4;
+ // 计算当前行的x指针偏移
+
+ for (int i = 0; i < group32_num; i++) {
+ const uint8_t * px = x_row + i * 1024 * 4;
+ __m256i accu32[4];
+ for(int rb = 0; rb < 4; rb++) {
+ accu32[rb] = _mm256_setzero_si256();
+ }
+ const int8_t *py = y + i * 4096;
+
+ for (int j = 0; j < 32 * 4; j++) {
+ // each 32 index
+ __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py));
+ __m256i xq8[4];
+ xq8[3] = _mm256_loadu_si256((const __m256i*)(px));
+ xq8[2] = _mm256_srli_epi16(xq8[3], 2);
+ xq8[1] = _mm256_srli_epi16(xq8[3], 4);
+ xq8[0] = _mm256_srli_epi16(xq8[3], 6);
+ xq8[3] = _mm256_and_si256(xq8[3], mask);
+ xq8[2] = _mm256_and_si256(xq8[2], mask);
+ xq8[1] = _mm256_and_si256(xq8[1], mask);
+ xq8[0] = _mm256_and_si256(xq8[0], mask);
+ for (int rb = 0; rb < 4; rb++)
+ {
+ xq8[rb] = _mm256_maddubs_epi16(xq8[rb], yq8_0);
+ accu32[rb] = _mm256_add_epi16(accu32[rb], xq8[rb]);
+ }
+ px += 32;
+ py += 32;
+ }
+ for(int rb = 0; rb < 4; rb++) {
+ accu[rb] = _mm256_add_epi32(_mm256_madd_epi16(accu32[rb], one16), accu[rb]);
+ }
+ }
+
+ for (int i = 0; i < groupla_num; i++) {
+ const int8_t *py = y + group32_num * 4096; // 32 * 128
+ __m256i accula[4];
+ for(int rb = 0; rb < 4; rb++) {
+ accula[rb] = _mm256_setzero_si256();
+ }
+ const uint8_t * px = x_row + group32_num * 1024 * 4;
+
+ for (int j = 0; j < la_num * 4; j++) {
+ // each 32 index
+ __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py));
+ __m256i xq8[4];
+ xq8[3] = _mm256_loadu_si256((const __m256i*)(px));
+ xq8[2] = _mm256_srli_epi16(xq8[3], 2);
+ xq8[1] = _mm256_srli_epi16(xq8[3], 4);
+ xq8[0] = _mm256_srli_epi16(xq8[3], 6);
+ xq8[3] = _mm256_and_si256(xq8[3], mask);
+ xq8[2] = _mm256_and_si256(xq8[2], mask);
+ xq8[1] = _mm256_and_si256(xq8[1], mask);
+ xq8[0] = _mm256_and_si256(xq8[0], mask);
+
+ for (int rb = 0; rb < 4; rb++) {
+ xq8[rb] = _mm256_maddubs_epi16(xq8[rb], yq8_0);
+ accula[rb] = _mm256_add_epi16(accula[rb], xq8[rb]);
+ }
+ px += 32;
+ py += 32;
+ }
+ for(int rb = 0; rb < 4; rb++) {
+ accu[rb] = _mm256_add_epi32(accu[rb], _mm256_madd_epi16(accula[rb], one16));
+ }
+ }
+
+ for(int rb = 0; rb < 4; rb++) {
+ int sumi = hsum_i32_8(accu[rb]);
+ s[row + rb] = (float)sumi;
+ }
+ }
+#elif defined(__ARM_NEON)
+
+#endif
+}
+
+void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if defined(__AVX2__)
+ const uint8_t * x = (uint8_t *)vx;
+ const int8_t * y = (int8_t *)vy;
+
+ const int nb = n / QK_I2_S;
+ const int group32_num = nb / 32;
+ const int la_num = nb % 32;
+ const int groupla_num = nb % 32 != 0 ? 1 : 0;
+
+ const __m256i mask = _mm256_set1_epi8(0x03);
+ const __m256i one16 = _mm256_set1_epi16(1);
+
+ // 处理多行,nrc表示要处理的行数
+ for (int row = 0; row < nrc; row+=PARALLEL_SIZE) {
+ //__m256i accu = _mm256_setzero_si256();
+ __m256i accu[PARALLEL_SIZE];
+ const uint8_t * x_row[PARALLEL_SIZE];
+ for(int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ accu[rb] = _mm256_setzero_si256();
+ x_row[rb] = x + (row + rb) * bx / 4;
+ }
+ // 计算当前行的x指针偏移
+
+ for (int i = 0; i < group32_num; i++) {
+ const uint8_t * px[PARALLEL_SIZE];
+ __m256i accu32[PARALLEL_SIZE];
+ for(int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ px[rb] = x_row[rb] + i * 1024; // 32 * 32
+ accu32[rb] = _mm256_setzero_si256();
+ }
+ const int8_t *py = y + i * 4096; // 32 * 128
+
+ for (int j = 0; j < 32; j++) {
+ // each 32 index
+ __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py));
+ __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(py + 32));
+ __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(py + 64));
+ __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(py + 96));
+ for (int rb = 0; rb < PARALLEL_SIZE; rb++)
+ {
+ __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(px[rb]));
+ __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
+ __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
+ __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
+
+ // each 32 index
+ xq8_3 = _mm256_and_si256(xq8_3, mask);
+ xq8_2 = _mm256_and_si256(xq8_2, mask);
+ xq8_1 = _mm256_and_si256(xq8_1, mask);
+ xq8_0 = _mm256_and_si256(xq8_0, mask);
+
+ xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
+ xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
+ xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
+ xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
+
+ accu32[rb] = _mm256_add_epi16(accu32[rb], _mm256_add_epi16(xq8_0, xq8_1));
+ accu32[rb] = _mm256_add_epi16(accu32[rb], _mm256_add_epi16(xq8_2, xq8_3));
+
+ px[rb] += 32;
+ }
+ py += 128;
+ }
+ for(int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ accu[rb] = _mm256_add_epi32(_mm256_madd_epi16(accu32[rb], one16), accu[rb]);
+ }
+ }
+
+ for (int i = 0; i < groupla_num; i++) {
+ const int8_t *py = y + group32_num * 4096; // 32 * 128
+ const uint8_t * px[PARALLEL_SIZE];
+ __m256i accula[PARALLEL_SIZE];
+ for(int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ px[rb] = x_row[rb] + group32_num * 1024; // 32 * 32
+ accula[rb] = _mm256_setzero_si256();
+ }
+
+ for (int j = 0; j < la_num; j++) {
+ // each 32 index
+ __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py));
+ __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(py + 32));
+ __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(py + 64));
+ __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(py + 96));
+
+ for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ // 128 index
+ __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(px[rb]));
+ __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
+ __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
+ __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
+
+ // each 32 index
+ xq8_3 = _mm256_and_si256(xq8_3, mask);
+ xq8_2 = _mm256_and_si256(xq8_2, mask);
+ xq8_1 = _mm256_and_si256(xq8_1, mask);
+ xq8_0 = _mm256_and_si256(xq8_0, mask);
+
+
+
+ xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
+ xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
+ xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
+ xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
+
+ accula[rb] = _mm256_add_epi16(accula[rb], _mm256_add_epi16(xq8_0, xq8_1));
+ accula[rb] = _mm256_add_epi16(accula[rb], _mm256_add_epi16(xq8_2, xq8_3));
+
+ px[rb] += 32;
+ }
+ py += 128;
+ }
+ for(int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ accu[rb] = _mm256_add_epi32(accu[rb], _mm256_madd_epi16(accula[rb], one16));
+ }
+ }
+
+ for(int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ int sumi = hsum_i32_8(accu[rb]);
+ s[row + rb] = (float)sumi;
+ }
+ }
+#elif defined(__ARM_NEON)
+ const uint8_t * x = (uint8_t *)vx;
+ const int8_t * y = (int8_t *)vy;
+
+ const int nb = n / QK_I2_S;
+ const int group32_num = nb / 32;
+ const int la_num = nb % 32;
+ const int groupla_num = nb % 32 != 0 ? 1 : 0;
+
+ const uint8x16_t mask = vdupq_n_u8(3);
+
+ // 处理多行,nrc表示要处理的行数
+ for (int row = 0; row < nrc; row += PARALLEL_SIZE) {
+
+ int32x4_t accu[PARALLEL_SIZE];
+ const uint8_t * x_row[PARALLEL_SIZE];
+
+ for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ accu[rb] = vdupq_n_s32(0);
+ x_row[rb] = x + (row + rb) * bx / 4;
+ }
+
+ for (int i = 0; i < group32_num; i++) {
#if defined(__ARM_FEATURE_DOTPROD)
#else
- accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accula_0)));
- accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accula_0));
- accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accula_1)));
- accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accula_1));
- accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accula_2)));
- accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accula_2));
- accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accula_3)));
- accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accula_3));
+ int16x8_t accu32[PARALLEL_SIZE];
+ for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ accu32[rb] = vdupq_n_s16(0);
+ }
#endif
- }
- accu_0 = vaddq_s32(accu_0, accu_1);
- accu_2 = vaddq_s32(accu_2, accu_3);
- accu_0 = vaddq_s32(accu_0, accu_2);
- int sumi = vaddlvq_s32(accu_0);
- *s = (float)sumi;
+ const uint8_t * px[PARALLEL_SIZE];
+ for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ px[rb] = x_row[rb] + i * 32 * 16;
+ }
+
+ for (int j = 0; j < 32; j++) {
+ // 加载 y 数据(对所有行共享)
+ const int8x16_t yq8_0 = vld1q_s8(y + i * 32 * 64 + j * 64 + 0);
+ const int8x16_t yq8_1 = vld1q_s8(y + i * 32 * 64 + j * 64 + 16);
+ const int8x16_t yq8_2 = vld1q_s8(y + i * 32 * 64 + j * 64 + 32);
+ const int8x16_t yq8_3 = vld1q_s8(y + i * 32 * 64 + j * 64 + 48);
+
+ // 处理每一行
+ for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ uint8x16_t xq8_3 = vld1q_u8(px[rb] + 0);
+ uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
+ uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
+ uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
+
+ int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
+ int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
+ int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
+ int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ accu[rb] = vdotq_s32(accu[rb], q8_0, yq8_0);
+ accu[rb] = vdotq_s32(accu[rb], q8_1, yq8_1);
+ accu[rb] = vdotq_s32(accu[rb], q8_2, yq8_2);
+ accu[rb] = vdotq_s32(accu[rb], q8_3, yq8_3);
+#else
+ accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_3), vget_low_s8(yq8_3));
+ accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_3), vget_high_s8(yq8_3));
+ accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_2), vget_low_s8(yq8_2));
+ accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_2), vget_high_s8(yq8_2));
+ accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_1), vget_low_s8(yq8_1));
+ accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_1), vget_high_s8(yq8_1));
+ accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_0), vget_low_s8(yq8_0));
+ accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_0), vget_high_s8(yq8_0));
+
+#endif
+ px[rb] += 16;
+ }
+ }
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+ for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ accu[rb] = vaddq_s32(accu[rb], vmovl_s16(vget_low_s16(accu32[rb])));
+ accu[rb] = vaddq_s32(accu[rb], vmovl_high_s16(accu32[rb]));
+ }
+#endif
+ }
+
+ for (int i = 0; i < groupla_num; i++) {
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+ int16x8_t accula[PARALLEL_SIZE];
+ for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ accula[rb] = vdupq_n_s16(0);
+ }
+#endif
+ const uint8_t * px[PARALLEL_SIZE];
+ for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ px[rb] = x_row[rb] + group32_num * 32 * 16;
+ }
+
+ for (int j = 0; j < la_num; j++) {
+ // 加载 y 数据(对所有行共享)
+ const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 0);
+ const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 16);
+ const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 32);
+ const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 48);
+
+ // 处理每一行
+ for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ uint8x16_t xq8_3 = vld1q_u8(px[rb] + 0);
+ uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
+ uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
+ uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
+
+ int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
+ int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
+ int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
+ int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ accu[rb] = vdotq_s32(accu[rb], q8_0, yq8_0);
+ accu[rb] = vdotq_s32(accu[rb], q8_1, yq8_1);
+ accu[rb] = vdotq_s32(accu[rb], q8_2, yq8_2);
+ accu[rb] = vdotq_s32(accu[rb], q8_3, yq8_3);
+#else
+ accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_3), vget_low_s8(yq8_3));
+ accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_3), vget_high_s8(yq8_3));
+ accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_2), vget_low_s8(yq8_2));
+ accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_2), vget_high_s8(yq8_2));
+ accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_1), vget_low_s8(yq8_1));
+ accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_1), vget_high_s8(yq8_1));
+ accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_0), vget_low_s8(yq8_0));
+ accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_0), vget_high_s8(yq8_0));
#endif
+ px[rb] += 16;
+ }
+ }
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+ for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ accu[rb] = vaddq_s32(accu[rb], vmovl_s16(vget_low_s16(accula[rb])));
+ accu[rb] = vaddq_s32(accu[rb], vmovl_high_s16(accula[rb]));
+ }
+#endif
+ }
+
+ // 合并结果并写回
+ for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+ int sumi = vaddlvq_s32(accu[rb]);
+ s[row + rb] = (float)sumi;
+ }
+ }
+#endif
+}
+
+void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if defined(__AVX2__)
+ const uint8_t * x = (uint8_t *)vx;
+ const int8_t * y = (int8_t *)vy;
+
+ const int nb = n / QK_I2_S;
+ const int group32_num = nb / 32;
+ const int la_num = nb % 32;
+ const int groupla_num = nb % 32 != 0 ? 1 : 0;
+
+ __m256i mask = _mm256_set1_epi8(0x03);
+ __m256i one16 = _mm256_set1_epi16(1);
+
+ for (int col = 0; col < nrc; col += PARALLEL_SIZE) {
+ __m256i accu[PARALLEL_SIZE];
+
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+ accu[iy] = _mm256_setzero_si256();
+ }
+
+ int8_t * y_col = y + col * by;
+
+ for (int i = 0; i < group32_num; i++) {
+ const uint8_t *px = x + i * 1024;
+ const int8_t *py = y_col + i * 4096;
+ __m256i accu32[PARALLEL_SIZE];
+
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+ accu32[iy] = _mm256_setzero_si256();
+ }
+
+ for (int j = 0; j < 32; j++) {
+
+ __m256i xq8 = _mm256_loadu_si256((const __m256i*)(px));
+ __m256i xq8_3 = _mm256_and_si256(xq8, mask);
+ __m256i xq8_2 = _mm256_and_si256(_mm256_srli_epi16(xq8, 2), mask);
+ __m256i xq8_1 = _mm256_and_si256(_mm256_srli_epi16(xq8, 4), mask);
+ __m256i xq8_0 = _mm256_and_si256(_mm256_srli_epi16(xq8, 6), mask);
+
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++)
+ {
+ accu32[iy] = _mm256_add_epi16(accu32[iy], _mm256_add_epi16(
+ _mm256_add_epi16(_mm256_maddubs_epi16(xq8_0, _mm256_loadu_si256((const __m256i*)(py + 0 * 32 + iy * by))),
+ _mm256_maddubs_epi16(xq8_1, _mm256_loadu_si256((const __m256i*)(py + 1 * 32 + iy * by)))),
+ _mm256_add_epi16(_mm256_maddubs_epi16(xq8_2, _mm256_loadu_si256((const __m256i*)(py + 2 * 32 + iy * by))),
+ _mm256_maddubs_epi16(xq8_3, _mm256_loadu_si256((const __m256i*)(py + 3 * 32 + iy * by))))));
+ }
+
+ px += 32;
+ py += 128;
+ }
+
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+ accu[iy] = _mm256_add_epi32(_mm256_madd_epi16(accu32[iy], one16), accu[iy]);
+ }
+ }
+
+ for (int i = 0; i < groupla_num; i++) {
+ const uint8_t *px = x + group32_num * 1024;
+ const int8_t *py = y_col + group32_num * 4096;
+ __m256i accula[PARALLEL_SIZE];
+
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+ accula[iy] = _mm256_setzero_si256();
+ }
+
+ for (int j = 0; j < la_num; j++) {
+
+ __m256i xq8 = _mm256_loadu_si256((const __m256i*)(px));
+ __m256i xq8_3 = _mm256_and_si256(xq8, mask);
+ __m256i xq8_2 = _mm256_and_si256(_mm256_srli_epi16(xq8, 2), mask);
+ __m256i xq8_1 = _mm256_and_si256(_mm256_srli_epi16(xq8, 4), mask);
+ __m256i xq8_0 = _mm256_and_si256(_mm256_srli_epi16(xq8, 6), mask);
+
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++)
+ {
+ accula[iy] = _mm256_add_epi16(accula[iy], _mm256_add_epi16(
+ _mm256_add_epi16(_mm256_maddubs_epi16(xq8_0, _mm256_loadu_si256((const __m256i*)(py + 0 * 32 + iy * by))),
+ _mm256_maddubs_epi16(xq8_1, _mm256_loadu_si256((const __m256i*)(py + 1 * 32 + iy * by)))),
+ _mm256_add_epi16(_mm256_maddubs_epi16(xq8_2, _mm256_loadu_si256((const __m256i*)(py + 2 * 32 + iy * by))),
+ _mm256_maddubs_epi16(xq8_3, _mm256_loadu_si256((const __m256i*)(py + 3 * 32 + iy * by))))));
+ }
+
+ px += 32;
+ py += 128;
+ }
+
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+ accu[iy] = _mm256_add_epi32(_mm256_madd_epi16(accula[iy], one16), accu[iy]);
+ }
+ }
+
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+ int sumi = hsum_i32_8(accu[iy]);
+ s[(col + iy) * bs] = (float)sumi;
+ }
+ }
+#elif defined(__ARM_NEON)
+ const uint8_t * x = (uint8_t *)vx;
+ const int8_t * y = (int8_t *)vy;
+
+ const int nb = n / QK_I2_S;
+ const int group32_num = nb / 32;
+ const int la_num = nb % 32;
+ const int groupla_num = nb % 32 != 0 ? 1 : 0;
+
+ const uint8x16_t mask = vdupq_n_u8(3);
+
+ for (int col = 0; col < nrc; col += PARALLEL_SIZE) {
+ int32x4_t accu[PARALLEL_SIZE];
+
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+ accu[iy] = vdupq_n_s32(0);
+ }
+
+ const int8_t * y_col = y + col * by;
+
+ for (int i = 0; i < group32_num; i++) {
+ const uint8_t *px = x + i * 512; // i * 32 * 16
+ const int8_t *py = y_col + i * 2048; // i * 32 * 64
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+ int16x8_t accu32[PARALLEL_SIZE];
+
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+ accu32[iy] = vdupq_n_s16(0);
+ }
+#endif
+ for (int j = 0; j < 32; j++) {
+ // 加载并解包 x 数据(对所有列共享)
+ uint8x16_t xq8_3 = vld1q_u8(px + 0);
+ uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
+ uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
+ uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
+
+ int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
+ int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
+ int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
+ int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
+
+ // 处理每一列
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+ const int8x16_t yq8_0 = vld1q_s8(py + 0 * 16 + iy * by);
+ const int8x16_t yq8_1 = vld1q_s8(py + 1 * 16 + iy * by);
+ const int8x16_t yq8_2 = vld1q_s8(py + 2 * 16 + iy * by);
+ const int8x16_t yq8_3 = vld1q_s8(py + 3 * 16 + iy * by);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ accu[iy] = vdotq_s32(accu[iy], q8_0, yq8_0);
+ accu[iy] = vdotq_s32(accu[iy], q8_1, yq8_1);
+ accu[iy] = vdotq_s32(accu[iy], q8_2, yq8_2);
+ accu[iy] = vdotq_s32(accu[iy], q8_3, yq8_3);
+#else
+ accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_0), vget_low_s8(yq8_0));
+ accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_0), vget_high_s8(yq8_0));
+ accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_1), vget_low_s8(yq8_1));
+ accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_1), vget_high_s8(yq8_1));
+ accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_2), vget_low_s8(yq8_2));
+ accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_2), vget_high_s8(yq8_2));
+ accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_3), vget_low_s8(yq8_3));
+ accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_3), vget_high_s8(yq8_3));
+#endif
+ }
+
+ px += 16;
+ py += 64;
+ }
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+ accu[iy] = vaddq_s32(accu[iy], vaddq_s32(vmovl_high_s16(accu32[iy]), vmovl_s16(vget_low_s16(accu32[iy]))));
+ }
+#endif
+ }
+
+ for (int i = 0; i < groupla_num; i++) {
+ const uint8_t *px = x + group32_num * 512;
+ const int8_t *py = y_col + group32_num * 2048;
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+ int16x8_t accula[PARALLEL_SIZE];
+
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+ accula[iy] = vdupq_n_s16(0);
+ }
+#endif
+
+ for (int j = 0; j < la_num; j++) {
+ // 加载并解包 x 数据(对所有列共享)
+ uint8x16_t xq8_3 = vld1q_u8(px + 0);
+ uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
+ uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
+ uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
+
+ int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
+ int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
+ int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
+ int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
+
+ // 处理每一列
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+ const int8x16_t yq8_0 = vld1q_s8(py + 0 * 16 + iy * by);
+ const int8x16_t yq8_1 = vld1q_s8(py + 1 * 16 + iy * by);
+ const int8x16_t yq8_2 = vld1q_s8(py + 2 * 16 + iy * by);
+ const int8x16_t yq8_3 = vld1q_s8(py + 3 * 16 + iy * by);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ accu[iy] = vdotq_s32(accu[iy], q8_0, yq8_0);
+ accu[iy] = vdotq_s32(accu[iy], q8_1, yq8_1);
+ accu[iy] = vdotq_s32(accu[iy], q8_2, yq8_2);
+ accu[iy] = vdotq_s32(accu[iy], q8_3, yq8_3);
+#else
+ accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_0), vget_low_s8(yq8_0));
+ accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_0), vget_high_s8(yq8_0));
+ accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_1), vget_low_s8(yq8_1));
+ accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_1), vget_high_s8(yq8_1));
+ accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_2), vget_low_s8(yq8_2));
+ accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_2), vget_high_s8(yq8_2));
+ accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_3), vget_low_s8(yq8_3));
+ accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_3), vget_high_s8(yq8_3));
+#endif
+ }
+
+ px += 16;
+ py += 64;
+ }
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+ accu[iy] = vaddq_s32(accu[iy], vaddq_s32(vmovl_high_s16(accula[iy]), vmovl_s16(vget_low_s16(accula[iy]))));
+ }
+#endif
+ }
+
+ // 合并结果并写回
+ for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+ int sumi = vaddlvq_s32(accu[iy]);
+ s[(col + iy) * bs] = (float)sumi;
+ }
+ }
+#endif
+}
+
+
+void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+ if (nrc % PARALLEL_SIZE == 0)
+ {
+#if defined(ACT_PARALLEL)
+ ggml_vec_dot_i2_i8_s_Nx1(n, s, bs, vx, bx, vy, by, nrc);
+#else
+ ggml_vec_dot_i2_i8_s_1xN(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+ }
+ else
+ {
+ ggml_vec_dot_i2_i8_s_1x1(n, s, bs, vx, bx, vy, by, nrc);
+ }
}
\ No newline at end of file
diff --git a/utils/convert-helper-bitnet.py b/utils/convert-helper-bitnet.py
index 5b4149a..9ed8db0 100644
--- a/utils/convert-helper-bitnet.py
+++ b/utils/convert-helper-bitnet.py
@@ -109,12 +109,12 @@ def main():
except OSError as e:
print(f"Warning: Could not remove {preprocessed_output_file}: {e}")
- if gguf_f32_output.exists():
- print(f"Removing f32 GGUF: {gguf_f32_output}")
- try:
- gguf_f32_output.unlink()
- except OSError as e:
- print(f"Warning: Could not remove {gguf_f32_output}: {e}")
+ # if gguf_f32_output.exists():
+ # print(f"Removing f32 GGUF: {gguf_f32_output}")
+ # try:
+ # gguf_f32_output.unlink()
+ # except OSError as e:
+ # print(f"Warning: Could not remove {gguf_f32_output}: {e}")
if input_backup_file.exists():
if not input_file.exists():
diff --git a/utils/kernel_tuning.py b/utils/kernel_tuning.py
deleted file mode 100644
index e69de29..0000000
diff --git a/utils/quantize_embeddings.py b/utils/quantize_embeddings.py
new file mode 100644
index 0000000..90b8020
--- /dev/null
+++ b/utils/quantize_embeddings.py
@@ -0,0 +1,473 @@
+#!/usr/bin/env python3
+"""
+Embedding Quantization Script
+This script converts ggml-model-f32.gguf to multiple quantized versions
+with different token embedding types.
+"""
+
+import subprocess
+import os
+import argparse
+import re
+import csv
+from pathlib import Path
+from datetime import datetime
+
+
+class EmbeddingQuantizer:
+ def __init__(self, input_model, output_dir, quantize_bin="../build/bin/llama-quantize",
+ bench_bin="../build/bin/llama-bench", stats_dir="../stats", csv_output=None):
+ self.input_model = Path(input_model)
+ self.output_dir = Path(output_dir)
+ self.quantize_bin = Path(quantize_bin)
+ self.bench_bin = Path(bench_bin)
+ self.stats_dir = Path(stats_dir)
+ self.csv_output = Path(csv_output) if csv_output else None
+
+ # Verify input file exists
+ if not self.input_model.exists():
+ raise FileNotFoundError(f"Input model not found: {self.input_model}")
+
+ # Verify quantize tool exists
+ if not self.quantize_bin.exists():
+ raise FileNotFoundError(f"Quantize binary not found: {self.quantize_bin}")
+
+ # Verify bench tool exists
+ if not self.bench_bin.exists():
+ raise FileNotFoundError(f"Benchmark binary not found: {self.bench_bin}")
+
+ # Create output directories
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+ self.stats_dir.mkdir(parents=True, exist_ok=True)
+
+ self.results = []
+ self.newly_created_files = set() # Track newly created files
+
+ def quantize(self, embedding_type, output_suffix):
+ """
+ Perform single quantization
+
+ Args:
+ embedding_type: Token embedding type (uppercase format, e.g., Q6_K)
+ output_suffix: Output file suffix (lowercase format, e.g., q6_k)
+
+ Returns:
+ bool: Whether successful
+ """
+ output_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
+
+ # Check if file already exists
+ file_already_existed = output_file.exists()
+
+ if file_already_existed:
+ print(f"ℹ️ File already exists: {output_file}")
+ print(f" Skipping quantization, will use existing file for benchmark")
+ return True
+
+ cmd = [
+ str(self.quantize_bin),
+ "--token-embedding-type", embedding_type,
+ str(self.input_model),
+ str(output_file),
+ "I2_S",
+ "1",
+ "1"
+ ]
+
+ print(f"\n{'='*80}")
+ print(f"🔄 Quantizing with embedding type: {embedding_type}")
+ print(f"📥 Input: {self.input_model}")
+ print(f"📤 Output: {output_file}")
+ print(f"💻 Command: {' '.join(cmd)}")
+ print(f"{'='*80}\n")
+
+ start_time = datetime.now()
+
+ try:
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ cwd=os.getcwd(),
+ timeout=600 # 10 minute timeout
+ )
+
+ end_time = datetime.now()
+ duration = (end_time - start_time).total_seconds()
+
+ if result.returncode == 0:
+ # Get output file size
+ file_size_mb = output_file.stat().st_size / (1024 * 1024)
+
+ print(f"✅ Success! Duration: {duration:.2f}s, Size: {file_size_mb:.2f} MB")
+
+ # Record newly created file
+ if not file_already_existed:
+ self.newly_created_files.add(output_file)
+
+ # Print part of output
+ if result.stdout:
+ print("\n📊 Quantization output:")
+ print(result.stdout[-500:] if len(result.stdout) > 500 else result.stdout)
+
+ return True
+ else:
+ print(f"❌ Failed with return code {result.returncode}")
+ print(f"Error: {result.stderr}")
+ return False
+
+ except subprocess.TimeoutExpired:
+ print(f"❌ Timeout (exceeded 10 minutes)")
+ return False
+
+ except Exception as e:
+ print(f"❌ Exception: {e}")
+ return False
+
+ def benchmark_model(self, output_suffix):
+ """
+ Benchmark model
+
+ Args:
+ output_suffix: Output file suffix (lowercase format, e.g., q6_k)
+
+ Returns:
+ dict: Dictionary with benchmark results, or None if failed
+ """
+ model_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
+
+ if not model_file.exists():
+ print(f"❌ Model file not found for benchmarking: {model_file}")
+ return None
+
+ cmd = [
+ str(self.bench_bin),
+ "-m", str(model_file),
+ "-p", "128",
+ "-n", "0",
+ "-t", "1,2,4,8",
+ "-ngl", "0"
+ ]
+
+ print(f"\n{'='*80}")
+ print(f"🏃 Running benchmark for: {output_suffix}")
+ print(f"💻 Command: {' '.join(cmd)}")
+ print(f"{'='*80}\n")
+
+ try:
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ cwd=os.getcwd(),
+ timeout=300 # 5 minute timeout
+ )
+
+ if result.returncode == 0:
+ print("✅ Benchmark completed successfully")
+ print("\n📊 Benchmark output:")
+ print(result.stdout)
+
+ # 解析输出
+ bench_results = self.parse_benchmark_output(result.stdout, output_suffix)
+ return bench_results
+ else:
+ print(f"❌ Benchmark failed with return code {result.returncode}")
+ print(f"Error: {result.stderr}")
+ return None
+
+ except subprocess.TimeoutExpired:
+ print(f"❌ Benchmark timeout (exceeded 5 minutes)")
+ return None
+
+ except Exception as e:
+ print(f"❌ Benchmark exception: {e}")
+ return None
+
+ def parse_benchmark_output(self, output, output_suffix):
+ """
+ Parse benchmark output to extract t/s data (mean±std)
+
+ Args:
+ output: Benchmark command output
+ output_suffix: Output file suffix
+
+ Returns:
+ dict: Dictionary with parsed results
+ """
+ results = {
+ 'embedding_type': output_suffix,
+ 'threads_1': None,
+ 'threads_2': None,
+ 'threads_4': None,
+ 'threads_8': None,
+ }
+
+ # Parse table data
+ # Find lines containing pp128 and t/s
+ lines = output.strip().split('\n')
+
+ for line in lines:
+ # Skip header and separator lines
+ if '|' not in line or 'model' in line or '---' in line:
+ continue
+
+ # Try to extract data
+ # Format similar to: | bitnet-25 2B I2_S - 2 bpw ternary | 1012.28 MiB | 2.74 B | CPU | 12 | pp128 | 405.73 ± 3.69 |
+ parts = [p.strip() for p in line.split('|')]
+
+ if len(parts) >= 8 and 'pp128' in parts[6]:
+ threads_str = parts[5].strip()
+ throughput_str = parts[7].strip()
+
+ # Extract thread count
+ try:
+ threads = int(threads_str)
+ except:
+ continue
+
+ # Extract t/s data (format: "405.73 ± 3.69" or "405.73")
+ # Try to match "mean ± std" format
+ match_with_std = re.search(r'([\d.]+)\s*±\s*([\d.]+)', throughput_str)
+ if match_with_std:
+ mean = float(match_with_std.group(1))
+ std = float(match_with_std.group(2))
+ throughput = f"{mean:.2f}±{std:.2f}"
+ else:
+ # Only mean, no std
+ match = re.search(r'([\d.]+)', throughput_str)
+ if match:
+ throughput = f"{float(match.group(1)):.2f}"
+ else:
+ continue
+
+ # Store result based on thread count
+ if threads == 1:
+ results['threads_1'] = throughput
+ elif threads == 2:
+ results['threads_2'] = throughput
+ elif threads == 4:
+ results['threads_4'] = throughput
+ elif threads == 8:
+ results['threads_8'] = throughput
+
+ return results
+
+ def cleanup_model(self, output_suffix):
+ """
+ Cleanup model files (only delete newly created files)
+
+ Args:
+ output_suffix: Output file suffix
+ """
+ model_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
+
+ if model_file in self.newly_created_files:
+ try:
+ model_file.unlink()
+ print(f"🗑️ Deleted newly created file: {model_file}")
+ self.newly_created_files.remove(model_file)
+ except Exception as e:
+ print(f"⚠️ Failed to delete {model_file}: {e}")
+ else:
+ print(f"ℹ️ Keeping existing file: {model_file}")
+
+ def run_all_quantizations(self, types_to_quantize):
+ """
+ Run all quantizations
+
+ Args:
+ types_to_quantize: List of quantization types, tuples of (embedding_type, output_suffix)
+ """
+ print(f"\n{'='*80}")
+ print(f"🚀 Starting Embedding Quantization and Benchmarking")
+ print(f"{'='*80}")
+ print(f"📥 Input model: {self.input_model}")
+ print(f"📤 Output directory: {self.output_dir}")
+ print(f"📊 Stats directory: {self.stats_dir}")
+ print(f"🔢 Total quantizations: {len(types_to_quantize)}")
+ print(f"{'='*80}\n")
+
+ total_start = datetime.now()
+
+ for i, (embedding_type, output_suffix) in enumerate(types_to_quantize, 1):
+ print(f"\n{'#'*80}")
+ print(f"[{i}/{len(types_to_quantize)}] Processing {output_suffix} ({embedding_type})")
+ print(f"{'#'*80}\n")
+
+ # Quantize model
+ success = self.quantize(embedding_type, output_suffix)
+
+ if not success:
+ print(f"⚠️ Skipping benchmark for {output_suffix} due to quantization failure")
+ continue
+
+ # Run benchmark
+ bench_results = self.benchmark_model(output_suffix)
+
+ if bench_results:
+ self.results.append(bench_results)
+ else:
+ print(f"⚠️ Benchmark failed for {output_suffix}")
+
+ # Cleanup model files (only delete newly created files)
+ self.cleanup_model(output_suffix)
+
+ print(f"\n{'#'*80}")
+ print(f"✅ Completed {output_suffix}")
+ print(f"{'#'*80}\n")
+
+ total_end = datetime.now()
+ total_duration = (total_end - total_start).total_seconds()
+
+ # 保存结果到CSV
+ self.save_results_to_csv()
+
+ # 打印总结
+ self.print_summary(total_duration)
+
+ def save_results_to_csv(self):
+ """将benchmark结果保存到CSV文件"""
+ if not self.results:
+ print("⚠️ No results to save")
+ return
+
+ # Use user-specified CSV path, otherwise use default path
+ if self.csv_output:
+ csv_file = self.csv_output
+ # Ensure parent directory exists
+ csv_file.parent.mkdir(parents=True, exist_ok=True)
+ else:
+ csv_file = self.stats_dir / f"embedding_benchmark.csv"
+
+ print(f"\n💾 Saving results to: {csv_file}")
+
+ try:
+ with open(csv_file, 'w', newline='') as f:
+ fieldnames = ['embedding_type', 'threads_1', 'threads_2', 'threads_4', 'threads_8']
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
+
+ writer.writeheader()
+ for result in self.results:
+ writer.writerow(result)
+
+ print(f"✅ Results saved successfully")
+
+ # Also print table
+ print(f"\n📊 Benchmark Results:")
+ print(f"{'Type':<15} {'1 thread':<18} {'2 threads':<18} {'4 threads':<18} {'8 threads':<18}")
+ print("-" * 87)
+ for result in self.results:
+ t1 = result['threads_1'] if result['threads_1'] else "N/A"
+ t2 = result['threads_2'] if result['threads_2'] else "N/A"
+ t4 = result['threads_4'] if result['threads_4'] else "N/A"
+ t8 = result['threads_8'] if result['threads_8'] else "N/A"
+ print(f"{result['embedding_type']:<15} {t1:<18} {t2:<18} {t4:<18} {t8:<18}")
+
+ except Exception as e:
+ print(f"❌ Failed to save results: {e}")
+
+ def print_summary(self, total_duration):
+ """Print quantization summary"""
+ print(f"\n\n{'='*80}")
+ print(f"📊 QUANTIZATION AND BENCHMARK SUMMARY")
+ print(f"{'='*80}\n")
+
+ successful = len(self.results)
+ total = len(self.results)
+
+ print(f"✅ Completed: {successful} benchmarks")
+ print(f"⏱️ Total duration: {total_duration/60:.2f} minutes\n")
+
+ if self.results:
+ if self.csv_output and self.csv_output.exists():
+ print(f"📁 Results saved to: {self.csv_output}")
+ else:
+ csv_files = list(self.stats_dir.glob("embedding_benchmark*.csv"))
+ if csv_files:
+ latest_csv = max(csv_files, key=lambda p: p.stat().st_mtime)
+ print(f"📁 Results saved to: {latest_csv}")
+
+ print(f"\n{'='*80}\n")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Quantize model embeddings to multiple formats')
+ parser.add_argument('--input', '-i',
+ default='../models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf',
+ help='Input model path (default: ../models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf)')
+ parser.add_argument('--output-dir', '-o',
+ default='../models/BitNet-b1.58-2B-4T',
+ help='Output directory (default: ../models/BitNet-b1.58-2B-4T)')
+ parser.add_argument('--quantize-bin', '-q',
+ default='../build/bin/llama-quantize',
+ help='Path to llama-quantize binary (default: ../build/bin/llama-quantize)')
+ parser.add_argument('--bench-bin', '-b',
+ default='../build/bin/llama-bench',
+ help='Path to llama-bench binary (default: ../build/bin/llama-bench)')
+ parser.add_argument('--stats-dir',
+ default='../stats',
+ help='Directory to save benchmark results (default: ../stats)')
+ parser.add_argument('--csv-output', '-c',
+ help='Custom path for CSV output file (e.g., stats/my_results.csv)')
+ parser.add_argument('--types', '-t',
+ nargs='+',
+ help='Specific types to quantize (e.g., f32 q6_k q4_0)')
+ parser.add_argument('--skip-existing', '-s',
+ action='store_true',
+ help='Skip quantization if output file already exists (will still benchmark existing files)')
+
+ args = parser.parse_args()
+
+ # Define all supported quantization types
+ # Format: (embedding_type for command line, output_suffix for filename)
+ all_types = [
+ ('F32', 'f32'),
+ ('F16', 'f16'),
+ ('Q8_0', 'q8_0'),
+ ('Q6_K', 'q6_k'),
+ ('Q5_0', 'q5_0'),
+ ('Q4_0', 'q4_0'),
+ ('Q3_K', 'q3_k'),
+ ('TQ2_0', 'tq2_0'),
+ ]
+
+ # If specific types are specified, filter the list
+ if args.types:
+ types_lower = [t.lower() for t in args.types]
+ types_to_quantize = [(et, os) for et, os in all_types if os.lower() in types_lower]
+ if not types_to_quantize:
+ print(f"❌ No valid types specified. Available types: {', '.join([os for _, os in all_types])}")
+ return
+ else:
+ types_to_quantize = all_types
+
+ # If skip existing files is enabled, no need to filter
+ # Because new logic will automatically detect and skip during quantization, but will still benchmark
+
+ # 创建量化器并运行
+ try:
+ quantizer = EmbeddingQuantizer(
+ args.input,
+ args.output_dir,
+ args.quantize_bin,
+ args.bench_bin,
+ args.stats_dir,
+ args.csv_output
+ )
+ quantizer.run_all_quantizations(types_to_quantize)
+ except FileNotFoundError as e:
+ print(f"❌ Error: {e}")
+ return 1
+ except KeyboardInterrupt:
+ print("\n\n⚠️ Quantization interrupted by user")
+ return 1
+ except Exception as e:
+ print(f"\n❌ Unexpected error: {e}")
+ import traceback
+ traceback.print_exc()
+ return 1
+
+
+if __name__ == "__main__":
+ exit(main() or 0)
diff --git a/utils/test_gemm_kernel.sh b/utils/test_gemm_kernel.sh
new file mode 100755
index 0000000..ae72c72
--- /dev/null
+++ b/utils/test_gemm_kernel.sh
@@ -0,0 +1,573 @@
+#!/bin/bash
+# Unified GEMM kernel benchmark script
+# Builds, tests, and benchmarks the GEMM kernel with configurable output
+
+set -e
+
+# Default values
+BUILD_DIR="../build"
+ITERATIONS=1000
+OUTPUT_CSV=""
+SKIP_BUILD=false
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Print usage
+print_usage() {
+ cat << EOF
+Usage: $0 [options]
+
+Options:
+ -o, --output Output CSV file path (default: ../stats/gemm_kernel_test_noparal.csv)
+ -i, --iterations Number of iterations per test (default: 1000)
+ -s, --skip-build Skip building the benchmark binary
+ -h, --help Show this help message
+
+Examples:
+ # Run with default settings
+ $0
+
+ # Specify custom output file
+ $0 -o /path/to/my_results.csv
+
+ # Quick test with fewer iterations
+ $0 -i 100 -o quick_test.csv
+
+ # Skip build if already compiled
+ $0 -s -o results.csv
+EOF
+}
+
+# Parse command line arguments
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ -o|--output)
+ OUTPUT_CSV="$2"
+ shift 2
+ ;;
+ -i|--iterations)
+ ITERATIONS="$2"
+ shift 2
+ ;;
+ -s|--skip-build)
+ SKIP_BUILD=true
+ shift
+ ;;
+ -h|--help)
+ print_usage
+ exit 0
+ ;;
+ *)
+ echo "Unknown option: $1"
+ print_usage
+ exit 1
+ ;;
+ esac
+done
+
+# Set default output CSV if not specified
+if [ -z "$OUTPUT_CSV" ]; then
+ OUTPUT_CSV="${SCRIPT_DIR}/../stats/gemm_kernel_test_noparal.csv"
+fi
+
+# Create output directory first
+mkdir -p "$(dirname "$OUTPUT_CSV")"
+
+# Convert to absolute path
+if [[ "$OUTPUT_CSV" = /* ]]; then
+ # Already absolute path
+ OUTPUT_CSV="$OUTPUT_CSV"
+else
+ # Convert relative path to absolute
+ OUTPUT_CSV="$(cd "$(dirname "$OUTPUT_CSV")" && pwd)/$(basename "$OUTPUT_CSV")"
+fi
+
+echo "=========================================="
+echo "GEMM Kernel Benchmark Suite"
+echo "=========================================="
+echo "Configuration:"
+echo " Iterations: $ITERATIONS"
+echo " Output CSV: $OUTPUT_CSV"
+echo " Skip build: $SKIP_BUILD"
+echo "=========================================="
+echo ""
+
+# Build the benchmark binary
+if [ "$SKIP_BUILD" = false ]; then
+ echo "Step 1: Building GEMM kernel benchmark..."
+ echo "------------------------------------------"
+
+ CXX=${CXX:-g++}
+
+ # Create build directory if it doesn't exist
+ mkdir -p "${SCRIPT_DIR}/${BUILD_DIR}"
+
+ # Create temporary C++ source file
+ TEMP_CPP="${SCRIPT_DIR}/${BUILD_DIR}/test_gemm_kernel_temp.cpp"
+
+ cat > "${TEMP_CPP}" << 'EOF'
+/**
+ * Standalone benchmark for ggml_gemm_i2_i8_s kernel
+ *
+ * This program tests the performance of the ggml_gemm_i2_i8_s kernel
+ * with configurable matrix sizes and iteration counts.
+ *
+ * Usage: ./test_gemm_kernel [options]
+ * -n : embedding dimension (must be divisible by 4, default: 2048)
+ * -r : number of rows in matrix Y (default: 32)
+ * -c : number of columns in matrix X (default: 128)
+ * -i : number of iterations (default: 1000)
+ * -w : number of warmup iterations (default: 10)
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+// Include necessary headers
+#include "../include/gemm-config.h"
+
+// Function declarations (from ggml-quants.h)
+extern "C" void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc);
+
+// GEMM kernel definition
+void ggml_gemm_i2_i8_s(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+#if defined(ACT_PARALLEL)
+ const int64_t row_block = ROW_BLOCK_SIZE;
+ const int64_t col_block = COL_BLOCK_SIZE;
+
+ for (int64_t c0 = 0; c0 < nc; c0 += col_block) {
+ int64_t cur_c = (c0 + col_block <= nc) ? col_block : (nc - c0);
+ for (int64_t r0 = 0; r0 < nr; r0 += row_block) {
+ int64_t cur_r = (r0 + row_block <= nr) ? row_block : (nr - r0);
+ const void * vy_r = (const uint8_t *)vy + r0 * n;
+ for (int64_t c = 0; c < cur_c; ++c) {
+ const int64_t col = c0 + c;
+ float * s_col = s + col;
+ const void * vx_col = (const uint8_t *)vx + col * n / 4;
+ ggml_vec_dot_i2_i8_s(n, s_col + r0 * bs, bs, vx_col, n, vy_r, n, cur_r);
+ }
+ }
+ }
+#else
+ const int64_t row_block = ROW_BLOCK_SIZE;
+ const int64_t col_block = COL_BLOCK_SIZE;
+
+ for (int64_t r0 = 0; r0 < nr; r0 += row_block) {
+ int64_t cur_r = (r0 + row_block <= nr) ? row_block : (nr - r0);
+ for (int64_t c0 = 0; c0 < nc; c0 += col_block) {
+ int64_t cur_c = (c0 + col_block <= nc) ? col_block : (nc - c0);
+ const void * vx_c = (const uint8_t *)vx + c0 * n / 4;
+ for (int64_t r = 0; r < cur_r; ++r) {
+ const int64_t row = r0 + r;
+ float * s_row = s + row * bs;
+ const void * vy_row = (const uint8_t *)vy + row * n;
+ ggml_vec_dot_i2_i8_s(n, s_row + c0, bs, vx_c, n, vy_row, n, cur_c);
+ }
+ }
+ }
+#endif
+}
+
+// Helper function to get current time in nanoseconds
+double get_time_ns() {
+ struct timespec ts;
+ clock_gettime(CLOCK_MONOTONIC, &ts);
+ return ts.tv_sec * 1e9 + ts.tv_nsec;
+}
+
+// Initialize matrix with random i2 values (2-bit quantized)
+void init_matrix_i2(uint8_t* data, int n, int cols) {
+ // i2 format: 4 values per byte (2 bits each)
+ int total_bytes = n * cols / 4;
+ for (int i = 0; i < total_bytes; i++) {
+ data[i] = rand() & 0xFF;
+ }
+}
+
+// Initialize matrix with random i8 values
+void init_matrix_i8(int8_t* data, int n, int rows) {
+ int total_elements = n * rows;
+ for (int i = 0; i < total_elements; i++) {
+ data[i] = (int8_t)((rand() % 256) - 128);
+ }
+}
+
+// Benchmark configuration
+struct BenchmarkConfig {
+ int n; // embedding dimension (must be divisible by 4)
+ int nr; // number of rows in Y matrix
+ int nc; // number of columns in X matrix
+ int iterations; // number of benchmark iterations
+ int warmup; // number of warmup iterations
+};
+
+void print_config(const BenchmarkConfig& config) {
+ printf("=" "=%.78s\n", "===============================================================================");
+ printf("Benchmark Configuration:\n");
+ printf("=" "=%.78s\n", "===============================================================================");
+ printf(" Embedding dimension (n) : %d\n", config.n);
+ printf(" Matrix Y rows (nr) : %d\n", config.nr);
+ printf(" Matrix X columns (nc) : %d\n", config.nc);
+ printf(" Iterations : %d\n", config.iterations);
+ printf(" Warmup iterations : %d\n", config.warmup);
+ printf("\nMatrix sizes:\n");
+ printf(" X (i2): %d x %d (%.2f KB)\n", config.nc, config.n,
+ (config.nc * config.n / 4) / 1024.0);
+ printf(" Y (i8): %d x %d (%.2f KB)\n", config.nr, config.n,
+ (config.nr * config.n) / 1024.0);
+ printf(" S (f32): %d x %d (%.2f KB)\n", config.nr, config.nc,
+ (config.nr * config.nc * sizeof(float)) / 1024.0);
+ printf("\nGEMM Config:\n");
+#if defined(ACT_PARALLEL)
+ printf(" ACT_PARALLEL : ON\n");
+#else
+ printf(" ACT_PARALLEL : OFF\n");
+#endif
+ printf(" ROW_BLOCK_SIZE : %d\n", ROW_BLOCK_SIZE);
+ printf(" COL_BLOCK_SIZE : %d\n", COL_BLOCK_SIZE);
+ printf(" PARALLEL_SIZE : %d\n", PARALLEL_SIZE);
+ printf("=" "=%.78s\n\n", "===============================================================================");
+}
+
+void run_benchmark(const BenchmarkConfig& config) {
+ // Allocate matrices
+ printf("Allocating matrices...\n");
+
+ // X matrix (i2 format): nc x n, but stored as nc x (n/4) bytes
+ // Align to 64 bytes for AVX-512, which is backward compatible with AVX2 (32 bytes)
+ size_t x_size = config.nc * config.n / 4;
+ size_t x_size_aligned = ((x_size + 63) / 64) * 64;
+ uint8_t* X = (uint8_t*)aligned_alloc(64, x_size_aligned);
+
+ // Y matrix (i8 format): nr x n
+ size_t y_size = config.nr * config.n;
+ size_t y_size_aligned = ((y_size + 63) / 64) * 64;
+ int8_t* Y = (int8_t*)aligned_alloc(64, y_size_aligned);
+
+ // Result matrix (float32): nr x nc
+ size_t s_size = config.nr * config.nc * sizeof(float);
+ size_t s_size_aligned = ((s_size + 63) / 64) * 64;
+ float* S = (float*)aligned_alloc(64, s_size_aligned);
+
+ if (!X || !Y || !S) {
+ fprintf(stderr, "Failed to allocate memory\n");
+ exit(1);
+ }
+
+ // Initialize matrices with random data
+ printf("Initializing matrices with random data...\n");
+ srand(time(NULL));
+ init_matrix_i2(X, config.n, config.nc);
+ init_matrix_i8(Y, config.n, config.nr);
+ memset(S, 0, config.nr * config.nc * sizeof(float));
+
+ // Warmup
+ printf("Running %d warmup iterations...\n", config.warmup);
+ for (int i = 0; i < config.warmup; i++) {
+ ggml_gemm_i2_i8_s(config.n, S, config.nc, X, Y, config.nr, config.nc);
+ }
+
+ // Benchmark
+ printf("Running %d benchmark iterations...\n", config.iterations);
+ double total_time = 0.0;
+ double min_time = 1e20;
+ double max_time = 0.0;
+
+ for (int i = 0; i < config.iterations; i++) {
+ double start = get_time_ns();
+ ggml_gemm_i2_i8_s(config.n, S, config.nc, X, Y, config.nr, config.nc);
+ double end = get_time_ns();
+
+ double elapsed = end - start;
+ total_time += elapsed;
+ if (elapsed < min_time) min_time = elapsed;
+ if (elapsed > max_time) max_time = elapsed;
+
+ if ((i + 1) % 100 == 0) {
+ printf(" Progress: %d/%d iterations\n", i + 1, config.iterations);
+ }
+ }
+
+ // Calculate statistics
+ double avg_time_ns = total_time / config.iterations;
+ double avg_time_ms = avg_time_ns / 1e6;
+ double min_time_ms = min_time / 1e6;
+ double max_time_ms = max_time / 1e6;
+
+ // Calculate GFLOPS
+ // For GEMM: nr x nc x n multiply-adds = 2 * nr * nc * n FLOPs
+ double flops = 2.0 * config.nr * config.nc * config.n;
+ double gflops = (flops / avg_time_ns);
+
+ // Calculate throughput (tokens/s assuming each column is a token)
+ double throughput = (config.nc * 1e9) / avg_time_ns;
+
+ // Print results
+ printf("\n");
+ printf("=" "=%.78s\n", "===============================================================================");
+ printf("Benchmark Results:\n");
+ printf("=" "=%.78s\n", "===============================================================================");
+ printf(" Average time : %.3f ms\n", avg_time_ms);
+ printf(" Min time : %.3f ms\n", min_time_ms);
+ printf(" Max time : %.3f ms\n", max_time_ms);
+ printf(" Std dev : %.3f ms\n", sqrt((max_time_ms - min_time_ms) * (max_time_ms - min_time_ms) / 12));
+ printf("\nPerformance:\n");
+ printf(" GFLOPS : %.2f\n", gflops);
+ printf(" Throughput : %.2f tokens/s\n", throughput);
+ printf(" Latency/token : %.3f us\n", (avg_time_ms * 1000) / config.nc);
+ printf("=" "=%.78s\n", "===============================================================================");
+
+ // Cleanup
+ free(X);
+ free(Y);
+ free(S);
+}
+
+void print_usage(const char* program) {
+ printf("Usage: %s [options]\n", program);
+ printf("Options:\n");
+ printf(" -n Embedding dimension (must be divisible by 4, default: 2048)\n");
+ printf(" -r Number of rows in matrix Y (default: 32)\n");
+ printf(" -c Number of columns in matrix X (default: 128)\n");
+ printf(" -i Number of iterations (default: 1000)\n");
+ printf(" -w Number of warmup iterations (default: 10)\n");
+ printf(" -h Show this help message\n");
+}
+
+int main(int argc, char** argv) {
+ BenchmarkConfig config = {
+ .n = 2048,
+ .nr = 32,
+ .nc = 128,
+ .iterations = 1000,
+ .warmup = 10
+ };
+
+ // Parse command line arguments
+ for (int i = 1; i < argc; i++) {
+ if (strcmp(argv[i], "-n") == 0 && i + 1 < argc) {
+ config.n = atoi(argv[++i]);
+ } else if (strcmp(argv[i], "-r") == 0 && i + 1 < argc) {
+ config.nr = atoi(argv[++i]);
+ } else if (strcmp(argv[i], "-c") == 0 && i + 1 < argc) {
+ config.nc = atoi(argv[++i]);
+ } else if (strcmp(argv[i], "-i") == 0 && i + 1 < argc) {
+ config.iterations = atoi(argv[++i]);
+ } else if (strcmp(argv[i], "-w") == 0 && i + 1 < argc) {
+ config.warmup = atoi(argv[++i]);
+ } else if (strcmp(argv[i], "-h") == 0) {
+ print_usage(argv[0]);
+ return 0;
+ } else {
+ fprintf(stderr, "Unknown option: %s\n", argv[i]);
+ print_usage(argv[0]);
+ return 1;
+ }
+ }
+
+ // Validate configuration
+ if (config.n % 4 != 0) {
+ fprintf(stderr, "Error: Embedding dimension (-n) must be divisible by 4\n");
+ return 1;
+ }
+
+ if (config.n <= 0 || config.nr <= 0 || config.nc <= 0 || config.iterations <= 0) {
+ fprintf(stderr, "Error: All size parameters must be positive\n");
+ return 1;
+ }
+
+ // Run benchmark
+ print_config(config);
+ run_benchmark(config);
+
+ return 0;
+}
+EOF
+
+ # Compiler flags
+ CXXFLAGS="-O3 -march=native -mtune=native -std=c++17 -fopenmp"
+ CXXFLAGS+=" -I${SCRIPT_DIR}/.. -I${SCRIPT_DIR}/../include"
+ CXXFLAGS+=" -I${SCRIPT_DIR}/../3rdparty/llama.cpp/ggml/include"
+ CXXFLAGS+=" -I${SCRIPT_DIR}/../3rdparty/llama.cpp/ggml/src"
+ CXXFLAGS+=" -I${SCRIPT_DIR}/../3rdparty/llama.cpp/include"
+ CXXFLAGS+=" -DNDEBUG -ffast-math"
+
+ # Link flags
+ LDFLAGS="-lm -lpthread"
+
+ # Link with pre-built libraries
+ GGML_LIB_DIR="${SCRIPT_DIR}/../build/3rdparty/llama.cpp/ggml/src"
+ GGML_SO="${GGML_LIB_DIR}/libggml.so"
+
+ if [ ! -f "${GGML_SO}" ]; then
+ echo "❌ Error: Cannot find libggml.so at ${GGML_SO}"
+ echo "Please build the project first with: cmake --build build"
+ rm -f "${TEMP_CPP}"
+ exit 1
+ fi
+
+ LDFLAGS+=" -L${GGML_LIB_DIR} -lggml -Wl,-rpath,${GGML_LIB_DIR}"
+
+ # Output binary
+ BENCHMARK_BIN="${SCRIPT_DIR}/${BUILD_DIR}/test_gemm_kernel"
+
+ echo "Compiler: ${CXX}"
+ echo "Building from embedded source..."
+ echo ""
+
+ # Build
+ ${CXX} ${CXXFLAGS} "${TEMP_CPP}" -o ${BENCHMARK_BIN} ${LDFLAGS}
+
+ if [ $? -eq 0 ]; then
+ echo "✅ Build successful!"
+ rm -f "${TEMP_CPP}"
+ echo ""
+ else
+ echo "❌ Build failed!"
+ rm -f "${TEMP_CPP}"
+ exit 1
+ fi
+else
+ echo "Step 1: Skipping build (using existing binary)"
+ echo "------------------------------------------"
+ BENCHMARK_BIN="${SCRIPT_DIR}/${BUILD_DIR}/test_gemm_kernel"
+
+ if [ ! -f "${BENCHMARK_BIN}" ]; then
+ echo "❌ Error: Benchmark binary not found at ${BENCHMARK_BIN}"
+ echo "Please run without -s to build it first."
+ exit 1
+ fi
+ echo "✅ Found existing binary"
+ echo ""
+fi
+
+# Set LD_LIBRARY_PATH to include the GGML library directory
+GGML_LIB_DIR="${SCRIPT_DIR}/../build/3rdparty/llama.cpp/ggml/src"
+export LD_LIBRARY_PATH="${GGML_LIB_DIR}:${LD_LIBRARY_PATH}"
+
+echo "Step 2: Running benchmark tests"
+echo "------------------------------------------"
+echo "Library path: ${GGML_LIB_DIR}"
+echo ""
+
+# Write CSV header
+echo "test_name,n,nr,nc,time_ms,gflops,throughput_tokens_per_sec" > "$OUTPUT_CSV"
+echo "Results will be saved to: $OUTPUT_CSV"
+echo ""
+
+# Function to extract metrics and append to CSV
+extract_and_save() {
+ local test_name="$1"
+ local output="$2"
+
+ # Extract values using grep and awk
+ local n=$(echo "$output" | grep "Embedding dimension" | awk '{print $5}')
+ local nr=$(echo "$output" | grep "Matrix Y rows" | awk '{print $6}')
+ local nc=$(echo "$output" | grep "Matrix X columns" | awk '{print $6}')
+ local avg_time=$(echo "$output" | grep "Average time" | awk '{print $4}')
+ local min_time=$(echo "$output" | grep "Min time" | awk '{print $4}')
+ local max_time=$(echo "$output" | grep "Max time" | awk '{print $4}')
+ local gflops=$(echo "$output" | grep "GFLOPS" | awk '{print $3}')
+ local throughput=$(echo "$output" | grep "Throughput" | awk '{print $3}')
+
+ # Check if values were extracted successfully
+ if [ -z "$avg_time" ] || [ -z "$min_time" ] || [ -z "$max_time" ]; then
+ echo "Warning: Failed to extract timing data for ${test_name}"
+ echo "${test_name},${n},${nr},${nc},N/A,N/A,N/A" >> "$OUTPUT_CSV"
+ return
+ fi
+
+ # Calculate standard deviation estimate from range
+ # Using awk with proper variable passing
+ local std_time=$(awk -v min="$min_time" -v max="$max_time" 'BEGIN {printf "%.4f", (max - min) / 4}')
+
+ # Format as mean±std
+ local time_formatted="${avg_time}±${std_time}"
+
+ # Append to CSV
+ echo "${test_name},${n},${nr},${nc},${time_formatted},${gflops},${throughput}" >> "$OUTPUT_CSV"
+}
+
+# Run benchmark tests
+echo "=========================================="
+echo "BitNet-2B Typical Shapes Performance Test"
+echo "=========================================="
+echo ""
+
+echo "Test 1: Single Token Generation (Attention QKV projection)"
+echo " Scenario: Generating 1 token at a time"
+echo " Shape: n=2048, r=1, c=2048"
+OUTPUT=$($BENCHMARK_BIN -n 2048 -r 1 -c 2048 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "single_token_gen" "$OUTPUT"
+echo ""
+
+echo "Test 2: Small Batch Prompt Processing (Attention QKV projection)"
+echo " Scenario: Processing prompt with 128 tokens, batch size 1"
+echo " Shape: n=2048, r=128, c=2048"
+OUTPUT=$($BENCHMARK_BIN -n 2048 -r 128 -c 2048 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "small_batch_prompt" "$OUTPUT"
+echo ""
+
+echo "Test 3: Medium Batch Prompt Processing (Attention QKV projection)"
+echo " Scenario: Processing prompt with 256 tokens or batch of 256"
+echo " Shape: n=2048, r=256, c=2048"
+OUTPUT=$($BENCHMARK_BIN -n 2048 -r 256 -c 2048 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "medium_batch_prompt" "$OUTPUT"
+echo ""
+
+echo "Test 4: Large Batch Processing (Attention QKV projection)"
+echo " Scenario: Processing 512 tokens or batch of 512"
+echo " Shape: n=2048, r=512, c=2048"
+OUTPUT=$($BENCHMARK_BIN -n 2048 -r 512 -c 2048 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "large_batch_prompt" "$OUTPUT"
+echo ""
+
+echo "Test 5: FFN Up-projection (Small batch)"
+echo " Scenario: Feed-forward network expansion, 128 tokens"
+echo " Shape: n=2048, r=128, c=8192"
+OUTPUT=$($BENCHMARK_BIN -n 2048 -r 128 -c 8192 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "ffn_up_projection" "$OUTPUT"
+echo ""
+
+echo "Test 6: FFN Down-projection (Small batch)"
+echo " Scenario: Feed-forward network reduction, 128 tokens"
+echo " Shape: n=8192, r=128, c=2048"
+OUTPUT=$($BENCHMARK_BIN -n 8192 -r 128 -c 2048 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "ffn_down_projection" "$OUTPUT"
+echo ""
+
+echo "Test 7: Long Context Processing"
+echo " Scenario: Processing very long context (2048 tokens)"
+echo " Shape: n=2048, r=2048, c=2048"
+OUTPUT=$($BENCHMARK_BIN -n 2048 -r 2048 -c 2048 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "long_context" "$OUTPUT"
+echo ""
+
+echo "Test 8: Batched Token Generation"
+echo " Scenario: Generating tokens for 32 sequences simultaneously"
+echo " Shape: n=2048, r=32, c=2048"
+OUTPUT=$($BENCHMARK_BIN -n 2048 -r 32 -c 2048 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "batched_token_gen" "$OUTPUT"
+echo ""
+
+echo "=========================================="
+echo "All tests completed successfully!"
+echo "=========================================="
+echo "Results saved to: $OUTPUT_CSV"
+echo ""
+echo "Summary:"
+wc -l "$OUTPUT_CSV" | awk '{print " Total records:", $1 - 1}'
+echo " Output file: $OUTPUT_CSV"
+echo "=========================================="
diff --git a/utils/test_perplexity.py b/utils/test_perplexity.py
new file mode 100644
index 0000000..f2d9788
--- /dev/null
+++ b/utils/test_perplexity.py
@@ -0,0 +1,608 @@
+#!/usr/bin/env python3
+"""
+Perplexity Test Script
+Tests GGUF model perplexity on multiple datasets using llama-perplexity.
+"""
+
+import os
+import subprocess
+import time
+import csv
+import re
+from datetime import datetime
+from pathlib import Path
+import argparse
+import tempfile
+import shutil
+import statistics
+
+
+class PerplexityTester:
+ def __init__(self, model_path, llama_perplexity_bin="../build/bin/llama-perplexity",
+ data_dir="../data", output_dir="perplexity_results", quick_mode=False,
+ quantize_bin="../build/bin/llama-quantize", test_embeddings=False, csv_output=None):
+ self.model_path = Path(model_path)
+ self.llama_perplexity_bin = Path(llama_perplexity_bin)
+ self.quantize_bin = Path(quantize_bin)
+ self.data_dir = Path(data_dir)
+ self.output_dir = Path(output_dir)
+ self.quick_mode = quick_mode
+ self.test_embeddings = test_embeddings
+ self.csv_output = Path(csv_output) if csv_output else None
+ self.results = []
+ self.created_models = set() # Track newly created model files
+ self.temp_files = [] # Track temporary files for cleanup
+
+ # Embedding types to test
+ self.embedding_types = [
+ ('F32', 'f32'),
+ ('F16', 'f16'),
+ ('Q8_0', 'q8_0'),
+ ('Q6_K', 'q6_k'),
+ ('Q5_0', 'q5_0'),
+ ('Q4_0', 'q4_0'),
+ ('Q3_K', 'q3_k'),
+ ('TQ2_0', 'tq2_0'),
+ ]
+
+ # Create output directory
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Verify llama-perplexity binary exists
+ if not self.llama_perplexity_bin.exists():
+ raise FileNotFoundError(f"llama-perplexity binary not found: {self.llama_perplexity_bin}")
+
+ # Verify quantize binary exists if testing embeddings
+ if self.test_embeddings and not self.quantize_bin.exists():
+ raise FileNotFoundError(f"llama-quantize binary not found: {self.quantize_bin}")
+
+ # Verify model file exists
+ if not self.model_path.exists():
+ raise FileNotFoundError(f"Model file not found: {self.model_path}")
+
+ def find_datasets(self):
+ """Find all test.txt files in dataset directories."""
+ datasets = []
+
+ if not self.data_dir.exists():
+ print(f"❌ Data directory not found: {self.data_dir}")
+ return datasets
+
+ print(f"\n🔍 Searching for datasets in {self.data_dir}...")
+
+ # Look for test.txt files in subdirectories
+ for dataset_dir in sorted(self.data_dir.iterdir()):
+ if dataset_dir.is_dir():
+ test_file = dataset_dir / "test.txt"
+ if test_file.exists():
+ size_mb = test_file.stat().st_size / (1024 * 1024)
+ datasets.append({
+ 'name': dataset_dir.name,
+ 'path': test_file,
+ 'size': test_file.stat().st_size,
+ 'size_mb': size_mb
+ })
+ print(f" ✅ {dataset_dir.name:<20} ({size_mb:.2f} MB)")
+ else:
+ print(f" ⚠️ {dataset_dir.name:<20} (no test.txt found)")
+
+ return datasets
+
+ def create_quick_dataset(self, dataset_path, num_chars=4096):
+ """Create a temporary dataset with only the first N characters for quick testing."""
+ temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt', encoding='utf-8')
+ self.temp_files.append(temp_file.name)
+
+ try:
+ with open(dataset_path, 'r', encoding='utf-8', errors='ignore') as f:
+ content = f.read(num_chars)
+ temp_file.write(content)
+ temp_file.close()
+ return Path(temp_file.name)
+ except Exception as e:
+ print(f"⚠️ Failed to create quick dataset: {e}")
+ temp_file.close()
+ return dataset_path
+
+ def cleanup_temp_files(self):
+ """Clean up temporary files."""
+ for temp_file in self.temp_files:
+ try:
+ os.unlink(temp_file)
+ except:
+ pass
+ self.temp_files = []
+
+ def run_perplexity_test(self, dataset_name, dataset_path, threads=16, ctx_size=512, model_override=None):
+ """Run perplexity test on a single dataset."""
+ test_model = model_override if model_override else self.model_path
+
+ print(f"\n{'='*80}")
+ print(f"📊 Testing on dataset: {dataset_name}")
+ print(f" File: {dataset_path}")
+ print(f" Model: {test_model.name}")
+ print(f"{'='*80}")
+
+ cmd = [
+ str(self.llama_perplexity_bin),
+ "-m", str(test_model),
+ "-f", str(dataset_path),
+ "-t", str(threads),
+ "-c", str(ctx_size),
+ "-ngl", "0" # CPU only
+ ]
+
+ print(f"💻 Command: {' '.join(cmd)}")
+ print(f"⏱️ Starting test...\n")
+
+ start_time = time.time()
+
+ try:
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ timeout=3600, # 1 hour timeout
+ cwd=os.getcwd()
+ )
+
+ elapsed_time = time.time() - start_time
+
+ if result.returncode == 0:
+ # Parse perplexity from output (check both stdout and stderr)
+ combined_output = result.stdout + "\n" + result.stderr
+ ppl = self.parse_perplexity(combined_output)
+
+ if ppl is not None:
+ print(f"\n✅ Perplexity: {ppl}")
+ print(f"⏱️ Time: {elapsed_time:.2f}s ({elapsed_time/60:.2f} min)")
+ status = "success"
+ else:
+ print(f"\n⚠️ Test completed but could not parse perplexity")
+ print(f"Last 500 chars of stdout:")
+ print(result.stdout[-500:])
+ print(f"Last 500 chars of stderr:")
+ print(result.stderr[-500:])
+ status = "parse_error"
+ ppl = None
+ else:
+ print(f"\n❌ Test failed with return code {result.returncode}")
+ print(f"Error: {result.stderr[:500]}")
+ status = "failed"
+ ppl = None
+ elapsed_time = time.time() - start_time
+
+ return {
+ 'dataset': dataset_name,
+ 'perplexity': ppl,
+ 'time': elapsed_time,
+ 'status': status,
+ 'stdout': result.stdout,
+ 'stderr': result.stderr
+ }
+
+ except subprocess.TimeoutExpired:
+ elapsed_time = time.time() - start_time
+ print(f"\n❌ Timeout after {elapsed_time:.2f}s")
+ return {
+ 'dataset': dataset_name,
+ 'perplexity': None,
+ 'time': elapsed_time,
+ 'status': 'timeout',
+ 'stdout': '',
+ 'stderr': 'Test exceeded 1 hour timeout'
+ }
+ except Exception as e:
+ elapsed_time = time.time() - start_time
+ print(f"\n❌ Error: {e}")
+ return {
+ 'dataset': dataset_name,
+ 'perplexity': None,
+ 'time': elapsed_time,
+ 'status': 'error',
+ 'stdout': '',
+ 'stderr': str(e)
+ }
+
+ def parse_perplexity(self, output):
+ """Parse perplexity value (mean±std format) from llama-perplexity output."""
+ # First try to match "PPL = mean +/- std" format
+ pattern_with_std = r'PPL\s*=\s*(\d+\.?\d*)\s*\+/-\s*(\d+\.?\d*)'
+ match = re.search(pattern_with_std, output, re.IGNORECASE | re.MULTILINE)
+ if match:
+ try:
+ mean = float(match.group(1))
+ std = float(match.group(2))
+ return f"{mean:.4f}±{std:.4f}"
+ except ValueError:
+ pass
+
+ # Fallback to patterns without std
+ patterns = [
+ r'Final estimate:\s*PPL\s*=\s*(\d+\.?\d*)',
+ r'Final perplexity:\s*(\d+\.?\d*)',
+ r'PPL\s*=\s*(\d+\.?\d*)',
+ r'PPL:\s*(\d+\.?\d*)',
+ r'perplexity:\s*(\d+\.?\d*)',
+ r'ppl\s*=\s*(\d+\.?\d*)',
+ r'Perplexity:\s*(\d+\.?\d*)',
+ ]
+
+ for pattern in patterns:
+ match = re.search(pattern, output, re.IGNORECASE | re.MULTILINE)
+ if match:
+ try:
+ return f"{float(match.group(1)):.4f}"
+ except ValueError:
+ continue
+
+ return None
+
+ def quantize_embedding(self, embedding_type, output_suffix):
+ """
+ Quantize model with specific embedding type.
+
+ Args:
+ embedding_type: Token embedding type (uppercase, e.g., 'Q6_K')
+ output_suffix: Output file suffix (lowercase, e.g., 'q6_k')
+
+ Returns:
+ Path to quantized model or None if failed
+ """
+ # Construct output path
+ model_dir = self.model_path.parent
+ output_path = model_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
+
+ # Check if file already exists
+ file_existed = output_path.exists()
+
+ if file_existed:
+ print(f"ℹ️ Model already exists: {output_path.name}")
+ return output_path
+
+ cmd = [
+ str(self.quantize_bin),
+ "--token-embedding-type", embedding_type,
+ str(self.model_path),
+ str(output_path),
+ "I2_S",
+ "1",
+ "1"
+ ]
+
+ print(f"\n{'='*80}")
+ print(f"🔄 Quantizing with embedding type: {embedding_type}")
+ print(f"📥 Input: {self.model_path.name}")
+ print(f"📤 Output: {output_path.name}")
+ print(f"💻 Command: {' '.join(cmd)}")
+ print(f"{'='*80}\n")
+
+ start_time = time.time()
+
+ try:
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ cwd=os.getcwd(),
+ timeout=600 # 10 minutes timeout
+ )
+
+ duration = time.time() - start_time
+
+ if result.returncode == 0:
+ file_size_mb = output_path.stat().st_size / (1024 * 1024)
+ print(f"✅ Quantization successful!")
+ print(f" Duration: {duration:.2f}s")
+ print(f" Size: {file_size_mb:.2f} MB")
+
+ # Mark as newly created
+ self.created_models.add(output_path)
+ return output_path
+ else:
+ print(f"❌ Quantization failed with return code {result.returncode}")
+ print(f"Error: {result.stderr[:500]}")
+ return None
+
+ except subprocess.TimeoutExpired:
+ print(f"❌ Quantization timeout (exceeded 10 minutes)")
+ return None
+ except Exception as e:
+ print(f"❌ Quantization error: {e}")
+ return None
+
+ def cleanup_model(self, model_path):
+ """Delete model file if it was created during this session."""
+ if model_path in self.created_models:
+ try:
+ model_path.unlink()
+ print(f"🗑️ Deleted: {model_path.name}")
+ self.created_models.remove(model_path)
+ except Exception as e:
+ print(f"⚠️ Failed to delete {model_path.name}: {e}")
+ else:
+ print(f"ℹ️ Keeping existing file: {model_path.name}")
+
+ def run_all_tests(self, threads=16, ctx_size=512):
+ """Run perplexity tests on all datasets."""
+ datasets = self.find_datasets()
+
+ if not datasets:
+ print(f"\n❌ No datasets found in {self.data_dir}")
+ print(f" Make sure each dataset directory has a test.txt file")
+ return
+
+ # Quick mode: test all datasets but only first 4096 chars with smaller context
+ if self.quick_mode:
+ ctx_size = min(ctx_size, 128) # Use smaller context in quick mode
+ print(f"\n⚡ QUICK TEST MODE ENABLED")
+ print(f" - Testing all datasets with first 4096 characters only")
+ print(f" - Using reduced context size: {ctx_size}")
+
+ # Determine models to test
+ if self.test_embeddings:
+ print(f"\n{'='*80}")
+ print(f"🧪 EMBEDDING QUANTIZATION TEST MODE")
+ print(f"{'='*80}")
+ print(f"📦 Base model: {self.model_path.name}")
+ print(f"🔢 Embedding types to test: {len(self.embedding_types)}")
+ print(f"📊 Datasets: {len(datasets)}")
+ print(f"🧵 Threads: {threads}")
+ print(f"📏 Context size: {ctx_size}")
+ print(f"{'='*80}")
+
+ total_start = time.time()
+
+ # Test each embedding type
+ for i, (embedding_type, output_suffix) in enumerate(self.embedding_types, 1):
+ print(f"\n\n{'#'*80}")
+ print(f"[{i}/{len(self.embedding_types)}] Testing embedding type: {output_suffix} ({embedding_type})")
+ print(f"{'#'*80}")
+
+ # Quantize model
+ quantized_model = self.quantize_embedding(embedding_type, output_suffix)
+
+ if quantized_model is None:
+ print(f"⚠️ Skipping tests for {output_suffix} due to quantization failure")
+ continue
+
+ # Test on all datasets
+ for j, dataset in enumerate(datasets, 1):
+ print(f"\n[{j}/{len(datasets)}] Testing {dataset['name']} with {output_suffix}...")
+
+ # Use quick dataset if in quick mode
+ test_path = dataset['path']
+ if self.quick_mode:
+ test_path = self.create_quick_dataset(dataset['path'])
+
+ result = self.run_perplexity_test(
+ f"{dataset['name']}_embed-{output_suffix}",
+ test_path,
+ threads,
+ ctx_size,
+ model_override=quantized_model
+ )
+ self.results.append(result)
+
+ # Cleanup model after testing
+ print(f"\n🧹 Cleaning up {output_suffix} model...")
+ self.cleanup_model(quantized_model)
+
+ print(f"\n{'#'*80}")
+ print(f"✅ Completed {output_suffix}")
+ print(f"{'#'*80}")
+
+ total_time = time.time() - total_start
+
+ else:
+ # Regular single model test
+ print(f"\n{'='*80}")
+ print(f"🚀 PERPLEXITY TEST SESSION{' (QUICK MODE)' if self.quick_mode else ''}")
+ print(f"{'='*80}")
+ print(f"📦 Model: {self.model_path.name}")
+ print(f"📁 Model path: {self.model_path}")
+ print(f"📊 Datasets {'to test' if self.quick_mode else 'found'}: {len(datasets)}")
+ print(f"🧵 Threads: {threads}")
+ print(f"📏 Context size: {ctx_size}")
+ print(f"{'='*80}")
+
+ total_start = time.time()
+
+ # Run tests
+ for i, dataset in enumerate(datasets, 1):
+ print(f"\n\n[{i}/{len(datasets)}] Processing {dataset['name']}...")
+
+ # Use quick dataset if in quick mode
+ test_path = dataset['path']
+ if self.quick_mode:
+ test_path = self.create_quick_dataset(dataset['path'])
+
+ result = self.run_perplexity_test(
+ dataset['name'],
+ test_path,
+ threads,
+ ctx_size
+ )
+ self.results.append(result)
+
+ total_time = time.time() - total_start
+
+ # Clean up temporary files
+ if self.quick_mode:
+ print(f"\n🧹 Cleaning up temporary files...")
+ self.cleanup_temp_files()
+
+ # Save results
+ self.save_results()
+
+ # Print summary
+ self.print_summary(total_time)
+
+ def save_results(self):
+ """Save results to CSV file."""
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ model_name = self.model_path.stem
+
+ # Use custom CSV path if provided
+ if self.csv_output:
+ csv_file = self.csv_output
+ # Create parent directory if needed
+ csv_file.parent.mkdir(parents=True, exist_ok=True)
+ else:
+ csv_file = self.output_dir / f"ppl_{model_name}_{timestamp}.csv"
+
+ print(f"\n💾 Saving results...")
+
+ with open(csv_file, 'w', newline='') as f:
+ writer = csv.DictWriter(f, fieldnames=['dataset', 'perplexity', 'time_seconds', 'status'])
+ writer.writeheader()
+ for result in self.results:
+ writer.writerow({
+ 'dataset': result['dataset'],
+ 'perplexity': result['perplexity'] if result['perplexity'] is not None else 'N/A',
+ 'time_seconds': f"{result['time']:.2f}",
+ 'status': result['status']
+ })
+
+ print(f" ✅ CSV saved: {csv_file}")
+
+ # Save detailed log
+ log_file = self.output_dir / f"ppl_{model_name}_{timestamp}.log"
+ with open(log_file, 'w') as f:
+ f.write(f"Perplexity Test Results\n")
+ f.write(f"{'='*80}\n")
+ f.write(f"Model: {self.model_path}\n")
+ f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
+ f.write(f"{'='*80}\n\n")
+
+ for result in self.results:
+ f.write(f"\n{'='*80}\n")
+ f.write(f"Dataset: {result['dataset']}\n")
+ f.write(f"Perplexity: {result['perplexity']}\n")
+ f.write(f"Time: {result['time']:.2f}s\n")
+ f.write(f"Status: {result['status']}\n")
+ f.write(f"\nOutput:\n{result['stdout']}\n")
+ if result['stderr']:
+ f.write(f"\nErrors:\n{result['stderr']}\n")
+
+ print(f" ✅ Log saved: {log_file}")
+
+ def print_summary(self, total_time):
+ """Print summary of all tests."""
+ print(f"\n\n{'='*80}")
+ print(f"📊 TEST SUMMARY")
+ print(f"{'='*80}\n")
+
+ # Sort results by perplexity (lower is better)
+ successful = [r for r in self.results if r['perplexity'] is not None]
+ failed = [r for r in self.results if r['perplexity'] is None]
+
+ if successful:
+ # Extract numeric value from "mean±std" format for sorting
+ def get_ppl_value(result):
+ ppl = result['perplexity']
+ if isinstance(ppl, str) and '±' in ppl:
+ return float(ppl.split('±')[0])
+ elif isinstance(ppl, str):
+ try:
+ return float(ppl)
+ except ValueError:
+ return float('inf')
+ return ppl
+
+ successful_sorted = sorted(successful, key=get_ppl_value)
+
+ print(f"{'Dataset':<20} {'Perplexity':>20} {'Time (s)':>12} {'Status':<15}")
+ print(f"{'-'*80}")
+
+ for result in successful_sorted:
+ ppl_str = str(result['perplexity']) if result['perplexity'] is not None else 'N/A'
+ print(f"{result['dataset']:<20} {ppl_str:>20} "
+ f"{result['time']:>12.2f} {result['status']:<15}")
+
+ best_ppl = str(successful_sorted[0]['perplexity'])
+ print(f"\n🏆 Best result: {successful_sorted[0]['dataset']} "
+ f"(PPL: {best_ppl})")
+
+ if failed:
+ print(f"\n❌ Failed tests ({len(failed)}):")
+ for result in failed:
+ print(f" - {result['dataset']}: {result['status']}")
+
+ print(f"\n{'='*80}")
+ print(f"✅ Completed: {len(successful)}/{len(self.results)}")
+ print(f"⏱️ Total time: {total_time:.2f}s ({total_time/60:.2f} min)")
+ print(f"📁 Results saved in: {self.output_dir}")
+ print(f"{'='*80}\n")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Test model perplexity on multiple datasets')
+ parser.add_argument('--model', '-m',
+ required=True,
+ help='Path to GGUF model file')
+ parser.add_argument('--data-dir', '-d',
+ default='data',
+ help='Directory containing dataset folders (default: data)')
+ parser.add_argument('--threads', '-t',
+ type=int,
+ default=16,
+ help='Number of threads (default: 16)')
+ parser.add_argument('--ctx-size', '-c',
+ type=int,
+ default=512,
+ help='Context size (default: 512)')
+ parser.add_argument('--output-dir', '-o',
+ default='perplexity_results',
+ help='Output directory for results (default: perplexity_results)')
+ parser.add_argument('--llama-perplexity',
+ default='./build/bin/llama-perplexity',
+ help='Path to llama-perplexity binary (default: ./build/bin/llama-perplexity)')
+ parser.add_argument('--quick', '-q',
+ action='store_true',
+ help='Quick test mode: test all datasets with first 4096 characters and reduced context size (128)')
+ parser.add_argument('--test-embeddings', '-e',
+ action='store_true',
+ help='Test different embedding quantization types (f32, f16, q8_0, q6_k, q5_0, q4_0, q3_k, tq2_0)')
+ parser.add_argument('--csv-output',
+ help='Custom path for CSV output file (e.g., results/my_ppl_results.csv)')
+ parser.add_argument('--quantize-bin',
+ default='./build/bin/llama-quantize',
+ help='Path to llama-quantize binary (default: ./build/bin/llama-quantize)')
+
+ args = parser.parse_args()
+
+ try:
+ tester = PerplexityTester(
+ model_path=args.model,
+ llama_perplexity_bin=args.llama_perplexity,
+ data_dir=args.data_dir,
+ output_dir=args.output_dir,
+ quick_mode=args.quick,
+ quantize_bin=args.quantize_bin,
+ test_embeddings=args.test_embeddings,
+ csv_output=args.csv_output
+ )
+
+ tester.run_all_tests(
+ threads=args.threads,
+ ctx_size=args.ctx_size
+ )
+
+ except FileNotFoundError as e:
+ print(f"❌ Error: {e}")
+ return 1
+ except KeyboardInterrupt:
+ print("\n\n⚠️ Test interrupted by user")
+ return 1
+ except Exception as e:
+ print(f"\n❌ Unexpected error: {e}")
+ import traceback
+ traceback.print_exc()
+ return 1
+
+ return 0
+
+
+if __name__ == "__main__":
+ exit(main())
diff --git a/utils/test_power.sh b/utils/test_power.sh
new file mode 100755
index 0000000..79a1a68
--- /dev/null
+++ b/utils/test_power.sh
@@ -0,0 +1,151 @@
+#!/bin/bash
+# Monitor power consumption for llama-bench with different thread configurations
+# Usage: ./monitor_power.sh
+# Example: ./monitor_power.sh models/model.gguf results.csv "1,2,4,8" "1,2,4,8"
+
+set -e
+
+# Parse arguments
+if [ $# -ne 4 ]; then
+ echo "Usage: $0 "
+ echo "Example: $0 models/model.gguf results.csv \"1,2,4,8\" \"1,2,4,8\""
+ exit 1
+fi
+
+MODEL_PATH="$1"
+OUTPUT_CSV="$2"
+PP_THREADS="$3"
+TG_THREADS="$4"
+
+TEMP_LOG="/tmp/power_monitor_$$.log"
+PID_FILE="/tmp/monitor_$$.pid"
+BENCH_OUTPUT="/tmp/bench_output_$$.txt"
+
+# Validate model exists
+if [ ! -f "$MODEL_PATH" ]; then
+ echo "Error: Model file not found: $MODEL_PATH"
+ exit 1
+fi
+
+# Create output directory if needed
+mkdir -p "$(dirname "$OUTPUT_CSV")"
+
+# Function to monitor CPU stats
+monitor_cpu() {
+ local log_file="$1"
+ echo "Timestamp,CPU_Usage(%),Avg_Freq(MHz)" > "$log_file"
+ while [ -f "$PID_FILE" ]; do
+ cpu_usage=$(top -bn1 | grep "Cpu(s)" | awk '{print 100-$8}')
+ avg_freq=$(grep "cpu MHz" /proc/cpuinfo | awk '{sum+=$4; count++} END {printf "%.0f", sum/count}')
+ timestamp=$(date +%s.%N)
+ echo "$timestamp,$cpu_usage,$avg_freq" >> "$log_file"
+ sleep 0.5
+ done
+}
+
+# Function to calculate average power
+calculate_power() {
+ local log_file="$1"
+ awk -F',' 'NR>1 {sum_cpu+=$2; count++} END {
+ if (count > 0) {
+ avg_cpu = sum_cpu/count
+ est_power = avg_cpu * 200 / 100
+ printf "%.2f", est_power
+ } else {
+ print "0"
+ }
+ }' "$log_file"
+}
+
+# Function to extract throughput from llama-bench output
+extract_throughput() {
+ local bench_output="$1"
+ local workload="$2"
+ grep "$workload" "$bench_output" | awk '{
+ # Extract mean from "mean ± std" format
+ for (i=1; i<=NF; i++) {
+ if ($(i+1) == "±") {
+ printf "%.2f", $i
+ exit
+ }
+ }
+ }'
+}
+
+# Function to run single benchmark
+run_benchmark() {
+ local workload="$1" # "pp" or "tg"
+ local threads="$2"
+ local n_flag=""
+
+ if [ "$workload" = "pp" ]; then
+ n_flag="-n 0"
+ workload_name="pp128"
+ else
+ n_flag="-n 128"
+ workload_name="tg128"
+ fi
+
+ # Output progress to stderr (won't be captured in CSV)
+ echo "Testing $workload_name with $threads threads..." >&2
+
+ # Start monitoring
+ touch "$PID_FILE"
+ monitor_cpu "$TEMP_LOG" &
+ local monitor_pid=$!
+
+ # Run benchmark
+ ./build/bin/llama-bench -m "$MODEL_PATH" -p 128 $n_flag -t "$threads" -ngl 0 > "$BENCH_OUTPUT" 2>&1
+
+ # Stop monitoring
+ rm -f "$PID_FILE"
+ wait $monitor_pid 2>/dev/null || true
+
+ # Extract results
+ local throughput=$(extract_throughput "$BENCH_OUTPUT" "$workload_name")
+ local power=$(calculate_power "$TEMP_LOG")
+
+ if [ -z "$throughput" ] || [ "$throughput" = "0" ]; then
+ echo "Warning: Failed to extract throughput for $workload_name, threads=$threads" >&2
+ throughput="0"
+ fi
+
+ # Calculate J/t (Joules per token)
+ local j_per_token=$(awk -v p="$power" -v t="$throughput" 'BEGIN {
+ if (t > 0) printf "%.4f", p/t; else print "0"
+ }')
+
+ # Output progress to stderr
+ echo " Throughput: $throughput t/s, Power: $power W, Energy: $j_per_token J/t" >&2
+
+ # Only output CSV line to stdout (this will be captured)
+ echo "$workload_name,$threads,$throughput,$power,$j_per_token"
+}
+
+# Initialize CSV
+echo "Workload,Threads,Throughput(t/s),Power(W),Energy(J/t)" > "$OUTPUT_CSV"
+
+# Test PP workloads
+IFS=',' read -ra PP_ARRAY <<< "$PP_THREADS"
+for threads in "${PP_ARRAY[@]}"; do
+ threads=$(echo "$threads" | xargs) # trim whitespace
+ result=$(run_benchmark "pp" "$threads")
+ echo "$result" >> "$OUTPUT_CSV"
+done
+
+# Test TG workloads
+IFS=',' read -ra TG_ARRAY <<< "$TG_THREADS"
+for threads in "${TG_ARRAY[@]}"; do
+ threads=$(echo "$threads" | xargs) # trim whitespace
+ result=$(run_benchmark "tg" "$threads")
+ echo "$result" >> "$OUTPUT_CSV"
+done
+
+# Cleanup
+rm -f "$TEMP_LOG" "$BENCH_OUTPUT" "$PID_FILE"
+
+echo ""
+echo "=== Benchmark Complete ==="
+echo "Results saved to: $OUTPUT_CSV"
+echo ""
+cat "$OUTPUT_CSV"
diff --git a/utils/tune_gemm_config.py b/utils/tune_gemm_config.py
new file mode 100644
index 0000000..e537cd8
--- /dev/null
+++ b/utils/tune_gemm_config.py
@@ -0,0 +1,362 @@
+#!/usr/bin/env python3
+"""
+GEMM Configuration Tuning Script
+This script automatically tunes ROW_BLOCK_SIZE, COL_BLOCK_SIZE, and PARALLEL_SIZE
+to find the optimal configuration for maximum throughput (t/s).
+"""
+
+import subprocess
+import os
+import re
+import csv
+import shutil
+from datetime import datetime
+from pathlib import Path
+import argparse
+
+
+class GemmTuner:
+ def __init__(self, config_path, model_path, threads=16):
+ self.config_path = Path(config_path)
+ self.model_path = model_path
+ self.threads = threads
+ self.backup_path = self.config_path.parent / f"gemm-config.h.backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+ self.build_dir = Path("../build")
+ self.results = []
+
+ def backup_config(self):
+ """Backup current configuration file"""
+ print(f"📦 Backing up current config to {self.backup_path}")
+ shutil.copy2(self.config_path, self.backup_path)
+
+ def restore_config(self):
+ """Restore original configuration file"""
+ print(f"♻️ Restoring original config from {self.backup_path}")
+ shutil.copy2(self.backup_path, self.config_path)
+
+ def generate_config(self, act_parallel, row_block_size, col_block_size, parallel_size):
+ """Generate new configuration file with simplified format"""
+ content = ""
+
+ # Simplified configuration format
+ if act_parallel:
+ content += "#define ACT_PARALLEL\n"
+
+ content += f"#define ROW_BLOCK_SIZE {row_block_size}\n"
+ content += f"#define COL_BLOCK_SIZE {col_block_size}\n"
+ content += f"#define PARALLEL_SIZE {parallel_size}\n"
+
+ with open(self.config_path, 'w') as f:
+ f.write(content)
+
+ def rebuild_project(self):
+ """Rebuild project"""
+ print("🔨 Rebuilding project...")
+ result = subprocess.run(
+ ["cmake", "--build", str(self.build_dir), "--target", "llama-bench"],
+ capture_output=True,
+ text=True,
+ cwd=os.getcwd()
+ )
+ if result.returncode != 0:
+ print(f"⚠️ Build warning/error: {result.stderr}")
+ return False
+ return True
+
+ def run_benchmark(self):
+ """Run benchmark test"""
+ cmd = [
+ f"{self.build_dir}/bin/llama-bench",
+ "-m", self.model_path,
+ "-p", "128",
+ "-n", "0",
+ "-t", str(self.threads),
+ "-ngl", "0"
+ ]
+
+ print(f"⚡ Running benchmark: {' '.join(cmd)}")
+
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ cwd=os.getcwd(),
+ timeout=300 # 5分钟超时
+ )
+
+ if result.returncode != 0:
+ print(f"❌ Benchmark failed: {result.stderr}")
+ return None
+
+ return result.stdout
+
+ def parse_throughput(self, output):
+ """Parse pp128 throughput from output"""
+ # 匹配 pp128: | pp128 | 501.06 ± 11.37 |
+ pp_pattern = r'\|\s+pp128\s+\|\s+([\d.]+)\s+±\s+([\d.]+)\s+\|'
+ pp_match = re.search(pp_pattern, output)
+
+ if pp_match:
+ pp_throughput = float(pp_match.group(1))
+ pp_std_dev = float(pp_match.group(2))
+
+ return {
+ 'pp_throughput': pp_throughput,
+ 'pp_std_dev': pp_std_dev
+ }
+
+ return None
+
+ def test_configuration(self, act_parallel, row_block_size, col_block_size, parallel_size):
+ """Test single configuration"""
+ config_name = f"ACT_{'ON' if act_parallel else 'OFF'}_R{row_block_size}_C{col_block_size}_P{parallel_size}"
+ print(f"\n{'='*80}")
+ print(f"🧪 Testing configuration: {config_name}")
+ print(f" ACT_PARALLEL: {act_parallel}")
+ print(f" ROW_BLOCK_SIZE: {row_block_size}")
+ print(f" COL_BLOCK_SIZE: {col_block_size}")
+ print(f" PARALLEL_SIZE: {parallel_size}")
+ print(f"{'='*80}")
+
+ # Generate configuration
+ self.generate_config(act_parallel, row_block_size, col_block_size, parallel_size)
+
+ # Rebuild project
+ if not self.rebuild_project():
+ print("⚠️ Build failed, skipping this configuration")
+ return None
+
+ # Run benchmark test
+ output = self.run_benchmark()
+ if output is None:
+ return None
+
+ # Parse results
+ metrics = self.parse_throughput(output)
+
+ if metrics is not None:
+ result = {
+ 'act_parallel': act_parallel,
+ 'row_block_size': row_block_size,
+ 'col_block_size': col_block_size,
+ 'parallel_size': parallel_size,
+ 'config_name': config_name,
+ **metrics
+ }
+ self.results.append(result)
+ print(f"✅ PP128: {metrics['pp_throughput']:.2f} ± {metrics['pp_std_dev']:.2f} t/s")
+ return result
+ else:
+ print("❌ Failed to parse throughput")
+ return None
+
+ def save_results(self, csv_path):
+ """Save results to CSV file"""
+ print(f"\n💾 Saving results to {csv_path}")
+
+ with open(csv_path, 'w', newline='') as f:
+ writer = csv.DictWriter(f, fieldnames=[
+ 'config_name', 'act_parallel', 'row_block_size',
+ 'col_block_size', 'parallel_size',
+ 'pp_throughput', 'pp_std_dev'
+ ])
+ writer.writeheader()
+ writer.writerows(self.results)
+
+ def find_best_config(self):
+ """Find the best configuration with highest throughput"""
+ if not self.results:
+ print("❌ No valid results found")
+ return None
+
+ best = max(self.results, key=lambda x: x['pp_throughput'])
+ return best
+
+ def run_tuning(self, configurations, output_csv=None):
+ """Run test for all configurations"""
+ print(f"\n🚀 Starting tuning process with {len(configurations)} configurations")
+ print(f"📊 Model: {self.model_path}")
+ print(f"🧵 Threads: {self.threads}\n")
+
+ # Backup configuration
+ self.backup_config()
+
+ try:
+ # Test all configurations
+ for i, config in enumerate(configurations, 1):
+ print(f"\n[{i}/{len(configurations)}]")
+ self.test_configuration(**config)
+
+ # Save results
+ if output_csv is None:
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+ csv_path = f"../stats/tuning_results_{timestamp}.csv"
+ else:
+ csv_path = output_csv
+
+ # Ensure stats directory exists
+ os.makedirs(os.path.dirname(csv_path), exist_ok=True)
+ self.save_results(csv_path)
+
+ # Find best configuration
+ best = self.find_best_config()
+ if best:
+ print(f"\n{'='*80}")
+ print(f"🏆 BEST CONFIGURATION FOUND!")
+ print(f"{'='*80}")
+ print(f"Configuration: {best['config_name']}")
+ print(f"ACT_PARALLEL: {best['act_parallel']}")
+ print(f"ROW_BLOCK_SIZE: {best['row_block_size']}")
+ print(f"COL_BLOCK_SIZE: {best['col_block_size']}")
+ print(f"PARALLEL_SIZE: {best['parallel_size']}")
+ print(f"PP128 Throughput: {best['pp_throughput']:.2f} ± {best['pp_std_dev']:.2f} t/s")
+ print(f"{'='*80}\n")
+
+ # Show the configuration that will be written
+ print("Configuration to be written to gemm-config.h:")
+ print("-" * 80)
+ if best['act_parallel']:
+ print("#define ACT_PARALLEL")
+ print(f"#define ROW_BLOCK_SIZE {best['row_block_size']}")
+ print(f"#define COL_BLOCK_SIZE {best['col_block_size']}")
+ print(f"#define PARALLEL_SIZE {best['parallel_size']}")
+ print("-" * 80)
+
+ # Apply best configuration
+ apply = input("\nDo you want to apply this configuration to gemm-config.h? (y/n): ").strip().lower()
+ if apply == 'y':
+ self.generate_config(
+ best['act_parallel'],
+ best['row_block_size'],
+ best['col_block_size'],
+ best['parallel_size']
+ )
+ self.rebuild_project()
+ print("✅ Best configuration applied and project rebuilt!")
+ else:
+ self.restore_config()
+ print("✅ Original configuration restored")
+
+ # Clean up backup file
+ if self.backup_path.exists():
+ self.backup_path.unlink()
+ print(f"🗑️ Removed backup file: {self.backup_path}")
+
+ except KeyboardInterrupt:
+ print("\n⚠️ Tuning interrupted by user")
+ self.restore_config()
+ # Clean up backup file
+ if self.backup_path.exists():
+ self.backup_path.unlink()
+ print(f"🗑️ Removed backup file: {self.backup_path}")
+ except Exception as e:
+ print(f"\n❌ Error during tuning: {e}")
+ self.restore_config()
+ # Clean up backup file
+ if self.backup_path.exists():
+ self.backup_path.unlink()
+ print(f"🗑️ Removed backup file: {self.backup_path}")
+ raise
+
+
+def generate_configurations():
+ """Generate list of configurations to test"""
+ configurations = []
+
+ act_parallel_options = [True]
+
+ row_sizes = [2, 4, 8]#[2, 4, 8, 16, 32]
+ col_sizes = [32, 64]#[32, 64, 128, 256, 512, 1024]
+ parallelism_degree = [4]
+
+ for act_parallel in act_parallel_options:
+ for row in row_sizes:
+ for col in col_sizes:
+ for parallel in parallelism_degree:
+ # Add filtering conditions
+ if act_parallel:
+ # When ACT_PARALLEL=True, only calculate combinations with parallel < row
+ if parallel > row:
+ continue
+ else:
+ # When ACT_PARALLEL=False, only calculate combinations with parallel < col
+ if parallel > col:
+ continue
+
+ configurations.append({
+ 'act_parallel': act_parallel,
+ 'row_block_size': row,
+ 'col_block_size': col,
+ 'parallel_size': parallel
+ })
+
+ return configurations
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Tune GEMM configuration for optimal performance')
+ parser.add_argument('--config', default='../include/gemm-config.h',
+ help='Path to gemm-config.h file')
+ parser.add_argument('--model', default='../models/BitNet-b1.58-2B-4T/ggml-model-i2_s-embed-q6_k.gguf',
+ help='Path to model file')
+ parser.add_argument('--threads', type=int, default=8,
+ help='Number of threads to use')
+ parser.add_argument('--quick', action='store_true',
+ help='Quick test with fewer configurations')
+ parser.add_argument('--custom', action='store_true',
+ help='Manually specify configurations to test')
+ parser.add_argument('--output', type=str, default=None,
+ help='Output CSV file path (default: stats/tuning_results_.csv)')
+
+ args = parser.parse_args()
+
+ tuner = GemmTuner(args.config, args.model, args.threads)
+
+ if args.custom:
+ # Custom configuration mode
+ print("Custom configuration mode")
+ configurations = []
+ while True:
+ print("\nEnter configuration (or 'done' to finish):")
+ act = input("ACT_PARALLEL (y/n): ").strip().lower() == 'y'
+ if input == 'done':
+ break
+ row = int(input("ROW_BLOCK_SIZE: "))
+ col = int(input("COL_BLOCK_SIZE: "))
+ par = int(input("PARALLEL_SIZE: "))
+ configurations.append({
+ 'act_parallel': act,
+ 'row_block_size': row,
+ 'col_block_size': col,
+ 'parallel_size': par
+ })
+ elif args.quick:
+ # Quick test mode - test only a few key configurations
+ configurations = [
+ {'act_parallel': True, 'row_block_size': 4, 'col_block_size': 128, 'parallel_size': 4},
+ {'act_parallel': True, 'row_block_size': 8, 'col_block_size': 128, 'parallel_size': 4},
+ {'act_parallel': True, 'row_block_size': 4, 'col_block_size': 64, 'parallel_size': 4},
+ {'act_parallel': False, 'row_block_size': 32, 'col_block_size': 4, 'parallel_size': 4},
+ {'act_parallel': False, 'row_block_size': 16, 'col_block_size': 4, 'parallel_size': 4},
+ ]
+ else:
+ # Full test mode
+ configurations = generate_configurations()
+
+ print(f"\n{'='*80}")
+ print(f"GEMM Configuration Tuner")
+ print(f"{'='*80}")
+ print(f"Total configurations to test: {len(configurations)}")
+ print(f"Estimated time: ~{len(configurations) * 0.5:.1f} minutes (assuming 30s per test)")
+ print(f"{'='*80}\n")
+
+ proceed = input("Proceed with tuning? (y/n): ").strip().lower()
+ if proceed != 'y':
+ print("Tuning cancelled")
+ return
+
+ tuner.run_tuning(configurations, output_csv=args.output)
+
+
+if __name__ == "__main__":
+ main()