Merge pull request #379 from XsquirrelC/main

BitNet CPU Inference Optimization
This commit is contained in:
tsong-ms
2026-01-27 11:24:02 +08:00
committed by GitHub
21 changed files with 3349 additions and 248 deletions
+5 -4
@@ -10,10 +10,10 @@ bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.5
The first release of bitnet.cpp is to support inference on CPUs. bitnet.cpp achieves speedups of **1.37x** to **5.07x** on ARM CPUs, with larger models experiencing greater performance gains. Additionally, it reduces energy consumption by **55.4%** to **70.0%**, further boosting overall efficiency. On x86 CPUs, speedups range from **2.37x** to **6.17x** with energy reductions between **71.9%** to **82.2%**. Furthermore, bitnet.cpp can run a 100B BitNet b1.58 model on a single CPU, achieving speeds comparable to human reading (5-7 tokens per second), significantly enhancing the potential for running LLMs on local devices. Please refer to the [technical report](https://arxiv.org/abs/2410.16144) for more details.
<img src="./assets/m2_performance.jpg" alt="m2_performance" width="800"/>
<img src="./assets/intel_performance.jpg" alt="m2_performance" width="800"/>
**Latest optimization** introduces parallel kernel implementations with configurable tiling and embedding quantization support, achieving **1.15x to 2.1x** additional speedup over the original implementation across different hardware platforms and workloads. For detailed technical information, see the [optimization guide](src/README.md).
<img src="./assets/performance.png" alt="performance_comparison" width="800"/>
> The tested models are dummy setups used in a research context to demonstrate the inference performance of bitnet.cpp.
## Demo
@@ -22,7 +22,8 @@ A demo of bitnet.cpp running a BitNet b1.58 3B model on Apple M2:
https://github.com/user-attachments/assets/7f46b736-edec-4828-b809-4be780a3e5b1
## What's New:
-- 05/20/2025 [BitNet Official GPU inference kernel](https://github.com/microsoft/BitNet/blob/main/gpu/README.md) ![NEW](https://img.shields.io/badge/NEW-red)
+- 01/15/2026 [BitNet CPU Inference Optimization](https://github.com/XsquirrelC/BitNet/blob/main/src/README.md) ![NEW](https://img.shields.io/badge/NEW-red)
+- 05/20/2025 [BitNet Official GPU inference kernel](https://github.com/microsoft/BitNet/blob/main/gpu/README.md)
- 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
- 02/18/2025 [Bitnet.cpp: Efficient Edge Inference for Ternary LLMs](https://arxiv.org/abs/2502.11880)
- 11/08/2024 [BitNet a4.8: 4-bit Activations for 1-bit LLMs](https://arxiv.org/abs/2411.04965)
Binary file not shown.

Before

Width:  |  Height:  |  Size: 353 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 238 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

+35
@@ -0,0 +1,35 @@
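// Tiling and parallelism configuration for the I2_S GEMM/GEMV kernels.
// ACT_PARALLEL selects the activation-parallel kernel variant (the
// default); undefine it to use weight-parallel tiling instead.
// ROW_BLOCK_SIZE and COL_BLOCK_SIZE set the tile shape; PARALLEL_SIZE
// sets the parallelism degree. See src/README.md for tuning guidance.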
#define ACT_PARALLEL
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
#if defined(ACT_PARALLEL)
#define ROW_BLOCK_SIZE 4
#define COL_BLOCK_SIZE 128
#define PARALLEL_SIZE 4
#else
#define ROW_BLOCK_SIZE 128
#define COL_BLOCK_SIZE 32
#define PARALLEL_SIZE 8
#endif // ACT_PARALLEL
#elif defined(__ARM_NEON)
#if defined(__ARM_FEATURE_DOTPROD)
#if defined(ACT_PARALLEL)
#define ROW_BLOCK_SIZE 8
#define COL_BLOCK_SIZE 256
#define PARALLEL_SIZE 8
#else
#define ROW_BLOCK_SIZE 64
#define COL_BLOCK_SIZE 16
#define PARALLEL_SIZE 2
#endif // ACT_PARALLEL
#else
#if defined(ACT_PARALLEL)
#define ROW_BLOCK_SIZE 8
#define COL_BLOCK_SIZE 256
#define PARALLEL_SIZE 4
#else
#define ROW_BLOCK_SIZE 128
#define COL_BLOCK_SIZE 32
#define PARALLEL_SIZE 4
#endif // ACT_PARALLEL
#endif // __ARM_FEATURE_DOTPROD
#endif // __AVX__
+2 -2
@@ -64,8 +64,8 @@ SUPPORTED_QUANT_TYPES = {
}
COMPILER_EXTRA_ARGS = {
"arm64": ["-DBITNET_ARM_TL1=ON"],
"x86_64": ["-DBITNET_X86_TL2=ON"]
"arm64": ["-DBITNET_ARM_TL1=OFF"],
"x86_64": ["-DBITNET_X86_TL2=OFF"]
}
OS_EXTRA_ARGS = {
+205
@@ -0,0 +1,205 @@
# BitNet CPU Inference Optimization
This update provides significant performance improvements for BitNet inference on CPU through parallel kernel implementations, native I2_S GEMM/GEMV support, configurable tiling block sizes, and embedding quantization.
## Update
- **Parallel Weight & Activation Computation**
Implemented parallel processing of weights and activations in the W2A8 vec_dot kernel, improving throughput on both x86 and ARM architectures.
- **Native I2_S GEMM & GEMV Support**
Integrated I2_S GEMM and GEMV operations into the ggml library, making them fully compatible with the llama.cpp architecture. This enables seamless integration with existing inference pipelines.
- **Configurable Tiling & Parallelism**
Introduced configurable GEMM & GEMV block sizes and parallelism levels, allowing performance fine-tuning for different CPU architectures.
- **Embedding Quantization**
Added support for embedding layer quantization with Q6_K format, reducing memory footprint and improving inference speed while maintaining high accuracy.
## Usage
### Configuration Options
The `include/gemm-config.h` file controls kernel behavior:
```c
#define ROW_BLOCK_SIZE 4
#define COL_BLOCK_SIZE 128
#define PARALLEL_SIZE 4
```
Modify these values in `include/gemm-config.h` to match your CPU's cache sizes and architecture; they are the main levers for fine-tuning performance on a given machine.
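As a rough sizing aid, here is an illustrative sketch (not part of the shipped code) of the per-tile working set, assuming int8 activations and I2_S weights packed four values per byte, with the default x86 block sizes:
```c
#include <stdio.h>

/* Hypothetical sizing helper: estimates the working set of one GEMM tile
 * under the default x86 values (illustrative only, not part of this PR). */
#define ROW_BLOCK_SIZE 4
#define COL_BLOCK_SIZE 128

int main(void) {
    const int n = 2048;                          /* embedding dimension        */
    size_t act = (size_t)ROW_BLOCK_SIZE * n;     /* int8 activations: 1 B each */
    size_t wgt = (size_t)COL_BLOCK_SIZE * n / 4; /* I2_S: 4 weights per byte   */
    printf("tile working set: %zu KiB\n", (act + wgt) / 1024); /* 72 KiB */
    return 0;
}
```
If this overflows the cache level you target, shrink `COL_BLOCK_SIZE` first; the packed weight tile dominates the footprint.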
### Enabling Embedding Quantization
To use embedding quantization for additional speedup:
**Using setup_env.py:**
```bash
python setup_env.py --quant-embd
```
This automatically converts embeddings to Q6_K format.
**Manual conversion:**
```bash
build/bin/llama-quantize --token-embedding-type Q6_K models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf models/BitNet-b1.58-2B-4T/ggml-model-i2_s-embed-q6_k.gguf I2_S 1 1
```
## Optimizations
### 1. Weight & Activation Parallelism
The kernel implements two parallelization strategies:
- **Weight Parallel:** Processes multiple weight rows/columns in a single kernel call, reducing kernel launch overhead.
- **Activation Parallel:** Built on top of weight parallel, amortizes the I2_S weight unpacking cost across multiple activation elements.
**Recommendation:** For the I2_S quantization format, activation parallel is recommended because it amortizes the weight-unpack cost across activation rows. The current kernel defaults to activation parallel.
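To see why, consider a scalar sketch of the activation-parallel vec_dot (the real kernel is SIMD; the bit order and the mapping of 2-bit codes to {-1, 0, 1} are illustrative assumptions here): each packed weight byte is unpacked once and reused against all `nrc` activation rows.
```c
#include <stdint.h>
#include <stddef.h>

/* Sketch of an activation-parallel vec_dot: computes nrc dot products of
 * one 2-bit weight column against consecutive int8 activation rows.
 * Assumes s[] is zero-initialized by the caller; bit layout illustrative. */
static void vec_dot_i2_i8_sketch(int n, float *s, size_t bs,
                                 const uint8_t *w_packed,
                                 const int8_t *act, int nrc) {
    for (int i = 0; i < n; i += 4) {
        const uint8_t b = w_packed[i / 4];   /* unpack four 2-bit weights once... */
        const int8_t w[4] = {
            (int8_t)(((b >> 6) & 3) - 1), (int8_t)(((b >> 4) & 3) - 1),
            (int8_t)(((b >> 2) & 3) - 1), (int8_t)((b & 3) - 1),
        };
        for (int r = 0; r < nrc; ++r)        /* ...then reuse them for every row */
            for (int k = 0; k < 4; ++k)
                s[r * bs] += (float)(w[k] * act[(size_t)r * n + i + k]);
    }
}
```
With `nrc == 1` (the no-parallel case) every byte is unpacked once per output; with `nrc == PARALLEL_SIZE` that cost is divided across several outputs, which is why the larger GEMM shapes in the table below gain the most.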
**Kernel Performance Comparison:**
<div align="center">
Test configuration: AMD EPYC 7V13 (x86), 1 thread, time in milliseconds (mean±std)
| Matrix Size | No Parallel | Weight Parallel | Activation Parallel |
|:---:|:---:|:---:|:---:|
| [1, 2048] × [2048, 2048] | 0.075±0.012 | **0.058±0.007** | 0.076±0.011 |
| [32, 2048] × [2048, 2048] | 2.400±0.041 | 1.599±0.020 | **1.202±0.018** |
| [128, 2048] × [2048, 2048] | 10.820±0.039 | 6.458±0.168 | **5.805±0.039** |
| [256, 2048] × [2048, 2048] | 21.669±0.080 | 12.739±0.183 | **11.882±0.040** |
| [512, 2048] × [2048, 2048] | 43.257±0.083 | 25.680±0.335 | **23.342±0.082** |
| [2048, 2048] × [2048, 2048] | 173.175±0.214 | 103.112±0.552 | **93.276±0.612** |
| [128, 2048] × [2048, 8192] | 43.345±0.090 | 25.541±0.239 | **23.528±0.052** |
| [128, 8192] × [8192, 2048] | 38.085±0.162 | 23.866±0.096 | **22.569±0.132** |
</div>
### 2. GEMM/GEMV Integration with llama.cpp
Integrated I2_S quantization format into llama.cpp's compute graph:
- **GEMV Operations:** Optimized matrix-vector multiplication for token generation.
- **GEMM Operations:** Efficient matrix-matrix multiplication for prompt processing.
- **Tiling Strategy:** Configurable block sizes for optimal cache utilization.
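In sketch form, the tiled driver walks the output in `COL_BLOCK_SIZE` x `ROW_BLOCK_SIZE` tiles; the loop structure below mirrors the `ggml_gemm_i2_i8_s` harness embedded later in this PR, with a scalar `vec_dot_stub` (illustrative bit layout) standing in for the SIMD kernel:
```c
#include <stdint.h>
#include <stddef.h>

#define ROW_BLOCK_SIZE 4
#define COL_BLOCK_SIZE 128

/* Scalar stand-in for the SIMD vec_dot: nrc dot products of one packed
 * 2-bit weight column against consecutive int8 activation rows. */
static void vec_dot_stub(int n, float *s, size_t bs, const uint8_t *w,
                         const int8_t *a, int nrc) {
    for (int r = 0; r < nrc; ++r) {
        float acc = 0.0f;
        for (int i = 0; i < n; ++i)
            acc += (float)((((w[i / 4] >> (6 - 2 * (i % 4))) & 3) - 1)
                           * a[(size_t)r * n + i]);
        s[r * bs] = acc;
    }
}

/* Activation-parallel tiling: each packed weight column stays cache-resident
 * while it is reused across a whole block of activation rows. */
static void gemm_i2_i8_tiled(int n, float *s, size_t bs, const uint8_t *vx,
                             const int8_t *vy, int nr, int nc) {
    for (int64_t c0 = 0; c0 < nc; c0 += COL_BLOCK_SIZE) {
        const int64_t cur_c = (c0 + COL_BLOCK_SIZE <= nc) ? COL_BLOCK_SIZE : nc - c0;
        for (int64_t r0 = 0; r0 < nr; r0 += ROW_BLOCK_SIZE) {
            const int64_t cur_r = (r0 + ROW_BLOCK_SIZE <= nr) ? ROW_BLOCK_SIZE : nr - r0;
            const int8_t *vy_r = vy + r0 * n;             /* activation row block */
            for (int64_t c = 0; c < cur_c; ++c) {
                const int64_t col = c0 + c;
                const uint8_t *vx_col = vx + col * n / 4; /* packed weight column */
                vec_dot_stub(n, s + col + r0 * bs, bs, vx_col, vy_r, (int)cur_r);
            }
        }
    }
}
```
With `nr == 1` the same driver degenerates to the GEMV path used during token generation.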
### 3. Configuration Fine-tuning
Kernel parameters can be fine-tuned for optimal performance on specific hardware:
**Example Configuration (x86, AMD EPYC 7V13):**
- Method: Activation Parallel
- Threads: 8
- Workload: 128 prompt tokens (pp128)
**Fine-tuning Parameters:**
- **Parallelism Degree:** [2, 4, 8]
- **Row Block Size:** [2, 4, 8, 16, 32]
- **Column Block Size:** [32, 64, 128, 256, 512, 1024]
**Fine-tuning Results:**
<div align="center">
<img src="./assets/fine_tuning_result.png" alt="fine_tune_result" width="800"/>
*Shows throughput (tokens/s) for various configurations.*
</div>
**Optimal Configuration:** Under this setup (x86, 8 threads, pp128), the best performance is achieved with parallelism degree = 4, row block size = 4, and column block size = 128.
### 4. Embedding Quantization
Evaluated multiple embedding quantization formats to balance memory usage, model quality, and inference speed:
**Perplexity Comparison:**
<div align="center">
Test configuration: BitNet-b1.58-2B-4T, TG128
| Embedding Type | Wikitext | PTB | LAMBADA | IMDB | AG NEWS |
|:---:|:---:|:---:|:---:|:---:|:---:|
| **F32** | 17.1090±0.1278 | 33.0858±0.4886 | 43.2850±0.6363 | 29.3016±0.2890 | 36.7686±0.3920 |
| **F16** | 17.1090±0.1278 | 33.0858±0.4886 | 43.2850±0.6363 | 29.3016±0.2890 | 36.7686±0.3920 |
| **Q8_0** | 17.1197±0.1280 | 33.1181±0.4893 | 43.2891±0.6364 | 29.3133±0.2892 | 36.7740±0.3920 |
| **Q6_K** | 17.1487±0.1282 | 33.2203±0.4914 | 43.3046±0.6362 | 29.3491±0.2897 | 36.7972±0.3921 |
| **Q5_0** | 17.2379±0.1288 | 33.2439±0.4907 | 43.4631±0.6379 | 29.5481±0.2920 | 36.8539±0.3924 |
| **Q4_0** | 17.3529±0.1300 | 33.7754±0.5001 | 44.4552±0.6559 | 30.1044±0.2978 | 37.3985±0.3997 |
| **Q3_K** | 17.6434±0.1320 | 34.3914±0.5089 | 45.4591±0.6735 | 30.8476±0.3069 | 39.5692±0.4259 |
| **I2_S** | N/A | N/A | N/A | N/A | N/A |
*N/A indicates model failure due to extreme quantization.*
</div>
**Inference Speed Comparison:**
<div align="center">
<img src="./assets/embedding_throughput.png" alt="embedding_throughput" width="800"/>
*Token generation throughput (tg128) for different embedding quantization types.*
</div>
**Recommendation:** Based on comprehensive evaluation of memory footprint, perplexity preservation, and inference speed, **Q6_K** is selected as the optimal embedding quantization format.
## Performance
Comparison of optimized parallel kernels vs. original implementation:
**Test Configuration:**
- Model: BitNet-b1.58-2B-4T
- Hardware: AMD EPYC 7V13
- Threads: 1 / 2 / 4 / 8 / 12 / 16
- Test: 128 prompt tokens (pp128) + 128 generated tokens (tg128)
- Method: Activation Parallel
<div align="center">
<img src="./assets/performance_comparison_amd_epyc.png" alt="performance_comparison_amd_epyc" width="800"/>
</div>
**Test Configuration:**
- Model: BitNet-b1.58-2B-4T
- Hardware: Intel i7-13800H
- Threads: 1 / 2 / 4 / 6
- Test: 128 prompt tokens (pp128) + 128 generated tokens (tg128)
- Method: Activation Parallel
<div align="center">
<img src="./assets/performance_comparison_i7-13800h.png" alt="performance_comparison_i7-13800h" width="800"/>
</div>
**Test Configuration:**
- Model: BitNet-b1.58-2B-4T
- Hardware: Cobalt 100
- Threads: 1 / 2 / 4 / 8
- Test: 128 prompt tokens (pp128) + 128 generated tokens (tg128)
- Method: Activation Parallel
<div align="center">
<img src="./assets/performance_comparison_cobalt100_dotprod.png" alt="performance_comparison_cobalt100_dotprod" width="800"/>
</div>
## Technical Details
### Key Files Modified
- `src/ggml-bitnet-mad.cpp`: Parallel kernel implementations
- `3rdparty/llama.cpp/ggml/src/ggml.c`: GEMM/GEMV integration
- `include/gemm-config.h`: Configuration file
### Supported Architectures
- ✅ x86-64 with AVX2
- ✅ ARM with NEON
- ✅ ARM with DOTPROD extension
Binary file not shown.

After

Width:  |  Height:  |  Size: 183 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 341 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 313 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 290 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 260 KiB

+928 -235
File diff suppressed because it is too large
+6 -6
@@ -109,12 +109,12 @@ def main():
except OSError as e:
print(f"Warning: Could not remove {preprocessed_output_file}: {e}")
-if gguf_f32_output.exists():
-print(f"Removing f32 GGUF: {gguf_f32_output}")
-try:
-gguf_f32_output.unlink()
-except OSError as e:
-print(f"Warning: Could not remove {gguf_f32_output}: {e}")
+# if gguf_f32_output.exists():
+# print(f"Removing f32 GGUF: {gguf_f32_output}")
+# try:
+# gguf_f32_output.unlink()
+# except OSError as e:
+# print(f"Warning: Could not remove {gguf_f32_output}: {e}")
if input_backup_file.exists():
if not input_file.exists():
+473
@@ -0,0 +1,473 @@
#!/usr/bin/env python3
"""
Embedding Quantization Script
This script converts ggml-model-f32.gguf to multiple quantized versions
with different token embedding types.
"""
import subprocess
import os
import argparse
import re
import csv
from pathlib import Path
from datetime import datetime
class EmbeddingQuantizer:
def __init__(self, input_model, output_dir, quantize_bin="../build/bin/llama-quantize",
bench_bin="../build/bin/llama-bench", stats_dir="../stats", csv_output=None):
self.input_model = Path(input_model)
self.output_dir = Path(output_dir)
self.quantize_bin = Path(quantize_bin)
self.bench_bin = Path(bench_bin)
self.stats_dir = Path(stats_dir)
self.csv_output = Path(csv_output) if csv_output else None
# Verify input file exists
if not self.input_model.exists():
raise FileNotFoundError(f"Input model not found: {self.input_model}")
# Verify quantize tool exists
if not self.quantize_bin.exists():
raise FileNotFoundError(f"Quantize binary not found: {self.quantize_bin}")
# Verify bench tool exists
if not self.bench_bin.exists():
raise FileNotFoundError(f"Benchmark binary not found: {self.bench_bin}")
# Create output directories
self.output_dir.mkdir(parents=True, exist_ok=True)
self.stats_dir.mkdir(parents=True, exist_ok=True)
self.results = []
self.newly_created_files = set() # Track newly created files
def quantize(self, embedding_type, output_suffix):
"""
Perform single quantization
Args:
embedding_type: Token embedding type (uppercase format, e.g., Q6_K)
output_suffix: Output file suffix (lowercase format, e.g., q6_k)
Returns:
bool: Whether successful
"""
output_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
# Check if file already exists
file_already_existed = output_file.exists()
if file_already_existed:
print(f"️ File already exists: {output_file}")
print(f" Skipping quantization, will use existing file for benchmark")
return True
cmd = [
str(self.quantize_bin),
"--token-embedding-type", embedding_type,
str(self.input_model),
str(output_file),
"I2_S",
"1",
"1"
]
print(f"\n{'='*80}")
print(f"🔄 Quantizing with embedding type: {embedding_type}")
print(f"📥 Input: {self.input_model}")
print(f"📤 Output: {output_file}")
print(f"💻 Command: {' '.join(cmd)}")
print(f"{'='*80}\n")
start_time = datetime.now()
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
cwd=os.getcwd(),
timeout=600 # 10 minute timeout
)
end_time = datetime.now()
duration = (end_time - start_time).total_seconds()
if result.returncode == 0:
# Get output file size
file_size_mb = output_file.stat().st_size / (1024 * 1024)
print(f"✅ Success! Duration: {duration:.2f}s, Size: {file_size_mb:.2f} MB")
# Record newly created file
if not file_already_existed:
self.newly_created_files.add(output_file)
# Print part of output
if result.stdout:
print("\n📊 Quantization output:")
print(result.stdout[-500:] if len(result.stdout) > 500 else result.stdout)
return True
else:
print(f"❌ Failed with return code {result.returncode}")
print(f"Error: {result.stderr}")
return False
except subprocess.TimeoutExpired:
print(f"❌ Timeout (exceeded 10 minutes)")
return False
except Exception as e:
print(f"❌ Exception: {e}")
return False
def benchmark_model(self, output_suffix):
"""
Benchmark model
Args:
output_suffix: Output file suffix (lowercase format, e.g., q6_k)
Returns:
dict: Dictionary with benchmark results, or None if failed
"""
model_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
if not model_file.exists():
print(f"❌ Model file not found for benchmarking: {model_file}")
return None
cmd = [
str(self.bench_bin),
"-m", str(model_file),
"-p", "128",
"-n", "0",
"-t", "1,2,4,8",
"-ngl", "0"
]
print(f"\n{'='*80}")
print(f"🏃 Running benchmark for: {output_suffix}")
print(f"💻 Command: {' '.join(cmd)}")
print(f"{'='*80}\n")
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
cwd=os.getcwd(),
timeout=300 # 5 minute timeout
)
if result.returncode == 0:
print("✅ Benchmark completed successfully")
print("\n📊 Benchmark output:")
print(result.stdout)
# Parse the output
bench_results = self.parse_benchmark_output(result.stdout, output_suffix)
return bench_results
else:
print(f"❌ Benchmark failed with return code {result.returncode}")
print(f"Error: {result.stderr}")
return None
except subprocess.TimeoutExpired:
print(f"❌ Benchmark timeout (exceeded 5 minutes)")
return None
except Exception as e:
print(f"❌ Benchmark exception: {e}")
return None
def parse_benchmark_output(self, output, output_suffix):
"""
Parse benchmark output to extract t/s data (mean±std)
Args:
output: Benchmark command output
output_suffix: Output file suffix
Returns:
dict: Dictionary with parsed results
"""
results = {
'embedding_type': output_suffix,
'threads_1': None,
'threads_2': None,
'threads_4': None,
'threads_8': None,
}
# Parse table data
# Find lines containing pp128 and t/s
lines = output.strip().split('\n')
for line in lines:
# Skip header and separator lines
if '|' not in line or 'model' in line or '---' in line:
continue
# Try to extract data
# Format similar to: | bitnet-25 2B I2_S - 2 bpw ternary | 1012.28 MiB | 2.74 B | CPU | 12 | pp128 | 405.73 ± 3.69 |
parts = [p.strip() for p in line.split('|')]
if len(parts) >= 8 and 'pp128' in parts[6]:
threads_str = parts[5].strip()
throughput_str = parts[7].strip()
# Extract thread count
try:
threads = int(threads_str)
except:
continue
# Extract t/s data (format: "405.73 ± 3.69" or "405.73")
# Try to match "mean ± std" format
match_with_std = re.search(r'([\d.]+)\s*±\s*([\d.]+)', throughput_str)
if match_with_std:
mean = float(match_with_std.group(1))
std = float(match_with_std.group(2))
throughput = f"{mean:.2f}±{std:.2f}"
else:
# Only mean, no std
match = re.search(r'([\d.]+)', throughput_str)
if match:
throughput = f"{float(match.group(1)):.2f}"
else:
continue
# Store result based on thread count
if threads == 1:
results['threads_1'] = throughput
elif threads == 2:
results['threads_2'] = throughput
elif threads == 4:
results['threads_4'] = throughput
elif threads == 8:
results['threads_8'] = throughput
return results
def cleanup_model(self, output_suffix):
"""
Cleanup model files (only delete newly created files)
Args:
output_suffix: Output file suffix
"""
model_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
if model_file in self.newly_created_files:
try:
model_file.unlink()
print(f"🗑️ Deleted newly created file: {model_file}")
self.newly_created_files.remove(model_file)
except Exception as e:
print(f"⚠️ Failed to delete {model_file}: {e}")
else:
print(f"️ Keeping existing file: {model_file}")
def run_all_quantizations(self, types_to_quantize):
"""
Run all quantizations
Args:
types_to_quantize: List of quantization types, tuples of (embedding_type, output_suffix)
"""
print(f"\n{'='*80}")
print(f"🚀 Starting Embedding Quantization and Benchmarking")
print(f"{'='*80}")
print(f"📥 Input model: {self.input_model}")
print(f"📤 Output directory: {self.output_dir}")
print(f"📊 Stats directory: {self.stats_dir}")
print(f"🔢 Total quantizations: {len(types_to_quantize)}")
print(f"{'='*80}\n")
total_start = datetime.now()
for i, (embedding_type, output_suffix) in enumerate(types_to_quantize, 1):
print(f"\n{'#'*80}")
print(f"[{i}/{len(types_to_quantize)}] Processing {output_suffix} ({embedding_type})")
print(f"{'#'*80}\n")
# Quantize model
success = self.quantize(embedding_type, output_suffix)
if not success:
print(f"⚠️ Skipping benchmark for {output_suffix} due to quantization failure")
continue
# Run benchmark
bench_results = self.benchmark_model(output_suffix)
if bench_results:
self.results.append(bench_results)
else:
print(f"⚠️ Benchmark failed for {output_suffix}")
# Cleanup model files (only delete newly created files)
self.cleanup_model(output_suffix)
print(f"\n{'#'*80}")
print(f"✅ Completed {output_suffix}")
print(f"{'#'*80}\n")
total_end = datetime.now()
total_duration = (total_end - total_start).total_seconds()
# Save results to CSV
self.save_results_to_csv()
# Print summary
self.print_summary(total_duration)
def save_results_to_csv(self):
"""将benchmark结果保存到CSV文件"""
if not self.results:
print("⚠️ No results to save")
return
# Use user-specified CSV path, otherwise use default path
if self.csv_output:
csv_file = self.csv_output
# Ensure parent directory exists
csv_file.parent.mkdir(parents=True, exist_ok=True)
else:
csv_file = self.stats_dir / "embedding_benchmark.csv"
print(f"\n💾 Saving results to: {csv_file}")
try:
with open(csv_file, 'w', newline='') as f:
fieldnames = ['embedding_type', 'threads_1', 'threads_2', 'threads_4', 'threads_8']
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for result in self.results:
writer.writerow(result)
print(f"✅ Results saved successfully")
# Also print table
print(f"\n📊 Benchmark Results:")
print(f"{'Type':<15} {'1 thread':<18} {'2 threads':<18} {'4 threads':<18} {'8 threads':<18}")
print("-" * 87)
for result in self.results:
t1 = result['threads_1'] if result['threads_1'] else "N/A"
t2 = result['threads_2'] if result['threads_2'] else "N/A"
t4 = result['threads_4'] if result['threads_4'] else "N/A"
t8 = result['threads_8'] if result['threads_8'] else "N/A"
print(f"{result['embedding_type']:<15} {t1:<18} {t2:<18} {t4:<18} {t8:<18}")
except Exception as e:
print(f"❌ Failed to save results: {e}")
def print_summary(self, total_duration):
"""Print quantization summary"""
print(f"\n\n{'='*80}")
print(f"📊 QUANTIZATION AND BENCHMARK SUMMARY")
print(f"{'='*80}\n")
successful = len(self.results)
print(f"✅ Completed: {successful} benchmarks")
print(f"⏱️ Total duration: {total_duration/60:.2f} minutes\n")
if self.results:
if self.csv_output and self.csv_output.exists():
print(f"📁 Results saved to: {self.csv_output}")
else:
csv_files = list(self.stats_dir.glob("embedding_benchmark*.csv"))
if csv_files:
latest_csv = max(csv_files, key=lambda p: p.stat().st_mtime)
print(f"📁 Results saved to: {latest_csv}")
print(f"\n{'='*80}\n")
def main():
parser = argparse.ArgumentParser(description='Quantize model embeddings to multiple formats')
parser.add_argument('--input', '-i',
default='../models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf',
help='Input model path (default: ../models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf)')
parser.add_argument('--output-dir', '-o',
default='../models/BitNet-b1.58-2B-4T',
help='Output directory (default: ../models/BitNet-b1.58-2B-4T)')
parser.add_argument('--quantize-bin', '-q',
default='../build/bin/llama-quantize',
help='Path to llama-quantize binary (default: ../build/bin/llama-quantize)')
parser.add_argument('--bench-bin', '-b',
default='../build/bin/llama-bench',
help='Path to llama-bench binary (default: ../build/bin/llama-bench)')
parser.add_argument('--stats-dir',
default='../stats',
help='Directory to save benchmark results (default: ../stats)')
parser.add_argument('--csv-output', '-c',
help='Custom path for CSV output file (e.g., stats/my_results.csv)')
parser.add_argument('--types', '-t',
nargs='+',
help='Specific types to quantize (e.g., f32 q6_k q4_0)')
parser.add_argument('--skip-existing', '-s',
action='store_true',
help='Skip quantization if output file already exists (will still benchmark existing files)')
args = parser.parse_args()
# Define all supported quantization types
# Format: (embedding_type for command line, output_suffix for filename)
all_types = [
('F32', 'f32'),
('F16', 'f16'),
('Q8_0', 'q8_0'),
('Q6_K', 'q6_k'),
('Q5_0', 'q5_0'),
('Q4_0', 'q4_0'),
('Q3_K', 'q3_k'),
('TQ2_0', 'tq2_0'),
]
# If specific types are specified, filter the list
if args.types:
types_lower = [t.lower() for t in args.types]
# Use `suffix` as the loop variable to avoid shadowing the `os` module
types_to_quantize = [(et, suffix) for et, suffix in all_types if suffix.lower() in types_lower]
if not types_to_quantize:
print(f"❌ No valid types specified. Available types: {', '.join([suffix for _, suffix in all_types])}")
return
else:
types_to_quantize = all_types
# If --skip-existing is set there is nothing extra to filter here: the
# quantization step already detects existing output files and skips
# re-quantizing them, while still benchmarking them.
# Create the quantizer and run it
try:
quantizer = EmbeddingQuantizer(
args.input,
args.output_dir,
args.quantize_bin,
args.bench_bin,
args.stats_dir,
args.csv_output
)
quantizer.run_all_quantizations(types_to_quantize)
except FileNotFoundError as e:
print(f"❌ Error: {e}")
return 1
except KeyboardInterrupt:
print("\n\n⚠️ Quantization interrupted by user")
return 1
except Exception as e:
print(f"\n❌ Unexpected error: {e}")
import traceback
traceback.print_exc()
return 1
if __name__ == "__main__":
exit(main() or 0)
+573
@@ -0,0 +1,573 @@
#!/bin/bash
# Unified GEMM kernel benchmark script
# Builds, tests, and benchmarks the GEMM kernel with configurable output
set -e
# Default values
BUILD_DIR="../build"
ITERATIONS=1000
OUTPUT_CSV=""
SKIP_BUILD=false
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# Print usage
print_usage() {
cat << EOF
Usage: $0 [options]
Options:
-o, --output <path> Output CSV file path (default: ../stats/gemm_kernel_test_noparal.csv)
-i, --iterations <num> Number of iterations per test (default: 1000)
-s, --skip-build Skip building the benchmark binary
-h, --help Show this help message
Examples:
# Run with default settings
$0
# Specify custom output file
$0 -o /path/to/my_results.csv
# Quick test with fewer iterations
$0 -i 100 -o quick_test.csv
# Skip build if already compiled
$0 -s -o results.csv
EOF
}
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
-o|--output)
OUTPUT_CSV="$2"
shift 2
;;
-i|--iterations)
ITERATIONS="$2"
shift 2
;;
-s|--skip-build)
SKIP_BUILD=true
shift
;;
-h|--help)
print_usage
exit 0
;;
*)
echo "Unknown option: $1"
print_usage
exit 1
;;
esac
done
# Set default output CSV if not specified
if [ -z "$OUTPUT_CSV" ]; then
OUTPUT_CSV="${SCRIPT_DIR}/../stats/gemm_kernel_test_noparal.csv"
fi
# Create output directory first
mkdir -p "$(dirname "$OUTPUT_CSV")"
# Convert to absolute path
if [[ "$OUTPUT_CSV" = /* ]]; then
# Already absolute path
OUTPUT_CSV="$OUTPUT_CSV"
else
# Convert relative path to absolute
OUTPUT_CSV="$(cd "$(dirname "$OUTPUT_CSV")" && pwd)/$(basename "$OUTPUT_CSV")"
fi
echo "=========================================="
echo "GEMM Kernel Benchmark Suite"
echo "=========================================="
echo "Configuration:"
echo " Iterations: $ITERATIONS"
echo " Output CSV: $OUTPUT_CSV"
echo " Skip build: $SKIP_BUILD"
echo "=========================================="
echo ""
# Build the benchmark binary
if [ "$SKIP_BUILD" = false ]; then
echo "Step 1: Building GEMM kernel benchmark..."
echo "------------------------------------------"
CXX=${CXX:-g++}
# Create build directory if it doesn't exist
mkdir -p "${SCRIPT_DIR}/${BUILD_DIR}"
# Create temporary C++ source file
TEMP_CPP="${SCRIPT_DIR}/${BUILD_DIR}/test_gemm_kernel_temp.cpp"
cat > "${TEMP_CPP}" << 'EOF'
/**
* Standalone benchmark for ggml_gemm_i2_i8_s kernel
*
* This program tests the performance of the ggml_gemm_i2_i8_s kernel
* with configurable matrix sizes and iteration counts.
*
* Usage: ./test_gemm_kernel [options]
* -n <size> : embedding dimension (must be divisible by 4, default: 2048)
* -r <rows> : number of rows in matrix Y (default: 32)
* -c <cols> : number of columns in matrix X (default: 128)
* -i <iters> : number of iterations (default: 1000)
* -w <warmup> : number of warmup iterations (default: 10)
*/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <stdint.h>
#include <math.h>
#include <assert.h>
// Include necessary headers
#include "../include/gemm-config.h"
// Function declarations (from ggml-quants.h)
extern "C" void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc);
// GEMM kernel definition
void ggml_gemm_i2_i8_s(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
#if defined(ACT_PARALLEL)
const int64_t row_block = ROW_BLOCK_SIZE;
const int64_t col_block = COL_BLOCK_SIZE;
for (int64_t c0 = 0; c0 < nc; c0 += col_block) {
int64_t cur_c = (c0 + col_block <= nc) ? col_block : (nc - c0);
for (int64_t r0 = 0; r0 < nr; r0 += row_block) {
int64_t cur_r = (r0 + row_block <= nr) ? row_block : (nr - r0);
const void * vy_r = (const uint8_t *)vy + r0 * n;
for (int64_t c = 0; c < cur_c; ++c) {
const int64_t col = c0 + c;
float * s_col = s + col;
const void * vx_col = (const uint8_t *)vx + col * n / 4;
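// One call computes cur_r outputs: the unpacked 2-bit weight column is
// reused across the whole activation row block (activation-parallel).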
ggml_vec_dot_i2_i8_s(n, s_col + r0 * bs, bs, vx_col, n, vy_r, n, cur_r);
}
}
}
#else
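// Weight-parallel fallback: rows outer, columns inner; one vec_dot call
// covers cur_c weight columns for a single activation row.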
const int64_t row_block = ROW_BLOCK_SIZE;
const int64_t col_block = COL_BLOCK_SIZE;
for (int64_t r0 = 0; r0 < nr; r0 += row_block) {
int64_t cur_r = (r0 + row_block <= nr) ? row_block : (nr - r0);
for (int64_t c0 = 0; c0 < nc; c0 += col_block) {
int64_t cur_c = (c0 + col_block <= nc) ? col_block : (nc - c0);
const void * vx_c = (const uint8_t *)vx + c0 * n / 4;
for (int64_t r = 0; r < cur_r; ++r) {
const int64_t row = r0 + r;
float * s_row = s + row * bs;
const void * vy_row = (const uint8_t *)vy + row * n;
ggml_vec_dot_i2_i8_s(n, s_row + c0, bs, vx_c, n, vy_row, n, cur_c);
}
}
}
#endif
}
// Helper function to get current time in nanoseconds
double get_time_ns() {
struct timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
return ts.tv_sec * 1e9 + ts.tv_nsec;
}
// Initialize matrix with random i2 values (2-bit quantized)
void init_matrix_i2(uint8_t* data, int n, int cols) {
// i2 format: 4 values per byte (2 bits each)
int total_bytes = n * cols / 4;
for (int i = 0; i < total_bytes; i++) {
data[i] = rand() & 0xFF;
}
}
// Initialize matrix with random i8 values
void init_matrix_i8(int8_t* data, int n, int rows) {
int total_elements = n * rows;
for (int i = 0; i < total_elements; i++) {
data[i] = (int8_t)((rand() % 256) - 128);
}
}
// Benchmark configuration
struct BenchmarkConfig {
int n; // embedding dimension (must be divisible by 4)
int nr; // number of rows in Y matrix
int nc; // number of columns in X matrix
int iterations; // number of benchmark iterations
int warmup; // number of warmup iterations
};
void print_config(const BenchmarkConfig& config) {
printf("=" "=%.78s\n", "===============================================================================");
printf("Benchmark Configuration:\n");
printf("=" "=%.78s\n", "===============================================================================");
printf(" Embedding dimension (n) : %d\n", config.n);
printf(" Matrix Y rows (nr) : %d\n", config.nr);
printf(" Matrix X columns (nc) : %d\n", config.nc);
printf(" Iterations : %d\n", config.iterations);
printf(" Warmup iterations : %d\n", config.warmup);
printf("\nMatrix sizes:\n");
printf(" X (i2): %d x %d (%.2f KB)\n", config.nc, config.n,
(config.nc * config.n / 4) / 1024.0);
printf(" Y (i8): %d x %d (%.2f KB)\n", config.nr, config.n,
(config.nr * config.n) / 1024.0);
printf(" S (f32): %d x %d (%.2f KB)\n", config.nr, config.nc,
(config.nr * config.nc * sizeof(float)) / 1024.0);
printf("\nGEMM Config:\n");
#if defined(ACT_PARALLEL)
printf(" ACT_PARALLEL : ON\n");
#else
printf(" ACT_PARALLEL : OFF\n");
#endif
printf(" ROW_BLOCK_SIZE : %d\n", ROW_BLOCK_SIZE);
printf(" COL_BLOCK_SIZE : %d\n", COL_BLOCK_SIZE);
printf(" PARALLEL_SIZE : %d\n", PARALLEL_SIZE);
printf("=" "=%.78s\n\n", "===============================================================================");
}
void run_benchmark(const BenchmarkConfig& config) {
// Allocate matrices
printf("Allocating matrices...\n");
// X matrix (i2 format): nc x n, but stored as nc x (n/4) bytes
// Align to 64 bytes for AVX-512, which is backward compatible with AVX2 (32 bytes)
size_t x_size = config.nc * config.n / 4;
size_t x_size_aligned = ((x_size + 63) / 64) * 64;
uint8_t* X = (uint8_t*)aligned_alloc(64, x_size_aligned);
// Y matrix (i8 format): nr x n
size_t y_size = config.nr * config.n;
size_t y_size_aligned = ((y_size + 63) / 64) * 64;
int8_t* Y = (int8_t*)aligned_alloc(64, y_size_aligned);
// Result matrix (float32): nr x nc
size_t s_size = config.nr * config.nc * sizeof(float);
size_t s_size_aligned = ((s_size + 63) / 64) * 64;
float* S = (float*)aligned_alloc(64, s_size_aligned);
if (!X || !Y || !S) {
fprintf(stderr, "Failed to allocate memory\n");
exit(1);
}
// Initialize matrices with random data
printf("Initializing matrices with random data...\n");
srand(time(NULL));
init_matrix_i2(X, config.n, config.nc);
init_matrix_i8(Y, config.n, config.nr);
memset(S, 0, config.nr * config.nc * sizeof(float));
// Warmup
printf("Running %d warmup iterations...\n", config.warmup);
for (int i = 0; i < config.warmup; i++) {
ggml_gemm_i2_i8_s(config.n, S, config.nc, X, Y, config.nr, config.nc);
}
// Benchmark
printf("Running %d benchmark iterations...\n", config.iterations);
double total_time = 0.0;
double min_time = 1e20;
double max_time = 0.0;
for (int i = 0; i < config.iterations; i++) {
double start = get_time_ns();
ggml_gemm_i2_i8_s(config.n, S, config.nc, X, Y, config.nr, config.nc);
double end = get_time_ns();
double elapsed = end - start;
total_time += elapsed;
if (elapsed < min_time) min_time = elapsed;
if (elapsed > max_time) max_time = elapsed;
if ((i + 1) % 100 == 0) {
printf(" Progress: %d/%d iterations\n", i + 1, config.iterations);
}
}
// Calculate statistics
double avg_time_ns = total_time / config.iterations;
double avg_time_ms = avg_time_ns / 1e6;
double min_time_ms = min_time / 1e6;
double max_time_ms = max_time / 1e6;
// Calculate GFLOPS
// For GEMM: nr x nc x n multiply-adds = 2 * nr * nc * n FLOPs
double flops = 2.0 * config.nr * config.nc * config.n;
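// FLOPs per nanosecond is numerically equal to GFLOP/s, so no scaling is needed.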
double gflops = (flops / avg_time_ns);
// Calculate throughput (tokens/s assuming each column is a token)
double throughput = (config.nc * 1e9) / avg_time_ns;
// Print results
printf("\n");
printf("=" "=%.78s\n", "===============================================================================");
printf("Benchmark Results:\n");
printf("=" "=%.78s\n", "===============================================================================");
printf(" Average time : %.3f ms\n", avg_time_ms);
printf(" Min time : %.3f ms\n", min_time_ms);
printf(" Max time : %.3f ms\n", max_time_ms);
printf(" Std dev : %.3f ms\n", sqrt((max_time_ms - min_time_ms) * (max_time_ms - min_time_ms) / 12));
printf("\nPerformance:\n");
printf(" GFLOPS : %.2f\n", gflops);
printf(" Throughput : %.2f tokens/s\n", throughput);
printf(" Latency/token : %.3f us\n", (avg_time_ms * 1000) / config.nc);
printf("=" "=%.78s\n", "===============================================================================");
// Cleanup
free(X);
free(Y);
free(S);
}
void print_usage(const char* program) {
printf("Usage: %s [options]\n", program);
printf("Options:\n");
printf(" -n <size> Embedding dimension (must be divisible by 4, default: 2048)\n");
printf(" -r <rows> Number of rows in matrix Y (default: 32)\n");
printf(" -c <cols> Number of columns in matrix X (default: 128)\n");
printf(" -i <iters> Number of iterations (default: 1000)\n");
printf(" -w <warmup> Number of warmup iterations (default: 10)\n");
printf(" -h Show this help message\n");
}
int main(int argc, char** argv) {
BenchmarkConfig config = {
.n = 2048,
.nr = 32,
.nc = 128,
.iterations = 1000,
.warmup = 10
};
// Parse command line arguments
for (int i = 1; i < argc; i++) {
if (strcmp(argv[i], "-n") == 0 && i + 1 < argc) {
config.n = atoi(argv[++i]);
} else if (strcmp(argv[i], "-r") == 0 && i + 1 < argc) {
config.nr = atoi(argv[++i]);
} else if (strcmp(argv[i], "-c") == 0 && i + 1 < argc) {
config.nc = atoi(argv[++i]);
} else if (strcmp(argv[i], "-i") == 0 && i + 1 < argc) {
config.iterations = atoi(argv[++i]);
} else if (strcmp(argv[i], "-w") == 0 && i + 1 < argc) {
config.warmup = atoi(argv[++i]);
} else if (strcmp(argv[i], "-h") == 0) {
print_usage(argv[0]);
return 0;
} else {
fprintf(stderr, "Unknown option: %s\n", argv[i]);
print_usage(argv[0]);
return 1;
}
}
// Validate configuration
if (config.n % 4 != 0) {
fprintf(stderr, "Error: Embedding dimension (-n) must be divisible by 4\n");
return 1;
}
if (config.n <= 0 || config.nr <= 0 || config.nc <= 0 || config.iterations <= 0) {
fprintf(stderr, "Error: All size parameters must be positive\n");
return 1;
}
// Run benchmark
print_config(config);
run_benchmark(config);
return 0;
}
EOF
# Compiler flags
CXXFLAGS="-O3 -march=native -mtune=native -std=c++17 -fopenmp"
CXXFLAGS+=" -I${SCRIPT_DIR}/.. -I${SCRIPT_DIR}/../include"
CXXFLAGS+=" -I${SCRIPT_DIR}/../3rdparty/llama.cpp/ggml/include"
CXXFLAGS+=" -I${SCRIPT_DIR}/../3rdparty/llama.cpp/ggml/src"
CXXFLAGS+=" -I${SCRIPT_DIR}/../3rdparty/llama.cpp/include"
CXXFLAGS+=" -DNDEBUG -ffast-math"
# Link flags
LDFLAGS="-lm -lpthread"
# Link with pre-built libraries
GGML_LIB_DIR="${SCRIPT_DIR}/../build/3rdparty/llama.cpp/ggml/src"
GGML_SO="${GGML_LIB_DIR}/libggml.so"
if [ ! -f "${GGML_SO}" ]; then
echo "❌ Error: Cannot find libggml.so at ${GGML_SO}"
echo "Please build the project first with: cmake --build build"
rm -f "${TEMP_CPP}"
exit 1
fi
LDFLAGS+=" -L${GGML_LIB_DIR} -lggml -Wl,-rpath,${GGML_LIB_DIR}"
# Output binary
BENCHMARK_BIN="${SCRIPT_DIR}/${BUILD_DIR}/test_gemm_kernel"
echo "Compiler: ${CXX}"
echo "Building from embedded source..."
echo ""
# Build
${CXX} ${CXXFLAGS} "${TEMP_CPP}" -o ${BENCHMARK_BIN} ${LDFLAGS}
if [ $? -eq 0 ]; then
echo "✅ Build successful!"
rm -f "${TEMP_CPP}"
echo ""
else
echo "❌ Build failed!"
rm -f "${TEMP_CPP}"
exit 1
fi
else
echo "Step 1: Skipping build (using existing binary)"
echo "------------------------------------------"
BENCHMARK_BIN="${SCRIPT_DIR}/${BUILD_DIR}/test_gemm_kernel"
if [ ! -f "${BENCHMARK_BIN}" ]; then
echo "❌ Error: Benchmark binary not found at ${BENCHMARK_BIN}"
echo "Please run without -s to build it first."
exit 1
fi
echo "✅ Found existing binary"
echo ""
fi
# Set LD_LIBRARY_PATH to include the GGML library directory
GGML_LIB_DIR="${SCRIPT_DIR}/../build/3rdparty/llama.cpp/ggml/src"
export LD_LIBRARY_PATH="${GGML_LIB_DIR}:${LD_LIBRARY_PATH}"
echo "Step 2: Running benchmark tests"
echo "------------------------------------------"
echo "Library path: ${GGML_LIB_DIR}"
echo ""
# Write CSV header
echo "test_name,n,nr,nc,time_ms,gflops,throughput_tokens_per_sec" > "$OUTPUT_CSV"
echo "Results will be saved to: $OUTPUT_CSV"
echo ""
# Function to extract metrics and append to CSV
extract_and_save() {
local test_name="$1"
local output="$2"
# Extract values using grep and awk
local n=$(echo "$output" | grep "Embedding dimension" | awk '{print $5}')
local nr=$(echo "$output" | grep "Matrix Y rows" | awk '{print $6}')
local nc=$(echo "$output" | grep "Matrix X columns" | awk '{print $6}')
local avg_time=$(echo "$output" | grep "Average time" | awk '{print $4}')
local min_time=$(echo "$output" | grep "Min time" | awk '{print $4}')
local max_time=$(echo "$output" | grep "Max time" | awk '{print $4}')
local gflops=$(echo "$output" | grep "GFLOPS" | awk '{print $3}')
local throughput=$(echo "$output" | grep "Throughput" | awk '{print $3}')
# Check if values were extracted successfully
if [ -z "$avg_time" ] || [ -z "$min_time" ] || [ -z "$max_time" ]; then
echo "Warning: Failed to extract timing data for ${test_name}"
echo "${test_name},${n},${nr},${nc},N/A,N/A,N/A" >> "$OUTPUT_CSV"
return
fi
# Calculate standard deviation estimate from range
# Using awk with proper variable passing
local std_time=$(awk -v min="$min_time" -v max="$max_time" 'BEGIN {printf "%.4f", (max - min) / 4}')
# Format as mean±std
local time_formatted="${avg_time}±${std_time}"
# Append to CSV
echo "${test_name},${n},${nr},${nc},${time_formatted},${gflops},${throughput}" >> "$OUTPUT_CSV"
}
# Run benchmark tests
echo "=========================================="
echo "BitNet-2B Typical Shapes Performance Test"
echo "=========================================="
echo ""
echo "Test 1: Single Token Generation (Attention QKV projection)"
echo " Scenario: Generating 1 token at a time"
echo " Shape: n=2048, r=1, c=2048"
OUTPUT=$($BENCHMARK_BIN -n 2048 -r 1 -c 2048 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "single_token_gen" "$OUTPUT"
echo ""
echo "Test 2: Small Batch Prompt Processing (Attention QKV projection)"
echo " Scenario: Processing prompt with 128 tokens, batch size 1"
echo " Shape: n=2048, r=128, c=2048"
OUTPUT=$($BENCHMARK_BIN -n 2048 -r 128 -c 2048 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "small_batch_prompt" "$OUTPUT"
echo ""
echo "Test 3: Medium Batch Prompt Processing (Attention QKV projection)"
echo " Scenario: Processing prompt with 256 tokens or batch of 256"
echo " Shape: n=2048, r=256, c=2048"
OUTPUT=$($BENCHMARK_BIN -n 2048 -r 256 -c 2048 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "medium_batch_prompt" "$OUTPUT"
echo ""
echo "Test 4: Large Batch Processing (Attention QKV projection)"
echo " Scenario: Processing 512 tokens or batch of 512"
echo " Shape: n=2048, r=512, c=2048"
OUTPUT=$($BENCHMARK_BIN -n 2048 -r 512 -c 2048 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "large_batch_prompt" "$OUTPUT"
echo ""
echo "Test 5: FFN Up-projection (Small batch)"
echo " Scenario: Feed-forward network expansion, 128 tokens"
echo " Shape: n=2048, r=128, c=8192"
OUTPUT=$($BENCHMARK_BIN -n 2048 -r 128 -c 8192 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "ffn_up_projection" "$OUTPUT"
echo ""
echo "Test 6: FFN Down-projection (Small batch)"
echo " Scenario: Feed-forward network reduction, 128 tokens"
echo " Shape: n=8192, r=128, c=2048"
OUTPUT=$($BENCHMARK_BIN -n 8192 -r 128 -c 2048 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "ffn_down_projection" "$OUTPUT"
echo ""
echo "Test 7: Long Context Processing"
echo " Scenario: Processing very long context (2048 tokens)"
echo " Shape: n=2048, r=2048, c=2048"
OUTPUT=$($BENCHMARK_BIN -n 2048 -r 2048 -c 2048 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "long_context" "$OUTPUT"
echo ""
echo "Test 8: Batched Token Generation"
echo " Scenario: Generating tokens for 32 sequences simultaneously"
echo " Shape: n=2048, r=32, c=2048"
OUTPUT=$($BENCHMARK_BIN -n 2048 -r 32 -c 2048 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "batched_token_gen" "$OUTPUT"
echo ""
echo "=========================================="
echo "All tests completed successfully!"
echo "=========================================="
echo "Results saved to: $OUTPUT_CSV"
echo ""
echo "Summary:"
wc -l "$OUTPUT_CSV" | awk '{print " Total records:", $1 - 1}'
echo " Output file: $OUTPUT_CSV"
echo "=========================================="
+608
@@ -0,0 +1,608 @@
#!/usr/bin/env python3
"""
Perplexity Test Script
Tests GGUF model perplexity on multiple datasets using llama-perplexity.
"""
import os
import subprocess
import time
import csv
import re
from datetime import datetime
from pathlib import Path
import argparse
import tempfile
import shutil
import statistics
class PerplexityTester:
def __init__(self, model_path, llama_perplexity_bin="../build/bin/llama-perplexity",
data_dir="../data", output_dir="perplexity_results", quick_mode=False,
quantize_bin="../build/bin/llama-quantize", test_embeddings=False, csv_output=None):
self.model_path = Path(model_path)
self.llama_perplexity_bin = Path(llama_perplexity_bin)
self.quantize_bin = Path(quantize_bin)
self.data_dir = Path(data_dir)
self.output_dir = Path(output_dir)
self.quick_mode = quick_mode
self.test_embeddings = test_embeddings
self.csv_output = Path(csv_output) if csv_output else None
self.results = []
self.created_models = set() # Track newly created model files
self.temp_files = [] # Track temporary files for cleanup
# Embedding types to test
self.embedding_types = [
('F32', 'f32'),
('F16', 'f16'),
('Q8_0', 'q8_0'),
('Q6_K', 'q6_k'),
('Q5_0', 'q5_0'),
('Q4_0', 'q4_0'),
('Q3_K', 'q3_k'),
('TQ2_0', 'tq2_0'),
]
# Create output directory
self.output_dir.mkdir(parents=True, exist_ok=True)
# Verify llama-perplexity binary exists
if not self.llama_perplexity_bin.exists():
raise FileNotFoundError(f"llama-perplexity binary not found: {self.llama_perplexity_bin}")
# Verify quantize binary exists if testing embeddings
if self.test_embeddings and not self.quantize_bin.exists():
raise FileNotFoundError(f"llama-quantize binary not found: {self.quantize_bin}")
# Verify model file exists
if not self.model_path.exists():
raise FileNotFoundError(f"Model file not found: {self.model_path}")
def find_datasets(self):
"""Find all test.txt files in dataset directories."""
datasets = []
if not self.data_dir.exists():
print(f"❌ Data directory not found: {self.data_dir}")
return datasets
print(f"\n🔍 Searching for datasets in {self.data_dir}...")
# Look for test.txt files in subdirectories
for dataset_dir in sorted(self.data_dir.iterdir()):
if dataset_dir.is_dir():
test_file = dataset_dir / "test.txt"
if test_file.exists():
size_mb = test_file.stat().st_size / (1024 * 1024)
datasets.append({
'name': dataset_dir.name,
'path': test_file,
'size': test_file.stat().st_size,
'size_mb': size_mb
})
print(f"{dataset_dir.name:<20} ({size_mb:.2f} MB)")
else:
print(f" ⚠️ {dataset_dir.name:<20} (no test.txt found)")
return datasets
def create_quick_dataset(self, dataset_path, num_chars=4096):
"""Create a temporary dataset with only the first N characters for quick testing."""
temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt', encoding='utf-8')
self.temp_files.append(temp_file.name)
try:
with open(dataset_path, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read(num_chars)
temp_file.write(content)
temp_file.close()
return Path(temp_file.name)
except Exception as e:
print(f"⚠️ Failed to create quick dataset: {e}")
temp_file.close()
return dataset_path
def cleanup_temp_files(self):
"""Clean up temporary files."""
for temp_file in self.temp_files:
try:
os.unlink(temp_file)
except:
pass
self.temp_files = []
def run_perplexity_test(self, dataset_name, dataset_path, threads=16, ctx_size=512, model_override=None):
"""Run perplexity test on a single dataset."""
test_model = model_override if model_override else self.model_path
print(f"\n{'='*80}")
print(f"📊 Testing on dataset: {dataset_name}")
print(f" File: {dataset_path}")
print(f" Model: {test_model.name}")
print(f"{'='*80}")
cmd = [
str(self.llama_perplexity_bin),
"-m", str(test_model),
"-f", str(dataset_path),
"-t", str(threads),
"-c", str(ctx_size),
"-ngl", "0" # CPU only
]
print(f"💻 Command: {' '.join(cmd)}")
print(f"⏱️ Starting test...\n")
start_time = time.time()
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=3600, # 1 hour timeout
cwd=os.getcwd()
)
elapsed_time = time.time() - start_time
if result.returncode == 0:
# Parse perplexity from output (check both stdout and stderr)
combined_output = result.stdout + "\n" + result.stderr
ppl = self.parse_perplexity(combined_output)
if ppl is not None:
print(f"\n✅ Perplexity: {ppl}")
print(f"⏱️ Time: {elapsed_time:.2f}s ({elapsed_time/60:.2f} min)")
status = "success"
else:
print(f"\n⚠️ Test completed but could not parse perplexity")
print(f"Last 500 chars of stdout:")
print(result.stdout[-500:])
print(f"Last 500 chars of stderr:")
print(result.stderr[-500:])
status = "parse_error"
ppl = None
else:
print(f"\n❌ Test failed with return code {result.returncode}")
print(f"Error: {result.stderr[:500]}")
status = "failed"
ppl = None
elapsed_time = time.time() - start_time
return {
'dataset': dataset_name,
'perplexity': ppl,
'time': elapsed_time,
'status': status,
'stdout': result.stdout,
'stderr': result.stderr
}
except subprocess.TimeoutExpired:
elapsed_time = time.time() - start_time
print(f"\n❌ Timeout after {elapsed_time:.2f}s")
return {
'dataset': dataset_name,
'perplexity': None,
'time': elapsed_time,
'status': 'timeout',
'stdout': '',
'stderr': 'Test exceeded 1 hour timeout'
}
except Exception as e:
elapsed_time = time.time() - start_time
print(f"\n❌ Error: {e}")
return {
'dataset': dataset_name,
'perplexity': None,
'time': elapsed_time,
'status': 'error',
'stdout': '',
'stderr': str(e)
}
def parse_perplexity(self, output):
"""Parse perplexity value (mean±std format) from llama-perplexity output."""
# First try to match "PPL = mean +/- std" format
pattern_with_std = r'PPL\s*=\s*(\d+\.?\d*)\s*\+/-\s*(\d+\.?\d*)'
match = re.search(pattern_with_std, output, re.IGNORECASE | re.MULTILINE)
if match:
try:
mean = float(match.group(1))
std = float(match.group(2))
return f"{mean:.4f}±{std:.4f}"
except ValueError:
pass
# Fallback to patterns without std
patterns = [
r'Final estimate:\s*PPL\s*=\s*(\d+\.?\d*)',
r'Final perplexity:\s*(\d+\.?\d*)',
r'PPL\s*=\s*(\d+\.?\d*)',
r'PPL:\s*(\d+\.?\d*)',
r'perplexity:\s*(\d+\.?\d*)',
r'ppl\s*=\s*(\d+\.?\d*)',
r'Perplexity:\s*(\d+\.?\d*)',
]
for pattern in patterns:
match = re.search(pattern, output, re.IGNORECASE | re.MULTILINE)
if match:
try:
return f"{float(match.group(1)):.4f}"
except ValueError:
continue
return None
def quantize_embedding(self, embedding_type, output_suffix):
"""
Quantize model with specific embedding type.
Args:
embedding_type: Token embedding type (uppercase, e.g., 'Q6_K')
output_suffix: Output file suffix (lowercase, e.g., 'q6_k')
Returns:
Path to quantized model or None if failed
"""
# Construct output path
model_dir = self.model_path.parent
output_path = model_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
# Check if file already exists
file_existed = output_path.exists()
if file_existed:
print(f"️ Model already exists: {output_path.name}")
return output_path
cmd = [
str(self.quantize_bin),
"--token-embedding-type", embedding_type,
str(self.model_path),
str(output_path),
"I2_S",
"1",
"1"
]
print(f"\n{'='*80}")
print(f"🔄 Quantizing with embedding type: {embedding_type}")
print(f"📥 Input: {self.model_path.name}")
print(f"📤 Output: {output_path.name}")
print(f"💻 Command: {' '.join(cmd)}")
print(f"{'='*80}\n")
start_time = time.time()
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
cwd=os.getcwd(),
timeout=600 # 10 minutes timeout
)
duration = time.time() - start_time
if result.returncode == 0:
file_size_mb = output_path.stat().st_size / (1024 * 1024)
print(f"✅ Quantization successful!")
print(f" Duration: {duration:.2f}s")
print(f" Size: {file_size_mb:.2f} MB")
# Mark as newly created
self.created_models.add(output_path)
return output_path
else:
print(f"❌ Quantization failed with return code {result.returncode}")
print(f"Error: {result.stderr[:500]}")
return None
except subprocess.TimeoutExpired:
print(f"❌ Quantization timeout (exceeded 10 minutes)")
return None
except Exception as e:
print(f"❌ Quantization error: {e}")
return None
def cleanup_model(self, model_path):
"""Delete model file if it was created during this session."""
if model_path in self.created_models:
try:
model_path.unlink()
print(f"🗑️ Deleted: {model_path.name}")
self.created_models.remove(model_path)
except Exception as e:
print(f"⚠️ Failed to delete {model_path.name}: {e}")
else:
print(f"️ Keeping existing file: {model_path.name}")
def run_all_tests(self, threads=16, ctx_size=512):
"""Run perplexity tests on all datasets."""
datasets = self.find_datasets()
if not datasets:
print(f"\n❌ No datasets found in {self.data_dir}")
print(f" Make sure each dataset directory has a test.txt file")
return
# Quick mode: test all datasets but only first 4096 chars with smaller context
if self.quick_mode:
ctx_size = min(ctx_size, 128) # Use smaller context in quick mode
print(f"\n⚡ QUICK TEST MODE ENABLED")
print(f" - Testing all datasets with first 4096 characters only")
print(f" - Using reduced context size: {ctx_size}")
# Determine models to test
if self.test_embeddings:
print(f"\n{'='*80}")
print(f"🧪 EMBEDDING QUANTIZATION TEST MODE")
print(f"{'='*80}")
print(f"📦 Base model: {self.model_path.name}")
print(f"🔢 Embedding types to test: {len(self.embedding_types)}")
print(f"📊 Datasets: {len(datasets)}")
print(f"🧵 Threads: {threads}")
print(f"📏 Context size: {ctx_size}")
print(f"{'='*80}")
total_start = time.time()
# Test each embedding type
for i, (embedding_type, output_suffix) in enumerate(self.embedding_types, 1):
print(f"\n\n{'#'*80}")
print(f"[{i}/{len(self.embedding_types)}] Testing embedding type: {output_suffix} ({embedding_type})")
print(f"{'#'*80}")
# Quantize model
quantized_model = self.quantize_embedding(embedding_type, output_suffix)
if quantized_model is None:
print(f"⚠️ Skipping tests for {output_suffix} due to quantization failure")
continue
# Test on all datasets
for j, dataset in enumerate(datasets, 1):
print(f"\n[{j}/{len(datasets)}] Testing {dataset['name']} with {output_suffix}...")
# Use quick dataset if in quick mode
test_path = dataset['path']
if self.quick_mode:
test_path = self.create_quick_dataset(dataset['path'])
result = self.run_perplexity_test(
f"{dataset['name']}_embed-{output_suffix}",
test_path,
threads,
ctx_size,
model_override=quantized_model
)
self.results.append(result)
# Cleanup model after testing
print(f"\n🧹 Cleaning up {output_suffix} model...")
self.cleanup_model(quantized_model)
print(f"\n{'#'*80}")
print(f"✅ Completed {output_suffix}")
print(f"{'#'*80}")
total_time = time.time() - total_start
else:
# Regular single model test
print(f"\n{'='*80}")
print(f"🚀 PERPLEXITY TEST SESSION{' (QUICK MODE)' if self.quick_mode else ''}")
print(f"{'='*80}")
print(f"📦 Model: {self.model_path.name}")
print(f"📁 Model path: {self.model_path}")
print(f"📊 Datasets {'to test' if self.quick_mode else 'found'}: {len(datasets)}")
print(f"🧵 Threads: {threads}")
print(f"📏 Context size: {ctx_size}")
print(f"{'='*80}")
total_start = time.time()
# Run tests
for i, dataset in enumerate(datasets, 1):
print(f"\n\n[{i}/{len(datasets)}] Processing {dataset['name']}...")
# Use quick dataset if in quick mode
test_path = dataset['path']
if self.quick_mode:
test_path = self.create_quick_dataset(dataset['path'])
result = self.run_perplexity_test(
dataset['name'],
test_path,
threads,
ctx_size
)
self.results.append(result)
total_time = time.time() - total_start
# Clean up temporary files
if self.quick_mode:
print(f"\n🧹 Cleaning up temporary files...")
self.cleanup_temp_files()
# Save results
self.save_results()
# Print summary
self.print_summary(total_time)
def save_results(self):
"""Save results to CSV file."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_name = self.model_path.stem
# Use custom CSV path if provided
if self.csv_output:
csv_file = self.csv_output
# Create parent directory if needed
csv_file.parent.mkdir(parents=True, exist_ok=True)
else:
csv_file = self.output_dir / f"ppl_{model_name}_{timestamp}.csv"
print(f"\n💾 Saving results...")
with open(csv_file, 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=['dataset', 'perplexity', 'time_seconds', 'status'])
writer.writeheader()
for result in self.results:
writer.writerow({
'dataset': result['dataset'],
'perplexity': result['perplexity'] if result['perplexity'] is not None else 'N/A',
'time_seconds': f"{result['time']:.2f}",
'status': result['status']
})
print(f" ✅ CSV saved: {csv_file}")
# Save detailed log
log_file = self.output_dir / f"ppl_{model_name}_{timestamp}.log"
with open(log_file, 'w') as f:
f.write(f"Perplexity Test Results\n")
f.write(f"{'='*80}\n")
f.write(f"Model: {self.model_path}\n")
f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"{'='*80}\n\n")
for result in self.results:
f.write(f"\n{'='*80}\n")
f.write(f"Dataset: {result['dataset']}\n")
f.write(f"Perplexity: {result['perplexity']}\n")
f.write(f"Time: {result['time']:.2f}s\n")
f.write(f"Status: {result['status']}\n")
f.write(f"\nOutput:\n{result['stdout']}\n")
if result['stderr']:
f.write(f"\nErrors:\n{result['stderr']}\n")
print(f" ✅ Log saved: {log_file}")
def print_summary(self, total_time):
"""Print summary of all tests."""
print(f"\n\n{'='*80}")
print(f"📊 TEST SUMMARY")
print(f"{'='*80}\n")
# Sort results by perplexity (lower is better)
successful = [r for r in self.results if r['perplexity'] is not None]
failed = [r for r in self.results if r['perplexity'] is None]
if successful:
# Extract numeric value from "mean±std" format for sorting
def get_ppl_value(result):
ppl = result['perplexity']
if isinstance(ppl, str) and '±' in ppl:
return float(ppl.split('±')[0])
elif isinstance(ppl, str):
try:
return float(ppl)
except ValueError:
return float('inf')
return ppl
successful_sorted = sorted(successful, key=get_ppl_value)
print(f"{'Dataset':<20} {'Perplexity':>20} {'Time (s)':>12} {'Status':<15}")
print(f"{'-'*80}")
for result in successful_sorted:
ppl_str = str(result['perplexity']) if result['perplexity'] is not None else 'N/A'
print(f"{result['dataset']:<20} {ppl_str:>20} "
f"{result['time']:>12.2f} {result['status']:<15}")
best_ppl = str(successful_sorted[0]['perplexity'])
print(f"\n🏆 Best result: {successful_sorted[0]['dataset']} "
f"(PPL: {best_ppl})")
if failed:
print(f"\n❌ Failed tests ({len(failed)}):")
for result in failed:
print(f" - {result['dataset']}: {result['status']}")
print(f"\n{'='*80}")
print(f"✅ Completed: {len(successful)}/{len(self.results)}")
print(f"⏱️ Total time: {total_time:.2f}s ({total_time/60:.2f} min)")
print(f"📁 Results saved in: {self.output_dir}")
print(f"{'='*80}\n")
def main():
parser = argparse.ArgumentParser(description='Test model perplexity on multiple datasets')
parser.add_argument('--model', '-m',
required=True,
help='Path to GGUF model file')
parser.add_argument('--data-dir', '-d',
default='data',
help='Directory containing dataset folders (default: data)')
parser.add_argument('--threads', '-t',
type=int,
default=16,
help='Number of threads (default: 16)')
parser.add_argument('--ctx-size', '-c',
type=int,
default=512,
help='Context size (default: 512)')
parser.add_argument('--output-dir', '-o',
default='perplexity_results',
help='Output directory for results (default: perplexity_results)')
parser.add_argument('--llama-perplexity',
default='./build/bin/llama-perplexity',
help='Path to llama-perplexity binary (default: ./build/bin/llama-perplexity)')
parser.add_argument('--quick', '-q',
action='store_true',
help='Quick test mode: test all datasets with first 4096 characters and reduced context size (128)')
parser.add_argument('--test-embeddings', '-e',
action='store_true',
help='Test different embedding quantization types (f32, f16, q8_0, q6_k, q5_0, q4_0, q3_k, tq2_0)')
parser.add_argument('--csv-output',
help='Custom path for CSV output file (e.g., results/my_ppl_results.csv)')
parser.add_argument('--quantize-bin',
default='./build/bin/llama-quantize',
help='Path to llama-quantize binary (default: ./build/bin/llama-quantize)')
args = parser.parse_args()
try:
tester = PerplexityTester(
model_path=args.model,
llama_perplexity_bin=args.llama_perplexity,
data_dir=args.data_dir,
output_dir=args.output_dir,
quick_mode=args.quick,
quantize_bin=args.quantize_bin,
test_embeddings=args.test_embeddings,
csv_output=args.csv_output
)
tester.run_all_tests(
threads=args.threads,
ctx_size=args.ctx_size
)
except FileNotFoundError as e:
print(f"❌ Error: {e}")
return 1
except KeyboardInterrupt:
print("\n\n⚠️ Test interrupted by user")
return 1
except Exception as e:
print(f"\n❌ Unexpected error: {e}")
import traceback
traceback.print_exc()
return 1
return 0
if __name__ == "__main__":
exit(main())
+151
View File
@@ -0,0 +1,151 @@
#!/bin/bash
# Monitor power consumption for llama-bench with different thread configurations
# Usage: ./monitor_power.sh <model_path> <output_csv> <pp_threads> <tg_threads>
# Example: ./monitor_power.sh models/model.gguf results.csv "1,2,4,8" "1,2,4,8"
set -e
# Parse arguments
if [ $# -ne 4 ]; then
echo "Usage: $0 <model_path> <output_csv> <pp_threads> <tg_threads>"
echo "Example: $0 models/model.gguf results.csv \"1,2,4,8\" \"1,2,4,8\""
exit 1
fi
MODEL_PATH="$1"
OUTPUT_CSV="$2"
PP_THREADS="$3"
TG_THREADS="$4"
TEMP_LOG="/tmp/power_monitor_$$.log"
PID_FILE="/tmp/monitor_$$.pid"
BENCH_OUTPUT="/tmp/bench_output_$$.txt"
# Validate model exists
if [ ! -f "$MODEL_PATH" ]; then
echo "Error: Model file not found: $MODEL_PATH"
exit 1
fi
# Create output directory if needed
mkdir -p "$(dirname "$OUTPUT_CSV")"
# Function to monitor CPU stats
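# NOTE: Linux-only sampling; it assumes GNU top's "Cpu(s)" summary line
# (field 8 is %idle) and per-core "cpu MHz" entries in /proc/cpuinfo.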
monitor_cpu() {
local log_file="$1"
echo "Timestamp,CPU_Usage(%),Avg_Freq(MHz)" > "$log_file"
while [ -f "$PID_FILE" ]; do
cpu_usage=$(top -bn1 | grep "Cpu(s)" | awk '{print 100-$8}')
avg_freq=$(grep "cpu MHz" /proc/cpuinfo | awk '{sum+=$4; count++} END {printf "%.0f", sum/count}')
timestamp=$(date +%s.%N)
echo "$timestamp,$cpu_usage,$avg_freq" >> "$log_file"
sleep 0.5
done
}
# Function to calculate average power
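# NOTE: a rough linear model, not a measurement: estimated power =
# (average CPU utilization / 100) * (assumed 200 W peak package power).
# Adjust the 200 W constant for your CPU, or use RAPL readings if available.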
calculate_power() {
local log_file="$1"
awk -F',' 'NR>1 {sum_cpu+=$2; count++} END {
if (count > 0) {
avg_cpu = sum_cpu/count
est_power = avg_cpu * 200 / 100
printf "%.2f", est_power
} else {
print "0"
}
}' "$log_file"
}
# Function to extract throughput from llama-bench output
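# llama-bench prints a markdown table; the relevant row looks like
# "| ... | pp128 | 501.06 ± 11.37 |" and the loop below grabs the field
# immediately preceding the "±" sign (the mean).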
extract_throughput() {
local bench_output="$1"
local workload="$2"
grep "$workload" "$bench_output" | awk '{
# Extract mean from "mean ± std" format
for (i=1; i<=NF; i++) {
if ($(i+1) == "±") {
printf "%.2f", $i
exit
}
}
}'
}
# Function to run single benchmark
run_benchmark() {
local workload="$1" # "pp" or "tg"
local threads="$2"
local n_flag=""
if [ "$workload" = "pp" ]; then
n_flag="-n 0"
workload_name="pp128"
else
n_flag="-n 128"
workload_name="tg128"
fi
# Output progress to stderr (won't be captured in CSV)
echo "Testing $workload_name with $threads threads..." >&2
# Start monitoring
touch "$PID_FILE"
monitor_cpu "$TEMP_LOG" &
local monitor_pid=$!
# Run benchmark
./build/bin/llama-bench -m "$MODEL_PATH" -p 128 $n_flag -t "$threads" -ngl 0 > "$BENCH_OUTPUT" 2>&1
# Stop monitoring
rm -f "$PID_FILE"
wait $monitor_pid 2>/dev/null || true
# Extract results
local throughput=$(extract_throughput "$BENCH_OUTPUT" "$workload_name")
local power=$(calculate_power "$TEMP_LOG")
if [ -z "$throughput" ] || [ "$throughput" = "0" ]; then
echo "Warning: Failed to extract throughput for $workload_name, threads=$threads" >&2
throughput="0"
fi
# Calculate J/t (Joules per token)
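    # Dimensional check: power [J/s] divided by throughput [tokens/s] gives J/token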
local j_per_token=$(awk -v p="$power" -v t="$throughput" 'BEGIN {
if (t > 0) printf "%.4f", p/t; else print "0"
}')
# Output progress to stderr
echo " Throughput: $throughput t/s, Power: $power W, Energy: $j_per_token J/t" >&2
# Only output CSV line to stdout (this will be captured)
echo "$workload_name,$threads,$throughput,$power,$j_per_token"
}
# Initialize CSV
echo "Workload,Threads,Throughput(t/s),Power(W),Energy(J/t)" > "$OUTPUT_CSV"
# Test PP workloads
IFS=',' read -ra PP_ARRAY <<< "$PP_THREADS"
for threads in "${PP_ARRAY[@]}"; do
threads=$(echo "$threads" | xargs) # trim whitespace
result=$(run_benchmark "pp" "$threads")
echo "$result" >> "$OUTPUT_CSV"
done
# Test TG workloads
IFS=',' read -ra TG_ARRAY <<< "$TG_THREADS"
for threads in "${TG_ARRAY[@]}"; do
threads=$(echo "$threads" | xargs) # trim whitespace
result=$(run_benchmark "tg" "$threads")
echo "$result" >> "$OUTPUT_CSV"
done
# Cleanup
rm -f "$TEMP_LOG" "$BENCH_OUTPUT" "$PID_FILE"
echo ""
echo "=== Benchmark Complete ==="
echo "Results saved to: $OUTPUT_CSV"
echo ""
cat "$OUTPUT_CSV"
+362
View File
@@ -0,0 +1,362 @@
#!/usr/bin/env python3
"""
GEMM Configuration Tuning Script
This script automatically tunes ROW_BLOCK_SIZE, COL_BLOCK_SIZE, and PARALLEL_SIZE
to find the optimal configuration for maximum throughput (t/s).
"""
import subprocess
import os
import re
import csv
import shutil
from datetime import datetime
from pathlib import Path
import argparse
class GemmTuner:
def __init__(self, config_path, model_path, threads=16):
self.config_path = Path(config_path)
self.model_path = model_path
self.threads = threads
self.backup_path = self.config_path.parent / f"gemm-config.h.backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
self.build_dir = Path("../build")
self.results = []
def backup_config(self):
"""Backup current configuration file"""
print(f"📦 Backing up current config to {self.backup_path}")
shutil.copy2(self.config_path, self.backup_path)
def restore_config(self):
"""Restore original configuration file"""
print(f"♻️ Restoring original config from {self.backup_path}")
shutil.copy2(self.backup_path, self.config_path)
def generate_config(self, act_parallel, row_block_size, col_block_size, parallel_size):
"""Generate new configuration file with simplified format"""
content = ""
# Simplified configuration format
if act_parallel:
content += "#define ACT_PARALLEL\n"
content += f"#define ROW_BLOCK_SIZE {row_block_size}\n"
content += f"#define COL_BLOCK_SIZE {col_block_size}\n"
content += f"#define PARALLEL_SIZE {parallel_size}\n"
with open(self.config_path, 'w') as f:
f.write(content)
def rebuild_project(self):
"""Rebuild project"""
print("🔨 Rebuilding project...")
result = subprocess.run(
["cmake", "--build", str(self.build_dir), "--target", "llama-bench"],
capture_output=True,
text=True,
cwd=os.getcwd()
)
if result.returncode != 0:
print(f"⚠️ Build warning/error: {result.stderr}")
return False
return True
def run_benchmark(self):
"""Run benchmark test"""
cmd = [
f"{self.build_dir}/bin/llama-bench",
"-m", self.model_path,
"-p", "128",
"-n", "0",
"-t", str(self.threads),
"-ngl", "0"
]
print(f"⚡ Running benchmark: {' '.join(cmd)}")
result = subprocess.run(
cmd,
capture_output=True,
text=True,
cwd=os.getcwd(),
            timeout=300  # 5-minute timeout
)
if result.returncode != 0:
print(f"❌ Benchmark failed: {result.stderr}")
return None
return result.stdout
def parse_throughput(self, output):
"""Parse pp128 throughput from output"""
        # Match the pp128 row, e.g.: | pp128 | 501.06 ± 11.37 |
pp_pattern = r'\|\s+pp128\s+\|\s+([\d.]+)\s+±\s+([\d.]+)\s+\|'
pp_match = re.search(pp_pattern, output)
if pp_match:
pp_throughput = float(pp_match.group(1))
pp_std_dev = float(pp_match.group(2))
return {
'pp_throughput': pp_throughput,
'pp_std_dev': pp_std_dev
}
return None
def test_configuration(self, act_parallel, row_block_size, col_block_size, parallel_size):
"""Test single configuration"""
config_name = f"ACT_{'ON' if act_parallel else 'OFF'}_R{row_block_size}_C{col_block_size}_P{parallel_size}"
print(f"\n{'='*80}")
print(f"🧪 Testing configuration: {config_name}")
print(f" ACT_PARALLEL: {act_parallel}")
print(f" ROW_BLOCK_SIZE: {row_block_size}")
print(f" COL_BLOCK_SIZE: {col_block_size}")
print(f" PARALLEL_SIZE: {parallel_size}")
print(f"{'='*80}")
# Generate configuration
self.generate_config(act_parallel, row_block_size, col_block_size, parallel_size)
# Rebuild project
if not self.rebuild_project():
print("⚠️ Build failed, skipping this configuration")
return None
# Run benchmark test
output = self.run_benchmark()
if output is None:
return None
# Parse results
metrics = self.parse_throughput(output)
if metrics is not None:
result = {
'act_parallel': act_parallel,
'row_block_size': row_block_size,
'col_block_size': col_block_size,
'parallel_size': parallel_size,
'config_name': config_name,
**metrics
}
self.results.append(result)
print(f"✅ PP128: {metrics['pp_throughput']:.2f} ± {metrics['pp_std_dev']:.2f} t/s")
return result
else:
print("❌ Failed to parse throughput")
return None
def save_results(self, csv_path):
"""Save results to CSV file"""
print(f"\n💾 Saving results to {csv_path}")
with open(csv_path, 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=[
'config_name', 'act_parallel', 'row_block_size',
'col_block_size', 'parallel_size',
'pp_throughput', 'pp_std_dev'
])
writer.writeheader()
writer.writerows(self.results)
def find_best_config(self):
"""Find the best configuration with highest throughput"""
if not self.results:
print("❌ No valid results found")
return None
best = max(self.results, key=lambda x: x['pp_throughput'])
return best
def run_tuning(self, configurations, output_csv=None):
"""Run test for all configurations"""
print(f"\n🚀 Starting tuning process with {len(configurations)} configurations")
print(f"📊 Model: {self.model_path}")
print(f"🧵 Threads: {self.threads}\n")
# Backup configuration
self.backup_config()
try:
# Test all configurations
for i, config in enumerate(configurations, 1):
print(f"\n[{i}/{len(configurations)}]")
self.test_configuration(**config)
# Save results
if output_csv is None:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
csv_path = f"../stats/tuning_results_{timestamp}.csv"
else:
csv_path = output_csv
            # Ensure the output directory exists (dirname is empty for bare filenames)
            csv_dir = os.path.dirname(csv_path)
            if csv_dir:
                os.makedirs(csv_dir, exist_ok=True)
self.save_results(csv_path)
# Find best configuration
best = self.find_best_config()
if best:
print(f"\n{'='*80}")
print(f"🏆 BEST CONFIGURATION FOUND!")
print(f"{'='*80}")
print(f"Configuration: {best['config_name']}")
print(f"ACT_PARALLEL: {best['act_parallel']}")
print(f"ROW_BLOCK_SIZE: {best['row_block_size']}")
print(f"COL_BLOCK_SIZE: {best['col_block_size']}")
print(f"PARALLEL_SIZE: {best['parallel_size']}")
print(f"PP128 Throughput: {best['pp_throughput']:.2f} ± {best['pp_std_dev']:.2f} t/s")
print(f"{'='*80}\n")
# Show the configuration that will be written
print("Configuration to be written to gemm-config.h:")
print("-" * 80)
if best['act_parallel']:
print("#define ACT_PARALLEL")
print(f"#define ROW_BLOCK_SIZE {best['row_block_size']}")
print(f"#define COL_BLOCK_SIZE {best['col_block_size']}")
print(f"#define PARALLEL_SIZE {best['parallel_size']}")
print("-" * 80)
# Apply best configuration
apply = input("\nDo you want to apply this configuration to gemm-config.h? (y/n): ").strip().lower()
if apply == 'y':
self.generate_config(
best['act_parallel'],
best['row_block_size'],
best['col_block_size'],
best['parallel_size']
)
self.rebuild_project()
print("✅ Best configuration applied and project rebuilt!")
else:
self.restore_config()
print("✅ Original configuration restored")
        except KeyboardInterrupt:
            print("\n⚠️ Tuning interrupted by user")
            self.restore_config()
        except Exception as e:
            print(f"\n❌ Error during tuning: {e}")
            self.restore_config()
            raise
        finally:
            # Clean up the backup file on success, interrupt, and error alike
            if self.backup_path.exists():
                self.backup_path.unlink()
                print(f"🗑️ Removed backup file: {self.backup_path}")
def generate_configurations():
"""Generate list of configurations to test"""
configurations = []
    act_parallel_options = [True]  # add False to also sweep the non-ACT_PARALLEL kernels
    row_sizes = [2, 4, 8]          # full sweep: [2, 4, 8, 16, 32]
    col_sizes = [32, 64]           # full sweep: [32, 64, 128, 256, 512, 1024]
    parallelism_degree = [4]
for act_parallel in act_parallel_options:
for row in row_sizes:
for col in col_sizes:
for parallel in parallelism_degree:
                    # Skip invalid combinations: the parallel degree cannot
                    # exceed the block dimension it partitions
                    if act_parallel:
                        # When ACT_PARALLEL=True, keep only combinations with parallel <= row
                        if parallel > row:
                            continue
                    else:
                        # When ACT_PARALLEL=False, keep only combinations with parallel <= col
                        if parallel > col:
                            continue
configurations.append({
'act_parallel': act_parallel,
'row_block_size': row,
'col_block_size': col,
'parallel_size': parallel
})
return configurations
def main():
parser = argparse.ArgumentParser(description='Tune GEMM configuration for optimal performance')
parser.add_argument('--config', default='../include/gemm-config.h',
help='Path to gemm-config.h file')
parser.add_argument('--model', default='../models/BitNet-b1.58-2B-4T/ggml-model-i2_s-embed-q6_k.gguf',
help='Path to model file')
parser.add_argument('--threads', type=int, default=8,
help='Number of threads to use')
parser.add_argument('--quick', action='store_true',
help='Quick test with fewer configurations')
parser.add_argument('--custom', action='store_true',
help='Manually specify configurations to test')
    parser.add_argument('--output', type=str, default=None,
                        help='Output CSV file path (default: ../stats/tuning_results_<timestamp>.csv)')
args = parser.parse_args()
tuner = GemmTuner(args.config, args.model, args.threads)
if args.custom:
# Custom configuration mode
print("Custom configuration mode")
configurations = []
        while True:
            print("\nEnter configuration (or 'done' to finish):")
            # Read the answer into a variable so 'done' can actually be detected
            # (the original compared the input builtin itself to 'done', which
            # is always False and made the loop inescapable)
            act_input = input("ACT_PARALLEL (y/n): ").strip().lower()
            if act_input == 'done':
                break
            act = act_input == 'y'
            row = int(input("ROW_BLOCK_SIZE: "))
            col = int(input("COL_BLOCK_SIZE: "))
            par = int(input("PARALLEL_SIZE: "))
configurations.append({
'act_parallel': act,
'row_block_size': row,
'col_block_size': col,
'parallel_size': par
})
elif args.quick:
# Quick test mode - test only a few key configurations
configurations = [
{'act_parallel': True, 'row_block_size': 4, 'col_block_size': 128, 'parallel_size': 4},
{'act_parallel': True, 'row_block_size': 8, 'col_block_size': 128, 'parallel_size': 4},
{'act_parallel': True, 'row_block_size': 4, 'col_block_size': 64, 'parallel_size': 4},
{'act_parallel': False, 'row_block_size': 32, 'col_block_size': 4, 'parallel_size': 4},
{'act_parallel': False, 'row_block_size': 16, 'col_block_size': 4, 'parallel_size': 4},
]
else:
# Full test mode
configurations = generate_configurations()
print(f"\n{'='*80}")
print(f"GEMM Configuration Tuner")
print(f"{'='*80}")
print(f"Total configurations to test: {len(configurations)}")
print(f"Estimated time: ~{len(configurations) * 0.5:.1f} minutes (assuming 30s per test)")
print(f"{'='*80}\n")
proceed = input("Proceed with tuning? (y/n): ").strip().lower()
if proceed != 'y':
print("Tuning cancelled")
return
tuner.run_tuning(configurations, output_csv=args.output)
if __name__ == "__main__":
main()