Files
BitNet/utils/quantize_embeddings.py
T

474 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Embedding Quantization Script
This script converts ggml-model-f32.gguf to multiple quantized versions
with different token embedding types.
"""
import subprocess
import os
import argparse
import re
import csv
from pathlib import Path
from datetime import datetime
class EmbeddingQuantizer:
    """Quantize a GGUF model with several token-embedding types and benchmark each result.

    For every requested embedding type this class runs ``llama-quantize`` to
    produce an I2_S model, benchmarks it with ``llama-bench``, parses the
    throughput table, and finally writes all results to a CSV file.
    """

    def __init__(self, input_model, output_dir, quantize_bin="../build/bin/llama-quantize",
                 bench_bin="../build/bin/llama-bench", stats_dir="../stats", csv_output=None):
        """
        Args:
            input_model: Path to the source GGUF model (e.g. ggml-model-f32.gguf).
            output_dir: Directory where quantized models are written.
            quantize_bin: Path to the llama-quantize executable.
            bench_bin: Path to the llama-bench executable.
            stats_dir: Default directory for the benchmark CSV.
            csv_output: Optional explicit CSV output path; overrides stats_dir.

        Raises:
            FileNotFoundError: If the input model or either binary is missing.
        """
        self.input_model = Path(input_model)
        self.output_dir = Path(output_dir)
        self.quantize_bin = Path(quantize_bin)
        self.bench_bin = Path(bench_bin)
        self.stats_dir = Path(stats_dir)
        self.csv_output = Path(csv_output) if csv_output else None
        # Fail fast: verify every required file before doing any work.
        if not self.input_model.exists():
            raise FileNotFoundError(f"Input model not found: {self.input_model}")
        if not self.quantize_bin.exists():
            raise FileNotFoundError(f"Quantize binary not found: {self.quantize_bin}")
        if not self.bench_bin.exists():
            raise FileNotFoundError(f"Benchmark binary not found: {self.bench_bin}")
        # Create output directories
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.stats_dir.mkdir(parents=True, exist_ok=True)
        self.results = []
        # Files created by this run; only these are eligible for cleanup later.
        self.newly_created_files = set()

    def quantize(self, embedding_type, output_suffix):
        """
        Perform a single quantization via llama-quantize.

        Args:
            embedding_type: Token embedding type (uppercase format, e.g., Q6_K)
            output_suffix: Output file suffix (lowercase format, e.g., q6_k)

        Returns:
            bool: Whether successful
        """
        output_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
        # Reuse a pre-existing output file instead of re-quantizing; the caller
        # will still benchmark it.
        if output_file.exists():
            print(f"️ File already exists: {output_file}")
            print(f" Skipping quantization, will use existing file for benchmark")
            return True
        cmd = [
            str(self.quantize_bin),
            "--token-embedding-type", embedding_type,
            str(self.input_model),
            str(output_file),
            "I2_S",
            "1",
            "1"
        ]
        print(f"\n{'='*80}")
        print(f"🔄 Quantizing with embedding type: {embedding_type}")
        print(f"📥 Input: {self.input_model}")
        print(f"📤 Output: {output_file}")
        print(f"💻 Command: {' '.join(cmd)}")
        print(f"{'='*80}\n")
        start_time = datetime.now()
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                cwd=os.getcwd(),
                timeout=600  # 10 minute timeout
            )
            duration = (datetime.now() - start_time).total_seconds()
            if result.returncode == 0:
                # Get output file size
                file_size_mb = output_file.stat().st_size / (1024 * 1024)
                print(f"✅ Success! Duration: {duration:.2f}s, Size: {file_size_mb:.2f} MB")
                # We only reach this point when the file did not exist before,
                # so it is always a newly created file (eligible for cleanup).
                self.newly_created_files.add(output_file)
                # Print the tail of the tool's output
                if result.stdout:
                    print("\n📊 Quantization output:")
                    print(result.stdout[-500:] if len(result.stdout) > 500 else result.stdout)
                return True
            else:
                print(f"❌ Failed with return code {result.returncode}")
                print(f"Error: {result.stderr}")
                return False
        except subprocess.TimeoutExpired:
            print(f"❌ Timeout (exceeded 10 minutes)")
            return False
        except Exception as e:
            print(f"❌ Exception: {e}")
            return False

    def benchmark_model(self, output_suffix):
        """
        Benchmark a quantized model with llama-bench (pp128, threads 1/2/4/8).

        Args:
            output_suffix: Output file suffix (lowercase format, e.g., q6_k)

        Returns:
            dict: Dictionary with benchmark results, or None if failed
        """
        model_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
        if not model_file.exists():
            print(f"❌ Model file not found for benchmarking: {model_file}")
            return None
        cmd = [
            str(self.bench_bin),
            "-m", str(model_file),
            "-p", "128",
            "-n", "0",
            "-t", "1,2,4,8",
            "-ngl", "0"
        ]
        print(f"\n{'='*80}")
        print(f"🏃 Running benchmark for: {output_suffix}")
        print(f"💻 Command: {' '.join(cmd)}")
        print(f"{'='*80}\n")
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                cwd=os.getcwd(),
                timeout=300  # 5 minute timeout
            )
            if result.returncode == 0:
                print("✅ Benchmark completed successfully")
                print("\n📊 Benchmark output:")
                print(result.stdout)
                # Parse the markdown table emitted by llama-bench
                return self.parse_benchmark_output(result.stdout, output_suffix)
            else:
                print(f"❌ Benchmark failed with return code {result.returncode}")
                print(f"Error: {result.stderr}")
                return None
        except subprocess.TimeoutExpired:
            print(f"❌ Benchmark timeout (exceeded 5 minutes)")
            return None
        except Exception as e:
            print(f"❌ Benchmark exception: {e}")
            return None

    def parse_benchmark_output(self, output, output_suffix):
        """
        Parse benchmark output to extract t/s data (mean±std).

        Args:
            output: Benchmark command output
            output_suffix: Output file suffix

        Returns:
            dict: Keys 'embedding_type' plus 'threads_{1,2,4,8}' (str or None)
        """
        results = {
            'embedding_type': output_suffix,
            'threads_1': None,
            'threads_2': None,
            'threads_4': None,
            'threads_8': None,
        }
        # Find table rows containing pp128 and a t/s value.
        # Format similar to:
        # | bitnet-25 2B I2_S - 2 bpw ternary | 1012.28 MiB | 2.74 B | CPU | 12 | pp128 | 405.73 ± 3.69 |
        for line in output.strip().split('\n'):
            # Skip header and separator lines
            if '|' not in line or 'model' in line or '---' in line:
                continue
            parts = [p.strip() for p in line.split('|')]
            if len(parts) < 8 or 'pp128' not in parts[6]:
                continue
            try:
                threads = int(parts[5].strip())
            except ValueError:
                # Not a data row (column 5 is not a thread count)
                continue
            throughput_str = parts[7].strip()
            # Throughput is either "405.73 ± 3.69" (mean ± std) or "405.73"
            match_with_std = re.search(r'([\d.]+)\s*±\s*([\d.]+)', throughput_str)
            if match_with_std:
                mean = float(match_with_std.group(1))
                std = float(match_with_std.group(2))
                throughput = f"{mean:.2f}±{std:.2f}"
            else:
                match = re.search(r'([\d.]+)', throughput_str)
                if not match:
                    continue
                throughput = f"{float(match.group(1)):.2f}"
            # Store result for the thread counts we report on
            if threads in (1, 2, 4, 8):
                results[f'threads_{threads}'] = throughput
        return results

    def cleanup_model(self, output_suffix):
        """
        Cleanup model files (only delete files created by this run).

        Args:
            output_suffix: Output file suffix
        """
        model_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
        if model_file in self.newly_created_files:
            try:
                model_file.unlink()
                print(f"🗑️ Deleted newly created file: {model_file}")
                self.newly_created_files.remove(model_file)
            except Exception as e:
                print(f"⚠️ Failed to delete {model_file}: {e}")
        else:
            # Pre-existing files are never deleted.
            print(f"️ Keeping existing file: {model_file}")

    def run_all_quantizations(self, types_to_quantize):
        """
        Quantize, benchmark, and clean up every requested type, then save a CSV summary.

        Args:
            types_to_quantize: List of (embedding_type, output_suffix) tuples
        """
        print(f"\n{'='*80}")
        print(f"🚀 Starting Embedding Quantization and Benchmarking")
        print(f"{'='*80}")
        print(f"📥 Input model: {self.input_model}")
        print(f"📤 Output directory: {self.output_dir}")
        print(f"📊 Stats directory: {self.stats_dir}")
        print(f"🔢 Total quantizations: {len(types_to_quantize)}")
        print(f"{'='*80}\n")
        total_start = datetime.now()
        for i, (embedding_type, output_suffix) in enumerate(types_to_quantize, 1):
            print(f"\n{'#'*80}")
            print(f"[{i}/{len(types_to_quantize)}] Processing {output_suffix} ({embedding_type})")
            print(f"{'#'*80}\n")
            # Quantize model (or reuse an existing file)
            if not self.quantize(embedding_type, output_suffix):
                print(f"⚠️ Skipping benchmark for {output_suffix} due to quantization failure")
                continue
            # Run benchmark
            bench_results = self.benchmark_model(output_suffix)
            if bench_results:
                self.results.append(bench_results)
            else:
                print(f"⚠️ Benchmark failed for {output_suffix}")
            # Cleanup model files (only delete newly created files)
            self.cleanup_model(output_suffix)
            print(f"\n{'#'*80}")
            print(f"✅ Completed {output_suffix}")
            print(f"{'#'*80}\n")
        total_duration = (datetime.now() - total_start).total_seconds()
        # Save results to CSV
        self.save_results_to_csv()
        # Print summary
        self.print_summary(total_duration)

    def save_results_to_csv(self):
        """Save benchmark results to a CSV file and print them as a table."""
        if not self.results:
            print("⚠️ No results to save")
            return
        # Use user-specified CSV path, otherwise use default path
        if self.csv_output:
            csv_file = self.csv_output
            # Ensure parent directory exists
            csv_file.parent.mkdir(parents=True, exist_ok=True)
        else:
            csv_file = self.stats_dir / "embedding_benchmark.csv"
        print(f"\n💾 Saving results to: {csv_file}")
        try:
            # UTF-8 explicitly: the throughput strings contain '±'.
            with open(csv_file, 'w', newline='', encoding='utf-8') as f:
                fieldnames = ['embedding_type', 'threads_1', 'threads_2', 'threads_4', 'threads_8']
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                for result in self.results:
                    writer.writerow(result)
            print(f"✅ Results saved successfully")
            # Also print table
            print(f"\n📊 Benchmark Results:")
            print(f"{'Type':<15} {'1 thread':<18} {'2 threads':<18} {'4 threads':<18} {'8 threads':<18}")
            print("-" * 87)
            for result in self.results:
                t1 = result['threads_1'] if result['threads_1'] else "N/A"
                t2 = result['threads_2'] if result['threads_2'] else "N/A"
                t4 = result['threads_4'] if result['threads_4'] else "N/A"
                t8 = result['threads_8'] if result['threads_8'] else "N/A"
                print(f"{result['embedding_type']:<15} {t1:<18} {t2:<18} {t4:<18} {t8:<18}")
        except Exception as e:
            print(f"❌ Failed to save results: {e}")

    def print_summary(self, total_duration):
        """Print quantization summary.

        Args:
            total_duration: Total wall-clock time of the run, in seconds.
        """
        print(f"\n\n{'='*80}")
        print(f"📊 QUANTIZATION AND BENCHMARK SUMMARY")
        print(f"{'='*80}\n")
        successful = len(self.results)
        print(f"✅ Completed: {successful} benchmarks")
        print(f"⏱️ Total duration: {total_duration/60:.2f} minutes\n")
        if self.results:
            if self.csv_output and self.csv_output.exists():
                print(f"📁 Results saved to: {self.csv_output}")
            else:
                csv_files = list(self.stats_dir.glob("embedding_benchmark*.csv"))
                if csv_files:
                    # Report the most recently modified CSV.
                    latest_csv = max(csv_files, key=lambda p: p.stat().st_mtime)
                    print(f"📁 Results saved to: {latest_csv}")
        print(f"\n{'='*80}\n")
def main():
    """Parse CLI arguments, build an EmbeddingQuantizer, and run all quantizations.

    Returns:
        int | None: 1 on failure; None on success (mapped to exit code 0 by
        the ``__main__`` guard via ``main() or 0``).
    """
    parser = argparse.ArgumentParser(description='Quantize model embeddings to multiple formats')
    parser.add_argument('--input', '-i',
                        default='../models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf',
                        help='Input model path (default: ../models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf)')
    parser.add_argument('--output-dir', '-o',
                        default='../models/BitNet-b1.58-2B-4T',
                        help='Output directory (default: ../models/BitNet-b1.58-2B-4T)')
    parser.add_argument('--quantize-bin', '-q',
                        default='../build/bin/llama-quantize',
                        help='Path to llama-quantize binary (default: ../build/bin/llama-quantize)')
    parser.add_argument('--bench-bin', '-b',
                        default='../build/bin/llama-bench',
                        help='Path to llama-bench binary (default: ../build/bin/llama-bench)')
    parser.add_argument('--stats-dir',
                        default='../stats',
                        help='Directory to save benchmark results (default: ../stats)')
    parser.add_argument('--csv-output', '-c',
                        help='Custom path for CSV output file (e.g., stats/my_results.csv)')
    parser.add_argument('--types', '-t',
                        nargs='+',
                        help='Specific types to quantize (e.g., f32 q6_k q4_0)')
    parser.add_argument('--skip-existing', '-s',
                        action='store_true',
                        help='Skip quantization if output file already exists (will still benchmark existing files)')
    args = parser.parse_args()
    # Define all supported quantization types
    # Format: (embedding_type for command line, output_suffix for filename)
    all_types = [
        ('F32', 'f32'),
        ('F16', 'f16'),
        ('Q8_0', 'q8_0'),
        ('Q6_K', 'q6_k'),
        ('Q5_0', 'q5_0'),
        ('Q4_0', 'q4_0'),
        ('Q3_K', 'q3_k'),
        ('TQ2_0', 'tq2_0'),
    ]
    # If specific types are specified, filter the list.
    # (Loop variables renamed: the original used `os`, shadowing the os module.)
    if args.types:
        wanted = {t.lower() for t in args.types}
        types_to_quantize = [(etype, suffix) for etype, suffix in all_types if suffix.lower() in wanted]
        if not types_to_quantize:
            print(f"❌ No valid types specified. Available types: {', '.join([suffix for _, suffix in all_types])}")
            return
    else:
        types_to_quantize = all_types
    # NOTE: --skip-existing needs no extra filtering here — EmbeddingQuantizer
    # already detects existing files, skips quantization, and still benchmarks.
    # Create the quantizer and run everything
    try:
        quantizer = EmbeddingQuantizer(
            args.input,
            args.output_dir,
            args.quantize_bin,
            args.bench_bin,
            args.stats_dir,
            args.csv_output
        )
        quantizer.run_all_quantizations(types_to_quantize)
    except FileNotFoundError as e:
        print(f"❌ Error: {e}")
        return 1
    except KeyboardInterrupt:
        print("\n\n⚠️ Quantization interrupted by user")
        return 1
    except Exception as e:
        print(f"\n❌ Unexpected error: {e}")
        import traceback
        traceback.print_exc()
        return 1
if __name__ == "__main__":
    # Raise SystemExit directly rather than calling the site-injected exit()
    # builtin, which is absent when Python runs with -S.
    raise SystemExit(main() or 0)