# BitNet Inference Kernel
This repository provides a highly efficient GEMV kernel for the BitNet model, optimized for W2A8 inference (2-bit weights, 8-bit activations). It is tailored for use with the [BitNet-b1.58-2B-4T](https://arxiv.org/abs/2504.12285) model.
## Features
- Support for W2A8 (2-bit weight × 8-bit activation) GEMV computation
- Custom CUDA kernels with low-latency execution
- Optimizations for memory access, decoding, and compute throughput
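As a point of reference for what the kernel computes, here is a plain-NumPy sketch of a W2A8 GEMV (BitNet b1.58 weights are ternary, so 2 bits per weight suffice); the `w2a8_gemv_ref` name and the int32 accumulation are illustrative assumptions, not the kernel's API:

```python
import numpy as np

def w2a8_gemv_ref(w_ternary: np.ndarray, x_int8: np.ndarray) -> np.ndarray:
    """Reference W2A8 GEMV: ternary weights in {-1, 0, 1} (storable in
    2 bits each) times int8 activations, accumulated in int32."""
    return w_ternary.astype(np.int32) @ x_int8.astype(np.int32)

w = np.array([[1, -1], [0, 1]], dtype=np.int8)  # (N, K) ternary weights
x = np.array([10, -3], dtype=np.int8)           # (K,) int8 activations
y = w2a8_gemv_ref(w, x)                         # int32 accumulator: [13, -3]
```

The CUDA kernel produces the same result, but reads the weights in packed 2-bit form and uses the optimizations described below.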
## Usage
Installation and kernel performance tests:
```bash
# (Recommended) Create a new conda environment
conda create --name bitnet-gpu "python<3.13"
conda activate bitnet-gpu
# Install dependencies
pip install -r requirements.txt
# Build the kernel
cd bitnet_kernels
bash compile.sh
cd ..
# Run performance tests
python test.py
```
End-to-end inference:
```bash
# Download and convert the BitNet-b1.58-2B-4T model
mkdir checkpoints
huggingface-cli download microsoft/bitnet-b1.58-2B-4T-bf16 --local-dir ./checkpoints/bitnet-b1.58-2B-4T-bf16
python ./convert_safetensors.py --safetensors_file ./checkpoints/bitnet-b1.58-2B-4T-bf16/model.safetensors --output checkpoints/model_state.pt --model_name 2B
python ./convert_checkpoint.py --input ./checkpoints/model_state.pt
rm ./checkpoints/model_state.pt
# Inference
python3 ./generate.py ./checkpoints/ --interactive --chat_format
```
## Optimizations
### Weight Permutation
The weight matrix is divided into 16×32 blocks to optimize memory access patterns.
Within each block, values are stored contiguously in memory and permuted to facilitate efficient access and processing.
See `convert_checkpoint.py` for details.
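The blocking idea can be sketched with NumPy. This only illustrates how 16×32 blocks become contiguous in memory; the exact within-block permutation is defined in `convert_checkpoint.py`, and the `block_permute` helper below is hypothetical:

```python
import numpy as np

def block_permute(w: np.ndarray, block_rows: int = 16, block_cols: int = 32) -> np.ndarray:
    """Rearrange an (N, K) weight matrix so that each 16x32 block is
    stored contiguously, one block per output row (illustrative only)."""
    n, k = w.shape
    assert n % block_rows == 0 and k % block_cols == 0
    # Split into a grid of blocks, then flatten each block row-major.
    blocks = w.reshape(n // block_rows, block_rows, k // block_cols, block_cols)
    blocks = blocks.transpose(0, 2, 1, 3)  # (block row, block col, 16, 32)
    return blocks.reshape(-1, block_rows * block_cols)
```

Storing each block contiguously lets a thread block load its weights with coalesced reads instead of strided accesses across the full row.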
### Fast Decoding
Each group of 16 two-bit values is packed into a single 32-bit integer using the following interleaving pattern:
```
[0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
```
This layout accelerates decoding: after a single shift, the four values of a group sit one per byte, so they can be extracted into `int8` lanes four at a time.
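The packing and the 4-at-a-time extraction can be sketched in Python (an illustration of the bit layout, not the kernel code):

```python
# Packed position i (bits 2i..2i+1) holds original value PERM[i].
PERM = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]

def pack16(vals):
    """Pack 16 two-bit values (0..3) into one 32-bit word in PERM order."""
    word = 0
    for i, src in enumerate(PERM):
        word |= (vals[src] & 0x3) << (2 * i)
    return word

def decode4(word, group):
    """Extract original values 4*group .. 4*group+3. With the interleaved
    layout they sit 8 bits apart, so one shift exposes one value per byte
    (the kernel would shift once and mask with 0x03030303)."""
    shifted = word >> (2 * group)
    return [(shifted >> (8 * b)) & 0x3 for b in range(4)]
```

Because consecutive original values land in consecutive bytes, a shift-and-mask yields four `int8` lanes ready for the dot-product step below.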
### `dp4a` Instruction
We use the `dp4a` instruction to accelerate low-precision dot product operations.
This instruction performs a dot product between two 4-element vectors (each stored in a 32-bit word as 8-bit integers) and accumulates the result into a 32-bit integer.
It significantly improves GEMV throughput when processing quantized weights and activations.
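A small Python emulation of what the signed `dp4a` variant computes (the `pack4` helper is illustrative; on the GPU this is a single hardware instruction):

```python
import struct

def pack4(lanes):
    """Pack four signed int8 values into one little-endian 32-bit word."""
    return struct.unpack('<I', struct.pack('<4b', *lanes))[0]

def dp4a(a_word: int, b_word: int, acc: int) -> int:
    """Emulate signed dp4a: interpret each 32-bit word as four signed
    8-bit lanes, dot the lanes, and add the result to `acc`."""
    a = struct.unpack('<4b', struct.pack('<I', a_word & 0xFFFFFFFF))
    b = struct.unpack('<4b', struct.pack('<I', b_word & 0xFFFFFFFF))
    return acc + sum(x * y for x, y in zip(a, b))
```

In the kernel, the four decoded weight values and four int8 activations each occupy one 32-bit register, so a single `dp4a` replaces four multiply-adds.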
## Performance
Kernel performance (tested on an NVIDIA A100 40GB GPU):

| Shape (N × K) | W2A8 Latency (µs) | BF16 Latency (µs) | Speedup |
|---------------|-------------------|-------------------|---------|
| 2560 × 2560   | 13.32             | 18.32             | 1.38    |
| 3840 × 2560   | 14.90             | 18.87             | 1.27    |
| 13824 × 2560  | 18.75             | 59.51             | 3.17    |
| 2560 × 6912   | 14.49             | 37.78             | 2.61    |
| 3200 × 3200   | 14.61             | 19.08             | 1.31    |
| 4800 × 3200   | 13.09             | 21.84             | 1.67    |
| 3200 × 10240  | 19.64             | 60.79             | 3.10    |
| 20480 × 3200  | 30.99             | 112.39            | 3.63    |
Generation throughput:

| BF16 (tokens/s) | W2A8 (tokens/s) | Speedup |
|-----------------|-----------------|---------|
| 10.9            | 213.3           | 19.6    |