diff --git a/README.md b/README.md
index 4318061..bfb09a6 100644
--- a/README.md
+++ b/README.md
@@ -10,10 +10,10 @@ bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.5
The first release of bitnet.cpp is to support inference on CPUs. bitnet.cpp achieves speedups of **1.37x** to **5.07x** on ARM CPUs, with larger models experiencing greater performance gains. Additionally, it reduces energy consumption by **55.4%** to **70.0%**, further boosting overall efficiency. On x86 CPUs, speedups range from **2.37x** to **6.17x** with energy reductions between **71.9%** to **82.2%**. Furthermore, bitnet.cpp can run a 100B BitNet b1.58 model on a single CPU, achieving speeds comparable to human reading (5-7 tokens per second), significantly enhancing the potential for running LLMs on local devices. Please refer to the [technical report](https://arxiv.org/abs/2410.16144) for more details.
-
-
+**Latest optimization:** parallel kernel implementations with configurable tiling and embedding quantization support achieve an additional **1.5x to 2.1x** speedup over the original implementation across hardware platforms and workloads. For detailed technical information, see the [optimization guide](src/README.md).
+
+
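+The tiling is controlled at compile time through `include/gemm-config.h`. A minimal sketch of that header (macro names as used by the kernels in this repo; the values are illustrative rather than tuned defaults, and per-machine tuning is handled by `tune_gemm_blocks.sh`):
+
+```c
+// include/gemm-config.h (sketch; illustrative values, tune per machine)
+#define ACT_PARALLEL        // select the activation-parallel loop order
+#define ROW_BLOCK_SIZE 4    // activation rows processed per tile
+#define COL_BLOCK_SIZE 128  // weight columns processed per tile
+#define PARALLEL_SIZE  4    // parallel work granularity
+```
+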
->The tested models are dummy setups used in a research context to demonstrate the inference performance of bitnet.cpp.
## Demo
@@ -214,7 +214,7 @@ optional arguments:
Directory to save the logging info
--quant-type {i2_s,tl1}, -q {i2_s,tl1}
Quantization type
- --quant-embd Quantize the embeddings to q6_k
+ --quant-embd Quantize the embeddings to f16
--use-pretuned, -p Use the pretuned kernel parameters
## Usage
diff --git a/assets/intel_performance.jpg b/assets/intel_performance.jpg
deleted file mode 100644
index 38a1bcf..0000000
Binary files a/assets/intel_performance.jpg and /dev/null differ
diff --git a/assets/m2_performance.jpg b/assets/m2_performance.jpg
deleted file mode 100644
index 9b59348..0000000
Binary files a/assets/m2_performance.jpg and /dev/null differ
diff --git a/assets/performance.png b/assets/performance.png
new file mode 100644
index 0000000..03d477d
Binary files /dev/null and b/assets/performance.png differ
diff --git a/demo_benchmark.sh b/demo_benchmark.sh
deleted file mode 100755
index dad999b..0000000
--- a/demo_benchmark.sh
+++ /dev/null
@@ -1,130 +0,0 @@
-#!/bin/bash
-
-################################################################################
-# Quick Demo of Benchmark Automation
-# This runs a subset of benchmarks to verify the script works
-################################################################################
-
-set -euo pipefail
-
-GREEN='\033[0;32m'
-BLUE='\033[0;34m'
-NC='\033[0m'
-
-STATS_DIR="stats/demo_$(date +%Y%m%d_%H%M%S)"
-mkdir -p "${STATS_DIR}"
-
-echo -e "${BLUE}========================================${NC}"
-echo -e "${BLUE}Quick Benchmark Demo (< 2 mins)${NC}"
-echo -e "${BLUE}========================================${NC}"
-echo ""
-echo "Output directory: ${STATS_DIR}"
-echo ""
-
-# Test 1: Machine info
-echo -e "${GREEN}[1/3] Collecting machine info...${NC}"
-{
- echo "=== Machine Information ==="
- echo "Architecture: $(uname -m)"
- echo "CPU cores: $(nproc)"
- echo "Timestamp: $(date)"
- echo ""
- lscpu | head -20
-} | tee "${STATS_DIR}/machine_info.txt"
-echo ""
-
-# Test 2: Quick benchmark test
-echo -e "${GREEN}[2/2] Running quick benchmark (single thread, minimal tokens)...${NC}"
-if [[ -f "build/bin/llama-bench" ]] && [[ -f "models/BitNet-b1.58-2B-4T/ggml-model-i2_s_embed_q6_k.gguf" ]]; then
- ./build/bin/llama-bench \
- -m models/BitNet-b1.58-2B-4T/ggml-model-i2_s_embed_q6_k.gguf \
- -p 32 -n 32 -t 1 -ngl 0 \
- 2>&1 | tee "${STATS_DIR}/bench_quick.txt"
-
- # Parse results
- {
- echo "# Quick Benchmark Results"
- echo ""
- echo "| Threads | Test | Tokens/sec |"
- echo "|---------|------|------------|"
-
- awk -F '|' '
- /bitnet.*pp128/ || /bitnet.*tg128/ {
- gsub(/^[[:space:]]+|[[:space:]]+$/, "", $6);
- gsub(/^[[:space:]]+|[[:space:]]+$/, "", $7);
- gsub(/^[[:space:]]+|[[:space:]]+$/, "", $8);
- split($8, perf, "±");
- printf "| %7s | %4s | %10s |\n", $6, $7, perf[1];
- }
- ' "${STATS_DIR}/bench_quick.txt"
- } > "${STATS_DIR}/bench_results.md"
-
- echo ""
- echo -e "${GREEN}Results saved to: ${STATS_DIR}/bench_results.md${NC}"
- cat "${STATS_DIR}/bench_results.md"
-else
- echo "Skipping benchmark (model or binary not found)"
-fi
-echo ""
-
-# Test 3: Quick PPL test (using simplified dataset)
-echo -e "${GREEN}[3/3] Running quick PPL test (wiki.simple, 1 embed type)...${NC}"
-
-# Create simplified dataset if needed (first 100 lines for quick demo)
-if [[ -f "data/wikitext-2-raw/wiki.test.raw" ]]; then
- echo "Creating simplified dataset (100 lines)..."
- head -100 data/wikitext-2-raw/wiki.test.raw > data/wikitext-2-raw/wiki.simple.raw
-fi
-
-if [[ -f "build/bin/llama-perplexity" ]] && [[ -f "data/wikitext-2-raw/wiki.simple.raw" ]]; then
- {
- echo "# Quick PPL Test (Simplified Dataset)"
- echo ""
- echo "| Embed Type | Dataset | PPL |"
- echo "|------------|---------|-----|"
-
- # Test only one embed type with simplified dataset for speed
- embed="q6_k"
- model="models/BitNet-b1.58-2B-4T/ggml-model-i2_s_embed_${embed}.gguf"
- if [[ -f "$model" ]]; then
- echo "Testing: $embed on wiki.simple..."
- output=$(./build/bin/llama-perplexity \
- -m "$model" \
- -f data/wikitext-2-raw/wiki.simple.raw \
- -t 4 -ngl 0 2>&1 || true)
-
- ppl=$(echo "$output" | awk '
- /Final estimate/ && /PPL/ {
- if (match($0, /PPL[[:space:]]*=[[:space:]]*([0-9]+(\.[0-9]+)?)/, m)) {
- print m[1];
- exit;
- }
- }
- ')
-
- if [[ -n "$ppl" ]]; then
- echo "| $embed | wiki.simple | $ppl |"
- else
- echo "| $embed | wiki.simple | N/A |"
- fi
- fi
- } | tee "${STATS_DIR}/ppl_quick.md"
-
- echo ""
- echo -e "${GREEN}Results saved to: ${STATS_DIR}/ppl_quick.md${NC}"
- cat "${STATS_DIR}/ppl_quick.md"
-else
- echo "Skipping PPL test (binary or simplified dataset not found)"
- echo "Note: Full PPL test available in: ./run_paper_benchmarks.sh"
-fi
-echo ""
-
-echo -e "${BLUE}========================================${NC}"
-echo -e "${GREEN}Demo completed! (Fast mode - PPL skipped)${NC}"
-echo -e "${BLUE}========================================${NC}"
-echo ""
-echo "All results in: ${STATS_DIR}/"
-echo ""
-echo "To run the full automation script:"
-echo " ./run_paper_benchmarks.sh"
-echo ""
diff --git a/run_paper_benchmarks.sh b/run_paper_benchmarks.sh
deleted file mode 100755
index 975ddde..0000000
--- a/run_paper_benchmarks.sh
+++ /dev/null
@@ -1,720 +0,0 @@
-#!/bin/bash
-
-################################################################################
-# Paper Benchmark Automation Script
-# This script automates all experiments needed for the paper on both Intel and ARM
-################################################################################
-
-set -euo pipefail
-
-# Color codes for output
-RED='\033[0;31m'
-GREEN='\033[0;32m'
-YELLOW='\033[1;33m'
-BLUE='\033[0;34m'
-NC='\033[0m' # No Color
-
-# Configuration
-STATS_DIR="stats"
-MODEL_NAME="BitNet-b1.58-2B-4T"
-MODEL_DIR="models/${MODEL_NAME}"
-HF_REPO="microsoft/${MODEL_NAME}"
-TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
-MACHINE_INFO_FILE="${STATS_DIR}/machine_info_${TIMESTAMP}.txt"
-BENCH_RESULTS_FILE="${STATS_DIR}/bench_results_${TIMESTAMP}.md"
-BENCH_RAW_FILE="${STATS_DIR}/bench_raw_${TIMESTAMP}.txt"
-PPL_RESULTS_FILE="${STATS_DIR}/ppl_results_${TIMESTAMP}.md"
-PPL_CSV_FILE="${STATS_DIR}/ppl_results_${TIMESTAMP}.csv"
-
-# Create stats directory if not exists
-mkdir -p "${STATS_DIR}"
-
-################################################################################
-# Helper Functions
-################################################################################
-
-log_info() {
- echo -e "${BLUE}[INFO]${NC} $1"
-}
-
-log_success() {
- echo -e "${GREEN}[SUCCESS]${NC} $1"
-}
-
-log_warning() {
- echo -e "${YELLOW}[WARNING]${NC} $1"
-}
-
-log_error() {
- echo -e "${RED}[ERROR]${NC} $1"
-}
-
-section_header() {
- echo ""
- echo "================================================================================"
- echo -e "${GREEN}$1${NC}"
- echo "================================================================================"
-}
-
-################################################################################
-# Step 1: Machine Information and Environment Setup
-################################################################################
-
-step1_machine_info() {
- section_header "STEP 1: Machine Information and Environment Setup"
-
- log_info "Collecting machine information..."
-
- {
- echo "================================"
- echo "Machine Information"
- echo "================================"
- echo "Timestamp: $(date)"
- echo ""
-
- echo "--- System Architecture ---"
- uname -a
- echo ""
-
- echo "--- CPU Information ---"
- if command -v lscpu &> /dev/null; then
- lscpu
- elif [[ -f /proc/cpuinfo ]]; then
- cat /proc/cpuinfo
- else
- log_warning "Could not get CPU information"
- fi
- echo ""
-
- echo "--- CPU Cores ---"
- NPROC=$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo "unknown")
- echo "Number of CPU cores: ${NPROC}"
- echo ""
-
- echo "--- Memory Information ---"
- if command -v free &> /dev/null; then
- free -h
- elif command -v vm_stat &> /dev/null; then
- vm_stat
- else
- log_warning "Could not get memory information"
- fi
- echo ""
-
- echo "--- Architecture Detection ---"
- ARCH=$(uname -m)
- echo "Architecture: ${ARCH}"
- if [[ "${ARCH}" == "x86_64" ]]; then
- echo "Platform: Intel/AMD x86_64"
- elif [[ "${ARCH}" == "aarch64" ]] || [[ "${ARCH}" == "arm64" ]]; then
- echo "Platform: ARM64"
- else
- echo "Platform: Other (${ARCH})"
- fi
- echo ""
-
- echo "--- Compiler Information ---"
- if command -v clang &> /dev/null; then
- clang --version
- fi
- if command -v gcc &> /dev/null; then
- gcc --version
- fi
- if command -v cmake &> /dev/null; then
- cmake --version
- fi
- echo ""
-
- echo "--- Python Environment ---"
- python --version || python3 --version
- if command -v conda &> /dev/null; then
- conda --version
- echo "Active conda environment: ${CONDA_DEFAULT_ENV:-none}"
- fi
- echo ""
-
- } | tee "${MACHINE_INFO_FILE}"
-
- log_success "Machine information saved to: ${MACHINE_INFO_FILE}"
-
- # Install dependencies according to README
- log_info "Installing Python dependencies..."
- if [[ -f requirements.txt ]]; then
- pip install -r requirements.txt
- log_success "Python dependencies installed"
- else
- log_warning "requirements.txt not found, skipping dependency installation"
- fi
-}
-
-################################################################################
-# Step 2: Build Project
-################################################################################
-
-step2_build() {
- section_header "STEP 2: Building Project"
-
- log_info "Configuring CMake..."
- cmake -B build -DCMAKE_BUILD_TYPE=Release
-
- log_info "Building project..."
- cmake --build build --config Release
-
- log_success "Build completed successfully"
-}
-
-################################################################################
-# Step 3: Download and Convert Model
-################################################################################
-
-step3_download_convert() {
- section_header "STEP 3: Download and Convert Model"
-
- if [[ -d "${MODEL_DIR}" ]] && [[ -f "${MODEL_DIR}/ggml-model-f32.gguf" ]]; then
- log_warning "Model directory already exists and contains f32 model, skipping download"
- read -p "Do you want to re-download and convert? (y/N): " -n 1 -r
- echo
- if [[ ! $REPLY =~ ^[Yy]$ ]]; then
- return
- fi
- fi
-
- # Create model directory
- mkdir -p "${MODEL_DIR}"
-
- # Download from HuggingFace
- log_info "Downloading model from HuggingFace: ${HF_REPO}"
- if command -v huggingface-cli &> /dev/null; then
- huggingface-cli download "${HF_REPO}" --local-dir "${MODEL_DIR}"
- else
- log_error "huggingface-cli not found. Please install it with: pip install huggingface_hub"
- exit 1
- fi
-
- # Convert to f32 GGUF using the helper script
- log_info "Converting model to f32 GGUF format..."
- if [[ -f "utils/convert-helper-bitnet.py" ]]; then
- # The script creates ggml-model-f32-bitnet.gguf, we'll rename it
- python utils/convert-helper-bitnet.py "${MODEL_DIR}"
-
- # Rename the output to match expected name
- if [[ -f "${MODEL_DIR}/ggml-model-f32-bitnet.gguf" ]]; then
- mv "${MODEL_DIR}/ggml-model-f32-bitnet.gguf" "${MODEL_DIR}/ggml-model-f32.gguf"
- fi
- else
- log_error "Convert helper script not found"
- exit 1
- fi
-
- log_success "Model downloaded and converted to f32 GGUF"
-}
-
-################################################################################
-# Step 4: Quantize Embeddings
-################################################################################
-
-step4_quantize_embeddings() {
- section_header "STEP 4: Quantize Embeddings"
-
- log_info "Running embed_quant.sh to create different embedding quantization variants..."
-
- if [[ ! -f "embed_quant.sh" ]]; then
- log_error "embed_quant.sh not found"
- exit 1
- fi
-
- bash embed_quant.sh
-
- log_success "Embedding quantization completed"
-}
-
-################################################################################
-# Step 5: Tune GEMM Block Sizes
-################################################################################
-
-step5_tune_gemm() {
- section_header "STEP 5: Tune GEMM Block Sizes"
-
- log_info "Running GEMM block size tuning..."
-
- # Backup original tune script if needed
- if [[ ! -f "tune_gemm_blocks.sh.bak" ]]; then
- cp tune_gemm_blocks.sh tune_gemm_blocks.sh.bak
- fi
-
- # Get number of threads
- NPROC=$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo "8")
-
- # Update the tuning script to use a broader search space
- log_info "Updating tune_gemm_blocks.sh for comprehensive search..."
-
- # Create a temporary tuning script with broader search
- cat > tune_gemm_blocks_auto.sh << 'EOF'
-#!/bin/bash
-set -e
-
-HEADER_FILE="include/gemm-config.h"
-BENCH_CMD="./build/bin/llama-bench -m models/BitNet-b1.58-2B-4T/ggml-model-i2_s_embed_i2_s.gguf -p 128 -n 0 -t 16 -ngl 0"
-BUILD_CMD="cmake --build build --config Release -j"
-
-ACT_PARALLEL_DEFINE=true
-
-# Expanded search space for better tuning
-ROW_BLOCK_VALUES=(2 4 8)
-COL_BLOCK_VALUES=(64 128 256)
-PARALLEL_SIZE_VALUES=(2 4 8)
-
-BEST_PERF=0
-BEST_ROW_BLOCK=0
-BEST_COL_BLOCK=0
-BEST_PARALLEL_SIZE=0
-LOG_FILE="stats/tuning_log.csv"
-
-if [ -f "$HEADER_FILE" ]; then
- cp "$HEADER_FILE" "${HEADER_FILE}.bak"
-fi
-
-echo "Starting comprehensive tuning process..."
-echo "row_block,col_block,parallel_size,tokens_per_second" > "$LOG_FILE"
-
-cleanup() {
- echo "Restoring original header file..."
- if [ -f "${HEADER_FILE}.bak" ]; then
- mv "${HEADER_FILE}.bak" "$HEADER_FILE"
- fi
- echo "Tuning finished."
- echo "Best: ROW_BLOCK=${BEST_ROW_BLOCK}, COL_BLOCK=${BEST_COL_BLOCK}, PARALLEL=${BEST_PARALLEL_SIZE} -> ${BEST_PERF} tokens/s"
-}
-
-trap cleanup EXIT
-
-for ps in "${PARALLEL_SIZE_VALUES[@]}"; do
- for rb in "${ROW_BLOCK_VALUES[@]}"; do
- for cb in "${COL_BLOCK_VALUES[@]}"; do
- echo "Testing: ROW=${rb}, COL=${cb}, PARALLEL=${ps}"
-
- echo "// Auto-generated by tuning script" > "$HEADER_FILE"
- if [ "$ACT_PARALLEL_DEFINE" = "true" ]; then
- echo "#define ACT_PARALLEL" >> "$HEADER_FILE"
- fi
- echo "#if defined(ACT_PARALLEL)" >> "$HEADER_FILE"
- echo " #define ROW_BLOCK_SIZE ${rb}" >> "$HEADER_FILE"
- echo " #define COL_BLOCK_SIZE ${cb}" >> "$HEADER_FILE"
- echo " #define PARALLEL_SIZE ${ps}" >> "$HEADER_FILE"
- echo "#else" >> "$HEADER_FILE"
- echo " #define ROW_BLOCK_SIZE ${rb}" >> "$HEADER_FILE"
- echo " #define COL_BLOCK_SIZE ${cb}" >> "$HEADER_FILE"
- echo " #define PARALLEL_SIZE ${ps}" >> "$HEADER_FILE"
- echo "#endif" >> "$HEADER_FILE"
-
- $BUILD_CMD > /dev/null 2>&1
-
- output=$(eval "$BENCH_CMD" 2>&1)
-
- perf=$(echo "$output" | awk -F '|' '
- /pp128/ && /bitnet/ {
- gsub(/ /, "", $8);
- split($8, perf, "±");
- print perf[1];
- exit;
- }
- ')
-
- if [ -z "$perf" ]; then
- perf=0
- fi
-
- echo "Performance: ${perf} tokens/s"
- echo "${rb},${cb},${ps},${perf}" >> "$LOG_FILE"
-
- if (( $(echo "$perf > $BEST_PERF" | bc -l) )); then
- BEST_PERF=$perf
- BEST_ROW_BLOCK=$rb
- BEST_COL_BLOCK=$cb
- BEST_PARALLEL_SIZE=$ps
- echo "*** New best found! ***"
- fi
- done
- done
-done
-
-echo "Best configuration: ROW=${BEST_ROW_BLOCK}, COL=${BEST_COL_BLOCK}, PARALLEL=${BEST_PARALLEL_SIZE}"
-echo "Best performance: ${BEST_PERF} tokens/s"
-EOF
-
- chmod +x tune_gemm_blocks_auto.sh
- bash tune_gemm_blocks_auto.sh
-
- # Read the best configuration from the log
- if [[ -f "stats/tuning_log.csv" ]]; then
- BEST_CONFIG=$(tail -n +2 "stats/tuning_log.csv" | sort -t',' -k4 -nr | head -1)
- BEST_ROW=$(echo "$BEST_CONFIG" | cut -d',' -f1)
- BEST_COL=$(echo "$BEST_CONFIG" | cut -d',' -f2)
- BEST_PAR=$(echo "$BEST_CONFIG" | cut -d',' -f3)
- BEST_PERF=$(echo "$BEST_CONFIG" | cut -d',' -f4)
-
- log_success "Best configuration found:"
- log_success " ROW_BLOCK_SIZE=${BEST_ROW}, COL_BLOCK_SIZE=${BEST_COL}, PARALLEL_SIZE=${BEST_PAR}"
- log_success " Performance: ${BEST_PERF} tokens/s"
-
- # Apply the best configuration
- log_info "Applying best configuration to gemm-config.h..."
- cat > include/gemm-config.h << EOF
-// Auto-generated with best tuning results
-// Best performance: ${BEST_PERF} tokens/s
-#define ACT_PARALLEL
-#if defined(ACT_PARALLEL)
- #define ROW_BLOCK_SIZE ${BEST_ROW}
- #define COL_BLOCK_SIZE ${BEST_COL}
- #define PARALLEL_SIZE ${BEST_PAR}
-#else
- #define ROW_BLOCK_SIZE ${BEST_ROW}
- #define COL_BLOCK_SIZE ${BEST_COL}
- #define PARALLEL_SIZE ${BEST_PAR}
-#endif
-EOF
-
- # Rebuild with best configuration
- log_info "Rebuilding with best configuration..."
- cmake --build build --config Release -j
-
- log_success "GEMM tuning completed and applied"
- else
- log_error "Tuning log not found"
- fi
-}
-
-################################################################################
-# Step 6: Run Performance Benchmarks
-################################################################################
-
-step6_benchmark() {
- section_header "STEP 6: Running Performance Benchmarks"
-
- # Get number of threads for this machine
- NPROC=$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo "8")
- log_info "Detected ${NPROC} CPU cores"
-
- # Generate thread counts: 1, 2, 4, 8, 16, ...
- THREAD_COUNTS="1"
- for ((i=2; i<=NPROC; i*=2)); do
- THREAD_COUNTS="${THREAD_COUNTS},${i}"
- done
-
- log_info "Testing with thread counts: ${THREAD_COUNTS}"
-
- # Create benchmark script
- cat > bench.sh << EOF
-#!/bin/bash
-set -e
-
-MODEL="${MODEL_DIR}/ggml-model-i2_s_embed_q6_k.gguf"
-THREADS="${THREAD_COUNTS}"
-
-if [[ ! -f "\${MODEL}" ]]; then
- echo "Error: Model not found: \${MODEL}"
- exit 1
-fi
-
-./build/bin/llama-bench -m "\${MODEL}" -p 128 -n 128 -t "\${THREADS}" -ngl 0
-EOF
-
- chmod +x bench.sh
-
- log_info "Running benchmark..."
-
- # Run benchmark and capture output
- ./bench.sh 2>&1 | tee "${BENCH_RAW_FILE}"
-
- # Parse and format results
- log_info "Parsing benchmark results..."
-
- {
- echo "# Benchmark Results"
- echo ""
- echo "**Machine:** $(uname -m)"
- echo "**Timestamp:** $(date)"
- echo "**Model:** ${MODEL_NAME}"
- echo "**Quantization:** I2_S weight, Q6_K embeddings"
- echo ""
- echo "## Performance Summary"
- echo ""
- echo "| Threads | Test Type | Tokens/sec | Std Dev |"
- echo "|---------|-----------|------------|---------|"
-
- awk -F '|' '
- /bitnet.*pp128/ || /bitnet.*tg128/ {
- gsub(/^[[:space:]]+|[[:space:]]+$/, "", $6); # threads
- gsub(/^[[:space:]]+|[[:space:]]+$/, "", $7); # test
- gsub(/^[[:space:]]+|[[:space:]]+$/, "", $8); # t/s
-
- threads = $6;
- test = $7;
-
- split($8, perf, "±");
- tokens = perf[1];
- gsub(/^[[:space:]]+|[[:space:]]+$/, "", tokens);
-
- stddev = perf[2];
- gsub(/^[[:space:]]+|[[:space:]]+$/, "", stddev);
-
- printf "| %7s | %9s | %10s | %7s |\n", threads, test, tokens, stddev;
- }
- ' "${BENCH_RAW_FILE}"
-
- echo ""
- echo "## Detailed Output"
- echo ""
- echo '```'
- cat "${BENCH_RAW_FILE}"
- echo '```'
-
- } > "${BENCH_RESULTS_FILE}"
-
- log_success "Benchmark results saved to: ${BENCH_RESULTS_FILE}"
-}
-
-################################################################################
-# Step 7: Run PPL Benchmarks
-################################################################################
-
-step7_ppl_benchmark() {
- section_header "STEP 7: Running Perplexity (PPL) Benchmarks"
-
- log_info "Checking benchmark datasets..."
-
- # Check which datasets are available
- DATASETS=""
- for ds in data/wikitext-2-raw/wiki.test.raw data/ptb/ptb.test.txt data/lambada/lambada_test_plain_text.txt data/clue/tnews.test.txt; do
- if [[ -f "$ds" ]]; then
- DATASETS="${DATASETS} ${ds}"
- log_info "Found dataset: ${ds}"
- else
- log_warning "Dataset not found: ${ds}"
- fi
- done
-
- if [[ -z "${DATASETS}" ]]; then
- log_error "No benchmark datasets found in data/ directory"
- log_warning "Skipping PPL benchmarks"
- return
- fi
-
- log_info "Creating PPL benchmark script..."
-
- # Create a modified PPL script
- cat > embed_quant_ppl_auto.sh << 'EOFPPL'
-#!/usr/bin/env bash
-set -euo pipefail
-
-BIN="./build/bin/llama-perplexity"
-MODEL_DIR="models/BitNet-b1.58-2B-4T"
-MODEL_TEMPLATE="ggml-model-i2_s_embed_{ET}.gguf"
-
-EMBED_TYPES="f32 bf16 f16 i2_s q3_k q4_0 q5_0 q6_k tq1_0 tq2_0"
-DATASETS="DATASETS_PLACEHOLDER"
-
-THREADS="${THREADS:-16}"
-NGL="${NGL:-0}"
-
-CSV_LOG="ppl_results_temp.csv"
-
-if [[ ! -x "$BIN" ]]; then
- echo "Error: llama-perplexity not found at $BIN" >&2
- exit 1
-fi
-
-model_size_mib() {
- local f="$1"
- local sz
- sz=$(stat -c %s "$f" 2>/dev/null || stat -f %z "$f" 2>/dev/null || echo 0)
- awk -v b="$sz" 'BEGIN { printf("%.2f", b/1024/1024) }'
-}
-
-extract_ppl_final() {
- awk '
- /Final estimate/ && /PPL/ {
- if (match($0, /PPL[[:space:]]*=[[:space:]]*([0-9]+(\.[0-9]+)?)\s*\+\/\-\s*([0-9]+(\.[0-9]+)?)/, m)) {
- print m[1] "," m[3];
- found=1;
- }
- }
- END { if (!found) exit 1 }
- '
-}
-
-extract_perplexity() {
- awk '
- {
- for (i=1; i<=NF; ++i) {
- if (tolower($i) ~ /perplexity/) {
- for (j=i; j<=NF; ++j) {
- if ($j ~ /^[0-9]+(\.[0-9]+)?$/) { p=$j; break }
- gsub(/^.*=/, "", $j); gsub(/,$/, "", $j); gsub(/^\(/, "", $j); gsub(/\)$/, "", $j)
- if ($j ~ /^[0-9]+(\.[0-9]+)?$/) { p=$j; break }
- }
- }
- }
- if (p) last=p
- }
- END { if (last) print last }'
-}
-
-echo "| embed-type | model | size | dataset | threads | ppl |"
-echo "| ---------- | --------------: | -----: | ------: | ------: | ---------: |"
-echo "embed_type,model,model_size_mib,dataset,threads,perplexity,perplexity_err" > "$CSV_LOG"
-
-for et in $EMBED_TYPES; do
- model_glob="${MODEL_DIR}/$(echo "$MODEL_TEMPLATE" | sed "s/{ET}/$et/")"
-
- found_any=0
- for model in $model_glob; do
- [[ -e "$model" ]] || continue
- found_any=1
- done
-
- if [[ $found_any -eq 0 ]]; then
- echo "Warning: no models found for embed type '$et', skipping." >&2
- continue
- fi
-
- for model in $model_glob; do
- [[ -e "$model" ]] || continue
- size_mib=$(model_size_mib "$model")
-
- for ds in $DATASETS; do
- if [[ ! -r "$ds" ]]; then
- echo "Warning: dataset not found: $ds (skipping)" >&2
- continue
- fi
-
- echo "==> Testing: model=$model, dataset=$ds"
- out=$("$BIN" -m "$model" -f "$ds" -t "$THREADS" -ngl "$NGL" 2>&1 || true)
-
- ppl_pair=$(echo "$out" | extract_ppl_final || true)
- if [[ -n "${ppl_pair:-}" ]]; then
- ppl="${ppl_pair%%,*}"
- ppl_err="${ppl_pair##*,}"
- else
- ppl=$(echo "$out" | extract_perplexity || true)
- if [[ -z "${ppl:-}" ]]; then
- ppl="NA"
- fi
- ppl_err="NA"
- fi
-
- if [[ "$ppl_err" != "NA" ]]; then
- ppl_disp="$ppl ± $ppl_err"
- else
- ppl_disp="$ppl"
- fi
-
- printf "| %10s | %14s | %6s MiB | %7s | %7s | %10s |\n" \
- "$et" "$(basename "$model")" "$size_mib" "$(basename "$ds")" "$THREADS" "$ppl_disp"
-
- echo "$et,$(basename "$model"),$size_mib,$(basename "$ds"),$THREADS,$ppl,$ppl_err" >> "$CSV_LOG"
- done
- done
-done
-
-echo "Done. Results saved to $CSV_LOG"
-EOFPPL
-
- # Replace DATASETS placeholder
- sed -i "s|DATASETS_PLACEHOLDER|${DATASETS}|g" embed_quant_ppl_auto.sh
- chmod +x embed_quant_ppl_auto.sh
-
- log_info "Running PPL benchmarks (this may take a while)..."
-
- # Run the PPL benchmark
- ./embed_quant_ppl_auto.sh 2>&1 | tee "${PPL_RESULTS_FILE}.raw"
-
- # Format the results
- {
- echo "# Perplexity (PPL) Benchmark Results"
- echo ""
- echo "**Machine:** $(uname -m)"
- echo "**Timestamp:** $(date)"
- echo "**Model:** ${MODEL_NAME}"
- echo ""
- echo "## Results by Embedding Type"
- echo ""
-
- grep "^|" "${PPL_RESULTS_FILE}.raw" || true
-
- echo ""
- echo "## Summary Statistics"
- echo ""
-
- if [[ -f "ppl_results_temp.csv" ]]; then
- # Copy to final location
- cp ppl_results_temp.csv "${PPL_CSV_FILE}"
-
- # Generate summary by embed type
- echo "### Average PPL by Embedding Type"
- echo ""
- echo "| Embed Type | Avg PPL | Models Tested |"
- echo "|------------|---------|---------------|"
-
- awk -F',' '
- NR > 1 && $6 != "NA" {
- sum[$1] += $6;
- count[$1]++;
- }
- END {
- for (et in sum) {
- printf "| %10s | %7.2f | %13d |\n", et, sum[et]/count[et], count[et];
- }
- }
- ' "${PPL_CSV_FILE}" | sort -t'|' -k3 -n
-
- echo ""
- fi
-
- echo "## Full Raw Output"
- echo ""
- echo '```'
- cat "${PPL_RESULTS_FILE}.raw"
- echo '```'
-
- } > "${PPL_RESULTS_FILE}"
-
- log_success "PPL results saved to: ${PPL_RESULTS_FILE}"
- log_success "PPL CSV data saved to: ${PPL_CSV_FILE}"
-}
-
-################################################################################
-# Main Execution
-################################################################################
-
-main() {
- section_header "Paper Benchmark Automation - Starting"
-
- log_info "All results will be saved to: ${STATS_DIR}/"
- log_info "Timestamp: ${TIMESTAMP}"
-
- # Execute all steps
- step1_machine_info
- step2_build
- step3_download_convert
- step4_quantize_embeddings
- step5_tune_gemm
- step6_benchmark
- step7_ppl_benchmark
-
- # Final summary
- section_header "All Benchmarks Completed!"
-
- log_success "Results summary:"
- log_success " - Machine info: ${MACHINE_INFO_FILE}"
- log_success " - Benchmark: ${BENCH_RESULTS_FILE}"
- log_success " - PPL results: ${PPL_RESULTS_FILE}"
- log_success " - PPL CSV: ${PPL_CSV_FILE}"
- log_success " - GEMM tuning log: stats/tuning_log.csv"
-
- echo ""
- log_info "You can find all results in the ${STATS_DIR}/ directory"
-}
-
-# Run main function
-main "$@"
diff --git a/setup_env.py b/setup_env.py
index 7d84ed7..f15d65f 100644
--- a/setup_env.py
+++ b/setup_env.py
@@ -136,12 +136,12 @@ def prepare_model():
# quantize to i2s
if platform.system() != "Windows":
if quant_embd:
- run_command(["./build/bin/llama-quantize", "--token-embedding-type", "q6_k", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s")
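+                # "f16" is the value passed to llama-quantize's --token-embedding-type:
+                # token embeddings stay in 16-bit floats while the weights go to I2_S.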
+ run_command(["./build/bin/llama-quantize", "--token-embedding-type", "f16", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s")
else:
run_command(["./build/bin/llama-quantize", f32_model, i2s_model, "I2_S", "1"], log_step="quantize_to_i2s")
else:
if quant_embd:
- run_command(["./build/bin/Release/llama-quantize", "--token-embedding-type", "q6_k", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s")
+ run_command(["./build/bin/Release/llama-quantize", "--token-embedding-type", "f16", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s")
else:
run_command(["./build/bin/Release/llama-quantize", f32_model, i2s_model, "I2_S", "1"], log_step="quantize_to_i2s")
@@ -228,7 +228,7 @@ def parse_args():
parser.add_argument("--model-dir", "-md", type=str, help="Directory to save/load the model", default="models")
parser.add_argument("--log-dir", "-ld", type=str, help="Directory to save the logging info", default="logs")
parser.add_argument("--quant-type", "-q", type=str, help="Quantization type", choices=SUPPORTED_QUANT_TYPES[arch], default="i2_s")
- parser.add_argument("--quant-embd", action="store_true", help="Quantize the embeddings to q6_k")
+ parser.add_argument("--quant-embd", action="store_true", help="Quantize the embeddings to f16")
parser.add_argument("--use-pretuned", "-p", action="store_true", help="Use the pretuned kernel parameters")
return parser.parse_args()
diff --git a/test_benchmark_setup.sh b/test_benchmark_setup.sh
deleted file mode 100755
index 0190cb3..0000000
--- a/test_benchmark_setup.sh
+++ /dev/null
@@ -1,160 +0,0 @@
-#!/bin/bash
-
-################################################################################
-# Quick Test Script for Benchmark Automation
-# This script tests individual components without running full benchmarks
-################################################################################
-
-set -euo pipefail
-
-GREEN='\033[0;32m'
-RED='\033[0;31m'
-NC='\033[0m'
-
-echo "========================================"
-echo "Testing Benchmark Automation Components"
-echo "========================================"
-echo ""
-
-# Test 1: Check system info
-echo "Test 1: System Information"
-echo " Architecture: $(uname -m)"
-echo " CPU cores: $(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 'unknown')"
-echo " Python: $(python --version 2>&1 || python3 --version 2>&1)"
-if command -v cmake &> /dev/null; then
- echo -e " CMake: ${GREEN}✓${NC} $(cmake --version | head -1)"
-else
- echo -e " CMake: ${RED}✗ Not found${NC}"
-fi
-if command -v clang &> /dev/null; then
- echo -e " Clang: ${GREEN}✓${NC} $(clang --version | head -1)"
-else
- echo -e " Clang: ${RED}✗ Not found${NC}"
-fi
-echo ""
-
-# Test 2: Check required files
-echo "Test 2: Required Files"
-files=(
- "embed_quant.sh"
- "tune_gemm_blocks.sh"
- "utils/convert-helper-bitnet.py"
- "requirements.txt"
-)
-for f in "${files[@]}"; do
- if [[ -f "$f" ]]; then
- echo -e " $f: ${GREEN}✓${NC}"
- else
- echo -e " $f: ${RED}✗ Missing${NC}"
- fi
-done
-echo ""
-
-# Test 3: Check build directory
-echo "Test 3: Build Status"
-if [[ -d "build" ]]; then
- echo -e " build/ directory: ${GREEN}✓${NC}"
- if [[ -f "build/bin/llama-bench" ]]; then
- echo -e " llama-bench: ${GREEN}✓${NC}"
- else
- echo -e " llama-bench: ${RED}✗ Not built${NC}"
- fi
- if [[ -f "build/bin/llama-perplexity" ]]; then
- echo -e " llama-perplexity: ${GREEN}✓${NC}"
- else
- echo -e " llama-perplexity: ${RED}✗ Not built${NC}"
- fi
- if [[ -f "build/bin/llama-quantize" ]]; then
- echo -e " llama-quantize: ${GREEN}✓${NC}"
- else
- echo -e " llama-quantize: ${RED}✗ Not built${NC}"
- fi
-else
- echo -e " build/ directory: ${RED}✗ Not found${NC}"
-fi
-echo ""
-
-# Test 4: Check data directory
-echo "Test 4: Benchmark Datasets"
-datasets=(
- "data/wikitext-2-raw/wiki.test.raw"
- "data/ptb/ptb.test.txt"
- "data/lambada/lambada_test_plain_text.txt"
- "data/clue/tnews.test.txt"
-)
-found=0
-for ds in "${datasets[@]}"; do
- if [[ -f "$ds" ]]; then
- echo -e " $(basename $(dirname $ds)): ${GREEN}✓${NC}"
- found=$((found + 1))
- else
- echo -e " $(basename $(dirname $ds)): ${RED}✗ Not found${NC}"
- fi
-done
-echo " Total: $found/4 datasets available"
-echo ""
-
-# Test 5: Check models
-echo "Test 5: Model Files"
-MODEL_DIR="models/BitNet-b1.58-2B-4T"
-if [[ -d "$MODEL_DIR" ]]; then
- echo -e " Model directory: ${GREEN}✓${NC}"
- if [[ -f "$MODEL_DIR/ggml-model-f32.gguf" ]]; then
- echo -e " F32 model: ${GREEN}✓${NC}"
- else
- echo -e " F32 model: ${RED}✗ Not found${NC}"
- fi
-
- # Count quantized models
- quant_count=$(ls "$MODEL_DIR"/ggml-model-i2_s_embed_*.gguf 2>/dev/null | wc -l)
- if [[ $quant_count -gt 0 ]]; then
- echo -e " Quantized models: ${GREEN}✓${NC} ($quant_count files)"
- else
- echo -e " Quantized models: ${RED}✗ None found${NC}"
- fi
-else
- echo -e " Model directory: ${RED}✗ Not found${NC}"
-fi
-echo ""
-
-# Test 6: Thread count generation
-echo "Test 6: Thread Configuration"
-NPROC=$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo "8")
-THREAD_COUNTS="1"
-for ((i=2; i<=NPROC; i*=2)); do
- THREAD_COUNTS="${THREAD_COUNTS},${i}"
-done
-echo " Max threads: $NPROC"
-echo " Test thread counts: $THREAD_COUNTS"
-echo ""
-
-# Test 7: Check stats directory
-echo "Test 7: Output Directory"
-if [[ -d "stats" ]]; then
- echo -e " stats/ directory: ${GREEN}✓${NC}"
- file_count=$(ls stats/ 2>/dev/null | wc -l)
- echo " Files in stats/: $file_count"
-else
- echo -e " stats/ directory: ${RED}✗ Not found${NC}"
- echo " Creating stats/ directory..."
- mkdir -p stats
- echo -e " ${GREEN}✓ Created${NC}"
-fi
-echo ""
-
-# Summary
-echo "========================================"
-echo "Test Summary"
-echo "========================================"
-echo ""
-echo "To run the full benchmark automation:"
-echo " ./run_paper_benchmarks.sh"
-echo ""
-echo "To build the project first (if not built):"
-echo " cmake -B build -DCMAKE_BUILD_TYPE=Release"
-echo " cmake --build build --config Release"
-echo ""
-echo "To download and convert the model:"
-echo " huggingface-cli download microsoft/BitNet-b1.58-2B-4T --local-dir models/BitNet-b1.58-2B-4T"
-echo " python utils/convert-helper-bitnet.py models/BitNet-b1.58-2B-4T"
-echo ""
diff --git a/utils/build_test_gemm_kernel.sh b/utils/build_test_gemm_kernel.sh
new file mode 100755
index 0000000..bc45942
--- /dev/null
+++ b/utils/build_test_gemm_kernel.sh
@@ -0,0 +1,76 @@
+#!/bin/bash
+# Build script for standalone GEMM kernel benchmark
+
+set -e
+
+echo "Building GEMM kernel benchmark..."
+
+# Compiler settings
+CXX=${CXX:-g++}
+BUILD_DIR="../build"
+SRC_DIR="../src"
+
+# Create build directory if it doesn't exist
+mkdir -p ${BUILD_DIR}
+
+# Compiler flags
+CXXFLAGS="-O3 -march=native -mtune=native -std=c++17 -fopenmp"
+CXXFLAGS+=" -I.. -I../include"
+CXXFLAGS+=" -I../3rdparty/llama.cpp/ggml/include"
+CXXFLAGS+=" -I../3rdparty/llama.cpp/ggml/src"
+CXXFLAGS+=" -I../3rdparty/llama.cpp/include"
+CXXFLAGS+=" -DNDEBUG -ffast-math"
+
+# Link flags
+LDFLAGS="-lm -lpthread"
+
+# Link with pre-built libraries
+GGML_LIB_DIR="../build/3rdparty/llama.cpp/ggml/src"
+GGML_SO="${GGML_LIB_DIR}/libggml.so"
+
+if [ ! -f "${GGML_SO}" ]; then
+ echo "⚠️ Warning: Cannot find libggml.so"
+ echo "Please build the project first with: cmake --build build"
+ exit 1
+fi
+
+LDFLAGS+=" -L${GGML_LIB_DIR} -lggml -Wl,-rpath,\$ORIGIN/../../${GGML_LIB_DIR}"
+echo "Linking with libggml.so"
+
+# Source files
+SOURCES="./test_gemm_kernel.cpp"
+
+# Output binary
+OUTPUT="${BUILD_DIR}/test_gemm_kernel"
+
+echo "Compiler: ${CXX}"
+echo "Flags: ${CXXFLAGS}"
+echo "Sources: ${SOURCES}"
+echo ""
+
+# Build
+# Run the compiler inside `if` so the failure branch stays reachable under `set -e`.
+
+if ${CXX} ${CXXFLAGS} ${SOURCES} -o ${OUTPUT} ${LDFLAGS}; then
+ echo ""
+ echo "✅ Build successful!"
+ echo "Output: ${OUTPUT}"
+ echo ""
+ echo "Usage examples:"
+ echo " # Default test (n=2048, nr=32, nc=128, 1000 iterations)"
+ echo " ${OUTPUT}"
+ echo ""
+ echo " # Custom matrix sizes"
+ echo " ${OUTPUT} -n 4096 -r 64 -c 256"
+ echo ""
+ echo " # Quick test (fewer iterations)"
+ echo " ${OUTPUT} -i 100 -w 5"
+ echo ""
+ echo " # Large-scale test"
+ echo " ${OUTPUT} -n 3200 -r 128 -c 512 -i 500"
+ echo ""
+else
+ echo ""
+ echo "❌ Build failed!"
+ exit 1
+fi
diff --git a/utils/convert-helper-bitnet.py b/utils/convert-helper-bitnet.py
index 5b4149a..9ed8db0 100644
--- a/utils/convert-helper-bitnet.py
+++ b/utils/convert-helper-bitnet.py
@@ -109,12 +109,12 @@ def main():
except OSError as e:
print(f"Warning: Could not remove {preprocessed_output_file}: {e}")
- if gguf_f32_output.exists():
- print(f"Removing f32 GGUF: {gguf_f32_output}")
- try:
- gguf_f32_output.unlink()
- except OSError as e:
- print(f"Warning: Could not remove {gguf_f32_output}: {e}")
+ # if gguf_f32_output.exists():
+ # print(f"Removing f32 GGUF: {gguf_f32_output}")
+ # try:
+ # gguf_f32_output.unlink()
+ # except OSError as e:
+ # print(f"Warning: Could not remove {gguf_f32_output}: {e}")
if input_backup_file.exists():
if not input_file.exists():
diff --git a/utils/quantize_embeddings.py b/utils/quantize_embeddings.py
new file mode 100644
index 0000000..90b8020
--- /dev/null
+++ b/utils/quantize_embeddings.py
@@ -0,0 +1,473 @@
+#!/usr/bin/env python3
+"""
+Embedding Quantization Script
+This script converts ggml-model-f32.gguf to multiple quantized versions
+with different token embedding types.
+"""
+
+import subprocess
+import os
+import argparse
+import re
+import csv
+from pathlib import Path
+from datetime import datetime
+
+
+class EmbeddingQuantizer:
+ def __init__(self, input_model, output_dir, quantize_bin="../build/bin/llama-quantize",
+ bench_bin="../build/bin/llama-bench", stats_dir="../stats", csv_output=None):
+ self.input_model = Path(input_model)
+ self.output_dir = Path(output_dir)
+ self.quantize_bin = Path(quantize_bin)
+ self.bench_bin = Path(bench_bin)
+ self.stats_dir = Path(stats_dir)
+ self.csv_output = Path(csv_output) if csv_output else None
+
+ # Verify input file exists
+ if not self.input_model.exists():
+ raise FileNotFoundError(f"Input model not found: {self.input_model}")
+
+ # Verify quantize tool exists
+ if not self.quantize_bin.exists():
+ raise FileNotFoundError(f"Quantize binary not found: {self.quantize_bin}")
+
+ # Verify bench tool exists
+ if not self.bench_bin.exists():
+ raise FileNotFoundError(f"Benchmark binary not found: {self.bench_bin}")
+
+ # Create output directories
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+ self.stats_dir.mkdir(parents=True, exist_ok=True)
+
+ self.results = []
+ self.newly_created_files = set() # Track newly created files
+
+ def quantize(self, embedding_type, output_suffix):
+ """
+ Perform single quantization
+
+ Args:
+ embedding_type: Token embedding type (uppercase format, e.g., Q6_K)
+ output_suffix: Output file suffix (lowercase format, e.g., q6_k)
+
+ Returns:
+ bool: Whether successful
+ """
+ output_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
+
+ # Check if file already exists
+ file_already_existed = output_file.exists()
+
+ if file_already_existed:
+ print(f"ℹ️ File already exists: {output_file}")
+ print(f" Skipping quantization, will use existing file for benchmark")
+ return True
+
+ cmd = [
+ str(self.quantize_bin),
+ "--token-embedding-type", embedding_type,
+ str(self.input_model),
+ str(output_file),
+ "I2_S",
+ "1",
+ "1"
+ ]
+
+ print(f"\n{'='*80}")
+ print(f"🔄 Quantizing with embedding type: {embedding_type}")
+ print(f"📥 Input: {self.input_model}")
+ print(f"📤 Output: {output_file}")
+ print(f"💻 Command: {' '.join(cmd)}")
+ print(f"{'='*80}\n")
+
+ start_time = datetime.now()
+
+ try:
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ cwd=os.getcwd(),
+ timeout=600 # 10 minute timeout
+ )
+
+ end_time = datetime.now()
+ duration = (end_time - start_time).total_seconds()
+
+ if result.returncode == 0:
+ # Get output file size
+ file_size_mb = output_file.stat().st_size / (1024 * 1024)
+
+ print(f"✅ Success! Duration: {duration:.2f}s, Size: {file_size_mb:.2f} MB")
+
+ # Record newly created file
+ if not file_already_existed:
+ self.newly_created_files.add(output_file)
+
+ # Print part of output
+ if result.stdout:
+ print("\n📊 Quantization output:")
+ print(result.stdout[-500:] if len(result.stdout) > 500 else result.stdout)
+
+ return True
+ else:
+ print(f"❌ Failed with return code {result.returncode}")
+ print(f"Error: {result.stderr}")
+ return False
+
+ except subprocess.TimeoutExpired:
+ print(f"❌ Timeout (exceeded 10 minutes)")
+ return False
+
+ except Exception as e:
+ print(f"❌ Exception: {e}")
+ return False
+
+ def benchmark_model(self, output_suffix):
+ """
+ Benchmark model
+
+ Args:
+ output_suffix: Output file suffix (lowercase format, e.g., q6_k)
+
+ Returns:
+ dict: Dictionary with benchmark results, or None if failed
+ """
+ model_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
+
+ if not model_file.exists():
+ print(f"❌ Model file not found for benchmarking: {model_file}")
+ return None
+
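+        # llama-bench flags: -p = prompt tokens, -n = tokens to generate (0 = prompt
+        # processing only), -t = comma-separated thread counts, -ngl = GPU layers (0 = CPU).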
+ cmd = [
+ str(self.bench_bin),
+ "-m", str(model_file),
+ "-p", "128",
+ "-n", "0",
+ "-t", "1,2,4,8",
+ "-ngl", "0"
+ ]
+
+ print(f"\n{'='*80}")
+ print(f"🏃 Running benchmark for: {output_suffix}")
+ print(f"💻 Command: {' '.join(cmd)}")
+ print(f"{'='*80}\n")
+
+ try:
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ cwd=os.getcwd(),
+ timeout=300 # 5 minute timeout
+ )
+
+ if result.returncode == 0:
+ print("✅ Benchmark completed successfully")
+ print("\n📊 Benchmark output:")
+ print(result.stdout)
+
+                # Parse the output
+ bench_results = self.parse_benchmark_output(result.stdout, output_suffix)
+ return bench_results
+ else:
+ print(f"❌ Benchmark failed with return code {result.returncode}")
+ print(f"Error: {result.stderr}")
+ return None
+
+ except subprocess.TimeoutExpired:
+ print(f"❌ Benchmark timeout (exceeded 5 minutes)")
+ return None
+
+ except Exception as e:
+ print(f"❌ Benchmark exception: {e}")
+ return None
+
+ def parse_benchmark_output(self, output, output_suffix):
+ """
+ Parse benchmark output to extract t/s data (mean±std)
+
+ Args:
+ output: Benchmark command output
+ output_suffix: Output file suffix
+
+ Returns:
+ dict: Dictionary with parsed results
+ """
+ results = {
+ 'embedding_type': output_suffix,
+ 'threads_1': None,
+ 'threads_2': None,
+ 'threads_4': None,
+ 'threads_8': None,
+ }
+
+ # Parse table data
+ # Find lines containing pp128 and t/s
+ lines = output.strip().split('\n')
+
+ for line in lines:
+ # Skip header and separator lines
+ if '|' not in line or 'model' in line or '---' in line:
+ continue
+
+ # Try to extract data
+ # Format similar to: | bitnet-25 2B I2_S - 2 bpw ternary | 1012.28 MiB | 2.74 B | CPU | 12 | pp128 | 405.73 ± 3.69 |
+ parts = [p.strip() for p in line.split('|')]
+
+ if len(parts) >= 8 and 'pp128' in parts[6]:
+ threads_str = parts[5].strip()
+ throughput_str = parts[7].strip()
+
+ # Extract thread count
+ try:
+ threads = int(threads_str)
+                except ValueError:
+ continue
+
+ # Extract t/s data (format: "405.73 ± 3.69" or "405.73")
+ # Try to match "mean ± std" format
+ match_with_std = re.search(r'([\d.]+)\s*±\s*([\d.]+)', throughput_str)
+ if match_with_std:
+ mean = float(match_with_std.group(1))
+ std = float(match_with_std.group(2))
+ throughput = f"{mean:.2f}±{std:.2f}"
+ else:
+ # Only mean, no std
+ match = re.search(r'([\d.]+)', throughput_str)
+ if match:
+ throughput = f"{float(match.group(1)):.2f}"
+ else:
+ continue
+
+ # Store result based on thread count
+ if threads == 1:
+ results['threads_1'] = throughput
+ elif threads == 2:
+ results['threads_2'] = throughput
+ elif threads == 4:
+ results['threads_4'] = throughput
+ elif threads == 8:
+ results['threads_8'] = throughput
+
+ return results
+
+ def cleanup_model(self, output_suffix):
+ """
+ Cleanup model files (only delete newly created files)
+
+ Args:
+ output_suffix: Output file suffix
+ """
+ model_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
+
+ if model_file in self.newly_created_files:
+ try:
+ model_file.unlink()
+ print(f"🗑️ Deleted newly created file: {model_file}")
+ self.newly_created_files.remove(model_file)
+ except Exception as e:
+ print(f"⚠️ Failed to delete {model_file}: {e}")
+ else:
+ print(f"ℹ️ Keeping existing file: {model_file}")
+
+ def run_all_quantizations(self, types_to_quantize):
+ """
+ Run all quantizations
+
+ Args:
+ types_to_quantize: List of quantization types, tuples of (embedding_type, output_suffix)
+ """
+ print(f"\n{'='*80}")
+ print(f"🚀 Starting Embedding Quantization and Benchmarking")
+ print(f"{'='*80}")
+ print(f"📥 Input model: {self.input_model}")
+ print(f"📤 Output directory: {self.output_dir}")
+ print(f"📊 Stats directory: {self.stats_dir}")
+ print(f"🔢 Total quantizations: {len(types_to_quantize)}")
+ print(f"{'='*80}\n")
+
+ total_start = datetime.now()
+
+ for i, (embedding_type, output_suffix) in enumerate(types_to_quantize, 1):
+ print(f"\n{'#'*80}")
+ print(f"[{i}/{len(types_to_quantize)}] Processing {output_suffix} ({embedding_type})")
+ print(f"{'#'*80}\n")
+
+ # Quantize model
+ success = self.quantize(embedding_type, output_suffix)
+
+ if not success:
+ print(f"⚠️ Skipping benchmark for {output_suffix} due to quantization failure")
+ continue
+
+ # Run benchmark
+ bench_results = self.benchmark_model(output_suffix)
+
+ if bench_results:
+ self.results.append(bench_results)
+ else:
+ print(f"⚠️ Benchmark failed for {output_suffix}")
+
+ # Cleanup model files (only delete newly created files)
+ self.cleanup_model(output_suffix)
+
+ print(f"\n{'#'*80}")
+ print(f"✅ Completed {output_suffix}")
+ print(f"{'#'*80}\n")
+
+ total_end = datetime.now()
+ total_duration = (total_end - total_start).total_seconds()
+
+        # Save the results to CSV
+ self.save_results_to_csv()
+
+        # Print the summary
+ self.print_summary(total_duration)
+
+ def save_results_to_csv(self):
+        """Save benchmark results to a CSV file."""
+ if not self.results:
+ print("⚠️ No results to save")
+ return
+
+ # Use user-specified CSV path, otherwise use default path
+ if self.csv_output:
+ csv_file = self.csv_output
+ # Ensure parent directory exists
+ csv_file.parent.mkdir(parents=True, exist_ok=True)
+ else:
+            csv_file = self.stats_dir / "embedding_benchmark.csv"
+
+ print(f"\n💾 Saving results to: {csv_file}")
+
+ try:
+ with open(csv_file, 'w', newline='') as f:
+ fieldnames = ['embedding_type', 'threads_1', 'threads_2', 'threads_4', 'threads_8']
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
+
+ writer.writeheader()
+ for result in self.results:
+ writer.writerow(result)
+
+ print(f"✅ Results saved successfully")
+
+ # Also print table
+ print(f"\n📊 Benchmark Results:")
+ print(f"{'Type':<15} {'1 thread':<18} {'2 threads':<18} {'4 threads':<18} {'8 threads':<18}")
+ print("-" * 87)
+ for result in self.results:
+ t1 = result['threads_1'] if result['threads_1'] else "N/A"
+ t2 = result['threads_2'] if result['threads_2'] else "N/A"
+ t4 = result['threads_4'] if result['threads_4'] else "N/A"
+ t8 = result['threads_8'] if result['threads_8'] else "N/A"
+ print(f"{result['embedding_type']:<15} {t1:<18} {t2:<18} {t4:<18} {t8:<18}")
+
+ except Exception as e:
+ print(f"❌ Failed to save results: {e}")
+
+ def print_summary(self, total_duration):
+ """Print quantization summary"""
+ print(f"\n\n{'='*80}")
+ print(f"📊 QUANTIZATION AND BENCHMARK SUMMARY")
+ print(f"{'='*80}\n")
+
+        successful = len(self.results)
+
+ print(f"✅ Completed: {successful} benchmarks")
+ print(f"⏱️ Total duration: {total_duration/60:.2f} minutes\n")
+
+ if self.results:
+ if self.csv_output and self.csv_output.exists():
+ print(f"📁 Results saved to: {self.csv_output}")
+ else:
+ csv_files = list(self.stats_dir.glob("embedding_benchmark*.csv"))
+ if csv_files:
+ latest_csv = max(csv_files, key=lambda p: p.stat().st_mtime)
+ print(f"📁 Results saved to: {latest_csv}")
+
+ print(f"\n{'='*80}\n")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Quantize model embeddings to multiple formats')
+ parser.add_argument('--input', '-i',
+ default='../models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf',
+ help='Input model path (default: ../models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf)')
+ parser.add_argument('--output-dir', '-o',
+ default='../models/BitNet-b1.58-2B-4T',
+ help='Output directory (default: ../models/BitNet-b1.58-2B-4T)')
+ parser.add_argument('--quantize-bin', '-q',
+ default='../build/bin/llama-quantize',
+ help='Path to llama-quantize binary (default: ../build/bin/llama-quantize)')
+ parser.add_argument('--bench-bin', '-b',
+ default='../build/bin/llama-bench',
+ help='Path to llama-bench binary (default: ../build/bin/llama-bench)')
+ parser.add_argument('--stats-dir',
+ default='../stats',
+ help='Directory to save benchmark results (default: ../stats)')
+ parser.add_argument('--csv-output', '-c',
+ help='Custom path for CSV output file (e.g., stats/my_results.csv)')
+ parser.add_argument('--types', '-t',
+ nargs='+',
+ help='Specific types to quantize (e.g., f32 q6_k q4_0)')
+ parser.add_argument('--skip-existing', '-s',
+ action='store_true',
+ help='Skip quantization if output file already exists (will still benchmark existing files)')
+
+ args = parser.parse_args()
+
+ # Define all supported quantization types
+ # Format: (embedding_type for command line, output_suffix for filename)
+ all_types = [
+ ('F32', 'f32'),
+ ('F16', 'f16'),
+ ('Q8_0', 'q8_0'),
+ ('Q6_K', 'q6_k'),
+ ('Q5_0', 'q5_0'),
+ ('Q4_0', 'q4_0'),
+ ('Q3_K', 'q3_k'),
+ ('TQ2_0', 'tq2_0'),
+ ]
+
+ # If specific types are specified, filter the list
+ if args.types:
+ types_lower = [t.lower() for t in args.types]
+        types_to_quantize = [(et, suffix) for et, suffix in all_types if suffix.lower() in types_lower]
+ if not types_to_quantize:
+            print(f"❌ No valid types specified. Available types: {', '.join(suffix for _, suffix in all_types)}")
+ return
+ else:
+ types_to_quantize = all_types
+
+    # No filtering is needed for --skip-existing: quantize() already detects an
+    # existing output file, skips re-quantization, and still benchmarks it.
+
+    # Create the quantizer and run it
+ try:
+ quantizer = EmbeddingQuantizer(
+ args.input,
+ args.output_dir,
+ args.quantize_bin,
+ args.bench_bin,
+ args.stats_dir,
+ args.csv_output
+ )
+ quantizer.run_all_quantizations(types_to_quantize)
+ except FileNotFoundError as e:
+ print(f"❌ Error: {e}")
+ return 1
+ except KeyboardInterrupt:
+ print("\n\n⚠️ Quantization interrupted by user")
+ return 1
+ except Exception as e:
+ print(f"\n❌ Unexpected error: {e}")
+ import traceback
+ traceback.print_exc()
+ return 1
+
+
+if __name__ == "__main__":
+ exit(main() or 0)
diff --git a/utils/test_gemm_kernel.cpp b/utils/test_gemm_kernel.cpp
new file mode 100644
index 0000000..36964ce
--- /dev/null
+++ b/utils/test_gemm_kernel.cpp
@@ -0,0 +1,274 @@
+/**
+ * Standalone benchmark for ggml_gemm_i2_i8_s kernel
+ *
+ * This program tests the performance of the ggml_gemm_i2_i8_s kernel
+ * with configurable matrix sizes and iteration counts.
+ *
+ * Usage: ./test_gemm_kernel [options]
+ * -n : embedding dimension (must be divisible by 4, default: 2048)
+ * -r : number of rows in matrix Y (default: 32)
+ * -c : number of columns in matrix X (default: 128)
+ * -i : number of iterations (default: 1000)
+ * -w : number of warmup iterations (default: 10)
+ */
+
+#include <cstdio>    // printf, fprintf
+#include <cstdlib>   // malloc, free, rand, srand, atoi, exit
+#include <cstring>   // memset, strcmp
+#include <cmath>     // sqrt
+#include <cstdint>   // int8_t, uint8_t, int64_t
+#include <ctime>     // time
+#include <time.h>    // clock_gettime, CLOCK_MONOTONIC (POSIX)
+
+// Include necessary headers
+#include "../include/gemm-config.h"
+
+// Function declarations (from ggml-quants.h)
+extern "C" void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc);
+
+// GEMM kernel definition
+void ggml_gemm_i2_i8_s(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+#if defined(ACT_PARALLEL)
+ const int64_t row_block = ROW_BLOCK_SIZE;
+ const int64_t col_block = COL_BLOCK_SIZE;
+
+ for (int64_t c0 = 0; c0 < nc; c0 += col_block) {
+ int64_t cur_c = (c0 + col_block <= nc) ? col_block : (nc - c0);
+ for (int64_t r0 = 0; r0 < nr; r0 += row_block) {
+ int64_t cur_r = (r0 + row_block <= nr) ? row_block : (nr - r0);
+ const void * vy_r = (const uint8_t *)vy + r0 * n;
+ for (int64_t c = 0; c < cur_c; ++c) {
+ const int64_t col = c0 + c;
+ float * s_col = s + col;
+ const void * vx_col = (const uint8_t *)vx + col * n / 4;
+ ggml_vec_dot_i2_i8_s(n, s_col + r0 * bs, bs, vx_col, n, vy_r, n, cur_r);
+ }
+ }
+ }
+#else
+ const int64_t row_block = ROW_BLOCK_SIZE;
+ const int64_t col_block = COL_BLOCK_SIZE;
+
+ for (int64_t r0 = 0; r0 < nr; r0 += row_block) {
+ int64_t cur_r = (r0 + row_block <= nr) ? row_block : (nr - r0);
+ for (int64_t c0 = 0; c0 < nc; c0 += col_block) {
+ int64_t cur_c = (c0 + col_block <= nc) ? col_block : (nc - c0);
+ const void * vx_c = (const uint8_t *)vx + c0 * n / 4;
+ for (int64_t r = 0; r < cur_r; ++r) {
+ const int64_t row = r0 + r;
+ float * s_row = s + row * bs;
+ const void * vy_row = (const uint8_t *)vy + row * n;
+ ggml_vec_dot_i2_i8_s(n, s_row + c0, bs, vx_c, n, vy_row, n, cur_c);
+ }
+ }
+ }
+#endif
+}
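+
+// Layout note (mirrors the two branches above): with ACT_PARALLEL the weight-column blocks
+// form the outer loop, so each activation block vy_r is reused across a whole column block
+// and every ggml_vec_dot_i2_i8_s call fills cur_r entries of one output column; without it,
+// activation-row blocks are outer and each call fills cur_c consecutive entries of one
+// output row. Block sizes come from include/gemm-config.h.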
+
+// Helper function to get current time in nanoseconds
+double get_time_ns() {
+ struct timespec ts;
+ clock_gettime(CLOCK_MONOTONIC, &ts);
+ return ts.tv_sec * 1e9 + ts.tv_nsec;
+}
+
+// Initialize matrix with random i2 values (2-bit quantized)
+void init_matrix_i2(uint8_t* data, int n, int cols) {
+ // i2 format: 4 values per byte (2 bits each)
+ int total_bytes = n * cols / 4;
+ for (int i = 0; i < total_bytes; i++) {
+ data[i] = rand() & 0xFF;
+ }
+}
+
+// Initialize matrix with random i8 values
+void init_matrix_i8(int8_t* data, int n, int rows) {
+ int total_elements = n * rows;
+ for (int i = 0; i < total_elements; i++) {
+ data[i] = (int8_t)((rand() % 256) - 128);
+ }
+}
+
+// Benchmark configuration
+struct BenchmarkConfig {
+ int n; // embedding dimension (must be divisible by 4)
+ int nr; // number of rows in Y matrix
+ int nc; // number of columns in X matrix
+ int iterations; // number of benchmark iterations
+ int warmup; // number of warmup iterations
+};
+
+void print_config(const BenchmarkConfig& config) {
+ printf("=" "=%.78s\n", "===============================================================================");
+ printf("Benchmark Configuration:\n");
+ printf("=" "=%.78s\n", "===============================================================================");
+ printf(" Embedding dimension (n) : %d\n", config.n);
+ printf(" Matrix Y rows (nr) : %d\n", config.nr);
+ printf(" Matrix X columns (nc) : %d\n", config.nc);
+ printf(" Iterations : %d\n", config.iterations);
+ printf(" Warmup iterations : %d\n", config.warmup);
+ printf("\nMatrix sizes:\n");
+ printf(" X (i2): %d x %d (%.2f KB)\n", config.nc, config.n,
+ (config.nc * config.n / 4) / 1024.0);
+ printf(" Y (i8): %d x %d (%.2f KB)\n", config.nr, config.n,
+ (config.nr * config.n) / 1024.0);
+ printf(" S (f32): %d x %d (%.2f KB)\n", config.nr, config.nc,
+ (config.nr * config.nc * sizeof(float)) / 1024.0);
+ printf("\nGEMM Config:\n");
+#if defined(ACT_PARALLEL)
+ printf(" ACT_PARALLEL : ON\n");
+#else
+ printf(" ACT_PARALLEL : OFF\n");
+#endif
+ printf(" ROW_BLOCK_SIZE : %d\n", ROW_BLOCK_SIZE);
+ printf(" COL_BLOCK_SIZE : %d\n", COL_BLOCK_SIZE);
+ printf(" PARALLEL_SIZE : %d\n", PARALLEL_SIZE);
+ printf("=" "=%.78s\n\n", "===============================================================================");
+}
+
+void run_benchmark(const BenchmarkConfig& config) {
+ // Allocate matrices
+ printf("Allocating matrices...\n");
+
+ // X matrix (i2 format): nc x n, but stored as nc x (n/4) bytes
+ uint8_t* X = (uint8_t*)malloc(config.nc * config.n / 4);
+
+ // Y matrix (i8 format): nr x n
+ int8_t* Y = (int8_t*)malloc(config.nr * config.n);
+
+ // Result matrix (float32): nr x nc
+ float* S = (float*)malloc(config.nr * config.nc * sizeof(float));
+
+ if (!X || !Y || !S) {
+ fprintf(stderr, "Failed to allocate memory\n");
+ exit(1);
+ }
+
+ // Initialize matrices with random data
+ printf("Initializing matrices with random data...\n");
+ srand(time(NULL));
+ init_matrix_i2(X, config.n, config.nc);
+ init_matrix_i8(Y, config.n, config.nr);
+ memset(S, 0, config.nr * config.nc * sizeof(float));
+
+ // Warmup
+ printf("Running %d warmup iterations...\n", config.warmup);
+ for (int i = 0; i < config.warmup; i++) {
+ ggml_gemm_i2_i8_s(config.n, S, config.nc, X, Y, config.nr, config.nc);
+ }
+
+ // Benchmark
+ printf("Running %d benchmark iterations...\n", config.iterations);
+ double total_time = 0.0;
+ double min_time = 1e20;
+ double max_time = 0.0;
+
+ for (int i = 0; i < config.iterations; i++) {
+ double start = get_time_ns();
+ ggml_gemm_i2_i8_s(config.n, S, config.nc, X, Y, config.nr, config.nc);
+ double end = get_time_ns();
+
+ double elapsed = end - start;
+ total_time += elapsed;
+ if (elapsed < min_time) min_time = elapsed;
+ if (elapsed > max_time) max_time = elapsed;
+
+ if ((i + 1) % 100 == 0) {
+ printf(" Progress: %d/%d iterations\n", i + 1, config.iterations);
+ }
+ }
+
+ // Calculate statistics
+ double avg_time_ns = total_time / config.iterations;
+ double avg_time_ms = avg_time_ns / 1e6;
+ double min_time_ms = min_time / 1e6;
+ double max_time_ms = max_time / 1e6;
+
+ // Calculate GFLOPS
+ // For GEMM: nr x nc x n multiply-adds = 2 * nr * nc * n FLOPs
+ double flops = 2.0 * config.nr * config.nc * config.n;
+ double gflops = (flops / avg_time_ns);
+
+ // Calculate throughput (tokens/s assuming each column is a token)
+ double throughput = (config.nc * 1e9) / avg_time_ns;
+
+ // Print results
+ printf("\n");
+ printf("=" "=%.78s\n", "===============================================================================");
+ printf("Benchmark Results:\n");
+ printf("=" "=%.78s\n", "===============================================================================");
+ printf(" Average time : %.3f ms\n", avg_time_ms);
+ printf(" Min time : %.3f ms\n", min_time_ms);
+ printf(" Max time : %.3f ms\n", max_time_ms);
+    // Rough spread estimate: the std dev of a uniform distribution over [min, max]
+    printf("  Std dev (est.)    : %.3f ms\n", (max_time_ms - min_time_ms) / sqrt(12.0));
+ printf("\nPerformance:\n");
+ printf(" GFLOPS : %.2f\n", gflops);
+ printf(" Throughput : %.2f tokens/s\n", throughput);
+ printf(" Latency/token : %.3f us\n", (avg_time_ms * 1000) / config.nc);
+ printf("=" "=%.78s\n", "===============================================================================");
+
+ // Cleanup
+ free(X);
+ free(Y);
+ free(S);
+}
+
+void print_usage(const char* program) {
+ printf("Usage: %s [options]\n", program);
+ printf("Options:\n");
+ printf(" -n Embedding dimension (must be divisible by 4, default: 2048)\n");
+ printf(" -r Number of rows in matrix Y (default: 32)\n");
+ printf(" -c Number of columns in matrix X (default: 128)\n");
+ printf(" -i Number of iterations (default: 1000)\n");
+ printf(" -w Number of warmup iterations (default: 10)\n");
+ printf(" -h Show this help message\n");
+}
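+
+// Example invocation (binary path as used by utils/test_typical_shapes.sh):
+//   ./build/test_gemm_kernel -n 2048 -r 128 -c 8192 -i 500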
+
+int main(int argc, char** argv) {
+ BenchmarkConfig config = {
+ .n = 2048,
+ .nr = 32,
+ .nc = 128,
+ .iterations = 1000,
+ .warmup = 10
+ };
+
+ // Parse command line arguments
+ for (int i = 1; i < argc; i++) {
+ if (strcmp(argv[i], "-n") == 0 && i + 1 < argc) {
+ config.n = atoi(argv[++i]);
+ } else if (strcmp(argv[i], "-r") == 0 && i + 1 < argc) {
+ config.nr = atoi(argv[++i]);
+ } else if (strcmp(argv[i], "-c") == 0 && i + 1 < argc) {
+ config.nc = atoi(argv[++i]);
+ } else if (strcmp(argv[i], "-i") == 0 && i + 1 < argc) {
+ config.iterations = atoi(argv[++i]);
+ } else if (strcmp(argv[i], "-w") == 0 && i + 1 < argc) {
+ config.warmup = atoi(argv[++i]);
+ } else if (strcmp(argv[i], "-h") == 0) {
+ print_usage(argv[0]);
+ return 0;
+ } else {
+ fprintf(stderr, "Unknown option: %s\n", argv[i]);
+ print_usage(argv[0]);
+ return 1;
+ }
+ }
+
+ // Validate configuration
+ if (config.n % 4 != 0) {
+ fprintf(stderr, "Error: Embedding dimension (-n) must be divisible by 4\n");
+ return 1;
+ }
+
+ if (config.n <= 0 || config.nr <= 0 || config.nc <= 0 || config.iterations <= 0) {
+ fprintf(stderr, "Error: All size parameters must be positive\n");
+ return 1;
+ }
+
+ // Run benchmark
+ print_config(config);
+ run_benchmark(config);
+
+ return 0;
+}
diff --git a/utils/test_parallel_strategy.sh b/utils/test_parallel_strategy.sh
new file mode 100755
index 0000000..44da140
--- /dev/null
+++ b/utils/test_parallel_strategy.sh
@@ -0,0 +1,277 @@
+#!/bin/bash
+
+# Script: Test different GEMM parallel strategy performance
+# Strategies: activation-parallel, weight-parallel, and no-parallel
+# Thread counts: 1, 2, 4, 8, 12, 16
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
+GEMM_CONFIG="$PROJECT_ROOT/include/gemm-config.h"
+GEMM_CONFIG_BACKUP="$PROJECT_ROOT/include/gemm-config.h.bak"
+BUILD_DIR="$PROJECT_ROOT/build"
+STATS_DIR="$PROJECT_ROOT/stats"
+CSV_FILE="$STATS_DIR/test_parallel_strategy_benchmark.csv"
+MODEL_PATH="$PROJECT_ROOT/models/BitNet-b1.58-2B-4T/ggml-model-original.gguf"
+BENCHMARK_CMD="./build/bin/llama-bench"
+THREADS_LIST="1 2 4 8 12 16"
+
+# Color output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+NC='\033[0m' # No Color
+
+log_info() {
+ echo -e "${GREEN}[INFO]${NC} $1"
+}
+
+log_warn() {
+ echo -e "${YELLOW}[WARN]${NC} $1"
+}
+
+log_error() {
+ echo -e "${RED}[ERROR]${NC} $1"
+}
+
+# Check prerequisites
+check_prerequisites() {
+ log_info "Checking prerequisites..."
+
+ if [ ! -f "$GEMM_CONFIG" ]; then
+ log_error "gemm-config.h not found: $GEMM_CONFIG"
+ exit 1
+ fi
+
+ if [ ! -f "$MODEL_PATH" ]; then
+ log_error "Model file not found: $MODEL_PATH"
+ exit 1
+ fi
+
+ if [ ! -d "$BUILD_DIR" ]; then
+ log_error "Build directory not found: $BUILD_DIR"
+ exit 1
+ fi
+
+ if [ ! -f "$BUILD_DIR/bin/llama-bench" ]; then
+ log_warn "llama-bench executable not found, building..."
+ build_project
+ fi
+
+ if [ ! -d "$STATS_DIR" ]; then
+ log_info "Creating stats directory..."
+ mkdir -p "$STATS_DIR"
+ fi
+
+ log_info "Prerequisites check completed"
+}
+
+# Backup original config file
+backup_config() {
+ log_info "Backing up gemm-config.h..."
+ cp "$GEMM_CONFIG" "$GEMM_CONFIG_BACKUP"
+ log_info "Backup completed: $GEMM_CONFIG_BACKUP"
+}
+
+# Restore original config file
+restore_config() {
+ if [ -f "$GEMM_CONFIG_BACKUP" ]; then
+ log_info "Restoring original gemm-config.h..."
+ cp "$GEMM_CONFIG_BACKUP" "$GEMM_CONFIG"
+ rm "$GEMM_CONFIG_BACKUP"
+ log_info "Restore completed"
+ else
+ log_warn "Backup file not found, skipping restore"
+ fi
+}
+
+# Set activation-parallel configuration (keep original ACT_PARALLEL)
+set_activation_parallel() {
+ log_info "Configuration: activation-parallel (keeping #define ACT_PARALLEL)"
+ log_info "Configuration completed"
+}
+
+# Set weight-parallel configuration (remove ACT_PARALLEL)
+set_weight_parallel() {
+ log_info "Configuration: weight-parallel (removing #define ACT_PARALLEL)"
+
+ # Remove ACT_PARALLEL definition
+ sed -i '/#define ACT_PARALLEL/d' "$GEMM_CONFIG"
+
+ # Verify modification
+ if grep -q "^#define ACT_PARALLEL" "$GEMM_CONFIG"; then
+ log_error "Failed to remove ACT_PARALLEL"
+ exit 1
+ fi
+ log_info "Configuration completed"
+}
+
+# Set no-parallel configuration (remove ACT_PARALLEL + modify SIZE to 1)
+set_no_parallel() {
+ log_info "Configuration: no-parallel (removing #define ACT_PARALLEL + modifying SIZE to 1)"
+
+ # Remove ACT_PARALLEL definition
+ sed -i '/#define ACT_PARALLEL/d' "$GEMM_CONFIG"
+
+ # Modify all ROW_BLOCK_SIZE and COL_BLOCK_SIZE to 1
+ sed -i 's/#define ROW_BLOCK_SIZE [0-9]\+/#define ROW_BLOCK_SIZE 1/g' "$GEMM_CONFIG"
+ sed -i 's/#define COL_BLOCK_SIZE [0-9]\+/#define COL_BLOCK_SIZE 1/g' "$GEMM_CONFIG"
+
+ log_info "Configuration completed"
+}
+
+# Build project
+build_project() {
+ log_info "Building project..."
+ cd "$PROJECT_ROOT"
+
+ if [ ! -f "$BUILD_DIR/Makefile" ]; then
+ log_info "First build, running cmake..."
+ cmake -B "$BUILD_DIR" -DCMAKE_BUILD_TYPE=Release > /dev/null 2>&1
+ fi
+
+ cd "$BUILD_DIR"
+ make -j$(nproc) llama-bench > /dev/null 2>&1
+
+ if [ ! -f "./bin/llama-bench" ]; then
+ log_error "Build failed"
+ exit 1
+ fi
+
+ log_info "Build completed"
+ cd "$PROJECT_ROOT"
+}
+
+# Run benchmark test
+run_benchmark() {
+ local strategy=$1
+ local threads=$2
+
+ cd "$PROJECT_ROOT"
+
+ # Run llama-bench
+ local output=$($BENCHMARK_CMD -m "$MODEL_PATH" -p 128 -n 0 -t "$threads" -ngl 0 2>&1)
+
+ # Extract line containing "pp128"
+ local line=$(echo "$output" | grep "pp128" | tail -1)
+
+ if [ -z "$line" ]; then
+ return 1
+ fi
+
+ echo "$line"
+}
+
+# Extract throughput value from benchmark output
+extract_throughput() {
+ local line=$1
+
+ # Remove any leading/trailing whitespace and log messages
+ # The line format is: | model | size | params | backend | threads | test | throughput |
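+    # For a (hypothetical) row such as
+    #   | bitnet 2B I2_S | 1.10 GiB | 2.41 B | CPU | 8 | pp128 | 501.06 ± 11.37 |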
+    # The row ends with '|', so awk's $NF is empty; take the second-to-last
+    # field, which holds the throughput in the format "XXX.XX ± YY.YY"
+    local throughput=$(echo "$line" | awk -F'|' '{print $(NF-1)}' | sed 's/\[.*\]//' | xargs)
+
+ echo "$throughput"
+}
+
+# Initialize CSV file
+init_csv() {
+ log_info "Initializing CSV file: $CSV_FILE"
+
+ cat > "$CSV_FILE" << 'EOF'
+Strategy,Threads,Throughput
+EOF
+
+ log_info "CSV file created"
+}
+
+# Add result to CSV
+add_to_csv() {
+ local strategy=$1
+ local threads=$2
+ local throughput=$3
+
+ echo "$strategy,$threads,$throughput" >> "$CSV_FILE"
+}
+
+# Main function
+main() {
+ log_info "Starting GEMM parallel strategy benchmark tests"
+ log_info "================================================"
+
+ # Check prerequisites
+ check_prerequisites
+
+    # Backup original configuration
+    backup_config
+
+    # Restore the original config if the run is interrupted
+    trap 'restore_config; exit 130' INT TERM
+
+ # Initialize CSV file
+ init_csv
+
+ # Define strategies to test
+ local strategies=("activation-parallel" "weight-parallel" "no-parallel")
+
+ for strategy in "${strategies[@]}"; do
+ log_info "================================================"
+ log_info "Testing strategy: $strategy"
+ log_info "================================================"
+
+ # Restore to original configuration
+ restore_config
+ backup_config
+
+ # Apply configuration based on strategy
+ case $strategy in
+ activation-parallel)
+ set_activation_parallel
+ ;;
+ weight-parallel)
+ set_weight_parallel
+ ;;
+ no-parallel)
+ set_no_parallel
+ ;;
+ esac
+
+ # Rebuild project to apply new configuration
+ log_info "Rebuilding project to apply new configuration..."
+ build_project
+
+ # Run test for each thread count
+ for threads in $THREADS_LIST; do
+ log_info ""
+ log_info "Strategy: $strategy, Threads: $threads"
+
+            # Run test (capture only output, not log messages).
+            # Declare and assign separately so $? reflects run_benchmark,
+            # not the 'local' builtin.
+            local result
+            result=$(run_benchmark "$strategy" "$threads")
+            local test_status=$?
+
+ if [ $test_status -eq 0 ]; then
+ # Extract throughput value from the result line
+ local throughput=$(extract_throughput "$result")
+ log_info "Throughput: $throughput"
+
+ # Add to CSV
+ add_to_csv "$strategy" "$threads" "$throughput"
+ else
+ log_warn "Test failed for strategy $strategy, threads $threads"
+ fi
+
+ sleep 2 # Give system time to cool down
+ done
+ done
+
+ # Restore original configuration
+ restore_config
+
+ log_info "================================================"
+ log_info "Test completed!"
+ log_info "Results saved to: $CSV_FILE"
+ log_info "================================================"
+
+ # Display CSV content
+ log_info "CSV file content:"
+ cat "$CSV_FILE"
+}
+
+# Run main function
+main "$@"
diff --git a/utils/test_perplexity.py b/utils/test_perplexity.py
new file mode 100644
index 0000000..f2d9788
--- /dev/null
+++ b/utils/test_perplexity.py
@@ -0,0 +1,608 @@
+#!/usr/bin/env python3
+"""
+Perplexity Test Script
+Tests GGUF model perplexity on multiple datasets using llama-perplexity.
+"""
+
+import os
+import subprocess
+import time
+import csv
+import re
+from datetime import datetime
+from pathlib import Path
+import argparse
+import tempfile
+
+
+class PerplexityTester:
+ def __init__(self, model_path, llama_perplexity_bin="../build/bin/llama-perplexity",
+ data_dir="../data", output_dir="perplexity_results", quick_mode=False,
+ quantize_bin="../build/bin/llama-quantize", test_embeddings=False, csv_output=None):
+ self.model_path = Path(model_path)
+ self.llama_perplexity_bin = Path(llama_perplexity_bin)
+ self.quantize_bin = Path(quantize_bin)
+ self.data_dir = Path(data_dir)
+ self.output_dir = Path(output_dir)
+ self.quick_mode = quick_mode
+ self.test_embeddings = test_embeddings
+ self.csv_output = Path(csv_output) if csv_output else None
+ self.results = []
+ self.created_models = set() # Track newly created model files
+ self.temp_files = [] # Track temporary files for cleanup
+
+ # Embedding types to test
+ self.embedding_types = [
+ ('F32', 'f32'),
+ ('F16', 'f16'),
+ ('Q8_0', 'q8_0'),
+ ('Q6_K', 'q6_k'),
+ ('Q5_0', 'q5_0'),
+ ('Q4_0', 'q4_0'),
+ ('Q3_K', 'q3_k'),
+ ('TQ2_0', 'tq2_0'),
+ ]
+
+ # Create output directory
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Verify llama-perplexity binary exists
+ if not self.llama_perplexity_bin.exists():
+ raise FileNotFoundError(f"llama-perplexity binary not found: {self.llama_perplexity_bin}")
+
+ # Verify quantize binary exists if testing embeddings
+ if self.test_embeddings and not self.quantize_bin.exists():
+ raise FileNotFoundError(f"llama-quantize binary not found: {self.quantize_bin}")
+
+ # Verify model file exists
+ if not self.model_path.exists():
+ raise FileNotFoundError(f"Model file not found: {self.model_path}")
+
+ def find_datasets(self):
+ """Find all test.txt files in dataset directories."""
+ datasets = []
+
+ if not self.data_dir.exists():
+ print(f"❌ Data directory not found: {self.data_dir}")
+ return datasets
+
+ print(f"\n🔍 Searching for datasets in {self.data_dir}...")
+
+ # Look for test.txt files in subdirectories
+ for dataset_dir in sorted(self.data_dir.iterdir()):
+ if dataset_dir.is_dir():
+ test_file = dataset_dir / "test.txt"
+ if test_file.exists():
+ size_mb = test_file.stat().st_size / (1024 * 1024)
+ datasets.append({
+ 'name': dataset_dir.name,
+ 'path': test_file,
+ 'size': test_file.stat().st_size,
+ 'size_mb': size_mb
+ })
+ print(f" ✅ {dataset_dir.name:<20} ({size_mb:.2f} MB)")
+ else:
+ print(f" ⚠️ {dataset_dir.name:<20} (no test.txt found)")
+
+ return datasets
+
+ def create_quick_dataset(self, dataset_path, num_chars=4096):
+ """Create a temporary dataset with only the first N characters for quick testing."""
+ temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt', encoding='utf-8')
+ self.temp_files.append(temp_file.name)
+
+ try:
+ with open(dataset_path, 'r', encoding='utf-8', errors='ignore') as f:
+ content = f.read(num_chars)
+ temp_file.write(content)
+ temp_file.close()
+ return Path(temp_file.name)
+ except Exception as e:
+ print(f"⚠️ Failed to create quick dataset: {e}")
+ temp_file.close()
+ return dataset_path
+
+ def cleanup_temp_files(self):
+ """Clean up temporary files."""
+ for temp_file in self.temp_files:
+ try:
+ os.unlink(temp_file)
+            except OSError:
+                pass
+ self.temp_files = []
+
+ def run_perplexity_test(self, dataset_name, dataset_path, threads=16, ctx_size=512, model_override=None):
+ """Run perplexity test on a single dataset."""
+ test_model = model_override if model_override else self.model_path
+
+ print(f"\n{'='*80}")
+ print(f"📊 Testing on dataset: {dataset_name}")
+ print(f" File: {dataset_path}")
+ print(f" Model: {test_model.name}")
+ print(f"{'='*80}")
+
+ cmd = [
+ str(self.llama_perplexity_bin),
+ "-m", str(test_model),
+ "-f", str(dataset_path),
+ "-t", str(threads),
+ "-c", str(ctx_size),
+ "-ngl", "0" # CPU only
+ ]
+
+ print(f"💻 Command: {' '.join(cmd)}")
+ print(f"⏱️ Starting test...\n")
+
+ start_time = time.time()
+
+ try:
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ timeout=3600, # 1 hour timeout
+ cwd=os.getcwd()
+ )
+
+ elapsed_time = time.time() - start_time
+
+ if result.returncode == 0:
+ # Parse perplexity from output (check both stdout and stderr)
+ combined_output = result.stdout + "\n" + result.stderr
+ ppl = self.parse_perplexity(combined_output)
+
+ if ppl is not None:
+ print(f"\n✅ Perplexity: {ppl}")
+ print(f"⏱️ Time: {elapsed_time:.2f}s ({elapsed_time/60:.2f} min)")
+ status = "success"
+ else:
+ print(f"\n⚠️ Test completed but could not parse perplexity")
+ print(f"Last 500 chars of stdout:")
+ print(result.stdout[-500:])
+ print(f"Last 500 chars of stderr:")
+ print(result.stderr[-500:])
+ status = "parse_error"
+ ppl = None
+ else:
+ print(f"\n❌ Test failed with return code {result.returncode}")
+ print(f"Error: {result.stderr[:500]}")
+ status = "failed"
+ ppl = None
+ elapsed_time = time.time() - start_time
+
+ return {
+ 'dataset': dataset_name,
+ 'perplexity': ppl,
+ 'time': elapsed_time,
+ 'status': status,
+ 'stdout': result.stdout,
+ 'stderr': result.stderr
+ }
+
+ except subprocess.TimeoutExpired:
+ elapsed_time = time.time() - start_time
+ print(f"\n❌ Timeout after {elapsed_time:.2f}s")
+ return {
+ 'dataset': dataset_name,
+ 'perplexity': None,
+ 'time': elapsed_time,
+ 'status': 'timeout',
+ 'stdout': '',
+ 'stderr': 'Test exceeded 1 hour timeout'
+ }
+ except Exception as e:
+ elapsed_time = time.time() - start_time
+ print(f"\n❌ Error: {e}")
+ return {
+ 'dataset': dataset_name,
+ 'perplexity': None,
+ 'time': elapsed_time,
+ 'status': 'error',
+ 'stdout': '',
+ 'stderr': str(e)
+ }
+
+ def parse_perplexity(self, output):
+ """Parse perplexity value (mean±std format) from llama-perplexity output."""
+ # First try to match "PPL = mean +/- std" format
+ pattern_with_std = r'PPL\s*=\s*(\d+\.?\d*)\s*\+/-\s*(\d+\.?\d*)'
+ match = re.search(pattern_with_std, output, re.IGNORECASE | re.MULTILINE)
+ if match:
+ try:
+ mean = float(match.group(1))
+ std = float(match.group(2))
+ return f"{mean:.4f}±{std:.4f}"
+ except ValueError:
+ pass
+
+ # Fallback to patterns without std
+ patterns = [
+ r'Final estimate:\s*PPL\s*=\s*(\d+\.?\d*)',
+ r'Final perplexity:\s*(\d+\.?\d*)',
+ r'PPL\s*=\s*(\d+\.?\d*)',
+ r'PPL:\s*(\d+\.?\d*)',
+ r'perplexity:\s*(\d+\.?\d*)',
+ r'ppl\s*=\s*(\d+\.?\d*)',
+ r'Perplexity:\s*(\d+\.?\d*)',
+ ]
+
+ for pattern in patterns:
+ match = re.search(pattern, output, re.IGNORECASE | re.MULTILINE)
+ if match:
+ try:
+ return f"{float(match.group(1)):.4f}"
+ except ValueError:
+ continue
+
+ return None
+
+ def quantize_embedding(self, embedding_type, output_suffix):
+ """
+ Quantize model with specific embedding type.
+
+ Args:
+ embedding_type: Token embedding type (uppercase, e.g., 'Q6_K')
+ output_suffix: Output file suffix (lowercase, e.g., 'q6_k')
+
+ Returns:
+ Path to quantized model or None if failed
+ """
+ # Construct output path
+ model_dir = self.model_path.parent
+ output_path = model_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
+
+ # Check if file already exists
+ file_existed = output_path.exists()
+
+ if file_existed:
+ print(f"ℹ️ Model already exists: {output_path.name}")
+ return output_path
+
+ cmd = [
+ str(self.quantize_bin),
+ "--token-embedding-type", embedding_type,
+ str(self.model_path),
+ str(output_path),
+ "I2_S",
+ "1",
+ "1"
+ ]
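+        # Equivalent shell invocation (paths illustrative):
+        #   llama-quantize --token-embedding-type Q6_K \
+        #       model.gguf ggml-model-i2_s-embed-q6_k.gguf I2_S 1 1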
+
+ print(f"\n{'='*80}")
+ print(f"🔄 Quantizing with embedding type: {embedding_type}")
+ print(f"📥 Input: {self.model_path.name}")
+ print(f"📤 Output: {output_path.name}")
+ print(f"💻 Command: {' '.join(cmd)}")
+ print(f"{'='*80}\n")
+
+ start_time = time.time()
+
+ try:
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ cwd=os.getcwd(),
+ timeout=600 # 10 minutes timeout
+ )
+
+ duration = time.time() - start_time
+
+ if result.returncode == 0:
+ file_size_mb = output_path.stat().st_size / (1024 * 1024)
+ print(f"✅ Quantization successful!")
+ print(f" Duration: {duration:.2f}s")
+ print(f" Size: {file_size_mb:.2f} MB")
+
+ # Mark as newly created
+ self.created_models.add(output_path)
+ return output_path
+ else:
+ print(f"❌ Quantization failed with return code {result.returncode}")
+ print(f"Error: {result.stderr[:500]}")
+ return None
+
+ except subprocess.TimeoutExpired:
+ print(f"❌ Quantization timeout (exceeded 10 minutes)")
+ return None
+ except Exception as e:
+ print(f"❌ Quantization error: {e}")
+ return None
+
+ def cleanup_model(self, model_path):
+ """Delete model file if it was created during this session."""
+ if model_path in self.created_models:
+ try:
+ model_path.unlink()
+ print(f"🗑️ Deleted: {model_path.name}")
+ self.created_models.remove(model_path)
+ except Exception as e:
+ print(f"⚠️ Failed to delete {model_path.name}: {e}")
+ else:
+ print(f"ℹ️ Keeping existing file: {model_path.name}")
+
+ def run_all_tests(self, threads=16, ctx_size=512):
+ """Run perplexity tests on all datasets."""
+ datasets = self.find_datasets()
+
+ if not datasets:
+ print(f"\n❌ No datasets found in {self.data_dir}")
+ print(f" Make sure each dataset directory has a test.txt file")
+ return
+
+ # Quick mode: test all datasets but only first 4096 chars with smaller context
+ if self.quick_mode:
+ ctx_size = min(ctx_size, 128) # Use smaller context in quick mode
+ print(f"\n⚡ QUICK TEST MODE ENABLED")
+ print(f" - Testing all datasets with first 4096 characters only")
+ print(f" - Using reduced context size: {ctx_size}")
+
+ # Determine models to test
+ if self.test_embeddings:
+ print(f"\n{'='*80}")
+ print(f"🧪 EMBEDDING QUANTIZATION TEST MODE")
+ print(f"{'='*80}")
+ print(f"📦 Base model: {self.model_path.name}")
+ print(f"🔢 Embedding types to test: {len(self.embedding_types)}")
+ print(f"📊 Datasets: {len(datasets)}")
+ print(f"🧵 Threads: {threads}")
+ print(f"📏 Context size: {ctx_size}")
+ print(f"{'='*80}")
+
+ total_start = time.time()
+
+ # Test each embedding type
+ for i, (embedding_type, output_suffix) in enumerate(self.embedding_types, 1):
+ print(f"\n\n{'#'*80}")
+ print(f"[{i}/{len(self.embedding_types)}] Testing embedding type: {output_suffix} ({embedding_type})")
+ print(f"{'#'*80}")
+
+ # Quantize model
+ quantized_model = self.quantize_embedding(embedding_type, output_suffix)
+
+ if quantized_model is None:
+ print(f"⚠️ Skipping tests for {output_suffix} due to quantization failure")
+ continue
+
+ # Test on all datasets
+ for j, dataset in enumerate(datasets, 1):
+ print(f"\n[{j}/{len(datasets)}] Testing {dataset['name']} with {output_suffix}...")
+
+ # Use quick dataset if in quick mode
+ test_path = dataset['path']
+ if self.quick_mode:
+ test_path = self.create_quick_dataset(dataset['path'])
+
+ result = self.run_perplexity_test(
+ f"{dataset['name']}_embed-{output_suffix}",
+ test_path,
+ threads,
+ ctx_size,
+ model_override=quantized_model
+ )
+ self.results.append(result)
+
+ # Cleanup model after testing
+ print(f"\n🧹 Cleaning up {output_suffix} model...")
+ self.cleanup_model(quantized_model)
+
+ print(f"\n{'#'*80}")
+ print(f"✅ Completed {output_suffix}")
+ print(f"{'#'*80}")
+
+ total_time = time.time() - total_start
+
+ else:
+ # Regular single model test
+ print(f"\n{'='*80}")
+ print(f"🚀 PERPLEXITY TEST SESSION{' (QUICK MODE)' if self.quick_mode else ''}")
+ print(f"{'='*80}")
+ print(f"📦 Model: {self.model_path.name}")
+ print(f"📁 Model path: {self.model_path}")
+ print(f"📊 Datasets {'to test' if self.quick_mode else 'found'}: {len(datasets)}")
+ print(f"🧵 Threads: {threads}")
+ print(f"📏 Context size: {ctx_size}")
+ print(f"{'='*80}")
+
+ total_start = time.time()
+
+ # Run tests
+ for i, dataset in enumerate(datasets, 1):
+ print(f"\n\n[{i}/{len(datasets)}] Processing {dataset['name']}...")
+
+ # Use quick dataset if in quick mode
+ test_path = dataset['path']
+ if self.quick_mode:
+ test_path = self.create_quick_dataset(dataset['path'])
+
+ result = self.run_perplexity_test(
+ dataset['name'],
+ test_path,
+ threads,
+ ctx_size
+ )
+ self.results.append(result)
+
+ total_time = time.time() - total_start
+
+ # Clean up temporary files
+ if self.quick_mode:
+ print(f"\n🧹 Cleaning up temporary files...")
+ self.cleanup_temp_files()
+
+ # Save results
+ self.save_results()
+
+ # Print summary
+ self.print_summary(total_time)
+
+ def save_results(self):
+ """Save results to CSV file."""
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ model_name = self.model_path.stem
+
+ # Use custom CSV path if provided
+ if self.csv_output:
+ csv_file = self.csv_output
+ # Create parent directory if needed
+ csv_file.parent.mkdir(parents=True, exist_ok=True)
+ else:
+ csv_file = self.output_dir / f"ppl_{model_name}_{timestamp}.csv"
+
+ print(f"\n💾 Saving results...")
+
+ with open(csv_file, 'w', newline='') as f:
+ writer = csv.DictWriter(f, fieldnames=['dataset', 'perplexity', 'time_seconds', 'status'])
+ writer.writeheader()
+ for result in self.results:
+ writer.writerow({
+ 'dataset': result['dataset'],
+ 'perplexity': result['perplexity'] if result['perplexity'] is not None else 'N/A',
+ 'time_seconds': f"{result['time']:.2f}",
+ 'status': result['status']
+ })
+
+ print(f" ✅ CSV saved: {csv_file}")
+
+ # Save detailed log
+ log_file = self.output_dir / f"ppl_{model_name}_{timestamp}.log"
+ with open(log_file, 'w') as f:
+ f.write(f"Perplexity Test Results\n")
+ f.write(f"{'='*80}\n")
+ f.write(f"Model: {self.model_path}\n")
+ f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
+ f.write(f"{'='*80}\n\n")
+
+ for result in self.results:
+ f.write(f"\n{'='*80}\n")
+ f.write(f"Dataset: {result['dataset']}\n")
+ f.write(f"Perplexity: {result['perplexity']}\n")
+ f.write(f"Time: {result['time']:.2f}s\n")
+ f.write(f"Status: {result['status']}\n")
+ f.write(f"\nOutput:\n{result['stdout']}\n")
+ if result['stderr']:
+ f.write(f"\nErrors:\n{result['stderr']}\n")
+
+ print(f" ✅ Log saved: {log_file}")
+
+ def print_summary(self, total_time):
+ """Print summary of all tests."""
+ print(f"\n\n{'='*80}")
+ print(f"📊 TEST SUMMARY")
+ print(f"{'='*80}\n")
+
+ # Sort results by perplexity (lower is better)
+ successful = [r for r in self.results if r['perplexity'] is not None]
+ failed = [r for r in self.results if r['perplexity'] is None]
+
+ if successful:
+ # Extract numeric value from "mean±std" format for sorting
+ def get_ppl_value(result):
+ ppl = result['perplexity']
+ if isinstance(ppl, str) and '±' in ppl:
+ return float(ppl.split('±')[0])
+ elif isinstance(ppl, str):
+ try:
+ return float(ppl)
+ except ValueError:
+ return float('inf')
+ return ppl
+
+ successful_sorted = sorted(successful, key=get_ppl_value)
+
+ print(f"{'Dataset':<20} {'Perplexity':>20} {'Time (s)':>12} {'Status':<15}")
+ print(f"{'-'*80}")
+
+ for result in successful_sorted:
+ ppl_str = str(result['perplexity']) if result['perplexity'] is not None else 'N/A'
+ print(f"{result['dataset']:<20} {ppl_str:>20} "
+ f"{result['time']:>12.2f} {result['status']:<15}")
+
+ best_ppl = str(successful_sorted[0]['perplexity'])
+ print(f"\n🏆 Best result: {successful_sorted[0]['dataset']} "
+ f"(PPL: {best_ppl})")
+
+ if failed:
+ print(f"\n❌ Failed tests ({len(failed)}):")
+ for result in failed:
+ print(f" - {result['dataset']}: {result['status']}")
+
+ print(f"\n{'='*80}")
+ print(f"✅ Completed: {len(successful)}/{len(self.results)}")
+ print(f"⏱️ Total time: {total_time:.2f}s ({total_time/60:.2f} min)")
+ print(f"📁 Results saved in: {self.output_dir}")
+ print(f"{'='*80}\n")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Test model perplexity on multiple datasets')
+ parser.add_argument('--model', '-m',
+ required=True,
+ help='Path to GGUF model file')
+ parser.add_argument('--data-dir', '-d',
+ default='data',
+ help='Directory containing dataset folders (default: data)')
+ parser.add_argument('--threads', '-t',
+ type=int,
+ default=16,
+ help='Number of threads (default: 16)')
+ parser.add_argument('--ctx-size', '-c',
+ type=int,
+ default=512,
+ help='Context size (default: 512)')
+ parser.add_argument('--output-dir', '-o',
+ default='perplexity_results',
+ help='Output directory for results (default: perplexity_results)')
+ parser.add_argument('--llama-perplexity',
+ default='./build/bin/llama-perplexity',
+ help='Path to llama-perplexity binary (default: ./build/bin/llama-perplexity)')
+ parser.add_argument('--quick', '-q',
+ action='store_true',
+ help='Quick test mode: test all datasets with first 4096 characters and reduced context size (128)')
+ parser.add_argument('--test-embeddings', '-e',
+ action='store_true',
+ help='Test different embedding quantization types (f32, f16, q8_0, q6_k, q5_0, q4_0, q3_k, tq2_0)')
+ parser.add_argument('--csv-output',
+ help='Custom path for CSV output file (e.g., results/my_ppl_results.csv)')
+ parser.add_argument('--quantize-bin',
+ default='./build/bin/llama-quantize',
+ help='Path to llama-quantize binary (default: ./build/bin/llama-quantize)')
+
+ args = parser.parse_args()
+
+ try:
+ tester = PerplexityTester(
+ model_path=args.model,
+ llama_perplexity_bin=args.llama_perplexity,
+ data_dir=args.data_dir,
+ output_dir=args.output_dir,
+ quick_mode=args.quick,
+ quantize_bin=args.quantize_bin,
+ test_embeddings=args.test_embeddings,
+ csv_output=args.csv_output
+ )
+
+ tester.run_all_tests(
+ threads=args.threads,
+ ctx_size=args.ctx_size
+ )
+
+ except FileNotFoundError as e:
+ print(f"❌ Error: {e}")
+ return 1
+ except KeyboardInterrupt:
+ print("\n\n⚠️ Test interrupted by user")
+ return 1
+ except Exception as e:
+ print(f"\n❌ Unexpected error: {e}")
+ import traceback
+ traceback.print_exc()
+ return 1
+
+ return 0
+
+
+if __name__ == "__main__":
+ exit(main())
diff --git a/utils/test_typical_shapes.sh b/utils/test_typical_shapes.sh
new file mode 100755
index 0000000..6ad805c
--- /dev/null
+++ b/utils/test_typical_shapes.sh
@@ -0,0 +1,120 @@
+#!/bin/bash
+# Test typical matrix shapes for BitNet-2B model
+# Based on BitNet-b1.58-2B-4T architecture
+
+echo "=========================================="
+echo "BitNet-2B Typical Shapes Performance Test"
+echo "=========================================="
+echo ""
+
+ITERATIONS=1000
+BENCHMARK="../build/test_gemm_kernel"
+
+# Create stats directory if not exists
+mkdir -p ../stats
+
+# Generate output CSV filename
+CSV_FILE="../stats/gemm_kernel_test_noparal.csv"
+
+# Write CSV header
+echo "test_name,n,nr,nc,time_ms,gflops,throughput_tokens_per_sec" > "$CSV_FILE"
+echo "Results will be saved to: $CSV_FILE"
+echo ""
+
+# Function to extract metrics and append to CSV
+extract_and_save() {
+ local test_name="$1"
+ local output="$2"
+
+ # Extract values using grep and awk
+ local n=$(echo "$output" | grep "Embedding dimension" | awk '{print $5}')
+ local nr=$(echo "$output" | grep "Matrix Y rows" | awk '{print $6}')
+ local nc=$(echo "$output" | grep "Matrix X columns" | awk '{print $6}')
+ local avg_time=$(echo "$output" | grep "Average time" | awk '{print $4}')
+ local min_time=$(echo "$output" | grep "Min time" | awk '{print $4}')
+ local max_time=$(echo "$output" | grep "Max time" | awk '{print $4}')
+ local gflops=$(echo "$output" | grep "GFLOPS" | awk '{print $3}')
+ local throughput=$(echo "$output" | grep "Throughput" | awk '{print $3}')
+
+ # Calculate standard deviation estimate from range (assuming ~95% of data within min-max)
+ # For normal distribution, range ≈ 4*std, so std ≈ range/4
+ local std_time=$(echo "scale=4; ($max_time - $min_time) / 4" | bc)
+
+ # Format as mean±std
+ local time_formatted="${avg_time}±${std_time}"
+
+    # GFLOPS and throughput are reported as single values; estimating their
+    # spread would require per-iteration samples from the benchmark
+
+ # Append to CSV
+ echo "${test_name},${n},${nr},${nc},${time_formatted},${gflops},${throughput}" >> "$CSV_FILE"
+}
+
+echo "Test 1: Single Token Generation (Attention QKV projection)"
+echo " Scenario: Generating 1 token at a time"
+echo " Shape: n=2048, r=1, c=2048"
+OUTPUT=$($BENCHMARK -n 2048 -r 1 -c 2048 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "single_token_gen" "$OUTPUT"
+echo ""
+
+echo "Test 2: Small Batch Prompt Processing (Attention QKV projection)"
+echo " Scenario: Processing prompt with 128 tokens, batch size 1"
+echo " Shape: n=2048, r=128, c=2048"
+OUTPUT=$($BENCHMARK -n 2048 -r 128 -c 2048 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "small_batch_prompt" "$OUTPUT"
+echo ""
+
+echo "Test 3: Medium Batch Prompt Processing (Attention QKV projection)"
+echo " Scenario: Processing prompt with 256 tokens or batch of 256"
+echo " Shape: n=2048, r=256, c=2048"
+OUTPUT=$($BENCHMARK -n 2048 -r 256 -c 2048 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "medium_batch_prompt" "$OUTPUT"
+echo ""
+
+echo "Test 4: Large Batch Processing (Attention QKV projection)"
+echo " Scenario: Processing 512 tokens or batch of 512"
+echo " Shape: n=2048, r=512, c=2048"
+OUTPUT=$($BENCHMARK -n 2048 -r 512 -c 2048 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "large_batch_prompt" "$OUTPUT"
+echo ""
+
+echo "Test 5: FFN Up-projection (Small batch)"
+echo " Scenario: Feed-forward network expansion, 128 tokens"
+echo " Shape: n=2048, r=128, c=8192"
+OUTPUT=$($BENCHMARK -n 2048 -r 128 -c 8192 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "ffn_up_projection" "$OUTPUT"
+echo ""
+
+echo "Test 6: FFN Down-projection (Small batch)"
+echo " Scenario: Feed-forward network reduction, 128 tokens"
+echo " Shape: n=8192, r=128, c=2048"
+OUTPUT=$($BENCHMARK -n 8192 -r 128 -c 2048 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "ffn_down_projection" "$OUTPUT"
+echo ""
+
+echo "Test 7: Long Context Processing"
+echo " Scenario: Processing very long context (2048 tokens)"
+echo " Shape: n=2048, r=2048, c=2048"
+OUTPUT=$($BENCHMARK -n 2048 -r 2048 -c 2048 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "long_context" "$OUTPUT"
+echo ""
+
+echo "Test 8: Batched Token Generation"
+echo " Scenario: Generating tokens for 32 sequences simultaneously"
+echo " Shape: n=2048, r=32, c=2048"
+OUTPUT=$($BENCHMARK -n 2048 -r 32 -c 2048 -i $ITERATIONS 2>&1)
+echo "$OUTPUT"
+extract_and_save "batched_token_gen" "$OUTPUT"
+echo ""
+
+echo "=========================================="
+echo "All tests completed!"
+echo "Results saved to: $CSV_FILE"
+echo "=========================================="
diff --git a/utils/tune_gemm_config.py b/utils/tune_gemm_config.py
new file mode 100644
index 0000000..83b4218
--- /dev/null
+++ b/utils/tune_gemm_config.py
@@ -0,0 +1,405 @@
+#!/usr/bin/env python3
+"""
+GEMM Configuration Tuning Script
+This script automatically tunes ROW_BLOCK_SIZE, COL_BLOCK_SIZE, and PARALLEL_SIZE
+to find the optimal configuration for maximum throughput (t/s).
+"""
+
+import subprocess
+import os
+import re
+import csv
+import shutil
+from datetime import datetime
+from pathlib import Path
+import argparse
+
+
+class GemmTuner:
+ def __init__(self, config_path, model_path, threads=16):
+ self.config_path = Path(config_path)
+ self.model_path = model_path
+ self.threads = threads
+ self.backup_path = self.config_path.parent / f"gemm-config.h.backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+ self.build_dir = Path("../build")
+ self.results = []
+
+ def backup_config(self):
+ """Backup current configuration file"""
+ print(f"📦 Backing up current config to {self.backup_path}")
+ shutil.copy2(self.config_path, self.backup_path)
+
+ def restore_config(self):
+ """Restore original configuration file"""
+ print(f"♻️ Restoring original config from {self.backup_path}")
+ shutil.copy2(self.backup_path, self.config_path)
+
+    def generate_config(self, act_parallel, row_block_size, col_block_size, parallel_size):
+        """Generate new configuration file"""
+        # Every branch defines the same values, so build the #define block once
+        block = (
+            f"  #define ROW_BLOCK_SIZE {row_block_size}\n"
+            f"  #define COL_BLOCK_SIZE {col_block_size}\n"
+            f"  #define PARALLEL_SIZE {parallel_size}\n"
+        )
+        act_block = (
+            "#if defined(ACT_PARALLEL)\n" + block
+            + "#else\n" + block
+            + "#endif\n"
+        )
+
+        # ACT_PARALLEL definition
+        content = "#define ACT_PARALLEL\n" if act_parallel else "// #define ACT_PARALLEL\n"
+
+        # Detect architecture branches in the original config file and
+        # mirror its guard structure
+        with open(self.backup_path, 'r') as f:
+            original = f.read()
+
+        has_avx = "__AVX__" in original or "__AVX2__" in original
+        has_arm = "__ARM_NEON" in original
+
+        avx_guard = "#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)\n"
+        if has_avx and has_arm:
+            # Multi-architecture configuration
+            content += avx_guard + act_block + "#elif defined(__ARM_NEON)\n" + act_block + "#endif\n"
+        elif has_avx:
+            # AVX architecture only
+            content += avx_guard + act_block + "#endif\n"
+        elif has_arm:
+            # ARM architecture only
+            content += "#if defined(__ARM_NEON)\n" + act_block + "#endif\n"
+        else:
+            # No architecture detection, define directly
+            content += act_block
+
+        content += "\n"
+
+        with open(self.config_path, 'w') as f:
+            f.write(content)
+
+ def rebuild_project(self):
+ """Rebuild project"""
+ print("🔨 Rebuilding project...")
+ result = subprocess.run(
+ ["cmake", "--build", str(self.build_dir), "--target", "llama-bench"],
+ capture_output=True,
+ text=True,
+ cwd=os.getcwd()
+ )
+ if result.returncode != 0:
+ print(f"⚠️ Build warning/error: {result.stderr}")
+ return False
+ return True
+
+ def run_benchmark(self):
+ """Run benchmark test"""
+ cmd = [
+ f"{self.build_dir}/bin/llama-bench",
+ "-m", self.model_path,
+ "-p", "128",
+ "-n", "0",
+ "-t", str(self.threads),
+ "-ngl", "0"
+ ]
+
+ print(f"⚡ Running benchmark: {' '.join(cmd)}")
+
+ result = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ cwd=os.getcwd(),
+            timeout=300  # 5 minute timeout
+ )
+
+ if result.returncode != 0:
+ print(f"❌ Benchmark failed: {result.stderr}")
+ return None
+
+ return result.stdout
+
+ def parse_throughput(self, output):
+ """Parse pp128 throughput from output"""
+        # Match the pp128 row, e.g.: |  pp128 |  501.06 ± 11.37 |
+ pp_pattern = r'\|\s+pp128\s+\|\s+([\d.]+)\s+±\s+([\d.]+)\s+\|'
+ pp_match = re.search(pp_pattern, output)
+
+ if pp_match:
+ pp_throughput = float(pp_match.group(1))
+ pp_std_dev = float(pp_match.group(2))
+
+ return {
+ 'pp_throughput': pp_throughput,
+ 'pp_std_dev': pp_std_dev
+ }
+
+ return None
+
+ def test_configuration(self, act_parallel, row_block_size, col_block_size, parallel_size):
+ """Test single configuration"""
+ config_name = f"ACT_{'ON' if act_parallel else 'OFF'}_R{row_block_size}_C{col_block_size}_P{parallel_size}"
+ print(f"\n{'='*80}")
+ print(f"🧪 Testing configuration: {config_name}")
+ print(f" ACT_PARALLEL: {act_parallel}")
+ print(f" ROW_BLOCK_SIZE: {row_block_size}")
+ print(f" COL_BLOCK_SIZE: {col_block_size}")
+ print(f" PARALLEL_SIZE: {parallel_size}")
+ print(f"{'='*80}")
+
+ # Generate configuration
+ self.generate_config(act_parallel, row_block_size, col_block_size, parallel_size)
+
+ # Rebuild project
+ if not self.rebuild_project():
+ print("⚠️ Build failed, skipping this configuration")
+ return None
+
+ # Run benchmark test
+ output = self.run_benchmark()
+ if output is None:
+ return None
+
+ # Parse results
+ metrics = self.parse_throughput(output)
+
+ if metrics is not None:
+ result = {
+ 'act_parallel': act_parallel,
+ 'row_block_size': row_block_size,
+ 'col_block_size': col_block_size,
+ 'parallel_size': parallel_size,
+ 'config_name': config_name,
+ **metrics
+ }
+ self.results.append(result)
+ print(f"✅ PP128: {metrics['pp_throughput']:.2f} ± {metrics['pp_std_dev']:.2f} t/s")
+ return result
+ else:
+ print("❌ Failed to parse throughput")
+ return None
+
+    def save_results(self, csv_path):
+        """Save results to CSV file"""
+        print(f"\n💾 Saving results to {csv_path}")
+
+        # Make sure the output directory (e.g. stats/) exists before writing
+        Path(csv_path).parent.mkdir(parents=True, exist_ok=True)
+        with open(csv_path, 'w', newline='') as f:
+ writer = csv.DictWriter(f, fieldnames=[
+ 'config_name', 'act_parallel', 'row_block_size',
+ 'col_block_size', 'parallel_size',
+ 'pp_throughput', 'pp_std_dev'
+ ])
+ writer.writeheader()
+ writer.writerows(self.results)
+
+ def find_best_config(self):
+ """Find the best configuration with highest throughput"""
+ if not self.results:
+ print("❌ No valid results found")
+ return None
+
+ best = max(self.results, key=lambda x: x['pp_throughput'])
+ return best
+
+ def run_tuning(self, configurations, output_csv=None):
+ """Run test for all configurations"""
+ print(f"\n🚀 Starting tuning process with {len(configurations)} configurations")
+ print(f"📊 Model: {self.model_path}")
+ print(f"🧵 Threads: {self.threads}\n")
+
+ # Backup configuration
+ self.backup_config()
+
+ try:
+ # Test all configurations
+ for i, config in enumerate(configurations, 1):
+ print(f"\n[{i}/{len(configurations)}]")
+ self.test_configuration(**config)
+
+ # Save results
+ if output_csv is None:
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+ csv_path = f"stats/tuning_results_{timestamp}.csv"
+ else:
+ csv_path = output_csv
+ self.save_results(csv_path)
+
+ # Find best configuration
+ best = self.find_best_config()
+ if best:
+ print(f"\n{'='*80}")
+ print(f"🏆 BEST CONFIGURATION FOUND!")
+ print(f"{'='*80}")
+ print(f"Configuration: {best['config_name']}")
+ print(f"ACT_PARALLEL: {best['act_parallel']}")
+ print(f"ROW_BLOCK_SIZE: {best['row_block_size']}")
+ print(f"COL_BLOCK_SIZE: {best['col_block_size']}")
+ print(f"PARALLEL_SIZE: {best['parallel_size']}")
+ print(f"PP128 Throughput: {best['pp_throughput']:.2f} ± {best['pp_std_dev']:.2f} t/s")
+ print(f"{'='*80}\n")
+
+ # Apply best configuration
+ apply = input("Do you want to apply this configuration? (y/n): ").strip().lower()
+ if apply == 'y':
+ self.generate_config(
+ best['act_parallel'],
+ best['row_block_size'],
+ best['col_block_size'],
+ best['parallel_size']
+ )
+ self.rebuild_project()
+ print("✅ Best configuration applied!")
+ else:
+ self.restore_config()
+ print("✅ Original configuration restored")
+
+ except KeyboardInterrupt:
+ print("\n⚠️ Tuning interrupted by user")
+ self.restore_config()
+ except Exception as e:
+ print(f"\n❌ Error during tuning: {e}")
+ self.restore_config()
+ raise
+
+
+def generate_configurations():
+ """Generate list of configurations to test"""
+ configurations = []
+
+ act_parallel_options = [True]
+
+ row_sizes = [2, 4, 8, 16, 32]
+ col_sizes = [32, 64, 128, 256, 512, 1024]
+ parallelism_degree = [2, 4, 8]
+
+ for act_parallel in act_parallel_options:
+ for row in row_sizes:
+ for col in col_sizes:
+ for parallel in parallelism_degree:
+                    # Filter out invalid combinations
+                    if act_parallel:
+                        # When ACT_PARALLEL=True, only keep combinations with parallel <= row
+                        if parallel > row:
+                            continue
+                    else:
+                        # When ACT_PARALLEL=False, only keep combinations with parallel <= col
+                        if parallel > col:
+                            continue
+
+ configurations.append({
+ 'act_parallel': act_parallel,
+ 'row_block_size': row,
+ 'col_block_size': col,
+ 'parallel_size': parallel
+ })
+
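+    # With the defaults above this yields 72 configurations: rows
+    # {2,4,8,16,32} admit {1,2,3,3,3} parallel sizes respectively
+    # (parallel <= row), times 6 column sizes.
+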
+ return configurations
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Tune GEMM configuration for optimal performance')
+ parser.add_argument('--config', default='../include/gemm-config.h',
+ help='Path to gemm-config.h file')
+ parser.add_argument('--model', default='../models/BitNet-b1.58-2B-4T/ggml-model-i2_s-embed-q6_k.gguf',
+ help='Path to model file')
+ parser.add_argument('--threads', type=int, default=8,
+ help='Number of threads to use')
+ parser.add_argument('--quick', action='store_true',
+ help='Quick test with fewer configurations')
+ parser.add_argument('--custom', action='store_true',
+ help='Manually specify configurations to test')
+ parser.add_argument('--output', type=str, default=None,
+                        help='Output CSV file path (default: stats/tuning_results_<timestamp>.csv)')
+
+ args = parser.parse_args()
+
+ tuner = GemmTuner(args.config, args.model, args.threads)
+
+ if args.custom:
+ # Custom configuration mode
+ print("Custom configuration mode")
+ configurations = []
+        while True:
+            print("\nEnter configuration (or 'done' to finish):")
+            act_input = input("ACT_PARALLEL (y/n, or 'done'): ").strip().lower()
+            if act_input == 'done':
+                break
+            act = act_input == 'y'
+            row = int(input("ROW_BLOCK_SIZE: "))
+            col = int(input("COL_BLOCK_SIZE: "))
+            par = int(input("PARALLEL_SIZE: "))
+ configurations.append({
+ 'act_parallel': act,
+ 'row_block_size': row,
+ 'col_block_size': col,
+ 'parallel_size': par
+ })
+ elif args.quick:
+ # Quick test mode - test only a few key configurations
+ configurations = [
+ {'act_parallel': True, 'row_block_size': 4, 'col_block_size': 128, 'parallel_size': 4},
+ {'act_parallel': True, 'row_block_size': 8, 'col_block_size': 128, 'parallel_size': 4},
+ {'act_parallel': True, 'row_block_size': 4, 'col_block_size': 64, 'parallel_size': 4},
+ {'act_parallel': False, 'row_block_size': 32, 'col_block_size': 4, 'parallel_size': 4},
+ {'act_parallel': False, 'row_block_size': 16, 'col_block_size': 4, 'parallel_size': 4},
+ ]
+ else:
+ # Full test mode
+ configurations = generate_configurations()
+
+ print(f"\n{'='*80}")
+ print(f"GEMM Configuration Tuner")
+ print(f"{'='*80}")
+ print(f"Total configurations to test: {len(configurations)}")
+ print(f"Estimated time: ~{len(configurations) * 0.5:.1f} minutes (assuming 30s per test)")
+ print(f"{'='*80}\n")
+
+ proceed = input("Proceed with tuning? (y/n): ").strip().lower()
+ if proceed != 'y':
+ print("Tuning cancelled")
+ return
+
+ tuner.run_tuning(configurations, output_csv=args.output)
+
+
+if __name__ == "__main__":
+ main()