Merge branch 'microsoft:main' into add-falcon-e-final

This commit is contained in:
Younes Belkada
2025-05-20 17:05:11 +04:00
committed by GitHub
18 changed files with 129726 additions and 3 deletions
+1
@@ -34,6 +34,7 @@ nppBackup
# Models
models/*
gpu/checkpoints/*
# Python
+4 -3
@@ -4,9 +4,9 @@
[<img src="./assets/header_model_release.png" alt="BitNet Model on Hugging Face" width="800"/>](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
Try it out via this [demo](https://bitnet-demo.azurewebsites.net/), or [build and run](https://github.com/microsoft/BitNet?tab=readme-ov-file#build-from-source) it on your own CPU.
Try it out via this [demo](https://bitnet-demo.azurewebsites.net/), or build and run it on your own [CPU](https://github.com/microsoft/BitNet?tab=readme-ov-file#build-from-source) or [GPU](https://github.com/microsoft/BitNet/blob/main/gpu/README.md).
bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.58). It offers a suite of optimized kernels, that support **fast** and **lossless** inference of 1.58-bit models on CPU (with NPU and GPU support coming next).
bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.58). It offers a suite of optimized kernels that support **fast** and **lossless** inference of 1.58-bit models on CPU and GPU (NPU support coming next).
The first release of bitnet.cpp supports inference on CPUs. bitnet.cpp achieves speedups of **1.37x** to **5.07x** on ARM CPUs, with larger models experiencing greater performance gains. Additionally, it reduces energy consumption by **55.4%** to **70.0%**, further boosting overall efficiency. On x86 CPUs, speedups range from **2.37x** to **6.17x** with energy reductions between **71.9%** and **82.2%**. Furthermore, bitnet.cpp can run a 100B BitNet b1.58 model on a single CPU, achieving speeds comparable to human reading (5-7 tokens per second), significantly enhancing the potential for running LLMs on local devices. Please refer to the [technical report](https://arxiv.org/abs/2410.16144) for more details.
@@ -22,7 +22,8 @@ A demo of bitnet.cpp running a BitNet b1.58 3B model on Apple M2:
https://github.com/user-attachments/assets/7f46b736-edec-4828-b809-4be780a3e5b1
## What's New:
- 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) ![NEW](https://img.shields.io/badge/NEW-red)
- 05/20/2025 [BitNet Official GPU inference kernel](https://github.com/microsoft/BitNet/blob/main/gpu/README.md) ![NEW](https://img.shields.io/badge/NEW-red)
- 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
- 02/18/2025 [Bitnet.cpp: Efficient Edge Inference for Ternary LLMs](https://arxiv.org/abs/2502.11880)
- 11/08/2024 [BitNet a4.8: 4-bit Activations for 1-bit LLMs](https://arxiv.org/abs/2411.04965)
- 10/21/2024 [1-bit AI Infra: Part 1.1, Fast and Lossless BitNet b1.58 Inference on CPUs](https://arxiv.org/abs/2410.16144)
Executable
+93
@@ -0,0 +1,93 @@
# BitNet Inference Kernel
This repository provides a highly efficient GEMV kernel implementation for the BitNet model, optimized for W2A8 inference — 2-bit weights and 8-bit activations. It is tailored for use with the [BitNet-b1.58-2B-4T](https://arxiv.org/abs/2504.12285) model.
## Features
- Support for W2A8 (2-bit weight × 8-bit activation) GEMV computation
- Custom CUDA kernels with low-latency execution
- Optimizations for memory access, decoding, and compute throughput
## Usage
Installation and kernel performance tests:
```bash
# (Recommended) Create a new conda environment
conda create --name bitnet-gpu "python<3.13"
conda activate bitnet-gpu
# Install dependencies
pip install -r requirements.txt
# Build the kernel
cd bitnet_kernels
bash compile.sh
cd ..
# Run performance tests
python test.py
```
End-to-end inference:
```bash
# Download and convert the BitNet-b1.58-2B model
mkdir checkpoints
huggingface-cli download microsoft/bitnet-b1.58-2B-4T-bf16 --local-dir ./checkpoints/bitnet-b1.58-2B-4T-bf16
python ./convert_safetensors.py --safetensors_file ./checkpoints/bitnet-b1.58-2B-4T-bf16/model.safetensors --output checkpoints/model_state.pt --model_name 2B
python ./convert_checkpoint.py --input ./checkpoints/model_state.pt
rm ./checkpoints/model_state.pt
# Inference
python3 ./generate.py ./checkpoints/ --interactive --chat_format
```
## Optimizations
### Weight Permutation
The weight matrix is divided into 16×32 blocks to optimize memory access patterns.
Within each block, values are stored contiguously in memory and permuted to facilitate efficient access and processing.
See `convert_checkpoint.py` for details.
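For intuition, here is a rough NumPy sketch of the tiling step. It only illustrates how an (N, K) ternary weight matrix is cut into contiguous 16×32 tiles; the exact intra-tile element order used by the kernel is produced by `permutate_weight_fastest` and `convert_weight_int8_to_int2` in `pack_weight.py` (shown later in this diff).

```python
import numpy as np

# Illustrative only: cut an (N, K) int8 ternary weight matrix into 16x32 tiles
# so that each tile ends up contiguous in memory. The real layout additionally
# permutes elements inside each tile (see pack_weight.py).
def tile_weight(weight: np.ndarray, tile_n: int = 16, tile_k: int = 32) -> np.ndarray:
    N, K = weight.shape
    assert N % tile_n == 0 and K % tile_k == 0
    tiles = weight.reshape(N // tile_n, tile_n, K // tile_k, tile_k)
    return np.ascontiguousarray(tiles.transpose(0, 2, 1, 3))  # (N/16, K/32, 16, 32)

w = np.random.randint(-1, 2, size=(2560, 2560), dtype=np.int8)
print(tile_weight(w).shape)  # (160, 80, 16, 32)
```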
### Fast Decoding
Every 16 two-bit values are packed into a single 32-bit integer using the following interleaving pattern:
```
[0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
```
This layout is designed to accelerate decoding by enabling efficient extraction of 4 values at a time into `int8`.
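A small Python sketch of this packing, and of pulling out 4 values at a time, is given below for reference. It assumes the two-bit values have already been shifted into the unsigned range 0..3; the actual CUDA decode path is `decode_i2s_to_i8s`, which appears further down in this diff.

```python
# Interleaving order used when packing 16 two-bit values into one 32-bit word.
ORDER = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]

def pack16(vals):
    """Pack 16 values in [0, 3] into a single 32-bit integer."""
    word = 0
    for slot, src in enumerate(ORDER):
        word |= (vals[src] & 0x3) << (2 * slot)
    return word

def unpack4(word, i):
    """Extract the i-th group of 4 consecutive original values (i in 0..3).

    Thanks to the interleave, the four fields sit exactly 8 bits apart, so a
    single shift-and-mask yields a 4-byte group ready for int8 / dp4a use.
    """
    return [(word >> (2 * i + 8 * k)) & 0x3 for k in range(4)]

word = pack16(list(range(16)))
print([unpack4(word, i) for i in range(4)])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
```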
### `dp4a` Instruction
We use the `dp4a` instruction to accelerate low-precision dot product operations.
This instruction performs a dot product between two 4-element vectors (each stored in a 32-bit word as 8-bit integers) and accumulates the result into a 32-bit integer.
It significantly improves GEMV throughput when processing quantized weights and activations.
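As a reference for what the instruction computes, the following Python snippet reproduces `dp4a` semantics on two 32-bit words that each hold four signed 8-bit lanes (illustrative only; little-endian byte order assumed, not the CUDA intrinsic itself).

```python
import numpy as np

def dp4a_ref(a: int, b: int, c: int) -> int:
    """Reference dp4a: treat a and b as four signed int8 lanes, multiply
    lane-wise, and accumulate the products into the 32-bit integer c."""
    a4 = np.frombuffer(np.int32(a).tobytes(), dtype=np.int8).astype(np.int32)
    b4 = np.frombuffer(np.int32(b).tobytes(), dtype=np.int8).astype(np.int32)
    return int(c + np.dot(a4, b4))

# lanes (1, -2, 3, 4) . (5, 6, -7, 8) + 10 = 5 - 12 - 21 + 32 + 10 = 14
a = int.from_bytes(np.array([1, -2, 3, 4], dtype=np.int8).tobytes(), "little", signed=True)
b = int.from_bytes(np.array([5, 6, -7, 8], dtype=np.int8).tobytes(), "little", signed=True)
print(dp4a_ref(a, b, 10))  # 14
```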
## Performance
Kernel performance (tested on NVIDIA A100 40GB GPU):
| Shape (N×K) | W2A8 Latency (us) | BF16 Latency (us) | Speedup Ratio |
|---------------------|-------------------|-------------------|----------------------|
| 2560 × 2560 | 13.32 | 18.32 | 1.38 |
| 3840 × 2560 | 14.90 | 18.87 | 1.27 |
| 13824 × 2560 | 18.75 | 59.51 | 3.17 |
| 2560 × 6912 | 14.49 | 37.78 | 2.61 |
| 3200 × 3200 | 14.61 | 19.08 | 1.31 |
| 4800 × 3200 | 13.09 | 21.84 | 1.67 |
| 3200 × 10240 | 19.64 | 60.79 | 3.10 |
| 20480 × 3200 | 30.99 | 112.39 | 3.63 |
Generation throughput:
| BF16 (tokens/s) | W2A8 (tokens/s) | Speedup Ratio |
|---|---|---|
| 10.9 | 213.3 | 19.6 |
+37
@@ -0,0 +1,37 @@
#include "bitnet_kernels.h"
extern "C" void bitlinear_int8xint2(int8_t* input0, int8_t* input1, __nv_bfloat16* output0, __nv_bfloat16* s, __nv_bfloat16* ws, int M, int N, int K, cudaStream_t stream){
if (M == 1 && N == 3840 && K == 2560){
ladder_int8xint2_kernel<1, 3840, 2560, 3, 8, 16><<<dim3(240, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if (M == 1 && N == 2560 && K == 2560){
ladder_int8xint2_kernel<1, 2560, 2560, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if (M == 1 && N == 13824 && K == 2560){
ladder_int8xint2_kernel<1, 13824, 2560, 2, 8, 16><<<dim3(864, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if (M == 1 && N == 2560 && K == 6912){
ladder_int8xint2_kernel<1, 2560, 6912, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if(M == 1 && N == 4800 && K == 3200){
ladder_int8xint2_kernel<1, 4800, 3200, 6, 8, 16><<<dim3(300, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if(M == 1 && N == 3200 && K == 3200){
ladder_int8xint2_kernel<1, 3200, 3200, 1, 8, 16><<<dim3(200, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if(M == 1 && N == 20480 && K == 3200){
ladder_int8xint2_kernel<1, 20480, 3200, 2, 8, 16><<<dim3(1280, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if(M == 1 && N == 3200 && K == 10240){
ladder_int8xint2_kernel<1, 3200, 10240, 1, 8, 16><<<dim3(200, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if(M == 1 && N == 5120 && K == 27648){
ladder_int8xint2_kernel<1, 5120, 27648, 1, 8, 16><<<dim3(320, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if(M == 1 && N == 55296 && K == 5120){
ladder_int8xint2_kernel<1, 55296, 5120, 1, 8, 16><<<dim3(3456, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else{
std::cout << "required ladder gemm kernel: M " << M << ", N " << N << ", K " << K << std::endl;
}
}
+83
@@ -0,0 +1,83 @@
#include <cuda_runtime.h>
#include <math_constants.h>
#include <math.h>
#include <mma.h>
#include <iostream>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || (__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1
#else
#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0
#endif
template <typename T1, typename T2>
__device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *_i8s, const int N = 16)
{
// convert 16 packed int2 values to 16 int8 values -> 4 int32 words
uint *i8s = reinterpret_cast<uint *>(_i8s);
// i2s = {e0, e4, e8, e12, e1, e5, e9, e13, e2, e6, e10, e14, e3, e7, e11, e15}
uint const i2s = *_i2s;
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010
static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3
static constexpr uint I4s_TO_I8s_MAGIC_NUM = 0x00000000;
#pragma unroll
for (int i = 0; i < (N / 4); i++)
{
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(i8s[i])
: "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(I4s_TO_I8s_MAGIC_NUM), "n"(immLut));
i8s[i] = __vsubss4(i8s[i], 0x02020202);
}
}
template <int M, int N, int K, int ws_num, int K_block_size, int N_block_size>
__global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restrict__ A, int8_t* __restrict__ B, __nv_bfloat16* __restrict__ dtype_transform, __nv_bfloat16* __restrict__ s, __nv_bfloat16* __restrict__ ws) {
constexpr int K_per_loop = 16;
constexpr int wmma_K = 32;
constexpr int wmma_N = 16;
int in_thread_C_local[1];
signed char A_local[K_per_loop];
int B_reshape_local[1];
signed char B_decode_local[K_per_loop];
int red_buf0[1];
in_thread_C_local[0] = 0;
#pragma unroll
for (int k_0 = 0; k_0 < K/(K_per_loop * K_block_size); ++k_0) {
*(int4*)(A_local + 0) = *(int4*)(A + ((k_0 * K_per_loop * K_block_size) + (((int)threadIdx.x) * K_per_loop)));
B_reshape_local[0] = *(int*)(B +
(((int)blockIdx.x) * N_block_size * K / 4) +
(k_0 * K_block_size * K_per_loop * wmma_N / 4) +
((((int)threadIdx.x) >> 1) * wmma_K * wmma_N / 4) +
((((int)threadIdx.y) >> 3) * (wmma_K * wmma_N / 2) / 4) +
((((int)threadIdx.x) & 1) * (wmma_K * wmma_N / 4) / 4) +
((((int)threadIdx.y) & 7) * (wmma_K / 2) / 4)
);
decode_i2s_to_i8s(B_reshape_local, B_decode_local, 16);
#pragma unroll
for (int k_2_0 = 0; k_2_0 < 4; ++k_2_0) {
in_thread_C_local[0] = __dp4a(*(int *)&A_local[((k_2_0 * 4))],*(int *)&B_decode_local[((k_2_0 * 4))], in_thread_C_local[0]);
}
}
red_buf0[0] = in_thread_C_local[0];
#pragma unroll
for (int offset = K_block_size/2; offset > 0; offset /= 2) {
red_buf0[0] += __shfl_down_sync(__activemask(), red_buf0[0], offset, K_block_size);
}
int out_idx = ((((int)blockIdx.x) * N_block_size) + ((int)threadIdx.y));
int ws_idx = out_idx / (N / ws_num);
if (threadIdx.x == 0)
dtype_transform[out_idx] = (__nv_bfloat16)(((float)red_buf0[0])/(float)s[0]*(float)ws[ws_idx]);
}
+3
@@ -0,0 +1,3 @@
nvcc -std=c++17 -Xcudafe --diag_suppress=177 --compiler-options -fPIC -lineinfo --shared bitnet_kernels.cu -lcuda -gencode=arch=compute_80,code=compute_80 -o libbitnet.so
+13
@@ -0,0 +1,13 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='bitlinear_cpp',
ext_modules=[
CUDAExtension('bitlinear_cuda', [
'bitnet_kernels.cu',
])
],
cmdclass={
'build_ext': BuildExtension
})
+100
@@ -0,0 +1,100 @@
import json
import os
import re
import sys
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
import torch
from einops import rearrange
from safetensors.torch import save_file
import model
from pack_weight import convert_weight_int8_to_int2
@torch.inference_mode()
def convert_ts_checkpoint(
*,
input_path: str = "",
) -> None:
config = model.ModelArgs()
print(f"Model config {config.__dict__}")
def quant_weight_int8(weight):
s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
new_weight = (weight * s).round().clamp(-1, 1).to(torch.int8)
new_scale = (1.0 / s).to(torch.bfloat16)
return new_weight, new_scale.reshape(1)
def quant_weight_fp16(weight):
s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
new_weight = (weight * s).round().clamp(-1, 1) / s
return new_weight
def convert_int8_to_int2(weight):
return convert_weight_int8_to_int2(weight)
merged_result = torch.load(input_path, map_location="cpu", mmap=True)
int2_result = {}
fp16_result = {}
zero = torch.zeros(1).to(torch.bfloat16)
for key, value in merged_result.items():
if 'wqkv' in key:
wq = value[:config.dim]
wk = value[config.dim:config.dim // config.n_heads * config.n_kv_heads + config.dim]
wv = value[config.dim // config.n_heads * config.n_kv_heads + config.dim:]
wq_weight, wa_scale = quant_weight_int8(wq)
wk_weight, wb_scale = quant_weight_int8(wk)
wv_weight, wc_scale = quant_weight_int8(wv)
wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
wqkv_scale = torch.cat([wa_scale, wb_scale, wc_scale, zero], dim=0)
int2_result[key] = convert_int8_to_int2(wqkv_weight)
int2_result[key.replace('weight', 'weight_scale')] = wqkv_scale
wq_weight = quant_weight_fp16(wq)
wk_weight = quant_weight_fp16(wk)
wv_weight = quant_weight_fp16(wv)
wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
fp16_result[key] = wqkv_weight
elif 'w13' in key:
w1 = value[:config.ffn_dim]
w3 = value[config.ffn_dim:]
w1_weight, w1_scale = quant_weight_int8(w1)
w3_weight, w3_scale = quant_weight_int8(w3)
w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
w13_scale = torch.cat([w1_scale, w3_scale, zero, zero], dim=0)
int2_result[key] = convert_int8_to_int2(w13_weight)
int2_result[key.replace('weight', 'weight_scale')] = w13_scale
w1_weight = quant_weight_fp16(w1)
w3_weight = quant_weight_fp16(w3)
w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
fp16_result[key] = w13_weight
elif 'w2' in key or 'wo' in key:
weight, scale = quant_weight_int8(value)
scale = torch.cat([scale, zero, zero, zero], dim=0)
int2_result[key] = convert_int8_to_int2(weight)
int2_result[key.replace('weight', 'weight_scale')] = scale
weight = quant_weight_fp16(value)
fp16_result[key] = weight
else:
int2_result[key] = value.clone()
fp16_result[key] = value.clone()
output_dir = os.path.dirname(input_path)
print(f"Saving checkpoint to {output_dir}/model_state_int2.pt")
torch.save(int2_result, f"{output_dir}/model_state_int2.pt")
print(f"Saving checkpoint to {output_dir}/model_state_fp16.pt")
torch.save(fp16_result, f"{output_dir}/model_state_fp16.pt")
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Convert TorchScale checkpoint.')
parser.add_argument('--input', type=str)
args = parser.parse_args()
convert_ts_checkpoint(
input_path=args.input,
)
+116
@@ -0,0 +1,116 @@
import re
import torch
from pathlib import Path
from safetensors.torch import load_file
from einops import rearrange
from dataclasses import dataclass
from typing import Optional
transformer_configs = {
"2B": dict(n_layer=30, n_head=20, dim=2560, vocab_size=128256, n_local_heads=5, intermediate_size=6912),
}
@dataclass
class ModelArgs:
block_size: int = 4096
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
dim: int = 4096
intermediate_size: int = None
n_local_heads: int = -1
head_dim: int = 64
rope_base: float = 10000
norm_eps: float = 1e-5
def __post_init__(self):
if self.n_local_heads == -1:
self.n_local_heads = self.n_head
if self.intermediate_size is None:
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = n_hidden + (256 - n_hidden % 256) if n_hidden % 256 else n_hidden
self.head_dim = self.dim // self.n_head
@classmethod
def from_name(cls, name: str):
if name in transformer_configs:
return cls(**transformer_configs[name])
config = [k for k in transformer_configs if k in name.upper() or k in name]
assert len(config) == 1, f"Unknown model name: {name}"
return cls(**transformer_configs[config[0]])
def invert_convert_q(w: torch.Tensor, config: ModelArgs) -> torch.Tensor:
return rearrange(w, '(h l d) i -> (h d l) i', h=config.n_head, l=2)
def invert_convert_k(w: torch.Tensor, config: ModelArgs) -> torch.Tensor:
return rearrange(w, '(h l d) i -> (h d l) i', h=config.n_local_heads, l=2)
def convert_back(
safetensors_path: str,
output_file: str,
model_name: Optional[str] = None,
):
st_dict = load_file(safetensors_path)
cfg = ModelArgs.from_name(model_name)
print(f"Using model configurations: {cfg}")
recovered: dict = {}
for layer in range(cfg.n_layer):
base = f"model.layers.{layer}."
wq = st_dict[f"{base}self_attn.q_proj.weight"]
wk = st_dict[f"{base}self_attn.k_proj.weight"]
wv = st_dict[f"{base}self_attn.v_proj.weight"]
wq = invert_convert_q(wq, cfg)
wk = invert_convert_k(wk, cfg)
wqkv = torch.cat([wq, wk, wv], dim=0)
recovered[f"layers.{layer}.attention.wqkv.weight"] = wqkv
recovered[f"layers.{layer}.attention.wo.weight"] = st_dict[f"{base}self_attn.o_proj.weight"]
recovered[f"layers.{layer}.attention_norm.weight"] = st_dict[f"{base}input_layernorm.weight"]
recovered[f"layers.{layer}.ffn_norm.weight"] = st_dict[f"{base}post_attention_layernorm.weight"]
recovered[f"layers.{layer}.attention.attn_sub_norm.weight"] = st_dict[f"{base}self_attn.attn_sub_norm.weight"]
recovered[f"layers.{layer}.feed_forward.ffn_sub_norm.weight"] = st_dict[f"{base}mlp.ffn_sub_norm.weight"]
gate = st_dict[f"{base}mlp.gate_proj.weight"]
up = st_dict[f"{base}mlp.up_proj.weight"]
w13 = torch.cat([gate, up], dim=0)
recovered[f"layers.{layer}.feed_forward.w13.weight"] = w13
recovered[f"layers.{layer}.feed_forward.w2.weight"] = st_dict[f"{base}mlp.down_proj.weight"]
recovered["tok_embeddings.weight"] = st_dict["model.embed_tokens.weight"]
recovered["output.weight"] = st_dict["model.embed_tokens.weight"]
recovered["norm.weight"] = st_dict["model.norm.weight"]
print(f"Saving to {output_file}")
torch.save(recovered, output_file)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Convert Safetensors back to Torch .pth checkpoint")
parser.add_argument(
"--safetensors_file", type=str, required=True,
help="Path to input .safetensors file"
)
parser.add_argument(
"--output", type=str, default="./checkpoints/model_state.pt",
help="Path to output .pt file"
)
parser.add_argument(
"--model_name", type=str, default="2B",
help="Model configuration name to use (e.g. 2B)"
)
args = parser.parse_args()
convert_back(
safetensors_path=args.safetensors_file,
output_file=args.output,
model_name=args.model_name,
)
+359
@@ -0,0 +1,359 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import readline # type: ignore # noqa
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional, Tuple, Union
import fire
import model as fast
import torch
from stats import Stats
from tokenizer import Tokenizer, ChatFormat
import sample_utils
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)
@dataclass
class GenArgs:
gen_length: int = 32
gen_bsz: int = 1
prompt_length: int = 64
use_sampling: bool = False
temperature: float = 0.8
top_p: float = 0.9
class FastGen:
GRAPH_WARMUPS: int = 1
tokenizer: Tokenizer
@staticmethod
def build(
ckpt_dir: str,
gen_args: GenArgs,
device: Union[torch.device, str],
tokenizer_path: Optional[str] = None,
num_layers: int = 13,
use_full_vocab: bool = False,
) -> "FastGen":
"""
Load a Llama or Code Llama checkpoint and return a new
generator for this model.
"""
start_time = time.time()
model_args_prefill = fast.ModelArgs(use_kernel=False)
model_args_decode = fast.ModelArgs(use_kernel=True)
tokenizer = Tokenizer("./tokenizer.model")
torch.set_default_device(device)
torch.set_default_dtype(torch.bfloat16)
prefill_model = fast.Transformer(model_args_prefill)
decode_model = fast.Transformer(model_args_decode)
fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu")
int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu")
prefill_model.load_state_dict(fp16_checkpoint, strict=True)
decode_model.load_state_dict(int2_checkpoint, strict=True)
torch.cuda.synchronize()
print(f"loaded model in {time.time() - start_time:.2f} seconds")
start_time = time.time()
return FastGen(gen_args, model_args_prefill, prefill_model, decode_model, tokenizer)
def __init__(
self,
args: GenArgs,
model_args: fast.ModelArgs,
prefill_model: fast.Transformer,
decode_model: fast.Transformer,
tokenizer: Tokenizer,
):
self.gen_args = args
self.max_seq_length = args.prompt_length + args.gen_length
self.model_args = model_args
# self.model = model
self.prefill_model = prefill_model
self.decode_model = decode_model
self.tokenizer = tokenizer
self._prefill_cuda_graph, self._prefill_compile_model, self._prefill_inputs, self._prefill_logits = None, None, None, None
self._generate_cuda_graph, self._generate_compile_model, self._generate_inputs, self._generate_logits = None, None, None, None
self._cache = None
start_time = time.time()
self._prefill_compile_model = self.compile_prefill()
self._generate_compile_model = self.compile_generate()
print(f"compiled model in {time.time() - start_time:.2f} seconds")
def compile_prefill(self):
if self._cache is None:
self._cache = fast.make_cache(
args=self.model_args,
length=self.gen_args.gen_bsz * self.max_seq_length,
)
seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
bias = AttnBias.from_seqlens(
q_seqlen=seq_lens,
kv_seqlen=seq_lens,
kv_padding=self.max_seq_length,
)
bias.q_seqinfo.to("cuda")
bias.k_seqinfo.to("cuda")
tokens = torch.IntTensor([1] * self.gen_args.gen_bsz * self.gen_args.prompt_length).cuda()
self._prefill_inputs = (tokens, bias)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
_ = self.prefill_model.forward_with_attn_bias(
token_values=self._prefill_inputs[0],
attn_bias=self._prefill_inputs[1],
cache=self._cache,
)
torch.cuda.current_stream().wait_stream(s)
self._prefill_cuda_graph = torch.cuda.CUDAGraph()
recording_kwargs = {}
if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
# In PyTorch 2.1+ and nightlies from late Aug 2023,
# we can do this to maybe avoid watchdog-related crashes
recording_kwargs["capture_error_mode"] = "thread_local"
with torch.cuda.graph(self._prefill_cuda_graph, **recording_kwargs):
self._prefill_logits = self.prefill_model.forward_with_attn_bias(
token_values=self._prefill_inputs[0],
attn_bias=self._prefill_inputs[1],
cache=self._cache,
)
def replay(tokens, seq_lens=None):
self._prefill_inputs[0].copy_(tokens)
if seq_lens is not None:
self._prefill_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
self._prefill_cuda_graph.replay()
torch.cuda.synchronize()
return self._prefill_logits
return replay
def compile_generate(self):
if self._cache is None:
self._cache = fast.make_cache(
args=self.model_args,
length=self.gen_args.gen_bsz * self.max_seq_length,
)
seq_lens = [1 for _ in range(self.gen_args.gen_bsz)]
kv_seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
bias = AttnBias.from_seqlens(
q_seqlen=seq_lens,
kv_seqlen=kv_seq_lens,
kv_padding=self.max_seq_length,
)
bias.q_seqinfo.to("cuda")
bias.k_seqinfo.to("cuda")
tokens = torch.IntTensor([1] * self.gen_args.gen_bsz).cuda()
self._generate_inputs = (tokens, bias)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
_ = self.decode_model.forward_with_attn_bias(
token_values=self._generate_inputs[0],
attn_bias=self._generate_inputs[1],
cache=self._cache,
)
torch.cuda.current_stream().wait_stream(s)
self._generate_cuda_graph = torch.cuda.CUDAGraph()
recording_kwargs = {}
if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
# In PyTorch 2.1+ and nightlies from late Aug 2023,
# we can do this to maybe avoid watchdog-related crashes
recording_kwargs["capture_error_mode"] = "thread_local"
with torch.cuda.graph(self._generate_cuda_graph, **recording_kwargs):
self._generate_logits = self.decode_model.forward_with_attn_bias(
token_values=self._generate_inputs[0],
attn_bias=self._generate_inputs[1],
cache=self._cache,
)
def replay(tokens, seq_lens):
self._generate_inputs[0].copy_(tokens)
self._generate_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
self._generate_cuda_graph.replay()
return self._generate_logits
return replay
@torch.inference_mode()
def generate_all(
self, prompts: list[list[int]], use_cuda_graphs: bool, use_sampling: bool
) -> Tuple[Stats, list[list[int]]]:
bs = len(prompts)
prompt_lens = [len(p) for p in prompts]
padded_prompt_lens = [self.gen_args.prompt_length] * bs
max_prompt_length = max(prompt_lens)
gen_length = self.gen_args.gen_length
max_seq_length = max_prompt_length + gen_length
print(max_prompt_length, gen_length)
bias = AttnBias.from_seqlens(
q_seqlen=padded_prompt_lens,
kv_seqlen=prompt_lens,
kv_padding=max_seq_length,
)
bias.q_seqinfo.to("cuda")
bias.k_seqinfo.to("cuda")
# Input tensors to the cuda graph
kv_seqlen = bias.k_seqinfo.seqlen
prompts = [prompt + [1] * (self.gen_args.prompt_length - len(prompt)) for prompt in prompts]
tokens = torch.IntTensor(sum(prompts, [])).cuda()
out_tokens = torch.zeros((max_seq_length, bs), dtype=torch.int)
stats = Stats()
torch.cuda.synchronize()
stats.phase("prefill" if use_cuda_graphs else "total")
# stats.phase("total")
output = self._prefill_compile_model(tokens, None)
logits = output[kv_seqlen - 1, :]
logits = logits.view(bs, self.model_args.vocab_size)
if use_sampling:
temp = 0.7
top_p = 0.95
probs = torch.softmax(logits / temp, dim=-1)
next_token = sample_utils.top_p(probs, top_p)
else:
next_token = torch.argmax(logits, dim=-1)
next_token = next_token.reshape(bs)
out_tokens[0, :] = next_token
torch.cuda.synchronize()
stats.phase("decode" if use_cuda_graphs else "total")
eos_id = self.tokenizer.eot_id
for niter in range(1, gen_length):
kv_seqlen.add_(kv_seqlen < max_seq_length)
output = self._generate_compile_model(next_token, kv_seqlen)
logits = output.view(bs, self.model_args.vocab_size)
if use_sampling:
temp = 0.7
top_p = 0.95
probs = torch.softmax(logits / temp, dim=-1)
next_token = sample_utils.top_p(probs, top_p)
else:
next_token = torch.argmax(logits, dim=-1)
next_token = next_token.reshape(bs)
out_tokens[niter, :] = next_token
if next_token.eq(eos_id).any():
break
torch.cuda.synchronize()
stats.end_phase(tokens=niter * bs)
def trim_answer(prompt_len, tokens):
# print(prompt, tokens)
"""Trim the answer to end it on an eos token."""
tokens = tokens[: max_seq_length - prompt_len]
eos_id = self.tokenizer.eot_id
if eos_id in tokens:
return tokens[: tokens.index(eos_id) + 1]
else:
return tokens
answers = [
trim_answer(prompt_len, answer)
for prompt_len, answer in zip(prompt_lens, out_tokens.t().tolist())
]
return stats, answers
def get_prompts(interactive: bool) -> Iterable[list[str]]:
if interactive:
while True:
try:
prompts = input("enter prompt: ").split("\n")
except EOFError:
print("exiting")
sys.exit(0)
yield prompts
else:
yield [
"Hello, my name is",
]
def main(ckpt_dir: str, interactive: bool = False, chat_format: bool = False, sampling: bool = False):
local_rank = 0
device = f"cuda:{local_rank}"
torch.cuda.set_device(local_rank)
g = FastGen.build(ckpt_dir, GenArgs(), device)
if chat_format:
g.tokenizer = ChatFormat(g.tokenizer)
for prompts in get_prompts(interactive):
# prompts = [f"{prompt}\n" for prompt in prompts]
if chat_format:
# prompts = [f'<|begin_of_text|>User: {prompt}<|eot_id|>Assistant: ' for prompt in prompts]
tokens = [g.tokenizer.encode_dialog_prompt(dialog=[{"role": "user", "content": prompt}], completion=True) for prompt in prompts]
else:
tokens = [g.tokenizer.encode(x, bos=False, eos=False) for x in prompts]
print(tokens)
stats, out_tokens = g.generate_all(
tokens, use_cuda_graphs="NO_CUDA_GRAPHS" not in os.environ, use_sampling=sampling,
)
for i, prompt in enumerate(prompts):
print(f"> {prompt}")
answer = g.tokenizer.decode(out_tokens[i])
print(answer)
print("---------------")
for phase_stats in stats.phases:
print(phase_stats.show())
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
if __name__ == "__main__":
fire.Fire(main)
Executable
+366
@@ -0,0 +1,366 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import functional as F
from xformers.ops import RMSNorm, fmha, rope_padded
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)
import ctypes
bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so')
def bitnet_int8xint2_linear(input0, input1, s, ws):
out_shape = list(input0.shape)
out_shape[-1] = input1.shape[0]
stream = torch.cuda.current_stream()
M = input0.shape[0]
if len(out_shape) == 3:
M *= input0.shape[1]
N = input1.shape[0]
K = input1.shape[1] * 4
ret = torch.zeros(*out_shape, dtype=torch.bfloat16, device=input0.device)
bitnet_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_void_p(s.data_ptr()), ctypes.c_void_p(ws.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)])
return ret
@dataclass
class ModelArgs:
dim: int = 2560
n_layers: int = 30
n_heads: int = 20
n_kv_heads: int = 5
vocab_size: int = 128256
ffn_dim: int = 6912
norm_eps: float = 1e-5
rope_theta: float = 500000.0
use_kernel: bool = False
LayerCache = Tuple[torch.Tensor, torch.Tensor]
class BitLinearKernel(nn.Module):
in_features: int
out_features: int
weight: torch.Tensor
weight_scale: torch.Tensor
def __init__(self, in_features: int, out_features: int, bias: bool = False):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features//4, dtype=torch.int8), requires_grad=False)
self.weight_scale = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16), requires_grad=False)
@torch.compile
def quant_input(self, input):
s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
return (input * s).round().clamp(-128, 127).to(torch.int8), s
def forward(self, input):
input, s = self.quant_input(input)
return bitnet_int8xint2_linear(input, self.weight, s, self.weight_scale)
class BitLinear(nn.Linear):
@torch.compile
def quant_input(self, input):
s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
return (input * s).round().clamp(-128, 127) / s
def forward(self, input):
input = self.quant_input(input)
return F.linear(input, self.weight)
class Attention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int,
n_heads: int,
n_kv_heads: int,
rope_theta: float,
norm_eps: float,
use_kernel: bool,
):
super().__init__()
self.head_dim = head_dim
self.rope_theta = rope_theta
self.n_local_heads = n_heads
self.n_local_kv_heads = n_kv_heads
Linear = BitLinearKernel if use_kernel else BitLinear
self.wqkv = Linear(
dim,
(self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim,
bias=False,
)
self.wo = Linear(
self.n_local_heads * head_dim,
dim,
bias=False,
)
self.attn_sub_norm = RMSNorm(dim, norm_eps)
def forward(
self,
x: torch.Tensor,
cache: LayerCache,
attn_bias: AttnBias,
) -> torch.Tensor:
xqkv = self.wqkv(x)
xq = xqkv[:, : (self.n_local_heads * self.head_dim)]
xkv = xqkv[:, (self.n_local_heads * self.head_dim) :]
xk, xv = xkv.chunk(2, 1)
output_shape = xq.shape
heads_per_group = self.n_local_heads // self.n_local_kv_heads
xq = xq.view(
1, xq.shape[0], self.n_local_kv_heads, heads_per_group, self.head_dim
)
xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, 1, self.head_dim)
# xq = rearrange(xq, 'b (g h l d) -> 1 b h g (d l)', g=heads_per_group, h=self.n_local_kv_heads, d=self.head_dim // 2, l=2)
# xk = rearrange(xk, 'b (g l d) -> 1 b g 1 (d l)', g=self.n_local_kv_heads, d=self.head_dim // 2)
xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, 1, self.head_dim)
cache_k, cache_v = cache
xq = rope_padded(
xq=xq,
xk=xk,
xv=xv,
cache_k=cache_k,
cache_v=cache_v,
attn_bias=attn_bias,
theta=self.rope_theta,
)
output = fmha.memory_efficient_attention_forward(
xq, cache_k, cache_v, attn_bias, op = fmha.flash.FwOp
)
output = output.reshape(output_shape)
output = self.attn_sub_norm(output)
output = self.wo(output)
return output
@torch.compile
def squared_relu(x: torch.Tensor) -> torch.Tensor:
return F.relu(x) ** 2
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
norm_eps: float,
use_kernel: bool,
):
super().__init__()
Linear = BitLinearKernel if use_kernel else BitLinear
self.w13 = Linear(
dim,
2 * hidden_dim,
bias=False,
)
self.w2 = Linear(
hidden_dim,
dim,
bias=False,
)
self.ffn_sub_norm = RMSNorm(hidden_dim, norm_eps)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x13 = self.w13(x)
x1, x3 = x13.chunk(2, -1)
inner = self.ffn_sub_norm(squared_relu(x1) * x3)
output = self.w2(inner)
return output
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert args.dim % args.n_heads == 0
head_dim = args.dim // args.n_heads
if args.n_kv_heads is not None:
n_kv_heads = args.n_kv_heads
else:
n_kv_heads = args.n_heads
assert args.n_heads % n_kv_heads == 0
self.attention = Attention(
dim=args.dim,
head_dim=head_dim,
n_heads=args.n_heads,
n_kv_heads=n_kv_heads,
rope_theta=args.rope_theta,
norm_eps=args.norm_eps,
use_kernel=args.use_kernel,
)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=args.ffn_dim,
norm_eps=args.norm_eps,
use_kernel=args.use_kernel,
)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(
self,
x: torch.Tensor,
cache: LayerCache,
attn_bias: AttnBias,
) -> torch.Tensor:
h = x + self.attention.forward(
self.attention_norm(x),
cache,
attn_bias,
)
out = h + self.feed_forward(self.ffn_norm(h))
return out
class Transformer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert args.vocab_size > 0
self.tok_embeddings = nn.Embedding(
num_embeddings=args.vocab_size,
embedding_dim=args.dim,
)
self.layers = nn.ModuleList()
for _ in range(args.n_layers):
self.layers.append(TransformerBlock(args))
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = nn.Linear(
args.dim,
args.vocab_size,
bias=False,
)
@torch.no_grad()
def forward_with_attn_bias(
self,
token_values: torch.Tensor,
attn_bias: AttnBias,
cache: list[LayerCache],
) -> torch.Tensor:
h = self.tok_embeddings(token_values)
for i, layer in enumerate(self.layers):
h = layer(h, cache[i], attn_bias)
logits = self.output(self.norm(h))
return logits.float()
def forward(
self,
token_values: torch.Tensor,
token_lengths: torch.Tensor,
start_pos: torch.Tensor,
cache: list[LayerCache],
kv_padding: int,
) -> torch.Tensor:
attn_bias = AttnBias.from_seqlens(
q_seqlen=token_lengths.tolist(),
kv_seqlen=(start_pos + token_lengths).tolist(),
kv_padding=kv_padding,
)
return self.forward_with_attn_bias(token_values, attn_bias, cache)
def make_cache(
args: ModelArgs,
length: int,
device: Optional[Union[str, torch.device]] = None,
n_layers: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
) -> list[LayerCache]:
"""
Allocate a cache to be used with the Transformer module.
Args:
args (ModelArgs): the model configuration.
length (int): per layer cache size.
It is usually budgeted as ``max_batch * max_seq``
device (torch.device, optional): the device on which
the cache should be allocated.
n_layers (int, optional): the number of layers to
allocate a cache for (defaults to the model
settings).
dtype (torch.dtype, optional): the dtype to use for
cache entries (defaults to the default dtype).
Returns:
The cache object to pass to ``Transformer.forward``.
"""
head_dim = args.dim // args.n_heads
n_kv_heads = args.n_kv_heads
if n_kv_heads is None:
n_kv_heads = args.n_heads
n_local_kv_heads = n_kv_heads
if n_layers is None:
n_layers = args.n_layers
shape = (1, length, n_local_kv_heads, 1, head_dim)
heads_per_group = args.n_heads // n_kv_heads
expansion = (-1, -1, -1, heads_per_group, -1)
return [
(
torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
)
for _ in range(n_layers)
]
def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]:
"""
Take a prefix view of a larger cache.
The original cache object remains valid and of identical size
after the shrunken alias has been used. This function is useful
when a cache was allocated for a larger batch size than what is
necessary.
Args:
cache: the cache to take a view in.
length (int): the desired length
Returns:
A view in the input cache object.
"""
if len(cache) > 0:
assert cache[0][0].shape[1] >= length
return [(ck[:, :length], cv[:, :length]) for ck, cv in cache]
+98
@@ -0,0 +1,98 @@
import torch
import numpy as np
def B_global_16x32_to_shared_load_16x32_layout(i, j):
"""
stride * 8 * (tx // HALF_WARP_expr)
+ (tx % 8) * stride
+ 16 * ((tx % HALF_WARP_expr) // 8)
"""
thread_id = i * 2 + j // 16
row = (thread_id // 16) * 8 + (thread_id % 8)
col = (j % 16) + 16 * ((thread_id % 16) // 8)
return row, col
def permutate_weight_fastest(weight):
wmma_n = 16
wmma_k = 32
N = weight.shape[0]
K = weight.shape[1]
# Create a lookup table for the permutation
mapping = np.zeros((wmma_n, wmma_k, 2), dtype=int)
for ii in range(wmma_n):
for jj in range(wmma_k):
mapping[ii, jj] = B_global_16x32_to_shared_load_16x32_layout(ii, jj)
# Reshape weight for the final format
permutated_weight = np.zeros((N // wmma_n, K // wmma_k, wmma_n, wmma_k), dtype="int8")
# Use advanced indexing for the entire operation
i_indices = np.arange(N // wmma_n)[:, np.newaxis, np.newaxis, np.newaxis]
j_indices = np.arange(K // wmma_k)[np.newaxis, :, np.newaxis, np.newaxis]
# Create the source indices
src_i = i_indices * wmma_n + mapping[:, :, 0]
src_j = j_indices * wmma_k + mapping[:, :, 1]
# Extract and reshape in one go
permutated_weight = weight[src_i, src_j]
return permutated_weight
def compress_int2_to_int8(int2_weight):
int8_weight = np.zeros(
(*int2_weight.shape[:-1], int2_weight.shape[-1] // 4), dtype=np.int8
)
for j in range(int2_weight.shape[-1] // 4):
for k in range(4):
int8_weight[:, :, :, j] |= int2_weight[:, :, :, j * 4 + k] << (k * 2)
return int8_weight
def interleave_weight_int8(qweight, nbits=2):
# reinterpret the data type of qweight to int32
# shift = [ 0, 8, 16, 24, 2, 10, 18, 26, 4, 12, 20, 28, 6, 14, 22, 30]
# index: [ 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
qweight = qweight.view(np.int32)
new_qweight = np.zeros_like(qweight)
bits_stride = 8
mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f
num_groups = 32 // bits_stride # 4
elems_per_group = bits_stride // nbits # 4
for i in range(num_groups):
for j in range(elems_per_group):
offset = i * elems_per_group + j
shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits
new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift
return new_qweight.view(np.int8)
def convert_weight_int8_to_int2(weight):
N = weight.shape[0]
K = weight.shape[1]
weight = weight+2
weight = weight.cpu().numpy()
# print(weight)
# print(torch.max(weight), torch.min(weight))
# permutated_weight_slow = permutate_weight(weight)
permutated_weight = permutate_weight_fastest(weight)
# assert np.all(permutated_weight_slow == permutated_weight)
# print("Permutation is correct")
compressed_weight = compress_int2_to_int8(permutated_weight)
interleaved_weight = interleave_weight_int8(compressed_weight, 2)
ret = torch.from_numpy(interleaved_weight)
ret = torch.reshape(ret, (N, K // 4))
return ret
+9
@@ -0,0 +1,9 @@
fire
sentencepiece
torch>=2.2.0
xformers>=0.0.22
tiktoken
blobfile
flask
einops
transformers
+31
@@ -0,0 +1,31 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
@torch.compile
def top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): probability distribution tensor.
p (float): probability threshold for top-p sampling.
Returns:
torch.Tensor: sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative
probability mass exceeds the threshold p. The distribution is
renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
Executable
+57
@@ -0,0 +1,57 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import time
from dataclasses import dataclass
from typing import Optional
@dataclass
class PhaseStats:
name: str
tokens: int
time: float
def show(self) -> str:
tps = self.tokens / self.time
return (
f"[{self.name}] "
f"generated tokens: {self.tokens}"
f" - total time: {self.time:.3f}s"
f" - {tps:.1f} tokens per second"
)
class Stats:
"""
Generation stats, split by phases.
"""
def __init__(self):
self.phases = []
self.current = None
def end_phase(self, tokens: int, now: Optional[float] = None):
"""Terminate the current phase."""
if self.current is None:
return
if now is None:
now = time.time()
cname, ctokens, ctime = self.current
stats = PhaseStats(
name=cname,
tokens=tokens - ctokens,
time=now - ctime,
)
self.phases.append(stats)
def phase(self, name: str, tokens: int = 0):
"""
Start a new phase, and terminate the current one,
if one is ongoing.
"""
now = time.time()
self.end_phase(tokens, now)
self.current = (name, tokens, now)
+99
@@ -0,0 +1,99 @@
import torch
from torch.utils import benchmark
from torch import nn
from pack_weight import convert_weight_int8_to_int2
from torch.profiler import profile, record_function, ProfilerActivity
import ctypes
import numpy as np
# set all seed
torch.manual_seed(42)
np.random.seed(42)
bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so')
def bitnet_int8xint2_linear(input0, input1, s, ws, ret):
out_shape = list(input0.shape)
out_shape[-1] = input1.shape[0]
stream = torch.cuda.current_stream()
M = input0.shape[0]
if len(out_shape) == 3:
M *= input0.shape[1]
N = input1.shape[0]
K = input1.shape[1] * 4
bitnet_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_void_p(s.data_ptr()), ctypes.c_void_p(ws.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)])
return ret
if __name__ == '__main__':
test_list = [
(2560, 2560),
(3840, 2560),
(13824, 2560),
(2560, 6912) ,
(3200, 3200),
(4800, 3200),
(3200, 10240),
(20480, 3200),
]
for N,K in test_list:
weight = torch.randint(-1, 2, (N, K), dtype=torch.int8, device='cuda')
weight_scale = torch.ones(1, dtype=torch.bfloat16, device='cuda')
weight_compressed = convert_weight_int8_to_int2(weight).to('cuda')
for i in range(1):
input0 = torch.randint(-128,127,(1, K),dtype=torch.int8, device='cuda')
input0_bf16 = input0.to(torch.bfloat16)
input_np = input0.cpu().to(torch.int32).numpy()
weight_np = weight.cpu().to(torch.int32).T.numpy()
out_np = np.matmul(input_np,weight_np)
out_np = torch.tensor(out_np).cuda().to(torch.bfloat16)
s = torch.ones(1, dtype=torch.bfloat16, device='cuda')
ws = torch.ones(6, dtype=torch.bfloat16, device='cuda')
ret = torch.empty((1,N), dtype=torch.bfloat16, device=input0.device)
out = bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)
print(f'custom == np {torch.all(out==out_np)}')
input0 = torch.randint(-128,127,(1, K),dtype=torch.int8, device='cuda')
input0_fp16 = input0.to(torch.float16)
input0_bf16 = input0.to(torch.bfloat16)
weight_fp16 = weight.to(torch.float16).T
weight_bf16 = weight.to(torch.bfloat16).T
ret = torch.empty((1,N), dtype=torch.bfloat16, device=input0.device)
s = torch.ones(1, dtype=torch.bfloat16, device='cuda')
ws = torch.ones(6, dtype=torch.bfloat16, device='cuda')
t0 = benchmark.Timer(
stmt="bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)",
setup="from __main__ import input0, weight_compressed, s, ws, ret, bitnet_int8xint2_linear",
num_threads=1,
)
t1 = benchmark.Timer(
stmt="torch.matmul(input0_bf16,weight_bf16)",
setup="from __main__ import input0_bf16, weight_bf16",
num_threads=1,
)
time0 = t0.timeit(50)
time1 = t1.timeit(50)
print(f'Shape{N,K}, W2A8: {time0.mean * 1e6:.2f}us, torch BF16: {time1.mean * 1e6:.2f}us')
# activities = [ ProfilerActivity.CUDA,
# # ProfilerActivity.CPU
# ]
# sort_by_keyword = 'cuda' + "_time_total"
# with profile(activities=activities, record_shapes=True) as prof:
# with record_function("model_inference1"):
# for _ in range(10):
# bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)
# torch.matmul(input0_fp16,weight_fp16)
# torch.matmul(input0_bf16,weight_bf16)
# print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=15))
+128000
File diff suppressed because it is too large
+257
@@ -0,0 +1,257 @@
import os
from logging import getLogger
from pathlib import Path
from typing import (
AbstractSet,
cast,
Collection,
Dict,
Iterator,
List,
Literal,
Sequence,
TypedDict,
Union,
)
import tiktoken
from tiktoken.load import load_tiktoken_bpe
logger = getLogger(__name__)
Role = Literal["system", "user", "assistant"]
class Message(TypedDict):
role: Role
content: str
Dialog = Sequence[Message]
class Tokenizer:
"""
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""
special_tokens: Dict[str, int]
num_reserved_special_tokens = 256
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
def __init__(self, model_path: str):
"""
Initializes the Tokenizer with a Tiktoken model.
Args:
model_path (str): The path to the Tiktoken model file.
"""
assert os.path.isfile(model_path), model_path
mergeable_ranks = load_tiktoken_bpe(model_path)
num_base_tokens = len(mergeable_ranks)
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
] + [
f"<|reserved_special_token_{i}|>"
for i in range(5, self.num_reserved_special_tokens - 5)
]
self.special_tokens = {
token: num_base_tokens + i for i, token in enumerate(special_tokens)
}
self.model = tiktoken.Encoding(
name=Path(model_path).name,
pat_str=self.pat_str,
mergeable_ranks=mergeable_ranks,
special_tokens=self.special_tokens,
)
logger.info(f"Reloaded tiktoken model from {model_path}")
self.n_words: int = self.model.n_vocab
# BOS / EOS token IDs
self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
self.eos_id: int = self.special_tokens["<|end_of_text|>"]
self.pad_id: int = self.n_words - 1
self.stop_tokens = {
self.special_tokens["<|end_of_text|>"],
self.special_tokens["<|eot_id|>"],
}
logger.info(
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
)
def encode(
self,
s: str,
*,
bos: bool,
eos: bool,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = (),
) -> List[int]:
"""
Encodes a string into a list of token IDs.
Args:
s (str): The input string to be encoded.
bos (bool): Whether to prepend the beginning-of-sequence token.
eos (bool): Whether to append the end-of-sequence token.
allowed_special ("all"|set[str]): allowed special tokens in string
disallowed_special ("all"|set[str]): special tokens that raise an error when in string
Returns:
list[int]: A list of token IDs.
By default, setting disallowed_special=() encodes a string by ignoring
special tokens. Specifically:
- Setting `disallowed_special` to () will cause all text corresponding
to special tokens to be encoded as natural text (instead of raising
an error).
- Setting `allowed_special` to "all" will cause all text corresponding
to special tokens to be encoded as special tokens.
"""
assert type(s) is str
# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
# https://github.com/openai/tiktoken/issues/195
# Here we iterate over subsequences and split if we exceed the limit
# of max consecutive non-whitespace or whitespace characters.
MAX_NO_WHITESPACES_CHARS = 25_000
substrs = (
substr
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
for substr in self._split_whitespaces_or_nonwhitespaces(
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
)
)
t: List[int] = []
for substr in substrs:
t.extend(
self.model.encode(
substr,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
)
)
if bos:
t.insert(0, self.bos_id)
if eos:
t.append(self.eos_id)
return t
def decode(self, t: Sequence[int]) -> str:
"""
Decodes a list of token IDs into a string.
Args:
t (List[int]): The list of token IDs to be decoded.
Returns:
str: The decoded string.
"""
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
return self.model.decode(cast(List[int], t))
@staticmethod
def _split_whitespaces_or_nonwhitespaces(
s: str, max_consecutive_slice_len: int
) -> Iterator[str]:
"""
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
consecutive whitespaces or consecutive non-whitespaces.
"""
current_slice_len = 0
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
slice_start = 0
for i in range(len(s)):
is_now_space = s[i].isspace()
if current_slice_is_space ^ is_now_space:
current_slice_len = 1
current_slice_is_space = is_now_space
else:
current_slice_len += 1
if current_slice_len > max_consecutive_slice_len:
yield s[slice_start:i]
slice_start = i
current_slice_len = 1
yield s[slice_start:]
class ChatFormat:
def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer
self.eot_id = tokenizer.special_tokens["<|eot_id|>"]
def decode(self, tokens: List[int]) -> str:
# Decode the tokens to a string.
decoded_str = self.tokenizer.decode(tokens)
# Remove the special tokens from the decoded string.
decoded_str = decoded_str.replace("<|eot_id|>", "")
return decoded_str
def encode_header(self, message: Message) -> List[int]:
tokens = []
if message["role"] == "system":
tokens.extend(self.tokenizer.encode("System: ", bos=False, eos=False))
elif message["role"] == "user":
tokens.extend(self.tokenizer.encode("User: ", bos=False, eos=False))
elif message["role"] == "assistant":
tokens.extend(self.tokenizer.encode("Assistant: ", bos=False, eos=False))
else:
raise NotImplementedError(f"Role {message['role']} not implemented.")
# tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
# tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
# tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
# tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
return tokens
def encode_message(self, message: Message, return_target=False) -> List[int]:
tokens, targets = [], []
headers = self.encode_header(message)
contents = self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
contents.append(self.tokenizer.special_tokens["<|eot_id|>"])
tokens = headers + contents
if message["role"] == "assistant":
targets = [-1] * len(headers) + contents
else:
targets = [-1] * len(tokens)
if return_target:
return tokens, targets
return tokens, None
def encode_dialog_prompt(self, dialog: Dialog, completion=False, return_target=False) -> List[int]:
tokens = [self.tokenizer.special_tokens["<|begin_of_text|>"]]
targets = [-1]
for message in dialog:
_tokens, _targets = self.encode_message(message, return_target=return_target)
tokens.extend(_tokens)
if _targets is not None:
targets.extend(_targets)
# Add the start of an assistant message for the model to complete.
if completion:
tokens.extend(self.encode_header({"role": "assistant", "content": ""}))
if return_target:
return tokens, targets
return tokens