mirror of
https://github.com/microsoft/BitNet.git
synced 2026-05-03 11:20:36 +00:00
474 lines
17 KiB
Python
474 lines
17 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Embedding Quantization Script
|
||
This script converts ggml-model-f32.gguf to multiple quantized versions
|
||
with different token embedding types.
|
||
"""
|
||
|
||
import subprocess
|
||
import os
|
||
import argparse
|
||
import re
|
||
import csv
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
|
||
|
||
class EmbeddingQuantizer:
|
||
def __init__(self, input_model, output_dir, quantize_bin="../build/bin/llama-quantize",
|
||
bench_bin="../build/bin/llama-bench", stats_dir="../stats", csv_output=None):
|
||
self.input_model = Path(input_model)
|
||
self.output_dir = Path(output_dir)
|
||
self.quantize_bin = Path(quantize_bin)
|
||
self.bench_bin = Path(bench_bin)
|
||
self.stats_dir = Path(stats_dir)
|
||
self.csv_output = Path(csv_output) if csv_output else None
|
||
|
||
# Verify input file exists
|
||
if not self.input_model.exists():
|
||
raise FileNotFoundError(f"Input model not found: {self.input_model}")
|
||
|
||
# Verify quantize tool exists
|
||
if not self.quantize_bin.exists():
|
||
raise FileNotFoundError(f"Quantize binary not found: {self.quantize_bin}")
|
||
|
||
# Verify bench tool exists
|
||
if not self.bench_bin.exists():
|
||
raise FileNotFoundError(f"Benchmark binary not found: {self.bench_bin}")
|
||
|
||
# Create output directories
|
||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||
self.stats_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
self.results = []
|
||
self.newly_created_files = set() # Track newly created files
|
||
|
||
def quantize(self, embedding_type, output_suffix):
|
||
"""
|
||
Perform single quantization
|
||
|
||
Args:
|
||
embedding_type: Token embedding type (uppercase format, e.g., Q6_K)
|
||
output_suffix: Output file suffix (lowercase format, e.g., q6_k)
|
||
|
||
Returns:
|
||
bool: Whether successful
|
||
"""
|
||
output_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
|
||
|
||
# Check if file already exists
|
||
file_already_existed = output_file.exists()
|
||
|
||
if file_already_existed:
|
||
print(f"ℹ️ File already exists: {output_file}")
|
||
print(f" Skipping quantization, will use existing file for benchmark")
|
||
return True
|
||
|
||
cmd = [
|
||
str(self.quantize_bin),
|
||
"--token-embedding-type", embedding_type,
|
||
str(self.input_model),
|
||
str(output_file),
|
||
"I2_S",
|
||
"1",
|
||
"1"
|
||
]
|
||
|
||
print(f"\n{'='*80}")
|
||
print(f"🔄 Quantizing with embedding type: {embedding_type}")
|
||
print(f"📥 Input: {self.input_model}")
|
||
print(f"📤 Output: {output_file}")
|
||
print(f"💻 Command: {' '.join(cmd)}")
|
||
print(f"{'='*80}\n")
|
||
|
||
start_time = datetime.now()
|
||
|
||
try:
|
||
result = subprocess.run(
|
||
cmd,
|
||
capture_output=True,
|
||
text=True,
|
||
cwd=os.getcwd(),
|
||
timeout=600 # 10 minute timeout
|
||
)
|
||
|
||
end_time = datetime.now()
|
||
duration = (end_time - start_time).total_seconds()
|
||
|
||
if result.returncode == 0:
|
||
# Get output file size
|
||
file_size_mb = output_file.stat().st_size / (1024 * 1024)
|
||
|
||
print(f"✅ Success! Duration: {duration:.2f}s, Size: {file_size_mb:.2f} MB")
|
||
|
||
# Record newly created file
|
||
if not file_already_existed:
|
||
self.newly_created_files.add(output_file)
|
||
|
||
# Print part of output
|
||
if result.stdout:
|
||
print("\n📊 Quantization output:")
|
||
print(result.stdout[-500:] if len(result.stdout) > 500 else result.stdout)
|
||
|
||
return True
|
||
else:
|
||
print(f"❌ Failed with return code {result.returncode}")
|
||
print(f"Error: {result.stderr}")
|
||
return False
|
||
|
||
except subprocess.TimeoutExpired:
|
||
print(f"❌ Timeout (exceeded 10 minutes)")
|
||
return False
|
||
|
||
except Exception as e:
|
||
print(f"❌ Exception: {e}")
|
||
return False
|
||
|
||
def benchmark_model(self, output_suffix):
|
||
"""
|
||
Benchmark model
|
||
|
||
Args:
|
||
output_suffix: Output file suffix (lowercase format, e.g., q6_k)
|
||
|
||
Returns:
|
||
dict: Dictionary with benchmark results, or None if failed
|
||
"""
|
||
model_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
|
||
|
||
if not model_file.exists():
|
||
print(f"❌ Model file not found for benchmarking: {model_file}")
|
||
return None
|
||
|
||
cmd = [
|
||
str(self.bench_bin),
|
||
"-m", str(model_file),
|
||
"-p", "128",
|
||
"-n", "0",
|
||
"-t", "1,2,4,8",
|
||
"-ngl", "0"
|
||
]
|
||
|
||
print(f"\n{'='*80}")
|
||
print(f"🏃 Running benchmark for: {output_suffix}")
|
||
print(f"💻 Command: {' '.join(cmd)}")
|
||
print(f"{'='*80}\n")
|
||
|
||
try:
|
||
result = subprocess.run(
|
||
cmd,
|
||
capture_output=True,
|
||
text=True,
|
||
cwd=os.getcwd(),
|
||
timeout=300 # 5 minute timeout
|
||
)
|
||
|
||
if result.returncode == 0:
|
||
print("✅ Benchmark completed successfully")
|
||
print("\n📊 Benchmark output:")
|
||
print(result.stdout)
|
||
|
||
# 解析输出
|
||
bench_results = self.parse_benchmark_output(result.stdout, output_suffix)
|
||
return bench_results
|
||
else:
|
||
print(f"❌ Benchmark failed with return code {result.returncode}")
|
||
print(f"Error: {result.stderr}")
|
||
return None
|
||
|
||
except subprocess.TimeoutExpired:
|
||
print(f"❌ Benchmark timeout (exceeded 5 minutes)")
|
||
return None
|
||
|
||
except Exception as e:
|
||
print(f"❌ Benchmark exception: {e}")
|
||
return None
|
||
|
||
def parse_benchmark_output(self, output, output_suffix):
|
||
"""
|
||
Parse benchmark output to extract t/s data (mean±std)
|
||
|
||
Args:
|
||
output: Benchmark command output
|
||
output_suffix: Output file suffix
|
||
|
||
Returns:
|
||
dict: Dictionary with parsed results
|
||
"""
|
||
results = {
|
||
'embedding_type': output_suffix,
|
||
'threads_1': None,
|
||
'threads_2': None,
|
||
'threads_4': None,
|
||
'threads_8': None,
|
||
}
|
||
|
||
# Parse table data
|
||
# Find lines containing pp128 and t/s
|
||
lines = output.strip().split('\n')
|
||
|
||
for line in lines:
|
||
# Skip header and separator lines
|
||
if '|' not in line or 'model' in line or '---' in line:
|
||
continue
|
||
|
||
# Try to extract data
|
||
# Format similar to: | bitnet-25 2B I2_S - 2 bpw ternary | 1012.28 MiB | 2.74 B | CPU | 12 | pp128 | 405.73 ± 3.69 |
|
||
parts = [p.strip() for p in line.split('|')]
|
||
|
||
if len(parts) >= 8 and 'pp128' in parts[6]:
|
||
threads_str = parts[5].strip()
|
||
throughput_str = parts[7].strip()
|
||
|
||
# Extract thread count
|
||
try:
|
||
threads = int(threads_str)
|
||
except:
|
||
continue
|
||
|
||
# Extract t/s data (format: "405.73 ± 3.69" or "405.73")
|
||
# Try to match "mean ± std" format
|
||
match_with_std = re.search(r'([\d.]+)\s*±\s*([\d.]+)', throughput_str)
|
||
if match_with_std:
|
||
mean = float(match_with_std.group(1))
|
||
std = float(match_with_std.group(2))
|
||
throughput = f"{mean:.2f}±{std:.2f}"
|
||
else:
|
||
# Only mean, no std
|
||
match = re.search(r'([\d.]+)', throughput_str)
|
||
if match:
|
||
throughput = f"{float(match.group(1)):.2f}"
|
||
else:
|
||
continue
|
||
|
||
# Store result based on thread count
|
||
if threads == 1:
|
||
results['threads_1'] = throughput
|
||
elif threads == 2:
|
||
results['threads_2'] = throughput
|
||
elif threads == 4:
|
||
results['threads_4'] = throughput
|
||
elif threads == 8:
|
||
results['threads_8'] = throughput
|
||
|
||
return results
|
||
|
||
def cleanup_model(self, output_suffix):
|
||
"""
|
||
Cleanup model files (only delete newly created files)
|
||
|
||
Args:
|
||
output_suffix: Output file suffix
|
||
"""
|
||
model_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
|
||
|
||
if model_file in self.newly_created_files:
|
||
try:
|
||
model_file.unlink()
|
||
print(f"🗑️ Deleted newly created file: {model_file}")
|
||
self.newly_created_files.remove(model_file)
|
||
except Exception as e:
|
||
print(f"⚠️ Failed to delete {model_file}: {e}")
|
||
else:
|
||
print(f"ℹ️ Keeping existing file: {model_file}")
|
||
|
||
def run_all_quantizations(self, types_to_quantize):
|
||
"""
|
||
Run all quantizations
|
||
|
||
Args:
|
||
types_to_quantize: List of quantization types, tuples of (embedding_type, output_suffix)
|
||
"""
|
||
print(f"\n{'='*80}")
|
||
print(f"🚀 Starting Embedding Quantization and Benchmarking")
|
||
print(f"{'='*80}")
|
||
print(f"📥 Input model: {self.input_model}")
|
||
print(f"📤 Output directory: {self.output_dir}")
|
||
print(f"📊 Stats directory: {self.stats_dir}")
|
||
print(f"🔢 Total quantizations: {len(types_to_quantize)}")
|
||
print(f"{'='*80}\n")
|
||
|
||
total_start = datetime.now()
|
||
|
||
for i, (embedding_type, output_suffix) in enumerate(types_to_quantize, 1):
|
||
print(f"\n{'#'*80}")
|
||
print(f"[{i}/{len(types_to_quantize)}] Processing {output_suffix} ({embedding_type})")
|
||
print(f"{'#'*80}\n")
|
||
|
||
# Quantize model
|
||
success = self.quantize(embedding_type, output_suffix)
|
||
|
||
if not success:
|
||
print(f"⚠️ Skipping benchmark for {output_suffix} due to quantization failure")
|
||
continue
|
||
|
||
# Run benchmark
|
||
bench_results = self.benchmark_model(output_suffix)
|
||
|
||
if bench_results:
|
||
self.results.append(bench_results)
|
||
else:
|
||
print(f"⚠️ Benchmark failed for {output_suffix}")
|
||
|
||
# Cleanup model files (only delete newly created files)
|
||
self.cleanup_model(output_suffix)
|
||
|
||
print(f"\n{'#'*80}")
|
||
print(f"✅ Completed {output_suffix}")
|
||
print(f"{'#'*80}\n")
|
||
|
||
total_end = datetime.now()
|
||
total_duration = (total_end - total_start).total_seconds()
|
||
|
||
# 保存结果到CSV
|
||
self.save_results_to_csv()
|
||
|
||
# 打印总结
|
||
self.print_summary(total_duration)
|
||
|
||
def save_results_to_csv(self):
|
||
"""将benchmark结果保存到CSV文件"""
|
||
if not self.results:
|
||
print("⚠️ No results to save")
|
||
return
|
||
|
||
# Use user-specified CSV path, otherwise use default path
|
||
if self.csv_output:
|
||
csv_file = self.csv_output
|
||
# Ensure parent directory exists
|
||
csv_file.parent.mkdir(parents=True, exist_ok=True)
|
||
else:
|
||
csv_file = self.stats_dir / f"embedding_benchmark.csv"
|
||
|
||
print(f"\n💾 Saving results to: {csv_file}")
|
||
|
||
try:
|
||
with open(csv_file, 'w', newline='') as f:
|
||
fieldnames = ['embedding_type', 'threads_1', 'threads_2', 'threads_4', 'threads_8']
|
||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||
|
||
writer.writeheader()
|
||
for result in self.results:
|
||
writer.writerow(result)
|
||
|
||
print(f"✅ Results saved successfully")
|
||
|
||
# Also print table
|
||
print(f"\n📊 Benchmark Results:")
|
||
print(f"{'Type':<15} {'1 thread':<18} {'2 threads':<18} {'4 threads':<18} {'8 threads':<18}")
|
||
print("-" * 87)
|
||
for result in self.results:
|
||
t1 = result['threads_1'] if result['threads_1'] else "N/A"
|
||
t2 = result['threads_2'] if result['threads_2'] else "N/A"
|
||
t4 = result['threads_4'] if result['threads_4'] else "N/A"
|
||
t8 = result['threads_8'] if result['threads_8'] else "N/A"
|
||
print(f"{result['embedding_type']:<15} {t1:<18} {t2:<18} {t4:<18} {t8:<18}")
|
||
|
||
except Exception as e:
|
||
print(f"❌ Failed to save results: {e}")
|
||
|
||
def print_summary(self, total_duration):
|
||
"""Print quantization summary"""
|
||
print(f"\n\n{'='*80}")
|
||
print(f"📊 QUANTIZATION AND BENCHMARK SUMMARY")
|
||
print(f"{'='*80}\n")
|
||
|
||
successful = len(self.results)
|
||
total = len(self.results)
|
||
|
||
print(f"✅ Completed: {successful} benchmarks")
|
||
print(f"⏱️ Total duration: {total_duration/60:.2f} minutes\n")
|
||
|
||
if self.results:
|
||
if self.csv_output and self.csv_output.exists():
|
||
print(f"📁 Results saved to: {self.csv_output}")
|
||
else:
|
||
csv_files = list(self.stats_dir.glob("embedding_benchmark*.csv"))
|
||
if csv_files:
|
||
latest_csv = max(csv_files, key=lambda p: p.stat().st_mtime)
|
||
print(f"📁 Results saved to: {latest_csv}")
|
||
|
||
print(f"\n{'='*80}\n")
|
||
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(description='Quantize model embeddings to multiple formats')
|
||
parser.add_argument('--input', '-i',
|
||
default='../models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf',
|
||
help='Input model path (default: ../models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf)')
|
||
parser.add_argument('--output-dir', '-o',
|
||
default='../models/BitNet-b1.58-2B-4T',
|
||
help='Output directory (default: ../models/BitNet-b1.58-2B-4T)')
|
||
parser.add_argument('--quantize-bin', '-q',
|
||
default='../build/bin/llama-quantize',
|
||
help='Path to llama-quantize binary (default: ../build/bin/llama-quantize)')
|
||
parser.add_argument('--bench-bin', '-b',
|
||
default='../build/bin/llama-bench',
|
||
help='Path to llama-bench binary (default: ../build/bin/llama-bench)')
|
||
parser.add_argument('--stats-dir',
|
||
default='../stats',
|
||
help='Directory to save benchmark results (default: ../stats)')
|
||
parser.add_argument('--csv-output', '-c',
|
||
help='Custom path for CSV output file (e.g., stats/my_results.csv)')
|
||
parser.add_argument('--types', '-t',
|
||
nargs='+',
|
||
help='Specific types to quantize (e.g., f32 q6_k q4_0)')
|
||
parser.add_argument('--skip-existing', '-s',
|
||
action='store_true',
|
||
help='Skip quantization if output file already exists (will still benchmark existing files)')
|
||
|
||
args = parser.parse_args()
|
||
|
||
# Define all supported quantization types
|
||
# Format: (embedding_type for command line, output_suffix for filename)
|
||
all_types = [
|
||
('F32', 'f32'),
|
||
('F16', 'f16'),
|
||
('Q8_0', 'q8_0'),
|
||
('Q6_K', 'q6_k'),
|
||
('Q5_0', 'q5_0'),
|
||
('Q4_0', 'q4_0'),
|
||
('Q3_K', 'q3_k'),
|
||
('TQ2_0', 'tq2_0'),
|
||
]
|
||
|
||
# If specific types are specified, filter the list
|
||
if args.types:
|
||
types_lower = [t.lower() for t in args.types]
|
||
types_to_quantize = [(et, os) for et, os in all_types if os.lower() in types_lower]
|
||
if not types_to_quantize:
|
||
print(f"❌ No valid types specified. Available types: {', '.join([os for _, os in all_types])}")
|
||
return
|
||
else:
|
||
types_to_quantize = all_types
|
||
|
||
# If skip existing files is enabled, no need to filter
|
||
# Because new logic will automatically detect and skip during quantization, but will still benchmark
|
||
|
||
# 创建量化器并运行
|
||
try:
|
||
quantizer = EmbeddingQuantizer(
|
||
args.input,
|
||
args.output_dir,
|
||
args.quantize_bin,
|
||
args.bench_bin,
|
||
args.stats_dir,
|
||
args.csv_output
|
||
)
|
||
quantizer.run_all_quantizations(types_to_quantize)
|
||
except FileNotFoundError as e:
|
||
print(f"❌ Error: {e}")
|
||
return 1
|
||
except KeyboardInterrupt:
|
||
print("\n\n⚠️ Quantization interrupted by user")
|
||
return 1
|
||
except Exception as e:
|
||
print(f"\n❌ Unexpected error: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return 1
|
||
|
||
|
||
if __name__ == "__main__":
|
||
exit(main() or 0)
|