BitNet Inference Kernel

This repository provides an efficient CUDA GEMV kernel implementation for the BitNet model, optimized for W2A8 inference (2-bit weights and 8-bit activations). It is tailored for use with the BitNet-b1.58-2B-4T model.

Features

  • Support for W2A8 (2-bit weight × 8-bit activation) GEMV computation
  • Custom CUDA kernels with low-latency execution
  • Optimizations for memory access, decoding, and compute throughput

Usage

Installation and kernel performance tests:

# (Recommended) Create a new conda environment
conda create --name bitnet-gpu "python<3.13"
conda activate bitnet-gpu

# Install dependencies
pip install -r requirements.txt

# Build the kernel
cd bitnet_kernels
bash compile.sh
cd ..

# Run performance tests
python test.py

End-to-end inference:

# Download and convert the BitNet-b1.58-2B-4T model
mkdir -p checkpoints
huggingface-cli download microsoft/bitnet-b1.58-2B-4T-bf16 --local-dir ./checkpoints/bitnet-b1.58-2B-4T-bf16
python ./convert_safetensors.py --safetensors_file ./checkpoints/bitnet-b1.58-2B-4T-bf16/model.safetensors --output checkpoints/model_state.pt --model_name 2B
python ./convert_checkpoint.py --input ./checkpoints/model_state.pt
rm ./checkpoints/model_state.pt

# Inference
python3 ./generate.py ./checkpoints/ --interactive --chat_format

Optimizations

Weight Permutation

The weight matrix is divided into 16×32 blocks to optimize memory access patterns.

Within each block, values are stored contiguously in memory and permuted to facilitate efficient access and processing.

See convert_checkpoint.py for details.
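
The blocking step can be illustrated with a minimal sketch in plain Python. It assumes blocks of 16 rows by 32 columns, visited in row-major order and stored row by row within each block. The function name `to_blocked` is hypothetical, and the in-block permutation applied by convert_checkpoint.py is not reproduced here.

```python
BLOCK_ROWS, BLOCK_COLS = 16, 32  # assumed orientation: 16 rows x 32 columns


def to_blocked(matrix):
    """Flatten a 2-D list so that each 16x32 block is contiguous.

    Illustrative only: blocks are visited in row-major order and each
    block is stored row by row. The real converter additionally permutes
    values inside each block.
    """
    rows, cols = len(matrix), len(matrix[0])
    assert rows % BLOCK_ROWS == 0 and cols % BLOCK_COLS == 0
    flat = []
    for block_row in range(0, rows, BLOCK_ROWS):
        for block_col in range(0, cols, BLOCK_COLS):
            # Copy one block; its 512 values end up adjacent in memory.
            for r in range(block_row, block_row + BLOCK_ROWS):
                flat.extend(matrix[r][block_col:block_col + BLOCK_COLS])
    return flat


# Usage: a 32x64 matrix whose entries encode their own (row, col) position.
demo = [[r * 64 + c for c in range(64)] for r in range(32)]
blocked = to_blocked(demo)
```

With this layout, a thread group that works on one block reads a single contiguous run of 512 values instead of 16 separate row segments.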

Fast Decoding

Every 16 two-bit values are packed into a single 32-bit integer using the following interleaving pattern:

[0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]

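This layout groups values whose indices differ by 4 into the same byte (0, 4, 8, 12 in one byte; 1, 5, 9, 13 in the next; and so on), so a single shift and mask extracts 4 consecutive values at a time, each into its own int8 lane.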
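
The interleaving can be checked with a small Python sketch. It assumes that slot i of the pattern occupies bits 2i and 2i+1 of the word (least significant first) and works on raw 2-bit codes 0..3; mapping those codes to signed weights is left out. The names `pack16` and `unpack16` are hypothetical and do not come from the repository.

```python
PATTERN = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
LANE_MASK = 0x03030303  # keeps the low 2 bits of each of the 4 bytes


def pack16(values):
    """Pack 16 two-bit codes into one 32-bit word using PATTERN."""
    assert len(values) == 16
    word = 0
    for slot, src in enumerate(PATTERN):
        code = values[src]
        assert 0 <= code <= 3
        word |= code << (2 * slot)  # assumption: slot i sits at bits 2i..2i+1
    return word


def unpack16(word):
    """Recover the 16 codes, 4 at a time, with one shift and one mask each."""
    out = []
    for shift in (0, 2, 4, 6):
        lanes = (word >> shift) & LANE_MASK  # 4 codes, one per byte lane
        out.extend((lanes >> (8 * lane)) & 0xFF for lane in range(4))
    return out


codes = [(7 * i + 3) % 4 for i in range(16)]
packed = pack16(codes)
restored = unpack16(packed)
```

Each shift by 0, 2, 4, or 6 bits yields four codes with consecutive indices (0..3, 4..7, 8..11, 12..15), already spread across byte lanes, which is the form a 4-way int8 dot product consumes.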

dp4a Instruction

We use the dp4a instruction to accelerate low-precision dot product operations.

This instruction performs a dot product between two 4-element vectors (each stored in a 32-bit word as 8-bit integers) and accumulates the result into a 32-bit integer.

It significantly improves GEMV throughput when processing quantized weights and activations.
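
For readers unfamiliar with the instruction, the following Python sketch emulates the arithmetic of the signed form of dp4a on the CPU. The helper names are hypothetical, and the sketch ignores the 32-bit wraparound of the hardware accumulator.

```python
def to_int8(byte):
    """Reinterpret an unsigned byte (0..255) as a signed 8-bit integer."""
    return byte - 256 if byte >= 128 else byte


def pack_int8x4(values):
    """Pack four signed 8-bit integers into one 32-bit word, lane 0 lowest."""
    assert len(values) == 4
    word = 0
    for lane, v in enumerate(values):
        assert -128 <= v <= 127
        word |= (v & 0xFF) << (8 * lane)
    return word


def dp4a(a, b, c):
    """Emulate signed dp4a: c + sum of the 4 lane-wise int8 products."""
    acc = c
    for lane in range(4):
        x = to_int8((a >> (8 * lane)) & 0xFF)
        y = to_int8((b >> (8 * lane)) & 0xFF)
        acc += x * y
    return acc


a = pack_int8x4([1, -2, 3, -4])
b = pack_int8x4([5, 6, -7, 8])
result = dp4a(a, b, 10)  # 10 + (5 - 12 - 21 - 32) = -50
```

On the GPU the same four multiply-accumulates are issued as a single instruction, so one call consumes four weights and four activations at once.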

Performance

Kernel performance (tested on NVIDIA A100 40GB GPU):

Shape (N×K)      W2A8 Latency (µs)   BF16 Latency (µs)   Speedup Ratio
2560 × 2560      13.32               18.32               1.38
3840 × 2560      14.90               18.87               1.27
13824 × 2560     18.75               59.51               3.17
2560 × 6912      14.49               37.78               2.61
3200 × 3200      14.61               19.08               1.31
4800 × 3200      13.09               21.84               1.67
3200 × 10240     19.64               60.79               3.10
20480 × 3200     30.99               112.39              3.63
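
The speedup column is consistent with BF16 latency divided by W2A8 latency, rounded to two decimals. The short Python check below recomputes it from the latencies listed above; the function name `speedup` is only illustrative.

```python
# (shape, W2A8 latency, BF16 latency, reported speedup), latencies in microseconds
ROWS = [
    ("2560 x 2560", 13.32, 18.32, 1.38),
    ("3840 x 2560", 14.90, 18.87, 1.27),
    ("13824 x 2560", 18.75, 59.51, 3.17),
    ("2560 x 6912", 14.49, 37.78, 2.61),
    ("3200 x 3200", 14.61, 19.08, 1.31),
    ("4800 x 3200", 13.09, 21.84, 1.67),
    ("3200 x 10240", 19.64, 60.79, 3.10),
    ("20480 x 3200", 30.99, 112.39, 3.63),
]


def speedup(w2a8_us, bf16_us):
    """Speedup of W2A8 over BF16, rounded to two decimals."""
    return round(bf16_us / w2a8_us, 2)


for shape, w2a8, bf16, reported in ROWS:
    print(f"{shape:>13}: {speedup(w2a8, bf16):.2f} (reported {reported:.2f})")
```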

Generation throughput:

BF16 (tokens/s)   W2A8 (tokens/s)   Speedup Ratio
10.9              213.3             19.6