Merge pull request #379 from XsquirrelC/main
BitNet CPU Inference Optimization
@@ -10,10 +10,10 @@ bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.5

The first release of bitnet.cpp supports inference on CPUs. bitnet.cpp achieves speedups of **1.37x** to **5.07x** on ARM CPUs, with larger models experiencing greater performance gains. Additionally, it reduces energy consumption by **55.4%** to **70.0%**, further boosting overall efficiency. On x86 CPUs, speedups range from **2.37x** to **6.17x** with energy reductions between **71.9%** and **82.2%**. Furthermore, bitnet.cpp can run a 100B BitNet b1.58 model on a single CPU, achieving speeds comparable to human reading (5-7 tokens per second), significantly enhancing the potential for running LLMs on local devices. Please refer to the [technical report](https://arxiv.org/abs/2410.16144) for more details.

<img src="./assets/m2_performance.jpg" alt="m2_performance" width="800"/>

<img src="./assets/intel_performance.jpg" alt="intel_performance" width="800"/>

**Latest optimization** introduces parallel kernel implementations with configurable tiling and embedding quantization support, achieving **1.15x to 2.1x** additional speedup over the original implementation across different hardware platforms and workloads. For detailed technical information, see the [optimization guide](src/README.md).

<img src="./assets/performance.png" alt="performance_comparison" width="800"/>

> The tested models are dummy setups used in a research context to demonstrate the inference performance of bitnet.cpp.

## Demo

@@ -22,7 +22,8 @@ A demo of bitnet.cpp running a BitNet b1.58 3B model on Apple M2:

https://github.com/user-attachments/assets/7f46b736-edec-4828-b809-4be780a3e5b1

## What's New:

- 01/15/2026 [BitNet CPU Inference Optimization](https://github.com/XsquirrelC/BitNet/blob/main/src/README.md)
- 05/20/2025 [BitNet Official GPU inference kernel](https://github.com/microsoft/BitNet/blob/main/gpu/README.md)
- 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
- 02/18/2025 [Bitnet.cpp: Efficient Edge Inference for Ternary LLMs](https://arxiv.org/abs/2502.11880)
- 11/08/2024 [BitNet a4.8: 4-bit Activations for 1-bit LLMs](https://arxiv.org/abs/2411.04965)

@@ -0,0 +1,35 @@
#define ACT_PARALLEL

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
#if defined(ACT_PARALLEL)
#define ROW_BLOCK_SIZE 4
#define COL_BLOCK_SIZE 128
#define PARALLEL_SIZE 4
#else
#define ROW_BLOCK_SIZE 128
#define COL_BLOCK_SIZE 32
#define PARALLEL_SIZE 8
#endif // ACT_PARALLEL
#elif defined(__ARM_NEON)
#if defined(__ARM_FEATURE_DOTPROD)
#if defined(ACT_PARALLEL)
#define ROW_BLOCK_SIZE 8
#define COL_BLOCK_SIZE 256
#define PARALLEL_SIZE 8
#else
#define ROW_BLOCK_SIZE 64
#define COL_BLOCK_SIZE 16
#define PARALLEL_SIZE 2
#endif // ACT_PARALLEL
#else
#if defined(ACT_PARALLEL)
#define ROW_BLOCK_SIZE 8
#define COL_BLOCK_SIZE 256
#define PARALLEL_SIZE 4
#else
#define ROW_BLOCK_SIZE 128
#define COL_BLOCK_SIZE 32
#define PARALLEL_SIZE 4
#endif // ACT_PARALLEL
#endif // __ARM_FEATURE_DOTPROD
#endif // __AVX__

@@ -64,8 +64,8 @@ SUPPORTED_QUANT_TYPES = {
 }

 COMPILER_EXTRA_ARGS = {
-    "arm64": ["-DBITNET_ARM_TL1=ON"],
-    "x86_64": ["-DBITNET_X86_TL2=ON"]
+    "arm64": ["-DBITNET_ARM_TL1=OFF"],
+    "x86_64": ["-DBITNET_X86_TL2=OFF"]
 }

 OS_EXTRA_ARGS = {

@@ -0,0 +1,205 @@

# BitNet CPU Inference Optimization

This update delivers significant performance improvements for BitNet inference on CPU through parallel kernel implementations, native I2_S GEMM/GEMV support, configurable tiling block sizes, and embedding quantization.

## Update

- **Parallel Weight & Activation Computation**
  Implemented parallel processing of weights and activations in the W2A8 vec_dot kernel, improving throughput on both x86 and ARM architectures.

- **Native I2_S GEMM & GEMV Support**
  Integrated I2_S GEMM and GEMV operations into the ggml library, making them fully compatible with the llama.cpp architecture. This enables seamless integration with existing inference pipelines.

- **Configurable Tiling & Parallelism**
  Introduced configurable GEMM & GEMV block sizes and parallelism levels, allowing performance fine-tuning for different CPU architectures.

- **Embedding Quantization**
  Added support for embedding-layer quantization in Q6_K format, reducing memory footprint and improving inference speed while maintaining high accuracy.

## Usage

### Configuration Options

The `include/gemm-config.h` file controls kernel behavior:

```c
#define ROW_BLOCK_SIZE 4
#define COL_BLOCK_SIZE 128
#define PARALLEL_SIZE 4
```

Modify these values in `include/gemm-config.h` to match your CPU architecture and cache sizes for optimal performance on your machine.
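
A rough way to reason about these values is the per-tile working set, which should fit comfortably in cache. The sketch below is a back-of-envelope estimate assuming W2A8 storage (2-bit packed weights, int8 activations, float32 accumulators); the helper name is illustrative, not part of the codebase.

```python
# Estimate the bytes one GEMM tile touches, assuming W2A8:
# 2-bit packed ternary weights, int8 activations, float32 accumulators.

def tile_working_set_bytes(row_block, col_block, parallel):
    weight_bytes = row_block * col_block // 4  # 4 two-bit weights per byte
    act_bytes = parallel * col_block           # one int8 activation row per lane
    acc_bytes = parallel * row_block * 4       # float32 partial sums
    return weight_bytes + act_bytes + acc_bytes

# Default x86 configuration: ROW_BLOCK_SIZE=4, COL_BLOCK_SIZE=128, PARALLEL_SIZE=4
print(tile_working_set_bytes(4, 128, 4))  # 704 bytes, well under a 32 KiB L1d
```

Larger blocks amortize loop overhead but risk evicting the tile from L1; that trade-off is why the optimal sizes differ per CPU.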

### Enabling Embedding Quantization

To use embedding quantization for additional speedup:

**Using setup_env.py:**
```bash
python setup_env.py --quant-embd
```
This automatically converts embeddings to Q6_K format.

**Manual conversion:**
```bash
build/bin/llama-quantize --token-embedding-type Q6_K models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf models/BitNet-b1.58-2B-4T/ggml-model-i2_s-embed-q6_k.gguf I2_S 1 1
```
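
To see why quantizing the embedding matters, a back-of-envelope comparison of F32 vs. Q6_K storage helps. In ggml, Q6_K packs 256 weights into 210 bytes (6.5625 bits/weight); the vocabulary and embedding sizes below are illustrative round numbers, not the exact BitNet-b1.58-2B-4T shapes.

```python
# Embedding table footprint: F32 vs Q6_K (210 bytes per 256-weight block in ggml).

def embedding_bytes_f32(vocab, dim):
    return vocab * dim * 4                  # 4 bytes per float32 weight

def embedding_bytes_q6_k(vocab, dim):
    assert dim % 256 == 0                   # Q6_K quantizes rows in blocks of 256
    return vocab * (dim // 256) * 210

vocab, dim = 128_000, 2048                  # illustrative shapes
print(embedding_bytes_f32(vocab, dim) // 2**20)   # 1000 (MiB)
print(embedding_bytes_q6_k(vocab, dim) // 2**20)  # 205 (MiB)
```

Roughly a 5x reduction on the embedding table, which also shrinks the memory traffic of every token lookup.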

## Optimizations

### 1. Weight & Activation Parallelism

The kernel implements two parallelization strategies:

- **Weight Parallel:** Processes multiple weight rows/columns in a single kernel call, reducing per-call overhead.

- **Activation Parallel:** Built on top of weight parallel, amortizes the I2_S weight-unpacking cost across multiple activation rows.

**Recommendation:** For the I2_S quantization format, activation parallel is recommended because each unpacked weight block is reused; the current kernel defaults to activation parallel.
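
The amortization argument can be made concrete with a toy model of the kernel: unpack a 2-bit ternary weight block once, then reuse it for several activation rows. The packing scheme and helper names below are illustrative, not the exact I2_S byte layout.

```python
# Toy illustration of activation parallelism for 2-bit ternary weights:
# unpacking is paid once per weight block, then reused for every activation row.

def pack_ternary(ws):                      # ws: values in {-1, 0, 1}
    out = bytearray()
    for i in range(0, len(ws), 4):
        b = 0
        for j, w in enumerate(ws[i:i + 4]):
            b |= (w + 1) << (2 * j)        # map {-1, 0, 1} -> {0, 1, 2}
        out.append(b)
    return bytes(out)

def unpack_ternary(packed, n):
    return [((b >> (2 * j)) & 3) - 1 for b in packed for j in range(4)][:n]

def gemv_block(packed, acts, n):
    ws = unpack_ternary(packed, n)         # unpack once...
    return [sum(w * a for w, a in zip(ws, row)) for row in acts]  # ...reuse per row

weights = [1, -1, 0, 1, -1, -1, 0, 0]
acts = [[2] * 8, [1] * 8, [-1] * 8]        # three activation rows share one unpack
print(gemv_block(pack_ternary(weights), acts, 8))  # [-2, -1, 1]
```

With `PARALLEL_SIZE` activation rows per block, the unpack cost per output drops by that factor, which is exactly why activation parallel wins on I2_S.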

**Kernel Performance Comparison:**

<div align="center">

Test configuration: AMD EPYC 7V13 (x86), 1 thread, time in milliseconds (mean±std)

| Matrix Size | No Parallel | Weight Parallel | Activation Parallel |
|:---:|:---:|:---:|:---:|
| [1, 2048] × [2048, 2048] | 0.075±0.012 | **0.058±0.007** | 0.076±0.011 |
| [32, 2048] × [2048, 2048] | 2.400±0.041 | 1.599±0.020 | **1.202±0.018** |
| [128, 2048] × [2048, 2048] | 10.820±0.039 | 6.458±0.168 | **5.805±0.039** |
| [256, 2048] × [2048, 2048] | 21.669±0.080 | 12.739±0.183 | **11.882±0.040** |
| [512, 2048] × [2048, 2048] | 43.257±0.083 | 25.680±0.335 | **23.342±0.082** |
| [2048, 2048] × [2048, 2048] | 173.175±0.214 | 103.112±0.552 | **93.276±0.612** |
| [128, 2048] × [2048, 8192] | 43.345±0.090 | 25.541±0.239 | **23.528±0.052** |
| [128, 8192] × [8192, 2048] | 38.085±0.162 | 23.866±0.096 | **22.569±0.132** |

</div>

### 2. GEMM/GEMV Integration with llama.cpp

Integrated the I2_S quantization format into llama.cpp's compute graph:

- **GEMV Operations:** Optimized matrix-vector multiplication for token generation.
- **GEMM Operations:** Efficient matrix-matrix multiplication for prompt processing.
- **Tiling Strategy:** Configurable block sizes for optimal cache utilization.
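
The tiling strategy can be sketched in plain Python: the output is computed block by block so each weight/activation tile stays cache-resident while it is reused. This is a stand-in for the C kernel, with tiny block sizes chosen for readability.

```python
# Blocked GEMV: iterate over ROW_BLOCK x COL_BLOCK tiles of the weight matrix
# so each tile is fully consumed while it sits in cache.

ROW_BLOCK, COL_BLOCK = 2, 4                 # tiny illustrative block sizes

def tiled_gemv(W, x):                       # W: rows x cols, x: length cols
    rows, cols = len(W), len(x)
    y = [0.0] * rows
    for r0 in range(0, rows, ROW_BLOCK):
        for c0 in range(0, cols, COL_BLOCK):
            # One tile: accumulate partial dot products for this block
            for r in range(r0, min(r0 + ROW_BLOCK, rows)):
                for c in range(c0, min(c0 + COL_BLOCK, cols)):
                    y[r] += W[r][c] * x[c]
    return y

W = [[1, 0, -1, 1, 0, 1, -1, 0],
     [0, 1, 1, -1, 1, 0, 0, -1]]
x = [1, 2, 3, 4, 5, 6, 7, 8]
print(tiled_gemv(W, x))  # [1.0, -2.0]
```

The production kernel swaps the innermost loops for SIMD dot products; the block structure is what `ROW_BLOCK_SIZE` and `COL_BLOCK_SIZE` control.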

### 3. Configuration Fine-tuning

Kernel parameters can be fine-tuned for optimal performance on specific hardware:

**Example Configuration (x86, AMD EPYC 7V13):**
- Method: Activation Parallel
- Threads: 8
- Workload: 128 prompt tokens (pp128)

**Fine-tuning Parameters:**
- **Parallelism Degree:** [2, 4, 8]
- **Row Block Size:** [2, 4, 8, 16, 32]
- **Column Block Size:** [32, 64, 128, 256, 512, 1024]
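
These lists define a plain grid search: every combination of parallelism degree, row block, and column block is benchmarked. Enumerating the grid shows the sweep size (variable names below are illustrative):

```python
# Enumerate the fine-tuning search space described above.
from itertools import product

PARALLELISM = [2, 4, 8]
ROW_BLOCKS = [2, 4, 8, 16, 32]
COL_BLOCKS = [32, 64, 128, 256, 512, 1024]

configs = list(product(PARALLELISM, ROW_BLOCKS, COL_BLOCKS))
print(len(configs))  # 3 * 5 * 6 = 90 candidate configurations
```

Each candidate requires editing `gemm-config.h` and rebuilding, so a full sweep is 90 build-and-benchmark runs per hardware/workload pair.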

**Fine-tuning Results:**

<div align="center">

<img src="./assets/fine_tuning_result.png" alt="fine_tune_result" width="800"/>

*Shows throughput (tokens/s) for various configurations.*

</div>

**Optimal Configuration:** Under this setup (x86, 8 threads, pp128), the best performance is achieved with parallelism degree = 4, row block size = 4, and column block size = 128.

### 4. Embedding Quantization

Evaluated multiple embedding quantization formats to balance memory usage, model quality, and inference speed:

**Perplexity Comparison:**

<div align="center">

Test configuration: BitNet-b1.58-2B-4T, TG128

| Embedding Type | Wikitext | PTB | LAMBADA | IMDB | AG NEWS |
|:---:|:---:|:---:|:---:|:---:|:---:|
| **F32** | 17.1090±0.1278 | 33.0858±0.4886 | 43.2850±0.6363 | 29.3016±0.2890 | 36.7686±0.3920 |
| **F16** | 17.1090±0.1278 | 33.0858±0.4886 | 43.2850±0.6363 | 29.3016±0.2890 | 36.7686±0.3920 |
| **Q8_0** | 17.1197±0.1280 | 33.1181±0.4893 | 43.2891±0.6364 | 29.3133±0.2892 | 36.7740±0.3920 |
| **Q6_K** | 17.1487±0.1282 | 33.2203±0.4914 | 43.3046±0.6362 | 29.3491±0.2897 | 36.7972±0.3921 |
| **Q5_0** | 17.2379±0.1288 | 33.2439±0.4907 | 43.4631±0.6379 | 29.5481±0.2920 | 36.8539±0.3924 |
| **Q4_0** | 17.3529±0.1300 | 33.7754±0.5001 | 44.4552±0.6559 | 30.1044±0.2978 | 37.3985±0.3997 |
| **Q3_K** | 17.6434±0.1320 | 34.3914±0.5089 | 45.4591±0.6735 | 30.8476±0.3069 | 39.5692±0.4259 |
| **I2_S** | N/A | N/A | N/A | N/A | N/A |

*N/A indicates model failure due to extreme quantization.*

</div>

**Inference Speed Comparison:**

<div align="center">

<img src="./assets/embedding_throughput.png" alt="embedding_throughput" width="800"/>

*Token generation throughput (tg128) for different embedding quantization types.*

</div>

**Recommendation:** Based on comprehensive evaluation of memory footprint, perplexity preservation, and inference speed, **Q6_K** is selected as the optimal embedding quantization format.

## Performance

Comparison of optimized parallel kernels vs. the original implementation:

**Test Configuration:**
- Model: BitNet-b1.58-2B-4T
- Hardware: AMD EPYC 7V13
- Threads: 1 / 2 / 4 / 8 / 12 / 16
- Test: 128 prompt tokens (pp128) + 128 generated tokens (tg128)
- Method: Activation Parallel

<div align="center">

<img src="./assets/performance_comparison_amd_epyc.png" alt="performance_comparison_amd_epyc" width="800"/>

</div>

**Test Configuration:**
- Model: BitNet-b1.58-2B-4T
- Hardware: Intel i7-13800H
- Threads: 1 / 2 / 4 / 6
- Test: 128 prompt tokens (pp128) + 128 generated tokens (tg128)
- Method: Activation Parallel

<div align="center">

<img src="./assets/performance_comparison_i7-13800h.png" alt="performance_comparison_i7-13800h" width="800"/>

</div>

**Test Configuration:**
- Model: BitNet-b1.58-2B-4T
- Hardware: Cobalt 100
- Threads: 1 / 2 / 4 / 8
- Test: 128 prompt tokens (pp128) + 128 generated tokens (tg128)
- Method: Activation Parallel

<div align="center">

<img src="./assets/performance_comparison_cobalt100_dotprod.png" alt="performance_comparison_cobalt100_dotprod" width="800"/>

</div>

## Technical Details

### Key Files Modified

- `src/ggml-bitnet-mad.cpp`: Parallel kernel implementations
- `3rdparty/llama.cpp/ggml/src/ggml.c`: GEMM/GEMV integration
- `include/gemm-config.h`: Configuration file

### Supported Architectures

- ✅ x86-64 with AVX2
- ✅ ARM with NEON
- ✅ ARM with DOTPROD extension

@@ -109,12 +109,12 @@ def main():
         except OSError as e:
             print(f"Warning: Could not remove {preprocessed_output_file}: {e}")

-    if gguf_f32_output.exists():
-        print(f"Removing f32 GGUF: {gguf_f32_output}")
-        try:
-            gguf_f32_output.unlink()
-        except OSError as e:
-            print(f"Warning: Could not remove {gguf_f32_output}: {e}")
+    # if gguf_f32_output.exists():
+    #     print(f"Removing f32 GGUF: {gguf_f32_output}")
+    #     try:
+    #         gguf_f32_output.unlink()
+    #     except OSError as e:
+    #         print(f"Warning: Could not remove {gguf_f32_output}: {e}")

     if input_backup_file.exists():
         if not input_file.exists():

@@ -0,0 +1,473 @@
#!/usr/bin/env python3
"""
Embedding Quantization Script

This script converts ggml-model-f32.gguf to multiple quantized versions
with different token embedding types.
"""

import subprocess
import os
import argparse
import re
import csv
from pathlib import Path
from datetime import datetime


class EmbeddingQuantizer:
    def __init__(self, input_model, output_dir, quantize_bin="../build/bin/llama-quantize",
                 bench_bin="../build/bin/llama-bench", stats_dir="../stats", csv_output=None):
        self.input_model = Path(input_model)
        self.output_dir = Path(output_dir)
        self.quantize_bin = Path(quantize_bin)
        self.bench_bin = Path(bench_bin)
        self.stats_dir = Path(stats_dir)
        self.csv_output = Path(csv_output) if csv_output else None

        # Verify input file exists
        if not self.input_model.exists():
            raise FileNotFoundError(f"Input model not found: {self.input_model}")

        # Verify quantize tool exists
        if not self.quantize_bin.exists():
            raise FileNotFoundError(f"Quantize binary not found: {self.quantize_bin}")

        # Verify bench tool exists
        if not self.bench_bin.exists():
            raise FileNotFoundError(f"Benchmark binary not found: {self.bench_bin}")

        # Create output directories
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.stats_dir.mkdir(parents=True, exist_ok=True)

        self.results = []
        self.newly_created_files = set()  # Track newly created files

    def quantize(self, embedding_type, output_suffix):
        """
        Perform a single quantization

        Args:
            embedding_type: Token embedding type (uppercase format, e.g., Q6_K)
            output_suffix: Output file suffix (lowercase format, e.g., q6_k)

        Returns:
            bool: Whether successful
        """
        output_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"

        # Check if the file already exists
        file_already_existed = output_file.exists()

        if file_already_existed:
            print(f"ℹ️ File already exists: {output_file}")
            print(f"   Skipping quantization, will use existing file for benchmark")
            return True

        cmd = [
            str(self.quantize_bin),
            "--token-embedding-type", embedding_type,
            str(self.input_model),
            str(output_file),
            "I2_S",
            "1",
            "1"
        ]

        print(f"\n{'='*80}")
        print(f"🔄 Quantizing with embedding type: {embedding_type}")
        print(f"📥 Input: {self.input_model}")
        print(f"📤 Output: {output_file}")
        print(f"💻 Command: {' '.join(cmd)}")
        print(f"{'='*80}\n")

        start_time = datetime.now()

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                cwd=os.getcwd(),
                timeout=600  # 10-minute timeout
            )

            end_time = datetime.now()
            duration = (end_time - start_time).total_seconds()

            if result.returncode == 0:
                # Get the output file size
                file_size_mb = output_file.stat().st_size / (1024 * 1024)

                print(f"✅ Success! Duration: {duration:.2f}s, Size: {file_size_mb:.2f} MB")

                # Record the newly created file
                if not file_already_existed:
                    self.newly_created_files.add(output_file)

                # Print part of the output
                if result.stdout:
                    print("\n📊 Quantization output:")
                    print(result.stdout[-500:] if len(result.stdout) > 500 else result.stdout)

                return True
            else:
                print(f"❌ Failed with return code {result.returncode}")
                print(f"Error: {result.stderr}")
                return False

        except subprocess.TimeoutExpired:
            print("❌ Timeout (exceeded 10 minutes)")
            return False

        except Exception as e:
            print(f"❌ Exception: {e}")
            return False

    def benchmark_model(self, output_suffix):
        """
        Benchmark a model

        Args:
            output_suffix: Output file suffix (lowercase format, e.g., q6_k)

        Returns:
            dict: Dictionary with benchmark results, or None if failed
        """
        model_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"

        if not model_file.exists():
            print(f"❌ Model file not found for benchmarking: {model_file}")
            return None

        cmd = [
            str(self.bench_bin),
            "-m", str(model_file),
            "-p", "128",
            "-n", "0",
            "-t", "1,2,4,8",
            "-ngl", "0"
        ]

        print(f"\n{'='*80}")
        print(f"🏃 Running benchmark for: {output_suffix}")
        print(f"💻 Command: {' '.join(cmd)}")
        print(f"{'='*80}\n")

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                cwd=os.getcwd(),
                timeout=300  # 5-minute timeout
            )

            if result.returncode == 0:
                print("✅ Benchmark completed successfully")
                print("\n📊 Benchmark output:")
                print(result.stdout)

                # Parse the output
                bench_results = self.parse_benchmark_output(result.stdout, output_suffix)
                return bench_results
            else:
                print(f"❌ Benchmark failed with return code {result.returncode}")
                print(f"Error: {result.stderr}")
                return None

        except subprocess.TimeoutExpired:
            print("❌ Benchmark timeout (exceeded 5 minutes)")
            return None

        except Exception as e:
            print(f"❌ Benchmark exception: {e}")
            return None

    def parse_benchmark_output(self, output, output_suffix):
        """
        Parse benchmark output to extract t/s data (mean±std)

        Args:
            output: Benchmark command output
            output_suffix: Output file suffix

        Returns:
            dict: Dictionary with parsed results
        """
        results = {
            'embedding_type': output_suffix,
            'threads_1': None,
            'threads_2': None,
            'threads_4': None,
            'threads_8': None,
        }

        # Parse the table data:
        # find lines containing pp128 and t/s
        lines = output.strip().split('\n')

        for line in lines:
            # Skip header and separator lines
            if '|' not in line or 'model' in line or '---' in line:
                continue

            # Try to extract data.
            # Format similar to: | bitnet-25 2B I2_S - 2 bpw ternary | 1012.28 MiB | 2.74 B | CPU | 12 | pp128 | 405.73 ± 3.69 |
            parts = [p.strip() for p in line.split('|')]

            if len(parts) >= 8 and 'pp128' in parts[6]:
                threads_str = parts[5].strip()
                throughput_str = parts[7].strip()

                # Extract the thread count
                try:
                    threads = int(threads_str)
                except ValueError:
                    continue

                # Extract t/s data (format: "405.73 ± 3.69" or "405.73").
                # Try to match the "mean ± std" format first
                match_with_std = re.search(r'([\d.]+)\s*±\s*([\d.]+)', throughput_str)
                if match_with_std:
                    mean = float(match_with_std.group(1))
                    std = float(match_with_std.group(2))
                    throughput = f"{mean:.2f}±{std:.2f}"
                else:
                    # Only a mean, no std
                    match = re.search(r'([\d.]+)', throughput_str)
                    if match:
                        throughput = f"{float(match.group(1)):.2f}"
                    else:
                        continue

                # Store the result based on thread count
                if threads == 1:
                    results['threads_1'] = throughput
                elif threads == 2:
                    results['threads_2'] = throughput
                elif threads == 4:
                    results['threads_4'] = throughput
                elif threads == 8:
                    results['threads_8'] = throughput

        return results

    def cleanup_model(self, output_suffix):
        """
        Clean up model files (only delete newly created files)

        Args:
            output_suffix: Output file suffix
        """
        model_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"

        if model_file in self.newly_created_files:
            try:
                model_file.unlink()
                print(f"🗑️ Deleted newly created file: {model_file}")
                self.newly_created_files.remove(model_file)
            except Exception as e:
                print(f"⚠️ Failed to delete {model_file}: {e}")
        else:
            print(f"ℹ️ Keeping existing file: {model_file}")

    def run_all_quantizations(self, types_to_quantize):
        """
        Run all quantizations

        Args:
            types_to_quantize: List of quantization types, tuples of (embedding_type, output_suffix)
        """
        print(f"\n{'='*80}")
        print("🚀 Starting Embedding Quantization and Benchmarking")
        print(f"{'='*80}")
        print(f"📥 Input model: {self.input_model}")
        print(f"📤 Output directory: {self.output_dir}")
        print(f"📊 Stats directory: {self.stats_dir}")
        print(f"🔢 Total quantizations: {len(types_to_quantize)}")
        print(f"{'='*80}\n")

        total_start = datetime.now()

        for i, (embedding_type, output_suffix) in enumerate(types_to_quantize, 1):
            print(f"\n{'#'*80}")
            print(f"[{i}/{len(types_to_quantize)}] Processing {output_suffix} ({embedding_type})")
            print(f"{'#'*80}\n")

            # Quantize the model
            success = self.quantize(embedding_type, output_suffix)

            if not success:
                print(f"⚠️ Skipping benchmark for {output_suffix} due to quantization failure")
                continue

            # Run the benchmark
            bench_results = self.benchmark_model(output_suffix)

            if bench_results:
                self.results.append(bench_results)
            else:
                print(f"⚠️ Benchmark failed for {output_suffix}")

            # Clean up model files (only deletes newly created files)
            self.cleanup_model(output_suffix)

            print(f"\n{'#'*80}")
            print(f"✅ Completed {output_suffix}")
            print(f"{'#'*80}\n")

        total_end = datetime.now()
        total_duration = (total_end - total_start).total_seconds()

        # Save results to CSV
        self.save_results_to_csv()

        # Print the summary
        self.print_summary(total_duration)

    def save_results_to_csv(self):
        """Save benchmark results to a CSV file"""
        if not self.results:
            print("⚠️ No results to save")
            return

        # Use the user-specified CSV path, otherwise use the default path
        if self.csv_output:
            csv_file = self.csv_output
            # Ensure the parent directory exists
            csv_file.parent.mkdir(parents=True, exist_ok=True)
        else:
            csv_file = self.stats_dir / "embedding_benchmark.csv"

        print(f"\n💾 Saving results to: {csv_file}")

        try:
            with open(csv_file, 'w', newline='') as f:
                fieldnames = ['embedding_type', 'threads_1', 'threads_2', 'threads_4', 'threads_8']
                writer = csv.DictWriter(f, fieldnames=fieldnames)

                writer.writeheader()
                for result in self.results:
                    writer.writerow(result)

            print("✅ Results saved successfully")

            # Also print a table
            print("\n📊 Benchmark Results:")
            print(f"{'Type':<15} {'1 thread':<18} {'2 threads':<18} {'4 threads':<18} {'8 threads':<18}")
            print("-" * 87)
            for result in self.results:
                t1 = result['threads_1'] if result['threads_1'] else "N/A"
                t2 = result['threads_2'] if result['threads_2'] else "N/A"
                t4 = result['threads_4'] if result['threads_4'] else "N/A"
                t8 = result['threads_8'] if result['threads_8'] else "N/A"
                print(f"{result['embedding_type']:<15} {t1:<18} {t2:<18} {t4:<18} {t8:<18}")

        except Exception as e:
            print(f"❌ Failed to save results: {e}")

    def print_summary(self, total_duration):
        """Print quantization summary"""
        print(f"\n\n{'='*80}")
        print("📊 QUANTIZATION AND BENCHMARK SUMMARY")
        print(f"{'='*80}\n")

        successful = len(self.results)

        print(f"✅ Completed: {successful} benchmarks")
        print(f"⏱️ Total duration: {total_duration/60:.2f} minutes\n")

        if self.results:
            if self.csv_output and self.csv_output.exists():
                print(f"📁 Results saved to: {self.csv_output}")
            else:
                csv_files = list(self.stats_dir.glob("embedding_benchmark*.csv"))
                if csv_files:
                    latest_csv = max(csv_files, key=lambda p: p.stat().st_mtime)
                    print(f"📁 Results saved to: {latest_csv}")

        print(f"\n{'='*80}\n")

def main():
    parser = argparse.ArgumentParser(description='Quantize model embeddings to multiple formats')
    parser.add_argument('--input', '-i',
                        default='../models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf',
                        help='Input model path (default: ../models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf)')
    parser.add_argument('--output-dir', '-o',
                        default='../models/BitNet-b1.58-2B-4T',
                        help='Output directory (default: ../models/BitNet-b1.58-2B-4T)')
    parser.add_argument('--quantize-bin', '-q',
                        default='../build/bin/llama-quantize',
                        help='Path to llama-quantize binary (default: ../build/bin/llama-quantize)')
    parser.add_argument('--bench-bin', '-b',
                        default='../build/bin/llama-bench',
                        help='Path to llama-bench binary (default: ../build/bin/llama-bench)')
    parser.add_argument('--stats-dir',
                        default='../stats',
                        help='Directory to save benchmark results (default: ../stats)')
    parser.add_argument('--csv-output', '-c',
                        help='Custom path for CSV output file (e.g., stats/my_results.csv)')
    parser.add_argument('--types', '-t',
                        nargs='+',
                        help='Specific types to quantize (e.g., f32 q6_k q4_0)')
    parser.add_argument('--skip-existing', '-s',
                        action='store_true',
                        help='Skip quantization if output file already exists (will still benchmark existing files)')

    args = parser.parse_args()

    # Define all supported quantization types.
    # Format: (embedding_type for the command line, output_suffix for the filename)
    all_types = [
        ('F32', 'f32'),
        ('F16', 'f16'),
        ('Q8_0', 'q8_0'),
        ('Q6_K', 'q6_k'),
        ('Q5_0', 'q5_0'),
        ('Q4_0', 'q4_0'),
        ('Q3_K', 'q3_k'),
        ('TQ2_0', 'tq2_0'),
    ]

    # If specific types are specified, filter the list
    if args.types:
        types_lower = [t.lower() for t in args.types]
        types_to_quantize = [(etype, suffix) for etype, suffix in all_types if suffix in types_lower]
        if not types_to_quantize:
            print(f"❌ No valid types specified. Available types: {', '.join(suffix for _, suffix in all_types)}")
            return
    else:
        types_to_quantize = all_types

    # --skip-existing needs no extra filtering here: quantize() detects an
    # existing output file, skips re-quantization, and still benchmarks it.

    # Create the quantizer and run
    try:
        quantizer = EmbeddingQuantizer(
            args.input,
            args.output_dir,
            args.quantize_bin,
            args.bench_bin,
            args.stats_dir,
            args.csv_output
        )
        quantizer.run_all_quantizations(types_to_quantize)
    except FileNotFoundError as e:
        print(f"❌ Error: {e}")
        return 1
    except KeyboardInterrupt:
        print("\n\n⚠️ Quantization interrupted by user")
        return 1
    except Exception as e:
        print(f"\n❌ Unexpected error: {e}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    exit(main() or 0)

@@ -0,0 +1,573 @@
#!/bin/bash
# Unified GEMM kernel benchmark script
# Builds, tests, and benchmarks the GEMM kernel with configurable output

set -e

# Default values
BUILD_DIR="../build"
ITERATIONS=1000
OUTPUT_CSV=""
SKIP_BUILD=false
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# Print usage
print_usage() {
    cat << EOF
Usage: $0 [options]

Options:
  -o, --output <path>     Output CSV file path (default: ../stats/gemm_kernel_test_noparal.csv)
  -i, --iterations <num>  Number of iterations per test (default: 1000)
  -s, --skip-build        Skip building the benchmark binary
  -h, --help              Show this help message

Examples:
  # Run with default settings
  $0

  # Specify a custom output file
  $0 -o /path/to/my_results.csv

  # Quick test with fewer iterations
  $0 -i 100 -o quick_test.csv

  # Skip the build if already compiled
  $0 -s -o results.csv
EOF
}

# Parse command-line arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        -o|--output)
            OUTPUT_CSV="$2"
            shift 2
            ;;
        -i|--iterations)
            ITERATIONS="$2"
            shift 2
            ;;
        -s|--skip-build)
            SKIP_BUILD=true
            shift
            ;;
        -h|--help)
            print_usage
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            print_usage
            exit 1
            ;;
    esac
done

# Set the default output CSV if not specified
if [ -z "$OUTPUT_CSV" ]; then
    OUTPUT_CSV="${SCRIPT_DIR}/../stats/gemm_kernel_test_noparal.csv"
fi

# Create the output directory first
mkdir -p "$(dirname "$OUTPUT_CSV")"

# Convert a relative path to an absolute one
if [[ "$OUTPUT_CSV" != /* ]]; then
    OUTPUT_CSV="$(cd "$(dirname "$OUTPUT_CSV")" && pwd)/$(basename "$OUTPUT_CSV")"
fi

echo "=========================================="
echo "GEMM Kernel Benchmark Suite"
echo "=========================================="
echo "Configuration:"
echo "  Iterations: $ITERATIONS"
echo "  Output CSV: $OUTPUT_CSV"
echo "  Skip build: $SKIP_BUILD"
echo "=========================================="
echo ""

# Build the benchmark binary
if [ "$SKIP_BUILD" = false ]; then
    echo "Step 1: Building GEMM kernel benchmark..."
    echo "------------------------------------------"

    CXX=${CXX:-g++}

    # Create build directory if it doesn't exist
    mkdir -p "${SCRIPT_DIR}/${BUILD_DIR}"

    # Create temporary C++ source file
    TEMP_CPP="${SCRIPT_DIR}/${BUILD_DIR}/test_gemm_kernel_temp.cpp"

    cat > "${TEMP_CPP}" << 'EOF'
/**
 * Standalone benchmark for the ggml_gemm_i2_i8_s kernel
 *
 * This program tests the performance of the ggml_gemm_i2_i8_s kernel
 * with configurable matrix sizes and iteration counts.
 *
 * Usage: ./test_gemm_kernel [options]
 *   -n <size>   : embedding dimension (must be divisible by 4, default: 2048)
 *   -r <rows>   : number of rows in matrix Y (default: 32)
 *   -c <cols>   : number of columns in matrix X (default: 128)
 *   -i <iters>  : number of iterations (default: 1000)
 *   -w <warmup> : number of warmup iterations (default: 10)
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <stdint.h>
#include <math.h>
#include <assert.h>

// Include necessary headers
#include "../include/gemm-config.h"

// Function declaration (from ggml-quants.h)
extern "C" void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc);

// GEMM kernel definition
void ggml_gemm_i2_i8_s(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
#if defined(ACT_PARALLEL)
    const int64_t row_block = ROW_BLOCK_SIZE;
    const int64_t col_block = COL_BLOCK_SIZE;

    for (int64_t c0 = 0; c0 < nc; c0 += col_block) {
        int64_t cur_c = (c0 + col_block <= nc) ? col_block : (nc - c0);
        for (int64_t r0 = 0; r0 < nr; r0 += row_block) {
            int64_t cur_r = (r0 + row_block <= nr) ? row_block : (nr - r0);
            const void * vy_r = (const uint8_t *)vy + r0 * n;
            for (int64_t c = 0; c < cur_c; ++c) {
                const int64_t col = c0 + c;
                float * s_col = s + col;
                const void * vx_col = (const uint8_t *)vx + col * n / 4;
                ggml_vec_dot_i2_i8_s(n, s_col + r0 * bs, bs, vx_col, n, vy_r, n, cur_r);
            }
        }
    }
#else
    const int64_t row_block = ROW_BLOCK_SIZE;
    const int64_t col_block = COL_BLOCK_SIZE;

    for (int64_t r0 = 0; r0 < nr; r0 += row_block) {
        int64_t cur_r = (r0 + row_block <= nr) ? row_block : (nr - r0);
        for (int64_t c0 = 0; c0 < nc; c0 += col_block) {
            int64_t cur_c = (c0 + col_block <= nc) ? col_block : (nc - c0);
            const void * vx_c = (const uint8_t *)vx + c0 * n / 4;
            for (int64_t r = 0; r < cur_r; ++r) {
                const int64_t row = r0 + r;
                float * s_row = s + row * bs;
                const void * vy_row = (const uint8_t *)vy + row * n;
                ggml_vec_dot_i2_i8_s(n, s_row + c0, bs, vx_c, n, vy_row, n, cur_c);
            }
        }
    }
#endif
}

// Helper function to get current time in nanoseconds
double get_time_ns() {
    struct timespec ts;
    clock_gettime(CLOCK_MONOTONIC, &ts);
    return ts.tv_sec * 1e9 + ts.tv_nsec;
}

// Initialize matrix with random i2 values (2-bit quantized)
void init_matrix_i2(uint8_t* data, int n, int cols) {
    // i2 format: 4 values per byte (2 bits each)
    int total_bytes = n * cols / 4;
    for (int i = 0; i < total_bytes; i++) {
        data[i] = rand() & 0xFF;
    }
}

// Initialize matrix with random i8 values
void init_matrix_i8(int8_t* data, int n, int rows) {
    int total_elements = n * rows;
    for (int i = 0; i < total_elements; i++) {
        data[i] = (int8_t)((rand() % 256) - 128);
    }
}

// Benchmark configuration
struct BenchmarkConfig {
    int n;          // embedding dimension (must be divisible by 4)
    int nr;         // number of rows in Y matrix
    int nc;         // number of columns in X matrix
    int iterations; // number of benchmark iterations
    int warmup;     // number of warmup iterations
};

void print_config(const BenchmarkConfig& config) {
    printf("=" "=%.78s\n", "================================================================================");
    printf("Benchmark Configuration:\n");
    printf("=" "=%.78s\n", "================================================================================");
    printf("  Embedding dimension (n) : %d\n", config.n);
    printf("  Matrix Y rows (nr)      : %d\n", config.nr);
    printf("  Matrix X columns (nc)   : %d\n", config.nc);
    printf("  Iterations              : %d\n", config.iterations);
    printf("  Warmup iterations       : %d\n", config.warmup);
    printf("\nMatrix sizes:\n");
    printf("  X (i2):  %d x %d (%.2f KB)\n", config.nc, config.n,
           (config.nc * config.n / 4) / 1024.0);
    printf("  Y (i8):  %d x %d (%.2f KB)\n", config.nr, config.n,
           (config.nr * config.n) / 1024.0);
    printf("  S (f32): %d x %d (%.2f KB)\n", config.nr, config.nc,
           (config.nr * config.nc * sizeof(float)) / 1024.0);
    printf("\nGEMM Config:\n");
#if defined(ACT_PARALLEL)
    printf("  ACT_PARALLEL   : ON\n");
#else
    printf("  ACT_PARALLEL   : OFF\n");
#endif
    printf("  ROW_BLOCK_SIZE : %d\n", ROW_BLOCK_SIZE);
    printf("  COL_BLOCK_SIZE : %d\n", COL_BLOCK_SIZE);
    printf("  PARALLEL_SIZE  : %d\n", PARALLEL_SIZE);
    printf("=" "=%.78s\n\n", "================================================================================");
}

void run_benchmark(const BenchmarkConfig& config) {
    // Allocate matrices
    printf("Allocating matrices...\n");

    // X matrix (i2 format): nc x n, but stored as nc x (n/4) bytes
    // Align to 64 bytes for AVX-512, which is backward compatible with AVX2 (32 bytes)
    size_t x_size = config.nc * config.n / 4;
    size_t x_size_aligned = ((x_size + 63) / 64) * 64;
    uint8_t* X = (uint8_t*)aligned_alloc(64, x_size_aligned);

    // Y matrix (i8 format): nr x n
    size_t y_size = config.nr * config.n;
    size_t y_size_aligned = ((y_size + 63) / 64) * 64;
    int8_t* Y = (int8_t*)aligned_alloc(64, y_size_aligned);

    // Result matrix (float32): nr x nc
    size_t s_size = config.nr * config.nc * sizeof(float);
    size_t s_size_aligned = ((s_size + 63) / 64) * 64;
    float* S = (float*)aligned_alloc(64, s_size_aligned);

    if (!X || !Y || !S) {
        fprintf(stderr, "Failed to allocate memory\n");
        exit(1);
    }

    // Initialize matrices with random data
    printf("Initializing matrices with random data...\n");
    srand(time(NULL));
    init_matrix_i2(X, config.n, config.nc);
    init_matrix_i8(Y, config.n, config.nr);
    memset(S, 0, config.nr * config.nc * sizeof(float));

    // Warmup
    printf("Running %d warmup iterations...\n", config.warmup);
    for (int i = 0; i < config.warmup; i++) {
        ggml_gemm_i2_i8_s(config.n, S, config.nc, X, Y, config.nr, config.nc);
    }

    // Benchmark
    printf("Running %d benchmark iterations...\n", config.iterations);
    double total_time = 0.0;
    double min_time = 1e20;
    double max_time = 0.0;

    for (int i = 0; i < config.iterations; i++) {
        double start = get_time_ns();
        ggml_gemm_i2_i8_s(config.n, S, config.nc, X, Y, config.nr, config.nc);
        double end = get_time_ns();

        double elapsed = end - start;
        total_time += elapsed;
        if (elapsed < min_time) min_time = elapsed;
        if (elapsed > max_time) max_time = elapsed;

        if ((i + 1) % 100 == 0) {
            printf("  Progress: %d/%d iterations\n", i + 1, config.iterations);
        }
    }

    // Calculate statistics
    double avg_time_ns = total_time / config.iterations;
    double avg_time_ms = avg_time_ns / 1e6;
    double min_time_ms = min_time / 1e6;
    double max_time_ms = max_time / 1e6;

    // Calculate GFLOPS
    // For GEMM: nr x nc x n multiply-adds = 2 * nr * nc * n FLOPs
    double flops = 2.0 * config.nr * config.nc * config.n;
    double gflops = (flops / avg_time_ns);

    // Calculate throughput (tokens/s assuming each column is a token)
    double throughput = (config.nc * 1e9) / avg_time_ns;

    // Print results
    printf("\n");
    printf("=" "=%.78s\n", "================================================================================");
    printf("Benchmark Results:\n");
    printf("=" "=%.78s\n", "================================================================================");
    printf("  Average time     : %.3f ms\n", avg_time_ms);
    printf("  Min time         : %.3f ms\n", min_time_ms);
    printf("  Max time         : %.3f ms\n", max_time_ms);
    // Rough range-based estimate: (max - min) / sqrt(12), i.e. the std dev of a
    // uniform distribution over [min, max]; not a true sample standard deviation.
    printf("  Std dev          : %.3f ms\n", sqrt((max_time_ms - min_time_ms) * (max_time_ms - min_time_ms) / 12));
    printf("\nPerformance:\n");
    printf("  GFLOPS           : %.2f\n", gflops);
    printf("  Throughput       : %.2f tokens/s\n", throughput);
    printf("  Latency/token    : %.3f us\n", (avg_time_ms * 1000) / config.nc);
    printf("=" "=%.78s\n", "================================================================================");

    // Cleanup
    free(X);
    free(Y);
    free(S);
}

void print_usage(const char* program) {
    printf("Usage: %s [options]\n", program);
    printf("Options:\n");
    printf("  -n <size>    Embedding dimension (must be divisible by 4, default: 2048)\n");
    printf("  -r <rows>    Number of rows in matrix Y (default: 32)\n");
    printf("  -c <cols>    Number of columns in matrix X (default: 128)\n");
    printf("  -i <iters>   Number of iterations (default: 1000)\n");
    printf("  -w <warmup>  Number of warmup iterations (default: 10)\n");
    printf("  -h           Show this help message\n");
}

int main(int argc, char** argv) {
    // Default configuration (plain member assignment keeps this valid C++17;
    // designated initializers would require C++20)
    BenchmarkConfig config;
    config.n = 2048;
    config.nr = 32;
    config.nc = 128;
    config.iterations = 1000;
    config.warmup = 10;

    // Parse command line arguments
    for (int i = 1; i < argc; i++) {
        if (strcmp(argv[i], "-n") == 0 && i + 1 < argc) {
            config.n = atoi(argv[++i]);
        } else if (strcmp(argv[i], "-r") == 0 && i + 1 < argc) {
            config.nr = atoi(argv[++i]);
        } else if (strcmp(argv[i], "-c") == 0 && i + 1 < argc) {
            config.nc = atoi(argv[++i]);
        } else if (strcmp(argv[i], "-i") == 0 && i + 1 < argc) {
            config.iterations = atoi(argv[++i]);
        } else if (strcmp(argv[i], "-w") == 0 && i + 1 < argc) {
            config.warmup = atoi(argv[++i]);
        } else if (strcmp(argv[i], "-h") == 0) {
            print_usage(argv[0]);
            return 0;
        } else {
            fprintf(stderr, "Unknown option: %s\n", argv[i]);
            print_usage(argv[0]);
            return 1;
        }
    }

    // Validate configuration
    if (config.n % 4 != 0) {
        fprintf(stderr, "Error: Embedding dimension (-n) must be divisible by 4\n");
        return 1;
    }

    if (config.n <= 0 || config.nr <= 0 || config.nc <= 0 || config.iterations <= 0) {
        fprintf(stderr, "Error: All size parameters must be positive\n");
        return 1;
    }

    // Run benchmark
    print_config(config);
    run_benchmark(config);

    return 0;
}
EOF

    # Compiler flags
    CXXFLAGS="-O3 -march=native -mtune=native -std=c++17 -fopenmp"
    CXXFLAGS+=" -I${SCRIPT_DIR}/.. -I${SCRIPT_DIR}/../include"
    CXXFLAGS+=" -I${SCRIPT_DIR}/../3rdparty/llama.cpp/ggml/include"
    CXXFLAGS+=" -I${SCRIPT_DIR}/../3rdparty/llama.cpp/ggml/src"
    CXXFLAGS+=" -I${SCRIPT_DIR}/../3rdparty/llama.cpp/include"
    CXXFLAGS+=" -DNDEBUG -ffast-math"

    # Link flags
    LDFLAGS="-lm -lpthread"

    # Link with pre-built libraries
    GGML_LIB_DIR="${SCRIPT_DIR}/../build/3rdparty/llama.cpp/ggml/src"
    GGML_SO="${GGML_LIB_DIR}/libggml.so"

    if [ ! -f "${GGML_SO}" ]; then
        echo "❌ Error: Cannot find libggml.so at ${GGML_SO}"
        echo "Please build the project first with: cmake --build build"
        rm -f "${TEMP_CPP}"
        exit 1
    fi

    LDFLAGS+=" -L${GGML_LIB_DIR} -lggml -Wl,-rpath,${GGML_LIB_DIR}"

    # Output binary
    BENCHMARK_BIN="${SCRIPT_DIR}/${BUILD_DIR}/test_gemm_kernel"

    echo "Compiler: ${CXX}"
    echo "Building from embedded source..."
    echo ""

    # Build (CXXFLAGS/LDFLAGS are intentionally unquoted so they word-split)
    ${CXX} ${CXXFLAGS} "${TEMP_CPP}" -o "${BENCHMARK_BIN}" ${LDFLAGS}

    if [ $? -eq 0 ]; then
        echo "✅ Build successful!"
        rm -f "${TEMP_CPP}"
        echo ""
    else
        echo "❌ Build failed!"
        rm -f "${TEMP_CPP}"
        exit 1
    fi
else
    echo "Step 1: Skipping build (using existing binary)"
    echo "------------------------------------------"
    BENCHMARK_BIN="${SCRIPT_DIR}/${BUILD_DIR}/test_gemm_kernel"

    if [ ! -f "${BENCHMARK_BIN}" ]; then
        echo "❌ Error: Benchmark binary not found at ${BENCHMARK_BIN}"
        echo "Please run without -s to build it first."
        exit 1
    fi
    echo "✅ Found existing binary"
    echo ""
fi

# Set LD_LIBRARY_PATH to include the GGML library directory
GGML_LIB_DIR="${SCRIPT_DIR}/../build/3rdparty/llama.cpp/ggml/src"
export LD_LIBRARY_PATH="${GGML_LIB_DIR}:${LD_LIBRARY_PATH}"

echo "Step 2: Running benchmark tests"
echo "------------------------------------------"
echo "Library path: ${GGML_LIB_DIR}"
echo ""

# Write CSV header
echo "test_name,n,nr,nc,time_ms,gflops,throughput_tokens_per_sec" > "$OUTPUT_CSV"
echo "Results will be saved to: $OUTPUT_CSV"
echo ""

# Function to extract metrics and append to CSV
extract_and_save() {
    local test_name="$1"
    local output="$2"

    # Extract values using grep and awk
    local n=$(echo "$output" | grep "Embedding dimension" | awk '{print $5}')
    local nr=$(echo "$output" | grep "Matrix Y rows" | awk '{print $6}')
    local nc=$(echo "$output" | grep "Matrix X columns" | awk '{print $6}')
    local avg_time=$(echo "$output" | grep "Average time" | awk '{print $4}')
    local min_time=$(echo "$output" | grep "Min time" | awk '{print $4}')
    local max_time=$(echo "$output" | grep "Max time" | awk '{print $4}')
    local gflops=$(echo "$output" | grep "GFLOPS" | awk '{print $3}')
    local throughput=$(echo "$output" | grep "Throughput" | awk '{print $3}')

    # Check if values were extracted successfully
    if [ -z "$avg_time" ] || [ -z "$min_time" ] || [ -z "$max_time" ]; then
        echo "Warning: Failed to extract timing data for ${test_name}"
        echo "${test_name},${n},${nr},${nc},N/A,N/A,N/A" >> "$OUTPUT_CSV"
        return
    fi

    # Rough standard deviation estimate from the min-max range
    local std_time=$(awk -v min="$min_time" -v max="$max_time" 'BEGIN {printf "%.4f", (max - min) / 4}')

    # Format as mean±std
    local time_formatted="${avg_time}±${std_time}"

    # Append to CSV
    echo "${test_name},${n},${nr},${nc},${time_formatted},${gflops},${throughput}" >> "$OUTPUT_CSV"
}

# Run benchmark tests
echo "=========================================="
echo "BitNet-2B Typical Shapes Performance Test"
echo "=========================================="
echo ""

echo "Test 1: Single Token Generation (Attention QKV projection)"
echo "  Scenario: Generating 1 token at a time"
echo "  Shape: n=2048, r=1, c=2048"
OUTPUT=$($BENCHMARK_BIN -n 2048 -r 1 -c 2048 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "single_token_gen" "$OUTPUT"
echo ""

echo "Test 2: Small Batch Prompt Processing (Attention QKV projection)"
echo "  Scenario: Processing prompt with 128 tokens, batch size 1"
echo "  Shape: n=2048, r=128, c=2048"
OUTPUT=$($BENCHMARK_BIN -n 2048 -r 128 -c 2048 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "small_batch_prompt" "$OUTPUT"
echo ""

echo "Test 3: Medium Batch Prompt Processing (Attention QKV projection)"
echo "  Scenario: Processing prompt with 256 tokens or batch of 256"
echo "  Shape: n=2048, r=256, c=2048"
OUTPUT=$($BENCHMARK_BIN -n 2048 -r 256 -c 2048 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "medium_batch_prompt" "$OUTPUT"
echo ""

echo "Test 4: Large Batch Processing (Attention QKV projection)"
echo "  Scenario: Processing 512 tokens or batch of 512"
echo "  Shape: n=2048, r=512, c=2048"
OUTPUT=$($BENCHMARK_BIN -n 2048 -r 512 -c 2048 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "large_batch_prompt" "$OUTPUT"
echo ""

echo "Test 5: FFN Up-projection (Small batch)"
echo "  Scenario: Feed-forward network expansion, 128 tokens"
echo "  Shape: n=2048, r=128, c=8192"
OUTPUT=$($BENCHMARK_BIN -n 2048 -r 128 -c 8192 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "ffn_up_projection" "$OUTPUT"
echo ""

echo "Test 6: FFN Down-projection (Small batch)"
echo "  Scenario: Feed-forward network reduction, 128 tokens"
echo "  Shape: n=8192, r=128, c=2048"
OUTPUT=$($BENCHMARK_BIN -n 8192 -r 128 -c 2048 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "ffn_down_projection" "$OUTPUT"
echo ""

echo "Test 7: Long Context Processing"
echo "  Scenario: Processing very long context (2048 tokens)"
echo "  Shape: n=2048, r=2048, c=2048"
OUTPUT=$($BENCHMARK_BIN -n 2048 -r 2048 -c 2048 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "long_context" "$OUTPUT"
echo ""

echo "Test 8: Batched Token Generation"
echo "  Scenario: Generating tokens for 32 sequences simultaneously"
echo "  Shape: n=2048, r=32, c=2048"
OUTPUT=$($BENCHMARK_BIN -n 2048 -r 32 -c 2048 -i $ITERATIONS 2>&1)
echo "$OUTPUT"
extract_and_save "batched_token_gen" "$OUTPUT"
echo ""

echo "=========================================="
echo "All tests completed successfully!"
echo "=========================================="
echo "Results saved to: $OUTPUT_CSV"
echo ""
echo "Summary:"
wc -l "$OUTPUT_CSV" | awk '{print "  Total records:", $1 - 1}'
echo "  Output file: $OUTPUT_CSV"
echo "=========================================="
#!/usr/bin/env python3
"""
Perplexity Test Script
Tests GGUF model perplexity on multiple datasets using llama-perplexity.
"""

import os
import subprocess
import time
import csv
import re
from datetime import datetime
from pathlib import Path
import argparse
import tempfile
import shutil
import statistics


class PerplexityTester:
    def __init__(self, model_path, llama_perplexity_bin="../build/bin/llama-perplexity",
                 data_dir="../data", output_dir="perplexity_results", quick_mode=False,
                 quantize_bin="../build/bin/llama-quantize", test_embeddings=False, csv_output=None):
        self.model_path = Path(model_path)
        self.llama_perplexity_bin = Path(llama_perplexity_bin)
        self.quantize_bin = Path(quantize_bin)
        self.data_dir = Path(data_dir)
        self.output_dir = Path(output_dir)
        self.quick_mode = quick_mode
        self.test_embeddings = test_embeddings
        self.csv_output = Path(csv_output) if csv_output else None
        self.results = []
        self.created_models = set()  # Track newly created model files
        self.temp_files = []         # Track temporary files for cleanup

        # Embedding types to test
        self.embedding_types = [
            ('F32', 'f32'),
            ('F16', 'f16'),
            ('Q8_0', 'q8_0'),
            ('Q6_K', 'q6_k'),
            ('Q5_0', 'q5_0'),
            ('Q4_0', 'q4_0'),
            ('Q3_K', 'q3_k'),
            ('TQ2_0', 'tq2_0'),
        ]

        # Create output directory
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Verify llama-perplexity binary exists
        if not self.llama_perplexity_bin.exists():
            raise FileNotFoundError(f"llama-perplexity binary not found: {self.llama_perplexity_bin}")

        # Verify quantize binary exists if testing embeddings
        if self.test_embeddings and not self.quantize_bin.exists():
            raise FileNotFoundError(f"llama-quantize binary not found: {self.quantize_bin}")

        # Verify model file exists
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

    def find_datasets(self):
        """Find all test.txt files in dataset directories."""
        datasets = []

        if not self.data_dir.exists():
            print(f"❌ Data directory not found: {self.data_dir}")
            return datasets

        print(f"\n🔍 Searching for datasets in {self.data_dir}...")

        # Look for test.txt files in subdirectories
        for dataset_dir in sorted(self.data_dir.iterdir()):
            if dataset_dir.is_dir():
                test_file = dataset_dir / "test.txt"
                if test_file.exists():
                    size_mb = test_file.stat().st_size / (1024 * 1024)
                    datasets.append({
                        'name': dataset_dir.name,
                        'path': test_file,
                        'size': test_file.stat().st_size,
                        'size_mb': size_mb
                    })
                    print(f"  ✅ {dataset_dir.name:<20} ({size_mb:.2f} MB)")
                else:
                    print(f"  ⚠️  {dataset_dir.name:<20} (no test.txt found)")

        return datasets

    def create_quick_dataset(self, dataset_path, num_chars=4096):
        """Create a temporary dataset with only the first N characters for quick testing."""
        temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt', encoding='utf-8')
        self.temp_files.append(temp_file.name)

        try:
            with open(dataset_path, 'r', encoding='utf-8', errors='ignore') as f:
                content = f.read(num_chars)
            temp_file.write(content)
            temp_file.close()
            return Path(temp_file.name)
        except Exception as e:
            print(f"⚠️  Failed to create quick dataset: {e}")
            temp_file.close()
            return dataset_path

    def cleanup_temp_files(self):
        """Clean up temporary files."""
        for temp_file in self.temp_files:
            try:
                os.unlink(temp_file)
            except OSError:
                pass
        self.temp_files = []

    def run_perplexity_test(self, dataset_name, dataset_path, threads=16, ctx_size=512, model_override=None):
        """Run perplexity test on a single dataset."""
        test_model = model_override if model_override else self.model_path

        print(f"\n{'='*80}")
        print(f"📊 Testing on dataset: {dataset_name}")
        print(f"   File: {dataset_path}")
        print(f"   Model: {test_model.name}")
        print(f"{'='*80}")

        cmd = [
            str(self.llama_perplexity_bin),
            "-m", str(test_model),
            "-f", str(dataset_path),
            "-t", str(threads),
            "-c", str(ctx_size),
            "-ngl", "0"  # CPU only
        ]

        print(f"💻 Command: {' '.join(cmd)}")
        print(f"⏱️  Starting test...\n")

        start_time = time.time()

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=3600,  # 1 hour timeout
                cwd=os.getcwd()
            )

            elapsed_time = time.time() - start_time

            if result.returncode == 0:
                # Parse perplexity from output (check both stdout and stderr)
                combined_output = result.stdout + "\n" + result.stderr
                ppl = self.parse_perplexity(combined_output)

                if ppl is not None:
                    print(f"\n✅ Perplexity: {ppl}")
                    print(f"⏱️  Time: {elapsed_time:.2f}s ({elapsed_time/60:.2f} min)")
                    status = "success"
                else:
                    print(f"\n⚠️  Test completed but could not parse perplexity")
                    print(f"Last 500 chars of stdout:")
                    print(result.stdout[-500:])
                    print(f"Last 500 chars of stderr:")
                    print(result.stderr[-500:])
                    status = "parse_error"
                    ppl = None
            else:
                print(f"\n❌ Test failed with return code {result.returncode}")
                print(f"Error: {result.stderr[:500]}")
                status = "failed"
                ppl = None

            return {
                'dataset': dataset_name,
                'perplexity': ppl,
                'time': elapsed_time,
                'status': status,
                'stdout': result.stdout,
                'stderr': result.stderr
            }

        except subprocess.TimeoutExpired:
            elapsed_time = time.time() - start_time
            print(f"\n❌ Timeout after {elapsed_time:.2f}s")
            return {
                'dataset': dataset_name,
                'perplexity': None,
                'time': elapsed_time,
                'status': 'timeout',
                'stdout': '',
                'stderr': 'Test exceeded 1 hour timeout'
            }
        except Exception as e:
            elapsed_time = time.time() - start_time
            print(f"\n❌ Error: {e}")
            return {
                'dataset': dataset_name,
                'perplexity': None,
                'time': elapsed_time,
                'status': 'error',
                'stdout': '',
                'stderr': str(e)
            }

    def parse_perplexity(self, output):
        """Parse perplexity value (mean±std format) from llama-perplexity output."""
        # First try to match "PPL = mean +/- std" format
        pattern_with_std = r'PPL\s*=\s*(\d+\.?\d*)\s*\+/-\s*(\d+\.?\d*)'
        match = re.search(pattern_with_std, output, re.IGNORECASE | re.MULTILINE)
        if match:
            try:
                mean = float(match.group(1))
                std = float(match.group(2))
                return f"{mean:.4f}±{std:.4f}"
            except ValueError:
                pass

        # Fall back to patterns without std
        patterns = [
            r'Final estimate:\s*PPL\s*=\s*(\d+\.?\d*)',
            r'Final perplexity:\s*(\d+\.?\d*)',
            r'PPL\s*=\s*(\d+\.?\d*)',
            r'PPL:\s*(\d+\.?\d*)',
            r'perplexity:\s*(\d+\.?\d*)',
            r'ppl\s*=\s*(\d+\.?\d*)',
            r'Perplexity:\s*(\d+\.?\d*)',
        ]

        for pattern in patterns:
            match = re.search(pattern, output, re.IGNORECASE | re.MULTILINE)
            if match:
                try:
                    return f"{float(match.group(1)):.4f}"
                except ValueError:
                    continue

        return None

    def quantize_embedding(self, embedding_type, output_suffix):
        """
        Quantize model with a specific embedding type.

        Args:
            embedding_type: Token embedding type (uppercase, e.g., 'Q6_K')
            output_suffix: Output file suffix (lowercase, e.g., 'q6_k')

        Returns:
            Path to quantized model, or None on failure
        """
        # Construct output path
        model_dir = self.model_path.parent
        output_path = model_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"

        # Reuse the file if it already exists
        if output_path.exists():
            print(f"ℹ️  Model already exists: {output_path.name}")
            return output_path

        cmd = [
            str(self.quantize_bin),
            "--token-embedding-type", embedding_type,
            str(self.model_path),
            str(output_path),
            "I2_S",
            "1",
            "1"
        ]

        print(f"\n{'='*80}")
        print(f"🔄 Quantizing with embedding type: {embedding_type}")
        print(f"📥 Input:  {self.model_path.name}")
        print(f"📤 Output: {output_path.name}")
        print(f"💻 Command: {' '.join(cmd)}")
        print(f"{'='*80}\n")

        start_time = time.time()

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                cwd=os.getcwd(),
                timeout=600  # 10 minute timeout
            )

            duration = time.time() - start_time

            if result.returncode == 0:
                file_size_mb = output_path.stat().st_size / (1024 * 1024)
                print(f"✅ Quantization successful!")
                print(f"   Duration: {duration:.2f}s")
                print(f"   Size: {file_size_mb:.2f} MB")

                # Mark as newly created
                self.created_models.add(output_path)
                return output_path
            else:
                print(f"❌ Quantization failed with return code {result.returncode}")
                print(f"Error: {result.stderr[:500]}")
                return None

        except subprocess.TimeoutExpired:
            print(f"❌ Quantization timeout (exceeded 10 minutes)")
            return None
        except Exception as e:
            print(f"❌ Quantization error: {e}")
            return None

    def cleanup_model(self, model_path):
        """Delete model file if it was created during this session."""
        if model_path in self.created_models:
            try:
                model_path.unlink()
                print(f"🗑️  Deleted: {model_path.name}")
                self.created_models.remove(model_path)
            except Exception as e:
                print(f"⚠️  Failed to delete {model_path.name}: {e}")
        else:
            print(f"ℹ️  Keeping existing file: {model_path.name}")

    def run_all_tests(self, threads=16, ctx_size=512):
        """Run perplexity tests on all datasets."""
        datasets = self.find_datasets()

        if not datasets:
            print(f"\n❌ No datasets found in {self.data_dir}")
            print(f"   Make sure each dataset directory has a test.txt file")
            return

        # Quick mode: test all datasets but only first 4096 chars with smaller context
        if self.quick_mode:
            ctx_size = min(ctx_size, 128)  # Use smaller context in quick mode
            print(f"\n⚡ QUICK TEST MODE ENABLED")
            print(f"   - Testing all datasets with first 4096 characters only")
            print(f"   - Using reduced context size: {ctx_size}")

        # Determine models to test
        if self.test_embeddings:
            print(f"\n{'='*80}")
            print(f"🧪 EMBEDDING QUANTIZATION TEST MODE")
            print(f"{'='*80}")
            print(f"📦 Base model: {self.model_path.name}")
            print(f"🔢 Embedding types to test: {len(self.embedding_types)}")
            print(f"📊 Datasets: {len(datasets)}")
            print(f"🧵 Threads: {threads}")
            print(f"📏 Context size: {ctx_size}")
            print(f"{'='*80}")

            total_start = time.time()

            # Test each embedding type
            for i, (embedding_type, output_suffix) in enumerate(self.embedding_types, 1):
                print(f"\n\n{'#'*80}")
                print(f"[{i}/{len(self.embedding_types)}] Testing embedding type: {output_suffix} ({embedding_type})")
                print(f"{'#'*80}")

                # Quantize model
                quantized_model = self.quantize_embedding(embedding_type, output_suffix)

                if quantized_model is None:
                    print(f"⚠️ Skipping tests for {output_suffix} due to quantization failure")
                    continue

                # Test on all datasets
                for j, dataset in enumerate(datasets, 1):
                    print(f"\n[{j}/{len(datasets)}] Testing {dataset['name']} with {output_suffix}...")

                    # Use quick dataset if in quick mode
                    test_path = dataset['path']
                    if self.quick_mode:
                        test_path = self.create_quick_dataset(dataset['path'])

                    result = self.run_perplexity_test(
                        f"{dataset['name']}_embed-{output_suffix}",
                        test_path,
                        threads,
                        ctx_size,
                        model_override=quantized_model
                    )
                    self.results.append(result)

                # Cleanup model after testing
                print(f"\n🧹 Cleaning up {output_suffix} model...")
                self.cleanup_model(quantized_model)

                print(f"\n{'#'*80}")
                print(f"✅ Completed {output_suffix}")
                print(f"{'#'*80}")

            total_time = time.time() - total_start

        else:
            # Regular single model test
            print(f"\n{'='*80}")
            print(f"🚀 PERPLEXITY TEST SESSION{' (QUICK MODE)' if self.quick_mode else ''}")
            print(f"{'='*80}")
            print(f"📦 Model: {self.model_path.name}")
            print(f"📁 Model path: {self.model_path}")
            print(f"📊 Datasets {'to test' if self.quick_mode else 'found'}: {len(datasets)}")
            print(f"🧵 Threads: {threads}")
            print(f"📏 Context size: {ctx_size}")
            print(f"{'='*80}")

            total_start = time.time()

            # Run tests
            for i, dataset in enumerate(datasets, 1):
                print(f"\n\n[{i}/{len(datasets)}] Processing {dataset['name']}...")

                # Use quick dataset if in quick mode
                test_path = dataset['path']
                if self.quick_mode:
                    test_path = self.create_quick_dataset(dataset['path'])

                result = self.run_perplexity_test(
                    dataset['name'],
                    test_path,
                    threads,
                    ctx_size
                )
                self.results.append(result)

            total_time = time.time() - total_start

        # Clean up temporary files
        if self.quick_mode:
            print(f"\n🧹 Cleaning up temporary files...")
            self.cleanup_temp_files()

        # Save results
        self.save_results()

        # Print summary
        self.print_summary(total_time)

    def save_results(self):
        """Save results to CSV file."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_name = self.model_path.stem

        # Use custom CSV path if provided
        if self.csv_output:
            csv_file = Path(self.csv_output)  # accept str or Path
            # Create parent directory if needed
            csv_file.parent.mkdir(parents=True, exist_ok=True)
        else:
            csv_file = self.output_dir / f"ppl_{model_name}_{timestamp}.csv"

        print(f"\n💾 Saving results...")

        with open(csv_file, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=['dataset', 'perplexity', 'time_seconds', 'status'])
            writer.writeheader()
            for result in self.results:
                writer.writerow({
                    'dataset': result['dataset'],
                    'perplexity': result['perplexity'] if result['perplexity'] is not None else 'N/A',
                    'time_seconds': f"{result['time']:.2f}",
                    'status': result['status']
                })

        print(f" ✅ CSV saved: {csv_file}")

        # Save detailed log
        log_file = self.output_dir / f"ppl_{model_name}_{timestamp}.log"
        with open(log_file, 'w') as f:
            f.write(f"Perplexity Test Results\n")
            f.write(f"{'='*80}\n")
            f.write(f"Model: {self.model_path}\n")
            f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"{'='*80}\n\n")

            for result in self.results:
                f.write(f"\n{'='*80}\n")
                f.write(f"Dataset: {result['dataset']}\n")
                f.write(f"Perplexity: {result['perplexity']}\n")
                f.write(f"Time: {result['time']:.2f}s\n")
                f.write(f"Status: {result['status']}\n")
                f.write(f"\nOutput:\n{result['stdout']}\n")
                if result['stderr']:
                    f.write(f"\nErrors:\n{result['stderr']}\n")

        print(f" ✅ Log saved: {log_file}")

    def print_summary(self, total_time):
        """Print summary of all tests."""
        print(f"\n\n{'='*80}")
        print(f"📊 TEST SUMMARY")
        print(f"{'='*80}\n")

        # Sort results by perplexity (lower is better)
        successful = [r for r in self.results if r['perplexity'] is not None]
        failed = [r for r in self.results if r['perplexity'] is None]

        if successful:
            # Extract numeric value from "mean±std" format for sorting
            def get_ppl_value(result):
                ppl = result['perplexity']
                if isinstance(ppl, str) and '±' in ppl:
                    return float(ppl.split('±')[0])
                elif isinstance(ppl, str):
                    try:
                        return float(ppl)
                    except ValueError:
                        return float('inf')
                return ppl

            successful_sorted = sorted(successful, key=get_ppl_value)

            print(f"{'Dataset':<20} {'Perplexity':>20} {'Time (s)':>12} {'Status':<15}")
            print(f"{'-'*80}")

            for result in successful_sorted:
                ppl_str = str(result['perplexity']) if result['perplexity'] is not None else 'N/A'
                print(f"{result['dataset']:<20} {ppl_str:>20} "
                      f"{result['time']:>12.2f} {result['status']:<15}")

            best_ppl = str(successful_sorted[0]['perplexity'])
            print(f"\n🏆 Best result: {successful_sorted[0]['dataset']} "
                  f"(PPL: {best_ppl})")

        if failed:
            print(f"\n❌ Failed tests ({len(failed)}):")
            for result in failed:
                print(f" - {result['dataset']}: {result['status']}")

        print(f"\n{'='*80}")
        print(f"✅ Completed: {len(successful)}/{len(self.results)}")
        print(f"⏱️ Total time: {total_time:.2f}s ({total_time/60:.2f} min)")
        print(f"📁 Results saved in: {self.output_dir}")
        print(f"{'='*80}\n")


def main():
    parser = argparse.ArgumentParser(description='Test model perplexity on multiple datasets')
    parser.add_argument('--model', '-m',
                        required=True,
                        help='Path to GGUF model file')
    parser.add_argument('--data-dir', '-d',
                        default='data',
                        help='Directory containing dataset folders (default: data)')
    parser.add_argument('--threads', '-t',
                        type=int,
                        default=16,
                        help='Number of threads (default: 16)')
    parser.add_argument('--ctx-size', '-c',
                        type=int,
                        default=512,
                        help='Context size (default: 512)')
    parser.add_argument('--output-dir', '-o',
                        default='perplexity_results',
                        help='Output directory for results (default: perplexity_results)')
    parser.add_argument('--llama-perplexity',
                        default='./build/bin/llama-perplexity',
                        help='Path to llama-perplexity binary (default: ./build/bin/llama-perplexity)')
    parser.add_argument('--quick', '-q',
                        action='store_true',
                        help='Quick test mode: test all datasets with first 4096 characters and reduced context size (128)')
    parser.add_argument('--test-embeddings', '-e',
                        action='store_true',
                        help='Test different embedding quantization types (f32, f16, q8_0, q6_k, q5_0, q4_0, q3_k, tq2_0)')
    parser.add_argument('--csv-output',
                        help='Custom path for CSV output file (e.g., results/my_ppl_results.csv)')
    parser.add_argument('--quantize-bin',
                        default='./build/bin/llama-quantize',
                        help='Path to llama-quantize binary (default: ./build/bin/llama-quantize)')

    args = parser.parse_args()

    try:
        tester = PerplexityTester(
            model_path=args.model,
            llama_perplexity_bin=args.llama_perplexity,
            data_dir=args.data_dir,
            output_dir=args.output_dir,
            quick_mode=args.quick,
            quantize_bin=args.quantize_bin,
            test_embeddings=args.test_embeddings,
            csv_output=args.csv_output
        )

        tester.run_all_tests(
            threads=args.threads,
            ctx_size=args.ctx_size
        )

    except FileNotFoundError as e:
        print(f"❌ Error: {e}")
        return 1
    except KeyboardInterrupt:
        print("\n\n⚠️ Test interrupted by user")
        return 1
    except Exception as e:
        print(f"\n❌ Unexpected error: {e}")
        import traceback
        traceback.print_exc()
        return 1

    return 0


if __name__ == "__main__":
    exit(main())
@@ -0,0 +1,151 @@
#!/bin/bash
# Monitor power consumption for llama-bench with different thread configurations
# Usage: ./monitor_power.sh <model_path> <output_csv> <pp_threads> <tg_threads>
# Example: ./monitor_power.sh models/model.gguf results.csv "1,2,4,8" "1,2,4,8"

set -e

# Parse arguments
if [ $# -ne 4 ]; then
    echo "Usage: $0 <model_path> <output_csv> <pp_threads> <tg_threads>"
    echo "Example: $0 models/model.gguf results.csv \"1,2,4,8\" \"1,2,4,8\""
    exit 1
fi

MODEL_PATH="$1"
OUTPUT_CSV="$2"
PP_THREADS="$3"
TG_THREADS="$4"

TEMP_LOG="/tmp/power_monitor_$$.log"
PID_FILE="/tmp/monitor_$$.pid"
BENCH_OUTPUT="/tmp/bench_output_$$.txt"

# Validate model exists
if [ ! -f "$MODEL_PATH" ]; then
    echo "Error: Model file not found: $MODEL_PATH"
    exit 1
fi

# Create output directory if needed
mkdir -p "$(dirname "$OUTPUT_CSV")"

# Function to monitor CPU stats (Linux-specific: relies on top's "Cpu(s)" line
# and /proc/cpuinfo)
monitor_cpu() {
    local log_file="$1"
    echo "Timestamp,CPU_Usage(%),Avg_Freq(MHz)" > "$log_file"
    while [ -f "$PID_FILE" ]; do
        cpu_usage=$(top -bn1 | grep "Cpu(s)" | awk '{print 100-$8}')
        avg_freq=$(grep "cpu MHz" /proc/cpuinfo | awk '{sum+=$4; count++} END {printf "%.0f", sum/count}')
        timestamp=$(date +%s.%N)
        echo "$timestamp,$cpu_usage,$avg_freq" >> "$log_file"
        sleep 0.5
    done
}

# Function to estimate average power: scales an assumed ~200 W full-load
# package power by the average CPU utilization (a rough proxy, not a
# hardware measurement)
calculate_power() {
    local log_file="$1"
    awk -F',' 'NR>1 {sum_cpu+=$2; count++} END {
        if (count > 0) {
            avg_cpu = sum_cpu/count
            est_power = avg_cpu * 200 / 100
            printf "%.2f", est_power
        } else {
            print "0"
        }
    }' "$log_file"
}

# Function to extract throughput from llama-bench output
extract_throughput() {
    local bench_output="$1"
    local workload="$2"
    grep "$workload" "$bench_output" | awk '{
        # Extract mean from "mean ± std" format
        for (i=1; i<=NF; i++) {
            if ($(i+1) == "±") {
                printf "%.2f", $i
                exit
            }
        }
    }'
}
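
# Example of the table row extract_throughput parses (illustrative values;
# the field just before "±" is the mean):
#   | pp128 |  501.06 ± 11.37 |
# so `extract_throughput bench.txt pp128` would print 501.06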

# Function to run single benchmark
run_benchmark() {
    local workload="$1"  # "pp" or "tg"
    local threads="$2"
    local n_flag=""

    if [ "$workload" = "pp" ]; then
        n_flag="-n 0"
        workload_name="pp128"
    else
        n_flag="-n 128"
        workload_name="tg128"
    fi

    # Output progress to stderr (won't be captured in CSV)
    echo "Testing $workload_name with $threads threads..." >&2

    # Start monitoring
    touch "$PID_FILE"
    monitor_cpu "$TEMP_LOG" &
    local monitor_pid=$!

    # Run benchmark
    ./build/bin/llama-bench -m "$MODEL_PATH" -p 128 $n_flag -t "$threads" -ngl 0 > "$BENCH_OUTPUT" 2>&1

    # Stop monitoring
    rm -f "$PID_FILE"
    wait $monitor_pid 2>/dev/null || true

    # Extract results
    local throughput=$(extract_throughput "$BENCH_OUTPUT" "$workload_name")
    local power=$(calculate_power "$TEMP_LOG")

    if [ -z "$throughput" ] || [ "$throughput" = "0" ]; then
        echo "Warning: Failed to extract throughput for $workload_name, threads=$threads" >&2
        throughput="0"
    fi

    # Calculate J/t (Joules per token)
    local j_per_token=$(awk -v p="$power" -v t="$throughput" 'BEGIN {
        if (t > 0) printf "%.4f", p/t; else print "0"
    }')

    # Output progress to stderr
    echo "  Throughput: $throughput t/s, Power: $power W, Energy: $j_per_token J/t" >&2

    # Only output CSV line to stdout (this will be captured)
    echo "$workload_name,$threads,$throughput,$power,$j_per_token"
}

# Initialize CSV
echo "Workload,Threads,Throughput(t/s),Power(W),Energy(J/t)" > "$OUTPUT_CSV"

# Test PP workloads
IFS=',' read -ra PP_ARRAY <<< "$PP_THREADS"
for threads in "${PP_ARRAY[@]}"; do
    threads=$(echo "$threads" | xargs)  # trim whitespace
    result=$(run_benchmark "pp" "$threads")
    echo "$result" >> "$OUTPUT_CSV"
done

# Test TG workloads
IFS=',' read -ra TG_ARRAY <<< "$TG_THREADS"
for threads in "${TG_ARRAY[@]}"; do
    threads=$(echo "$threads" | xargs)  # trim whitespace
    result=$(run_benchmark "tg" "$threads")
    echo "$result" >> "$OUTPUT_CSV"
done

# Cleanup
rm -f "$TEMP_LOG" "$BENCH_OUTPUT" "$PID_FILE"

echo ""
echo "=== Benchmark Complete ==="
echo "Results saved to: $OUTPUT_CSV"
echo ""
cat "$OUTPUT_CSV"
@@ -0,0 +1,362 @@
#!/usr/bin/env python3
"""
GEMM Configuration Tuning Script
This script automatically tunes ROW_BLOCK_SIZE, COL_BLOCK_SIZE, and PARALLEL_SIZE
to find the optimal configuration for maximum throughput (t/s).
"""

import subprocess
import os
import re
import csv
import shutil
from datetime import datetime
from pathlib import Path
import argparse


class GemmTuner:
    def __init__(self, config_path, model_path, threads=16):
        self.config_path = Path(config_path)
        self.model_path = model_path
        self.threads = threads
        self.backup_path = self.config_path.parent / f"gemm-config.h.backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.build_dir = Path("../build")
        self.results = []

    def backup_config(self):
        """Backup current configuration file"""
        print(f"📦 Backing up current config to {self.backup_path}")
        shutil.copy2(self.config_path, self.backup_path)

    def restore_config(self):
        """Restore original configuration file"""
        print(f"♻️ Restoring original config from {self.backup_path}")
        shutil.copy2(self.backup_path, self.config_path)

    def generate_config(self, act_parallel, row_block_size, col_block_size, parallel_size):
        """Generate new configuration file with simplified format"""
        content = ""

        # Simplified configuration format
        if act_parallel:
            content += "#define ACT_PARALLEL\n"

        content += f"#define ROW_BLOCK_SIZE {row_block_size}\n"
        content += f"#define COL_BLOCK_SIZE {col_block_size}\n"
        content += f"#define PARALLEL_SIZE {parallel_size}\n"

        with open(self.config_path, 'w') as f:
            f.write(content)

    def rebuild_project(self):
        """Rebuild project"""
        print("🔨 Rebuilding project...")
        result = subprocess.run(
            ["cmake", "--build", str(self.build_dir), "--target", "llama-bench"],
            capture_output=True,
            text=True,
            cwd=os.getcwd()
        )
        if result.returncode != 0:
            print(f"⚠️ Build warning/error: {result.stderr}")
            return False
        return True

    def run_benchmark(self):
        """Run benchmark test"""
        cmd = [
            f"{self.build_dir}/bin/llama-bench",
            "-m", self.model_path,
            "-p", "128",
            "-n", "0",
            "-t", str(self.threads),
            "-ngl", "0"
        ]

        print(f"⚡ Running benchmark: {' '.join(cmd)}")

        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            cwd=os.getcwd(),
            timeout=300  # 5-minute timeout
        )

        if result.returncode != 0:
            print(f"❌ Benchmark failed: {result.stderr}")
            return None

        return result.stdout

    def parse_throughput(self, output):
        """Parse pp128 throughput from output"""
        # Match a llama-bench table row such as: | pp128 | 501.06 ± 11.37 |
        pp_pattern = r'\|\s+pp128\s+\|\s+([\d.]+)\s+±\s+([\d.]+)\s+\|'
        pp_match = re.search(pp_pattern, output)

        if pp_match:
            pp_throughput = float(pp_match.group(1))
            pp_std_dev = float(pp_match.group(2))

            return {
                'pp_throughput': pp_throughput,
                'pp_std_dev': pp_std_dev
            }

        return None

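    # Quick illustration of the pattern above on a sample (made-up) bench row:
    #   re.search(pp_pattern, "| pp128 | 501.06 ± 11.37 |").groups()
    #   -> ('501.06', '11.37')
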
    def test_configuration(self, act_parallel, row_block_size, col_block_size, parallel_size):
        """Test single configuration"""
        config_name = f"ACT_{'ON' if act_parallel else 'OFF'}_R{row_block_size}_C{col_block_size}_P{parallel_size}"
        print(f"\n{'='*80}")
        print(f"🧪 Testing configuration: {config_name}")
        print(f"  ACT_PARALLEL: {act_parallel}")
        print(f"  ROW_BLOCK_SIZE: {row_block_size}")
        print(f"  COL_BLOCK_SIZE: {col_block_size}")
        print(f"  PARALLEL_SIZE: {parallel_size}")
        print(f"{'='*80}")

        # Generate configuration
        self.generate_config(act_parallel, row_block_size, col_block_size, parallel_size)

        # Rebuild project
        if not self.rebuild_project():
            print("⚠️ Build failed, skipping this configuration")
            return None

        # Run benchmark test
        output = self.run_benchmark()
        if output is None:
            return None

        # Parse results
        metrics = self.parse_throughput(output)

        if metrics is not None:
            result = {
                'act_parallel': act_parallel,
                'row_block_size': row_block_size,
                'col_block_size': col_block_size,
                'parallel_size': parallel_size,
                'config_name': config_name,
                **metrics
            }
            self.results.append(result)
            print(f"✅ PP128: {metrics['pp_throughput']:.2f} ± {metrics['pp_std_dev']:.2f} t/s")
            return result
        else:
            print("❌ Failed to parse throughput")
            return None

    def save_results(self, csv_path):
        """Save results to CSV file"""
        print(f"\n💾 Saving results to {csv_path}")

        with open(csv_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=[
                'config_name', 'act_parallel', 'row_block_size',
                'col_block_size', 'parallel_size',
                'pp_throughput', 'pp_std_dev'
            ])
            writer.writeheader()
            writer.writerows(self.results)

    def find_best_config(self):
        """Find the best configuration with highest throughput"""
        if not self.results:
            print("❌ No valid results found")
            return None

        best = max(self.results, key=lambda x: x['pp_throughput'])
        return best

    def run_tuning(self, configurations, output_csv=None):
        """Run test for all configurations"""
        print(f"\n🚀 Starting tuning process with {len(configurations)} configurations")
        print(f"📊 Model: {self.model_path}")
        print(f"🧵 Threads: {self.threads}\n")

        # Backup configuration
        self.backup_config()

        try:
            # Test all configurations
            for i, config in enumerate(configurations, 1):
                print(f"\n[{i}/{len(configurations)}]")
                self.test_configuration(**config)

            # Save results
            if output_csv is None:
                timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                csv_path = f"../stats/tuning_results_{timestamp}.csv"
            else:
                csv_path = output_csv

            # Ensure output directory exists (skip when the path has no directory part)
            if os.path.dirname(csv_path):
                os.makedirs(os.path.dirname(csv_path), exist_ok=True)
            self.save_results(csv_path)

            # Find best configuration
            best = self.find_best_config()
            if best:
                print(f"\n{'='*80}")
                print(f"🏆 BEST CONFIGURATION FOUND!")
                print(f"{'='*80}")
                print(f"Configuration: {best['config_name']}")
                print(f"ACT_PARALLEL: {best['act_parallel']}")
                print(f"ROW_BLOCK_SIZE: {best['row_block_size']}")
                print(f"COL_BLOCK_SIZE: {best['col_block_size']}")
                print(f"PARALLEL_SIZE: {best['parallel_size']}")
                print(f"PP128 Throughput: {best['pp_throughput']:.2f} ± {best['pp_std_dev']:.2f} t/s")
                print(f"{'='*80}\n")

                # Show the configuration that will be written
                print("Configuration to be written to gemm-config.h:")
                print("-" * 80)
                if best['act_parallel']:
                    print("#define ACT_PARALLEL")
                print(f"#define ROW_BLOCK_SIZE {best['row_block_size']}")
                print(f"#define COL_BLOCK_SIZE {best['col_block_size']}")
                print(f"#define PARALLEL_SIZE {best['parallel_size']}")
                print("-" * 80)

                # Apply best configuration
                apply = input("\nDo you want to apply this configuration to gemm-config.h? (y/n): ").strip().lower()
                if apply == 'y':
                    self.generate_config(
                        best['act_parallel'],
                        best['row_block_size'],
                        best['col_block_size'],
                        best['parallel_size']
                    )
                    self.rebuild_project()
                    print("✅ Best configuration applied and project rebuilt!")
                else:
                    self.restore_config()
                    print("✅ Original configuration restored")

            # Clean up backup file
            if self.backup_path.exists():
                self.backup_path.unlink()
                print(f"🗑️ Removed backup file: {self.backup_path}")

        except KeyboardInterrupt:
            print("\n⚠️ Tuning interrupted by user")
            self.restore_config()
            # Clean up backup file
            if self.backup_path.exists():
                self.backup_path.unlink()
                print(f"🗑️ Removed backup file: {self.backup_path}")
        except Exception as e:
            print(f"\n❌ Error during tuning: {e}")
            self.restore_config()
            # Clean up backup file
            if self.backup_path.exists():
                self.backup_path.unlink()
                print(f"🗑️ Removed backup file: {self.backup_path}")
            raise


def generate_configurations():
    """Generate list of configurations to test"""
    configurations = []

    act_parallel_options = [True]

    row_sizes = [2, 4, 8]  # full sweep: [2, 4, 8, 16, 32]
    col_sizes = [32, 64]  # full sweep: [32, 64, 128, 256, 512, 1024]
    parallelism_degree = [4]

    for act_parallel in act_parallel_options:
        for row in row_sizes:
            for col in col_sizes:
                for parallel in parallelism_degree:
                    # Add filtering conditions
                    if act_parallel:
                        # When ACT_PARALLEL=True, only keep combinations with parallel <= row
                        if parallel > row:
                            continue
                    else:
                        # When ACT_PARALLEL=False, only keep combinations with parallel <= col
                        if parallel > col:
                            continue

                    configurations.append({
                        'act_parallel': act_parallel,
                        'row_block_size': row,
                        'col_block_size': col,
                        'parallel_size': parallel
                    })

    return configurations
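
# With the default sweep above (act_parallel only True, rows [2, 4, 8],
# cols [32, 64], parallel [4]), the `parallel > row` filter drops row=2,
# leaving 4 configurations: R4/C32, R4/C64, R8/C32, R8/C64.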


def main():
    parser = argparse.ArgumentParser(description='Tune GEMM configuration for optimal performance')
    parser.add_argument('--config', default='../include/gemm-config.h',
                        help='Path to gemm-config.h file')
    parser.add_argument('--model', default='../models/BitNet-b1.58-2B-4T/ggml-model-i2_s-embed-q6_k.gguf',
                        help='Path to model file')
    parser.add_argument('--threads', type=int, default=8,
                        help='Number of threads to use')
    parser.add_argument('--quick', action='store_true',
                        help='Quick test with fewer configurations')
    parser.add_argument('--custom', action='store_true',
                        help='Manually specify configurations to test')
    parser.add_argument('--output', type=str, default=None,
                        help='Output CSV file path (default: stats/tuning_results_<timestamp>.csv)')

    args = parser.parse_args()

    tuner = GemmTuner(args.config, args.model, args.threads)

    if args.custom:
        # Custom configuration mode
        print("Custom configuration mode")
        configurations = []
        while True:
            print("\nEnter configuration (or 'done' to finish):")
            # Read the first answer as text so 'done' can end the loop
            answer = input("ACT_PARALLEL (y/n, or 'done'): ").strip().lower()
            if answer == 'done':
                break
            act = answer == 'y'
            row = int(input("ROW_BLOCK_SIZE: "))
            col = int(input("COL_BLOCK_SIZE: "))
            par = int(input("PARALLEL_SIZE: "))
            configurations.append({
                'act_parallel': act,
                'row_block_size': row,
                'col_block_size': col,
                'parallel_size': par
            })
    elif args.quick:
        # Quick test mode - test only a few key configurations
        configurations = [
            {'act_parallel': True, 'row_block_size': 4, 'col_block_size': 128, 'parallel_size': 4},
            {'act_parallel': True, 'row_block_size': 8, 'col_block_size': 128, 'parallel_size': 4},
            {'act_parallel': True, 'row_block_size': 4, 'col_block_size': 64, 'parallel_size': 4},
            {'act_parallel': False, 'row_block_size': 32, 'col_block_size': 4, 'parallel_size': 4},
            {'act_parallel': False, 'row_block_size': 16, 'col_block_size': 4, 'parallel_size': 4},
        ]
    else:
        # Full test mode
        configurations = generate_configurations()

    print(f"\n{'='*80}")
    print(f"GEMM Configuration Tuner")
    print(f"{'='*80}")
    print(f"Total configurations to test: {len(configurations)}")
    print(f"Estimated time: ~{len(configurations) * 0.5:.1f} minutes (assuming 30s per test)")
    print(f"{'='*80}\n")

    proceed = input("Proceed with tuning? (y/n): ").strip().lower()
    if proceed != 'y':
        print("Tuning cancelled")
        return

    tuner.run_tuning(configurations, output_csv=args.output)


if __name__ == "__main__":
    main()