mirror of https://github.com/microsoft/BitNet.git
synced 2026-05-03 11:20:36 +00:00
Merge branch 'microsoft:main' into add-falcon-e-final
@@ -34,6 +34,7 @@ nppBackup

# Models
models/*
gpu/checkpoints/*

# Python

@@ -4,9 +4,9 @@

[<img src="./assets/header_model_release.png" alt="BitNet Model on Hugging Face" width="800"/>](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)

Try it out via this [demo](https://bitnet-demo.azurewebsites.net/), or [build and run](https://github.com/microsoft/BitNet?tab=readme-ov-file#build-from-source) it on your own CPU.
Try it out via this [demo](https://bitnet-demo.azurewebsites.net/), or build and run it on your own [CPU](https://github.com/microsoft/BitNet?tab=readme-ov-file#build-from-source) or [GPU](https://github.com/microsoft/BitNet/blob/main/gpu/README.md).

bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.58). It offers a suite of optimized kernels, that support **fast** and **lossless** inference of 1.58-bit models on CPU (with NPU and GPU support coming next).
bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.58). It offers a suite of optimized kernels that support **fast** and **lossless** inference of 1.58-bit models on CPU and GPU (with NPU support coming next).

The first release of bitnet.cpp is to support inference on CPUs. bitnet.cpp achieves speedups of **1.37x** to **5.07x** on ARM CPUs, with larger models experiencing greater performance gains. Additionally, it reduces energy consumption by **55.4%** to **70.0%**, further boosting overall efficiency. On x86 CPUs, speedups range from **2.37x** to **6.17x** with energy reductions between **71.9%** and **82.2%**. Furthermore, bitnet.cpp can run a 100B BitNet b1.58 model on a single CPU, achieving speeds comparable to human reading (5-7 tokens per second), significantly enhancing the potential for running LLMs on local devices. Please refer to the [technical report](https://arxiv.org/abs/2410.16144) for more details.

@@ -22,7 +22,8 @@ A demo of bitnet.cpp running a BitNet b1.58 3B model on Apple M2:

https://github.com/user-attachments/assets/7f46b736-edec-4828-b809-4be780a3e5b1

## What's New:
- 05/20/2025 [BitNet Official GPU inference kernel](https://github.com/microsoft/BitNet/blob/main/gpu/README.md)
- 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
- 02/18/2025 [Bitnet.cpp: Efficient Edge Inference for Ternary LLMs](https://arxiv.org/abs/2502.11880)
- 11/08/2024 [BitNet a4.8: 4-bit Activations for 1-bit LLMs](https://arxiv.org/abs/2411.04965)
- 10/21/2024 [1-bit AI Infra: Part 1.1, Fast and Lossless BitNet b1.58 Inference on CPUs](https://arxiv.org/abs/2410.16144)

Executable
+93
@@ -0,0 +1,93 @@
# BitNet Inference Kernel

This repository provides a highly efficient GEMV kernel implementation for the BitNet model, optimized for W2A8 inference — 2-bit weights and 8-bit activations. It is tailored for use with the [BitNet-b1.58-2B-4T](https://arxiv.org/abs/2504.12285) model.
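
As a point of reference for what the kernel computes, the following NumPy sketch performs the same W2A8 GEMV math on unpacked data (an illustration only, not the CUDA implementation): ternary weights kept as `int8`, activations quantized per token to `int8` with an absmax scale `s`, an integer accumulate, and a rescale by `ws / s` in the epilogue, mirroring the final conversion in the kernel. The scale handling is simplified here; the real weight scales are produced during checkpoint conversion.

```python
import numpy as np

def w2a8_gemv_reference(w_ternary: np.ndarray, x: np.ndarray, ws: float) -> np.ndarray:
    """Illustrative W2A8 GEMV: 2-bit (ternary) weights x 8-bit activations.

    w_ternary: (N, K) int8 matrix with values in {-1, 0, +1}
    x:         (K,) float activations
    ws:        weight scale recovered during checkpoint conversion
    """
    # Per-token absmax quantization of the activations to int8.
    s = 127.0 / max(float(np.abs(x).max()), 1e-5)
    x_q = np.clip(np.round(x * s), -128, 127).astype(np.int8)
    # Integer GEMV, accumulated in 32-bit integers.
    acc = w_ternary.astype(np.int32) @ x_q.astype(np.int32)
    # Epilogue: rescale back to floating point, as the kernel does per output element.
    return acc.astype(np.float32) / s * ws
```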

## Features

- Support for W2A8 (2-bit weight × 8-bit activation) GEMV computation
- Custom CUDA kernels with low-latency execution
- Optimizations for memory access, decoding, and compute throughput

## Usage

Installation and kernel performance tests:

```bash
# (Recommended) Create a new conda environment
conda create --name bitnet-gpu "python<3.13"
conda activate bitnet-gpu

# Install dependencies
pip install -r requirements.txt

# Build the kernel
cd bitnet_kernels
bash compile.sh
cd ..

# Run performance tests
python test.py
```

End-to-end inference:

```bash
# Download and convert the BitNet-b1.58-2B model
mkdir checkpoints
huggingface-cli download microsoft/bitnet-b1.58-2B-4T-bf16 --local-dir ./checkpoints/bitnet-b1.58-2B-4T-bf16
python ./convert_safetensors.py --safetensors_file ./checkpoints/bitnet-b1.58-2B-4T-bf16/model.safetensors --output checkpoints/model_state.pt --model_name 2B
python ./convert_checkpoint.py --input ./checkpoints/model_state.pt
rm ./checkpoints/model_state.pt

# Inference
python3 ./generate.py ./checkpoints/ --interactive --chat_format
```

## Optimizations

### Weight Permutation

The weight matrix is divided into 16×32 blocks to optimize memory access patterns.

Within each block, values are stored contiguously in memory and permuted to facilitate efficient access and processing.

See `convert_checkpoint.py` for details.
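
A minimal Python sketch of the block-tiling step is shown below (the full permutation, including the in-block thread-level reordering, lives in `pack_weight.py` as `permutate_weight_fastest`):

```python
import numpy as np

def tile_weight_16x32(weight: np.ndarray) -> np.ndarray:
    """Group an (N, K) int8 weight matrix into contiguous 16x32 blocks.

    Returns an array of shape (N//16, K//32, 16, 32) so that every 16x32
    block is stored contiguously; pack_weight.py additionally permutes the
    elements inside each block to match the kernel's thread layout.
    """
    N, K = weight.shape
    assert N % 16 == 0 and K % 32 == 0
    tiled = weight.reshape(N // 16, 16, K // 32, 32)
    return tiled.transpose(0, 2, 1, 3).copy()
```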

### Fast Decoding

Every 16 two-bit values are packed into a single 32-bit integer using the following interleaving pattern:
```
[0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
```

This layout is designed to accelerate decoding by enabling efficient extraction of 4 values at a time into `int8`.
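
The pure-Python sketch below illustrates one way to read that layout (it is not the repository's packing path, which goes through `compress_int2_to_int8` and `interleave_weight_int8` in `pack_weight.py`): slot `i` of the 32-bit word holds element `ORDER[i]`, so a single shift-and-mask recovers four consecutive elements as the four bytes of an `int32`. In the kernel, `decode_i2s_to_i8s` performs this extraction with `lop3` and then subtracts a bias of 2 from every byte (`__vsubss4(..., 0x02020202)`) to recover signed ternary values.

```python
ORDER = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]

def pack16(vals):
    """Pack 16 two-bit values (0..3) into one 32-bit word: slot i <- vals[ORDER[i]]."""
    word = 0
    for slot, src in enumerate(ORDER):
        word |= (vals[src] & 0x3) << (2 * slot)
    return word

def decode4(word, group):
    """Extract elements 4*group .. 4*group+3 as four bytes with one shift + mask."""
    w = (word >> (2 * group)) & 0x03030303
    return [(w >> (8 * b)) & 0xFF for b in range(4)]

vals = [3, 1, 0, 2] * 4                       # any 16 two-bit values
word = pack16(vals)
assert [v for g in range(4) for v in decode4(word, g)] == vals
```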

### `dp4a` Instruction

We use the `dp4a` instruction to accelerate low-precision dot product operations.

This instruction performs a dot product between two 4-element vectors (each stored in a 32-bit word as 8-bit integers) and accumulates the result into a 32-bit integer.

It significantly improves GEMV throughput when processing quantized weights and activations.
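
For reference, the semantics of the signed `dp4a` variant used here can be modeled in a few lines of Python (illustration only; on the GPU this is a single instruction, invoked in the kernel as `__dp4a`):

```python
def dp4a(a: int, b: int, c: int) -> int:
    """Model of signed dp4a: c += sum over the four packed int8 lanes of a*b."""
    def lanes(x):
        # Split a 32-bit word into four signed 8-bit lanes (little-endian order).
        return [(((x >> (8 * i)) & 0xFF) ^ 0x80) - 0x80 for i in range(4)]
    acc = c + sum(ai * bi for ai, bi in zip(lanes(a), lanes(b)))
    # Wrap to a signed 32-bit accumulator like the hardware register.
    acc &= 0xFFFFFFFF
    return acc - (1 << 32) if acc & 0x80000000 else acc
```

In the GEMV inner loop, the kernel feeds four activation bytes and four decoded weight bytes into `__dp4a` per call, so each thread consumes 16 values per iteration with four such instructions.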

## Performance

Kernel performance (tested on NVIDIA A100 40GB GPU):

| Shape (N×K)   | W2A8 Latency (us) | BF16 Latency (us) | Speedup Ratio |
|---------------|-------------------|-------------------|---------------|
| 2560 × 2560   | 13.32             | 18.32             | 1.38          |
| 3840 × 2560   | 14.90             | 18.87             | 1.27          |
| 13824 × 2560  | 18.75             | 59.51             | 3.17          |
| 2560 × 6912   | 14.49             | 37.78             | 2.61          |
| 3200 × 3200   | 14.61             | 19.08             | 1.31          |
| 4800 × 3200   | 13.09             | 21.84             | 1.67          |
| 3200 × 10240  | 19.64             | 60.79             | 3.10          |
| 20480 × 3200  | 30.99             | 112.39            | 3.63          |

Generation throughput:

| BF16 (tokens/s) | W2A8 (tokens/s) | Speedup Ratio |
|---|---|---|
| 10.9 | 213.3 | 19.6 |

@@ -0,0 +1,37 @@
|
||||
#include "bitnet_kernels.h"
|
||||
|
||||
extern "C" void bitlinear_int8xint2(int8_t* input0, int8_t* input1, __nv_bfloat16* output0, __nv_bfloat16* s, __nv_bfloat16* ws, int M, int N, int K, cudaStream_t stream){
|
||||
if (M == 1 && N == 3840 && K == 2560){
|
||||
ladder_int8xint2_kernel<1, 3840, 2560, 3, 8, 16><<<dim3(240, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
|
||||
}
|
||||
else if (M == 1 && N == 2560 && K == 2560){
|
||||
ladder_int8xint2_kernel<1, 2560, 2560, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
|
||||
}
|
||||
else if (M == 1 && N == 13824 && K == 2560){
|
||||
ladder_int8xint2_kernel<1, 13824, 2560, 2, 8, 16><<<dim3(864, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
|
||||
}
|
||||
else if (M == 1 && N == 2560 && K == 6912){
|
||||
ladder_int8xint2_kernel<1, 2560, 6912, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
|
||||
}
|
||||
else if(M == 1 && N == 4800 && K == 3200){
|
||||
ladder_int8xint2_kernel<1, 4800, 3200, 6, 8, 16><<<dim3(300, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
|
||||
}
|
||||
else if(M == 1 && N == 3200 && K == 3200){
|
||||
ladder_int8xint2_kernel<1, 3200, 3200, 1, 8, 16><<<dim3(200, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
|
||||
}
|
||||
else if(M == 1 && N == 20480 && K == 3200){
|
||||
ladder_int8xint2_kernel<1, 20480, 3200, 2, 8, 16><<<dim3(1280, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
|
||||
}
|
||||
else if(M == 1 && N == 3200 && K == 10240){
|
||||
ladder_int8xint2_kernel<1, 3200, 10240, 1, 8, 16><<<dim3(200, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
|
||||
}
|
||||
else if(M == 1 && N == 5120 && K == 27648){
|
||||
ladder_int8xint2_kernel<1, 5120, 27648, 1, 8, 16><<<dim3(320, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
|
||||
}
|
||||
else if(M == 1 && N == 55296 && K == 5120){
|
||||
ladder_int8xint2_kernel<1, 55296, 5120, 1, 8, 16><<<dim3(3456, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
|
||||
}
|
||||
else{
|
||||
std::cout << "required ladder gemm kernel: M " << M << ", N " << N << ", K " << K << std::endl;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <math_constants.h>
|
||||
#include <math.h>
|
||||
#include <mma.h>
|
||||
#include <iostream>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
|
||||
#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || (__CUDACC_VER_MAJOR__ > 11))
|
||||
#define TVM_ENABLE_L2_PREFETCH 1
|
||||
#else
|
||||
#define TVM_ENABLE_L2_PREFETCH 0
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
|
||||
#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1
|
||||
#else
|
||||
#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0
|
||||
#endif
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *_i8s, const int N = 16)
|
||||
{
|
||||
// convert 16 int2b_t to 16 int8b_t -> 4 int32
|
||||
uint *i8s = reinterpret_cast<uint *>(_i8s);
|
||||
|
||||
// i2s = {e0, e4, e8, e12, e1, e5, e9, e13, e2, e6, e10, e14, e3, e7, e11, e15}
|
||||
uint const i2s = *_i2s;
|
||||
|
||||
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010
|
||||
static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3
|
||||
static constexpr uint I4s_TO_I8s_MAGIC_NUM = 0x00000000;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (N / 4); i++)
|
||||
{
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(i8s[i])
|
||||
: "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(I4s_TO_I8s_MAGIC_NUM), "n"(immLut));
|
||||
i8s[i] = __vsubss4(i8s[i], 0x02020202);
|
||||
}
|
||||
}
|
||||
|
||||
template <int M, int N, int K, int ws_num, int K_block_size, int N_block_size>
|
||||
__global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restrict__ A, int8_t* __restrict__ B, __nv_bfloat16* __restrict__ dtype_transform, __nv_bfloat16* __restrict__ s, __nv_bfloat16* __restrict__ ws) {
|
||||
constexpr int K_per_loop = 16;
|
||||
constexpr int wmma_K = 32;
|
||||
constexpr int wmma_N = 16;
|
||||
int in_thread_C_local[1];
|
||||
signed char A_local[K_per_loop];
|
||||
int B_reshape_local[1];
|
||||
signed char B_decode_local[K_per_loop];
|
||||
int red_buf0[1];
|
||||
in_thread_C_local[0] = 0;
|
||||
#pragma unroll
|
||||
for (int k_0 = 0; k_0 < K/(K_per_loop * K_block_size); ++k_0) {
|
||||
*(int4*)(A_local + 0) = *(int4*)(A + ((k_0 * K_per_loop * K_block_size) + (((int)threadIdx.x) * K_per_loop)));
|
||||
B_reshape_local[0] = *(int*)(B +
|
||||
(((int)blockIdx.x) * N_block_size * K / 4) +
|
||||
(k_0 * K_block_size * K_per_loop * wmma_N / 4) +
|
||||
((((int)threadIdx.x) >> 1) * wmma_K * wmma_N / 4) +
|
||||
((((int)threadIdx.y) >> 3) * (wmma_K * wmma_N / 2) / 4) +
|
||||
((((int)threadIdx.x) & 1) * (wmma_K * wmma_N / 4) / 4) +
|
||||
((((int)threadIdx.y) & 7) * (wmma_K / 2) / 4)
|
||||
);
|
||||
decode_i2s_to_i8s(B_reshape_local, B_decode_local, 16);
|
||||
#pragma unroll
|
||||
for (int k_2_0 = 0; k_2_0 < 4; ++k_2_0) {
|
||||
in_thread_C_local[0] = __dp4a(*(int *)&A_local[((k_2_0 * 4))],*(int *)&B_decode_local[((k_2_0 * 4))], in_thread_C_local[0]);
|
||||
}
|
||||
}
|
||||
red_buf0[0] = in_thread_C_local[0];
|
||||
#pragma unroll
|
||||
for (int offset = K_block_size/2; offset > 0; offset /= 2) {
|
||||
red_buf0[0] += __shfl_down_sync(__activemask(), red_buf0[0], offset, K_block_size);
|
||||
}
|
||||
int out_idx = ((((int)blockIdx.x) * N_block_size) + ((int)threadIdx.y));
|
||||
int ws_idx = out_idx / (N / ws_num);
|
||||
if (threadIdx.x == 0)
|
||||
dtype_transform[out_idx] = (__nv_bfloat16)(((float)red_buf0[0])/(float)s[0]*(float)ws[ws_idx]);
|
||||
}
|
||||
@@ -0,0 +1,3 @@
nvcc -std=c++17 -Xcudafe --diag_suppress=177 --compiler-options -fPIC -lineinfo --shared bitnet_kernels.cu -lcuda -gencode=arch=compute_80,code=compute_80 -o libbitnet.so

@@ -0,0 +1,13 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='bitlinear_cpp',
    ext_modules=[
        CUDAExtension('bitlinear_cuda', [
            'bitnet_kernels.cu',
        ])
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
Executable
+100
@@ -0,0 +1,100 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from safetensors.torch import save_file
|
||||
import model
|
||||
from pack_weight import convert_weight_int8_to_int2
|
||||
|
||||
@torch.inference_mode()
|
||||
def convert_ts_checkpoint(
|
||||
*,
|
||||
input_path: str = "",
|
||||
) -> None:
|
||||
|
||||
config = model.ModelArgs()
|
||||
print(f"Model config {config.__dict__}")
|
||||
|
||||
def quant_weight_int8(weight):
|
||||
s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
|
||||
new_weight = (weight * s).round().clamp(-1, 1).to(torch.int8)
|
||||
new_scale = (1.0 / s).to(torch.bfloat16)
|
||||
return new_weight, new_scale.reshape(1)
|
||||
|
||||
def quant_weight_fp16(weight):
|
||||
s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
|
||||
new_weight = (weight * s).round().clamp(-1, 1) / s
|
||||
return new_weight
|
||||
|
||||
def convert_int8_to_int2(weight):
|
||||
return convert_weight_int8_to_int2(weight)
|
||||
|
||||
merged_result = torch.load(input_path, map_location="cpu", mmap=True)
|
||||
int2_result = {}
|
||||
fp16_result = {}
|
||||
zero = torch.zeros(1).to(torch.bfloat16)
|
||||
for key, value in merged_result.items():
|
||||
if 'wqkv' in key:
|
||||
wq = value[:config.dim]
|
||||
wk = value[config.dim:config.dim // config.n_heads * config.n_kv_heads + config.dim]
|
||||
wv = value[config.dim // config.n_heads * config.n_kv_heads + config.dim:]
|
||||
wq_weight, wa_scale = quant_weight_int8(wq)
|
||||
wk_weight, wb_scale = quant_weight_int8(wk)
|
||||
wv_weight, wc_scale = quant_weight_int8(wv)
|
||||
wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
|
||||
wqkv_scale = torch.cat([wa_scale, wb_scale, wc_scale, zero], dim=0)
|
||||
int2_result[key] = convert_int8_to_int2(wqkv_weight)
|
||||
int2_result[key.replace('weight', 'weight_scale')] = wqkv_scale
|
||||
|
||||
wq_weight = quant_weight_fp16(wq)
|
||||
wk_weight = quant_weight_fp16(wk)
|
||||
wv_weight = quant_weight_fp16(wv)
|
||||
wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
|
||||
fp16_result[key] = wqkv_weight
|
||||
elif 'w13' in key:
|
||||
w1 = value[:config.ffn_dim]
|
||||
w3 = value[config.ffn_dim:]
|
||||
w1_weight, w1_scale = quant_weight_int8(w1)
|
||||
w3_weight, w3_scale = quant_weight_int8(w3)
|
||||
w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
|
||||
w13_scale = torch.cat([w1_scale, w3_scale, zero, zero], dim=0)
|
||||
int2_result[key] = convert_int8_to_int2(w13_weight)
|
||||
int2_result[key.replace('weight', 'weight_scale')] = w13_scale
|
||||
|
||||
w1_weight = quant_weight_fp16(w1)
|
||||
w3_weight = quant_weight_fp16(w3)
|
||||
w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
|
||||
fp16_result[key] = w13_weight
|
||||
elif 'w2' in key or 'wo' in key:
|
||||
weight, scale = quant_weight_int8(value)
|
||||
scale = torch.cat([scale, zero, zero, zero], dim=0)
|
||||
int2_result[key] = convert_int8_to_int2(weight)
|
||||
int2_result[key.replace('weight', 'weight_scale')] = scale
|
||||
|
||||
weight = quant_weight_fp16(value)
|
||||
fp16_result[key] = weight
|
||||
else:
|
||||
int2_result[key] = value.clone()
|
||||
fp16_result[key] = value.clone()
|
||||
|
||||
output_dir = os.path.dirname(input_path)
|
||||
print(f"Saving checkpoint to {output_dir}/model_state_int2.pt")
|
||||
torch.save(int2_result, f"{output_dir}/model_state_int2.pt")
|
||||
|
||||
print(f"Saving checkpoint to {output_dir}/model_state_fp16.pt")
|
||||
torch.save(fp16_result, f"{output_dir}/model_state_fp16.pt")
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description='Convert TorchScale checkpoint.')
|
||||
parser.add_argument('--input', type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_ts_checkpoint(
|
||||
input_path=args.input,
|
||||
)
|
||||
@@ -0,0 +1,116 @@
|
||||
import re
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from safetensors.torch import load_file
|
||||
from einops import rearrange
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
transformer_configs = {
|
||||
"2B": dict(n_layer=30, n_head=20, dim=2560, vocab_size=128256, n_local_heads=5, intermediate_size=6912),
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
block_size: int = 4096
|
||||
vocab_size: int = 32000
|
||||
n_layer: int = 32
|
||||
n_head: int = 32
|
||||
dim: int = 4096
|
||||
intermediate_size: int = None
|
||||
n_local_heads: int = -1
|
||||
head_dim: int = 64
|
||||
rope_base: float = 10000
|
||||
norm_eps: float = 1e-5
|
||||
|
||||
def __post_init__(self):
|
||||
if self.n_local_heads == -1:
|
||||
self.n_local_heads = self.n_head
|
||||
if self.intermediate_size is None:
|
||||
hidden_dim = 4 * self.dim
|
||||
n_hidden = int(2 * hidden_dim / 3)
|
||||
self.intermediate_size = n_hidden + (256 - n_hidden % 256) if n_hidden % 256 else n_hidden
|
||||
self.head_dim = self.dim // self.n_head
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, name: str):
|
||||
if name in transformer_configs:
|
||||
return cls(**transformer_configs[name])
|
||||
config = [k for k in transformer_configs if k in name.upper() or k in name]
|
||||
assert len(config) == 1, f"Unknown model name: {name}"
|
||||
return cls(**transformer_configs[config[0]])
|
||||
|
||||
def invert_convert_q(w: torch.Tensor, config: ModelArgs) -> torch.Tensor:
|
||||
return rearrange(w, '(h l d) i -> (h d l) i', h=config.n_head, l=2)
|
||||
|
||||
def invert_convert_k(w: torch.Tensor, config: ModelArgs) -> torch.Tensor:
|
||||
return rearrange(w, '(h l d) i -> (h d l) i', h=config.n_local_heads, l=2)
|
||||
|
||||
def convert_back(
|
||||
safetensors_path: str,
|
||||
output_file: str,
|
||||
model_name: Optional[str] = None,
|
||||
):
|
||||
st_dict = load_file(safetensors_path)
|
||||
|
||||
cfg = ModelArgs.from_name(model_name)
|
||||
print(f"Using model configurations: {cfg}")
|
||||
|
||||
recovered: dict = {}
|
||||
|
||||
for layer in range(cfg.n_layer):
|
||||
base = f"model.layers.{layer}."
|
||||
|
||||
wq = st_dict[f"{base}self_attn.q_proj.weight"]
|
||||
wk = st_dict[f"{base}self_attn.k_proj.weight"]
|
||||
wv = st_dict[f"{base}self_attn.v_proj.weight"]
|
||||
|
||||
wq = invert_convert_q(wq, cfg)
|
||||
wk = invert_convert_k(wk, cfg)
|
||||
|
||||
wqkv = torch.cat([wq, wk, wv], dim=0)
|
||||
recovered[f"layers.{layer}.attention.wqkv.weight"] = wqkv
|
||||
|
||||
recovered[f"layers.{layer}.attention.wo.weight"] = st_dict[f"{base}self_attn.o_proj.weight"]
|
||||
|
||||
recovered[f"layers.{layer}.attention_norm.weight"] = st_dict[f"{base}input_layernorm.weight"]
|
||||
recovered[f"layers.{layer}.ffn_norm.weight"] = st_dict[f"{base}post_attention_layernorm.weight"]
|
||||
recovered[f"layers.{layer}.attention.attn_sub_norm.weight"] = st_dict[f"{base}self_attn.attn_sub_norm.weight"]
|
||||
recovered[f"layers.{layer}.feed_forward.ffn_sub_norm.weight"] = st_dict[f"{base}mlp.ffn_sub_norm.weight"]
|
||||
|
||||
gate = st_dict[f"{base}mlp.gate_proj.weight"]
|
||||
up = st_dict[f"{base}mlp.up_proj.weight"]
|
||||
w13 = torch.cat([gate, up], dim=0)
|
||||
recovered[f"layers.{layer}.feed_forward.w13.weight"] = w13
|
||||
|
||||
recovered[f"layers.{layer}.feed_forward.w2.weight"] = st_dict[f"{base}mlp.down_proj.weight"]
|
||||
|
||||
recovered["tok_embeddings.weight"] = st_dict["model.embed_tokens.weight"]
|
||||
recovered["output.weight"] = st_dict["model.embed_tokens.weight"]
|
||||
recovered["norm.weight"] = st_dict["model.norm.weight"]
|
||||
|
||||
print(f"Saving to {output_file}")
|
||||
torch.save(recovered, output_file)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Convert Safetensors back to Torch .pth checkpoint")
|
||||
parser.add_argument(
|
||||
"--safetensors_file", type=str, required=True,
|
||||
help="Path to input .safetensors file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", type=str, default="./checkpoints/model_state.pt",
|
||||
help="Path to output .pt file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name", type=str, default="2B",
|
||||
help="Model configuration name to use (e.g. 2B)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_back(
|
||||
safetensors_path=args.safetensors_file,
|
||||
output_file=args.output,
|
||||
model_name=args.model_name,
|
||||
)
|
||||
Executable
+359
@@ -0,0 +1,359 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import readline # type: ignore # noqa
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Tuple, Union
|
||||
|
||||
import fire
|
||||
import model as fast
|
||||
import torch
|
||||
from stats import Stats
|
||||
from tokenizer import Tokenizer, ChatFormat
|
||||
import sample_utils
|
||||
from xformers.ops.fmha.attn_bias import (
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenArgs:
|
||||
gen_length: int = 32
|
||||
gen_bsz: int = 1
|
||||
prompt_length: int = 64
|
||||
|
||||
use_sampling: bool = False
|
||||
temperature: float = 0.8
|
||||
top_p: float = 0.9
|
||||
|
||||
|
||||
class FastGen:
|
||||
GRAPH_WARMUPS: int = 1
|
||||
tokenizer: Tokenizer
|
||||
|
||||
@staticmethod
|
||||
def build(
|
||||
ckpt_dir: str,
|
||||
gen_args: GenArgs,
|
||||
device: Union[torch.device, str],
|
||||
tokenizer_path: Optional[str] = None,
|
||||
num_layers: int = 13,
|
||||
use_full_vocab: bool = False,
|
||||
) -> "FastGen":
|
||||
"""
|
||||
Load a BitNet checkpoint and return a new
|
||||
generator for this model.
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
model_args_prefill = fast.ModelArgs(use_kernel=False)
|
||||
model_args_decode = fast.ModelArgs(use_kernel=True)
|
||||
tokenizer = Tokenizer("./tokenizer.model")
|
||||
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
|
||||
prefill_model = fast.Transformer(model_args_prefill)
|
||||
decode_model = fast.Transformer(model_args_decode)
|
||||
|
||||
fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
|
||||
fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu")
|
||||
int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
|
||||
int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu")
|
||||
prefill_model.load_state_dict(fp16_checkpoint, strict=True)
|
||||
decode_model.load_state_dict(int2_checkpoint, strict=True)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
print(f"loaded model in {time.time() - start_time:.2f} seconds")
|
||||
start_time = time.time()
|
||||
|
||||
return FastGen(gen_args, model_args_prefill, prefill_model, decode_model, tokenizer)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args: GenArgs,
|
||||
model_args: fast.ModelArgs,
|
||||
prefill_model: fast.Transformer,
|
||||
decode_model: fast.Transformer,
|
||||
tokenizer: Tokenizer,
|
||||
):
|
||||
self.gen_args = args
|
||||
self.max_seq_length = args.prompt_length + args.gen_length
|
||||
self.model_args = model_args
|
||||
# self.model = model
|
||||
self.prefill_model = prefill_model
|
||||
self.decode_model = decode_model
|
||||
self.tokenizer = tokenizer
|
||||
self._prefill_cuda_graph, self._prefill_compile_model, self._prefill_inputs, self._prefill_logits = None, None, None, None
|
||||
self._generate_cuda_graph, self._generate_compile_model, self._generate_inputs, self._generate_logits = None, None, None, None
|
||||
self._cache = None
|
||||
start_time = time.time()
|
||||
self._prefill_compile_model = self.compile_prefill()
|
||||
self._generate_compile_model = self.compile_generate()
|
||||
print(f"compiled model in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
def compile_prefill(self):
|
||||
|
||||
if self._cache is None:
|
||||
self._cache = fast.make_cache(
|
||||
args=self.model_args,
|
||||
length=self.gen_args.gen_bsz * self.max_seq_length,
|
||||
)
|
||||
|
||||
seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
|
||||
|
||||
bias = AttnBias.from_seqlens(
|
||||
q_seqlen=seq_lens,
|
||||
kv_seqlen=seq_lens,
|
||||
kv_padding=self.max_seq_length,
|
||||
)
|
||||
bias.q_seqinfo.to("cuda")
|
||||
bias.k_seqinfo.to("cuda")
|
||||
|
||||
tokens = torch.IntTensor([1] * self.gen_args.gen_bsz * self.gen_args.prompt_length).cuda()
|
||||
self._prefill_inputs = (tokens, bias)
|
||||
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
|
||||
with torch.cuda.stream(s):
|
||||
_ = self.prefill_model.forward_with_attn_bias(
|
||||
token_values=self._prefill_inputs[0],
|
||||
attn_bias=self._prefill_inputs[1],
|
||||
cache=self._cache,
|
||||
)
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
|
||||
self._prefill_cuda_graph = torch.cuda.CUDAGraph()
|
||||
recording_kwargs = {}
|
||||
if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
|
||||
# In PyTorch 2.1+ and nightlies from late Aug 2023,
|
||||
# we can do this to maybe avoid watchdog-related crashes
|
||||
recording_kwargs["capture_error_mode"] = "thread_local"
|
||||
with torch.cuda.graph(self._prefill_cuda_graph, **recording_kwargs):
|
||||
self._prefill_logits = self.prefill_model.forward_with_attn_bias(
|
||||
token_values=self._prefill_inputs[0],
|
||||
attn_bias=self._prefill_inputs[1],
|
||||
cache=self._cache,
|
||||
)
|
||||
|
||||
def replay(tokens, seq_lens=None):
|
||||
self._prefill_inputs[0].copy_(tokens)
|
||||
if seq_lens is not None:
|
||||
self._prefill_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
|
||||
|
||||
self._prefill_cuda_graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return self._prefill_logits
|
||||
|
||||
return replay
|
||||
|
||||
def compile_generate(self):
|
||||
|
||||
if self._cache is None:
|
||||
self._cache = fast.make_cache(
|
||||
args=self.model_args,
|
||||
length=self.gen_args.gen_bsz * self.max_seq_length,
|
||||
)
|
||||
|
||||
seq_lens = [1 for _ in range(self.gen_args.gen_bsz)]
|
||||
kv_seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
|
||||
|
||||
bias = AttnBias.from_seqlens(
|
||||
q_seqlen=seq_lens,
|
||||
kv_seqlen=kv_seq_lens,
|
||||
kv_padding=self.max_seq_length,
|
||||
)
|
||||
bias.q_seqinfo.to("cuda")
|
||||
bias.k_seqinfo.to("cuda")
|
||||
|
||||
tokens = torch.IntTensor([1] * self.gen_args.gen_bsz).cuda()
|
||||
self._generate_inputs = (tokens, bias)
|
||||
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
|
||||
with torch.cuda.stream(s):
|
||||
_ = self.decode_model.forward_with_attn_bias(
|
||||
token_values=self._generate_inputs[0],
|
||||
attn_bias=self._generate_inputs[1],
|
||||
cache=self._cache,
|
||||
)
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
|
||||
self._generate_cuda_graph = torch.cuda.CUDAGraph()
|
||||
recording_kwargs = {}
|
||||
if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
|
||||
# In PyTorch 2.1+ and nightlies from late Aug 2023,
|
||||
# we can do this to maybe avoid watchdog-related crashes
|
||||
recording_kwargs["capture_error_mode"] = "thread_local"
|
||||
with torch.cuda.graph(self._generate_cuda_graph, **recording_kwargs):
|
||||
self._generate_logits = self.decode_model.forward_with_attn_bias(
|
||||
token_values=self._generate_inputs[0],
|
||||
attn_bias=self._generate_inputs[1],
|
||||
cache=self._cache,
|
||||
)
|
||||
|
||||
def replay(tokens, seq_lens):
|
||||
self._generate_inputs[0].copy_(tokens)
|
||||
self._generate_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
|
||||
|
||||
self._generate_cuda_graph.replay()
|
||||
|
||||
return self._generate_logits
|
||||
|
||||
return replay
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_all(
|
||||
self, prompts: list[list[int]], use_cuda_graphs: bool, use_sampling: bool
|
||||
) -> Tuple[Stats, list[list[int]]]:
|
||||
bs = len(prompts)
|
||||
prompt_lens = [len(p) for p in prompts]
|
||||
padded_prompt_lens = [self.gen_args.prompt_length] * bs
|
||||
max_prompt_length = max(prompt_lens)
|
||||
gen_length = self.gen_args.gen_length
|
||||
max_seq_length = max_prompt_length + gen_length
|
||||
print(max_prompt_length, gen_length)
|
||||
|
||||
bias = AttnBias.from_seqlens(
|
||||
q_seqlen=padded_prompt_lens,
|
||||
kv_seqlen=prompt_lens,
|
||||
kv_padding=max_seq_length,
|
||||
)
|
||||
bias.q_seqinfo.to("cuda")
|
||||
bias.k_seqinfo.to("cuda")
|
||||
|
||||
# Input tensors to the cuda graph
|
||||
kv_seqlen = bias.k_seqinfo.seqlen
|
||||
prompts = [prompt + [1] * (self.gen_args.prompt_length - len(prompt)) for prompt in prompts]
|
||||
tokens = torch.IntTensor(sum(prompts, [])).cuda()
|
||||
out_tokens = torch.zeros((max_seq_length, bs), dtype=torch.int)
|
||||
|
||||
stats = Stats()
|
||||
torch.cuda.synchronize()
|
||||
stats.phase("prefill" if use_cuda_graphs else "total")
|
||||
# stats.phase("total")
|
||||
|
||||
output = self._prefill_compile_model(tokens, None)
|
||||
|
||||
logits = output[kv_seqlen - 1, :]
|
||||
logits = logits.view(bs, self.model_args.vocab_size)
|
||||
|
||||
if use_sampling:
|
||||
temp = 0.7
|
||||
top_p = 0.95
|
||||
probs = torch.softmax(logits / temp, dim=-1)
|
||||
next_token = sample_utils.top_p(probs, top_p)
|
||||
else:
|
||||
next_token = torch.argmax(logits, dim=-1)
|
||||
|
||||
next_token = next_token.reshape(bs)
|
||||
out_tokens[0, :] = next_token
|
||||
|
||||
torch.cuda.synchronize()
|
||||
stats.phase("decode" if use_cuda_graphs else "total")
|
||||
|
||||
eos_id = self.tokenizer.eot_id
|
||||
for niter in range(1, gen_length):
|
||||
kv_seqlen.add_(kv_seqlen < max_seq_length)
|
||||
output = self._generate_compile_model(next_token, kv_seqlen)
|
||||
|
||||
logits = output.view(bs, self.model_args.vocab_size)
|
||||
|
||||
if use_sampling:
|
||||
temp = 0.7
|
||||
top_p = 0.95
|
||||
probs = torch.softmax(logits / temp, dim=-1)
|
||||
next_token = sample_utils.top_p(probs, top_p)
|
||||
else:
|
||||
next_token = torch.argmax(logits, dim=-1)
|
||||
|
||||
next_token = next_token.reshape(bs)
|
||||
out_tokens[niter, :] = next_token
|
||||
|
||||
if next_token.eq(eos_id).any():
|
||||
break
|
||||
|
||||
torch.cuda.synchronize()
|
||||
stats.end_phase(tokens=niter * bs)
|
||||
|
||||
def trim_answer(prompt_len, tokens):
|
||||
# print(prompt, tokens)
|
||||
"""Trim the answer to end it on an eos token."""
|
||||
tokens = tokens[: max_seq_length - prompt_len]
|
||||
eos_id = self.tokenizer.eot_id
|
||||
if eos_id in tokens:
|
||||
return tokens[: tokens.index(eos_id) + 1]
|
||||
else:
|
||||
return tokens
|
||||
|
||||
answers = [
|
||||
trim_answer(prompt_len, answer)
|
||||
for prompt_len, answer in zip(prompt_lens, out_tokens.t().tolist())
|
||||
]
|
||||
return stats, answers
|
||||
|
||||
|
||||
def get_prompts(interactive: bool) -> Iterable[list[str]]:
|
||||
if interactive:
|
||||
while True:
|
||||
try:
|
||||
prompts = input("enter prompt: ").split("\n")
|
||||
except EOFError:
|
||||
print("exiting")
|
||||
sys.exit(0)
|
||||
yield prompts
|
||||
else:
|
||||
yield [
|
||||
"Hello, my name is",
|
||||
]
|
||||
|
||||
|
||||
def main(ckpt_dir: str, interactive: bool = False, chat_format: bool = False, sampling: bool = False):
|
||||
|
||||
local_rank = 0
|
||||
device = f"cuda:{local_rank}"
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
g = FastGen.build(ckpt_dir, GenArgs(), device)
|
||||
|
||||
if chat_format:
|
||||
g.tokenizer = ChatFormat(g.tokenizer)
|
||||
|
||||
for prompts in get_prompts(interactive):
|
||||
# prompts = [f"{prompt}\n" for prompt in prompts]
|
||||
if chat_format:
|
||||
# prompts = [f'<|begin_of_text|>User: {prompt}<|eot_id|>Assistant: ' for prompt in prompts]
|
||||
tokens = [g.tokenizer.encode_dialog_prompt(dialog=[{"role": "user", "content": prompt}], completion=True) for prompt in prompts]
|
||||
else:
|
||||
tokens = [g.tokenizer.encode(x, bos=False, eos=False) for x in prompts]
|
||||
|
||||
print(tokens)
|
||||
stats, out_tokens = g.generate_all(
|
||||
tokens, use_cuda_graphs="NO_CUDA_GRAPHS" not in os.environ, use_sampling=sampling,
|
||||
)
|
||||
|
||||
for i, prompt in enumerate(prompts):
|
||||
print(f"> {prompt}")
|
||||
answer = g.tokenizer.decode(out_tokens[i])
|
||||
print(answer)
|
||||
print("---------------")
|
||||
|
||||
for phase_stats in stats.phases:
|
||||
print(phase_stats.show())
|
||||
|
||||
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
Executable
+366
@@ -0,0 +1,366 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from xformers.ops import RMSNorm, fmha, rope_padded
|
||||
from xformers.ops.fmha.attn_bias import (
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
|
||||
)
|
||||
|
||||
import ctypes
|
||||
bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so')
|
||||
|
||||
def bitnet_int8xint2_linear(input0, input1, s, ws):
|
||||
out_shape = list(input0.shape)
|
||||
out_shape[-1] = input1.shape[0]
|
||||
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
M = input0.shape[0]
|
||||
if len(out_shape) == 3:
|
||||
M *= input0.shape[1]
|
||||
N = input1.shape[0]
|
||||
K = input1.shape[1] * 4
|
||||
|
||||
ret = torch.zeros(*out_shape, dtype=torch.bfloat16, device=input0.device)
|
||||
|
||||
bitnet_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_void_p(s.data_ptr()), ctypes.c_void_p(ws.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)])
|
||||
|
||||
return ret
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
dim: int = 2560
|
||||
n_layers: int = 30
|
||||
n_heads: int = 20
|
||||
n_kv_heads: int = 5
|
||||
vocab_size: int = 128256
|
||||
ffn_dim: int = 6912
|
||||
norm_eps: float = 1e-5
|
||||
rope_theta: float = 500000.0
|
||||
use_kernel: bool = False
|
||||
|
||||
|
||||
LayerCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
class BitLinearKernel(nn.Module):
|
||||
in_features: int
|
||||
out_features: int
|
||||
weight: torch.Tensor
|
||||
weight_scale: torch.Tensor
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = False):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features//4, dtype=torch.int8), requires_grad=False)
|
||||
self.weight_scale = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16), requires_grad=False)
|
||||
|
||||
@torch.compile
|
||||
def quant_input(self, input):
|
||||
s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
|
||||
return (input * s).round().clamp(-128, 127).to(torch.int8), s
|
||||
|
||||
def forward(self, input):
|
||||
input, s = self.quant_input(input)
|
||||
return bitnet_int8xint2_linear(input, self.weight, s, self.weight_scale)
|
||||
|
||||
class BitLinear(nn.Linear):
|
||||
@torch.compile
|
||||
def quant_input(self, input):
|
||||
s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
|
||||
return (input * s).round().clamp(-128, 127) / s
|
||||
|
||||
def forward(self, input):
|
||||
input = self.quant_input(input)
|
||||
return F.linear(input, self.weight)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
head_dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
rope_theta: float,
|
||||
norm_eps: float,
|
||||
use_kernel: bool,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.head_dim = head_dim
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
self.n_local_heads = n_heads
|
||||
self.n_local_kv_heads = n_kv_heads
|
||||
|
||||
Linear = BitLinearKernel if use_kernel else BitLinear
|
||||
|
||||
self.wqkv = Linear(
|
||||
dim,
|
||||
(self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.wo = Linear(
|
||||
self.n_local_heads * head_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.attn_sub_norm = RMSNorm(dim, norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cache: LayerCache,
|
||||
attn_bias: AttnBias,
|
||||
) -> torch.Tensor:
|
||||
|
||||
xqkv = self.wqkv(x)
|
||||
xq = xqkv[:, : (self.n_local_heads * self.head_dim)]
|
||||
xkv = xqkv[:, (self.n_local_heads * self.head_dim) :]
|
||||
xk, xv = xkv.chunk(2, 1)
|
||||
|
||||
output_shape = xq.shape
|
||||
heads_per_group = self.n_local_heads // self.n_local_kv_heads
|
||||
xq = xq.view(
|
||||
1, xq.shape[0], self.n_local_kv_heads, heads_per_group, self.head_dim
|
||||
)
|
||||
xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, 1, self.head_dim)
|
||||
# xq = rearrange(xq, 'b (g h l d) -> 1 b h g (d l)', g=heads_per_group, h=self.n_local_kv_heads, d=self.head_dim // 2, l=2)
|
||||
# xk = rearrange(xk, 'b (g l d) -> 1 b g 1 (d l)', g=self.n_local_kv_heads, d=self.head_dim // 2)
|
||||
xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, 1, self.head_dim)
|
||||
cache_k, cache_v = cache
|
||||
|
||||
xq = rope_padded(
|
||||
xq=xq,
|
||||
xk=xk,
|
||||
xv=xv,
|
||||
cache_k=cache_k,
|
||||
cache_v=cache_v,
|
||||
attn_bias=attn_bias,
|
||||
theta=self.rope_theta,
|
||||
)
|
||||
|
||||
output = fmha.memory_efficient_attention_forward(
|
||||
xq, cache_k, cache_v, attn_bias, op = fmha.flash.FwOp
|
||||
)
|
||||
|
||||
output = output.reshape(output_shape)
|
||||
output = self.attn_sub_norm(output)
|
||||
output = self.wo(output)
|
||||
|
||||
return output
|
||||
|
||||
@torch.compile
|
||||
def squared_relu(x: torch.Tensor) -> torch.Tensor:
|
||||
return F.relu(x) ** 2
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
norm_eps: float,
|
||||
use_kernel: bool,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
Linear = BitLinearKernel if use_kernel else BitLinear
|
||||
|
||||
self.w13 = Linear(
|
||||
dim,
|
||||
2 * hidden_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.w2 = Linear(
|
||||
hidden_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
)
|
||||
self.ffn_sub_norm = RMSNorm(hidden_dim, norm_eps)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x13 = self.w13(x)
|
||||
x1, x3 = x13.chunk(2, -1)
|
||||
inner = self.ffn_sub_norm(squared_relu(x1) * x3)
|
||||
output = self.w2(inner)
|
||||
return output
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
assert args.dim % args.n_heads == 0
|
||||
head_dim = args.dim // args.n_heads
|
||||
if args.n_kv_heads is not None:
|
||||
n_kv_heads = args.n_kv_heads
|
||||
else:
|
||||
n_kv_heads = args.n_heads
|
||||
|
||||
assert args.n_heads % n_kv_heads == 0
|
||||
|
||||
self.attention = Attention(
|
||||
dim=args.dim,
|
||||
head_dim=head_dim,
|
||||
n_heads=args.n_heads,
|
||||
n_kv_heads=n_kv_heads,
|
||||
rope_theta=args.rope_theta,
|
||||
norm_eps=args.norm_eps,
|
||||
use_kernel=args.use_kernel,
|
||||
)
|
||||
self.feed_forward = FeedForward(
|
||||
dim=args.dim,
|
||||
hidden_dim=args.ffn_dim,
|
||||
norm_eps=args.norm_eps,
|
||||
use_kernel=args.use_kernel,
|
||||
)
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cache: LayerCache,
|
||||
attn_bias: AttnBias,
|
||||
) -> torch.Tensor:
|
||||
h = x + self.attention.forward(
|
||||
self.attention_norm(x),
|
||||
cache,
|
||||
attn_bias,
|
||||
)
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
assert args.vocab_size > 0
|
||||
|
||||
self.tok_embeddings = nn.Embedding(
|
||||
num_embeddings=args.vocab_size,
|
||||
embedding_dim=args.dim,
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(args.n_layers):
|
||||
self.layers.append(TransformerBlock(args))
|
||||
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
|
||||
self.output = nn.Linear(
|
||||
args.dim,
|
||||
args.vocab_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_with_attn_bias(
|
||||
self,
|
||||
token_values: torch.Tensor,
|
||||
attn_bias: AttnBias,
|
||||
cache: list[LayerCache],
|
||||
) -> torch.Tensor:
|
||||
h = self.tok_embeddings(token_values)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
h = layer(h, cache[i], attn_bias)
|
||||
|
||||
logits = self.output(self.norm(h))
|
||||
return logits.float()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
token_values: torch.Tensor,
|
||||
token_lengths: torch.Tensor,
|
||||
start_pos: torch.Tensor,
|
||||
cache: list[LayerCache],
|
||||
kv_padding: int,
|
||||
) -> torch.Tensor:
|
||||
attn_bias = AttnBias.from_seqlens(
|
||||
q_seqlen=token_lengths.tolist(),
|
||||
kv_seqlen=(start_pos + token_lengths).tolist(),
|
||||
kv_padding=kv_padding,
|
||||
)
|
||||
return self.forward_with_attn_bias(token_values, attn_bias, cache)
|
||||
|
||||
|
||||
def make_cache(
|
||||
args: ModelArgs,
|
||||
length: int,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
n_layers: Optional[int] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> list[LayerCache]:
|
||||
"""
|
||||
Allocate a cache to be used with the Transformer module.
|
||||
|
||||
Args:
|
||||
args (ModelArgs): the model configuration.
|
||||
length (int): per layer cache size.
|
||||
It is usually budgeted as ``max_batch * max_seq``
|
||||
device (torch.device, optional): the device on which
|
||||
the cache should be allocated.
|
||||
n_layers (int, optional): the number of layers to
|
||||
allocate a cache for (defaults to the model
|
||||
settings).
|
||||
dtype (torch.dtype, optional): the dtype to use for
|
||||
cache entries (defaults to the default dtype).
|
||||
|
||||
Returns:
|
||||
The cache object to pass to ``Transformer.forward``.
|
||||
"""
|
||||
|
||||
head_dim = args.dim // args.n_heads
|
||||
n_kv_heads = args.n_kv_heads
|
||||
if n_kv_heads is None:
|
||||
n_kv_heads = args.n_heads
|
||||
n_local_kv_heads = n_kv_heads
|
||||
|
||||
if n_layers is None:
|
||||
n_layers = args.n_layers
|
||||
|
||||
shape = (1, length, n_local_kv_heads, 1, head_dim)
|
||||
heads_per_group = args.n_heads // n_kv_heads
|
||||
expansion = (-1, -1, -1, heads_per_group, -1)
|
||||
return [
|
||||
(
|
||||
torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
|
||||
torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
|
||||
)
|
||||
for _ in range(n_layers)
|
||||
]
|
||||
|
||||
|
||||
def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]:
|
||||
"""
|
||||
Take a prefix view of a larger cache.
|
||||
|
||||
The original cache object remains of identical size and valid
|
||||
after the shrunken alias has been used. This function is useful
|
||||
when a cache was allocated for a larger batch size than what is
|
||||
necessary.
|
||||
|
||||
Args:
|
||||
cache: the cache to take a view in.
|
||||
length (int): the desired length
|
||||
|
||||
Returns:
|
||||
A view in the input cache object.
|
||||
"""
|
||||
|
||||
if len(cache) > 0:
|
||||
assert cache[0][0].shape[1] >= length
|
||||
|
||||
return [(ck[:, :length], cv[:, :length]) for ck, cv in cache]
|
||||
Executable
+98
@@ -0,0 +1,98 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def B_global_16x32_to_shared_load_16x32_layout(i, j):
|
||||
"""
|
||||
stride * 8 * (tx // HALF_WARP_expr)
|
||||
+ (tx % 8) * stride
|
||||
+ 16 * ((tx % HALF_WARP_expr) // 8)
|
||||
"""
|
||||
thread_id = i * 2 + j // 16
|
||||
row = (thread_id // 16) * 8 + (thread_id % 8)
|
||||
col = (j % 16) + 16 * ((thread_id % 16) // 8)
|
||||
return row, col
|
||||
|
||||
|
||||
def permutate_weight_fastest(weight):
|
||||
wmma_n = 16
|
||||
wmma_k = 32
|
||||
N = weight.shape[0]
|
||||
K = weight.shape[1]
|
||||
|
||||
# Create a lookup table for the permutation
|
||||
mapping = np.zeros((wmma_n, wmma_k, 2), dtype=int)
|
||||
for ii in range(wmma_n):
|
||||
for jj in range(wmma_k):
|
||||
mapping[ii, jj] = B_global_16x32_to_shared_load_16x32_layout(ii, jj)
|
||||
|
||||
# Reshape weight for the final format
|
||||
permutated_weight = np.zeros((N // wmma_n, K // wmma_k, wmma_n, wmma_k), dtype="int8")
|
||||
|
||||
# Use advanced indexing for the entire operation
|
||||
i_indices = np.arange(N // wmma_n)[:, np.newaxis, np.newaxis, np.newaxis]
|
||||
j_indices = np.arange(K // wmma_k)[np.newaxis, :, np.newaxis, np.newaxis]
|
||||
|
||||
# Create the source indices
|
||||
src_i = i_indices * wmma_n + mapping[:, :, 0]
|
||||
src_j = j_indices * wmma_k + mapping[:, :, 1]
|
||||
|
||||
# Extract and reshape in one go
|
||||
permutated_weight = weight[src_i, src_j]
|
||||
|
||||
return permutated_weight
|
||||
|
||||
|
||||
def compress_int2_to_int8(int2_weight):
|
||||
int8_weight = np.zeros(
|
||||
(*int2_weight.shape[:-1], int2_weight.shape[-1] // 4), dtype=np.int8
|
||||
)
|
||||
for j in range(int2_weight.shape[-1] // 4):
|
||||
for k in range(4):
|
||||
int8_weight[:, :, :, j] |= int2_weight[:, :, :, j * 4 + k] << (k * 2)
|
||||
return int8_weight
|
||||
|
||||
|
||||
def interleave_weight_int8(qweight, nbits=2):
|
||||
# reinterpret the data type of qweight to int32
|
||||
# shift = [ 0, 8, 16, 24, 2, 10, 18, 26, 4, 12, 20, 28, 6, 14, 22, 30]
|
||||
# index: [ 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
|
||||
qweight = qweight.view(np.int32)
|
||||
new_qweight = np.zeros_like(qweight)
|
||||
bits_stride = 8
|
||||
mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f
|
||||
num_groups = 32 // bits_stride # 4
|
||||
elems_per_group = bits_stride // nbits # 4
|
||||
for i in range(num_groups):
|
||||
for j in range(elems_per_group):
|
||||
offset = i * elems_per_group + j
|
||||
shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits
|
||||
|
||||
new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift
|
||||
return new_qweight.view(np.int8)
|
||||
|
||||
|
||||
|
||||
def convert_weight_int8_to_int2(weight):
|
||||
N = weight.shape[0]
|
||||
K = weight.shape[1]
|
||||
|
||||
weight = weight+2
|
||||
|
||||
weight = weight.cpu().numpy()
|
||||
|
||||
# print(weight)
|
||||
# print(torch.max(weight), torch.min(weight))
|
||||
|
||||
# permutated_weight_slow = permutate_weight(weight)
|
||||
permutated_weight = permutate_weight_fastest(weight)
|
||||
# assert np.all(permutated_weight_slow == permutated_weight)
|
||||
# print("Permutation is correct")
|
||||
compressed_weight = compress_int2_to_int8(permutated_weight)
|
||||
interleaved_weight = interleave_weight_int8(compressed_weight, 2)
|
||||
|
||||
ret = torch.from_numpy(interleaved_weight)
|
||||
|
||||
ret = torch.reshape(ret, (N, K // 4))
|
||||
|
||||
return ret
|
||||
Executable
+9
@@ -0,0 +1,9 @@
fire
sentencepiece
torch>=2.2.0
xformers>=0.0.22
tiktoken
blobfile
flask
einops
transformers
Executable
+31
@@ -0,0 +1,31 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
@torch.compile
|
||||
def top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
|
||||
"""
|
||||
Perform top-p (nucleus) sampling on a probability distribution.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): probability distribution tensor.
|
||||
p (float): probability threshold for top-p sampling.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: sampled token indices.
|
||||
|
||||
Note:
|
||||
Top-p sampling selects the smallest set of tokens whose cumulative
|
||||
probability mass exceeds the threshold p. The distribution is
|
||||
renormalized based on the selected tokens.
|
||||
"""
|
||||
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
mask = probs_sum - probs_sort > p
|
||||
probs_sort[mask] = 0.0
|
||||
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||
next_token = torch.gather(probs_idx, -1, next_token)
|
||||
return next_token
|
||||
Executable
+57
@@ -0,0 +1,57 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhaseStats:
|
||||
name: str
|
||||
tokens: int
|
||||
time: float
|
||||
|
||||
def show(self) -> str:
|
||||
tps = self.tokens / self.time
|
||||
return (
|
||||
f"[{self.name}] "
|
||||
f"generated tokens: {self.tokens}"
|
||||
f" - total time: {self.time:.3f}s"
|
||||
f" - {tps:.1f} tokens per second"
|
||||
)
|
||||
|
||||
|
||||
class Stats:
|
||||
"""
|
||||
Generation stats, split by phases.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.phases = []
|
||||
self.current = None
|
||||
|
||||
def end_phase(self, tokens: int, now: Optional[float] = None):
|
||||
"""Terminate the current phase."""
|
||||
if self.current is None:
|
||||
return
|
||||
if now is None:
|
||||
now = time.time()
|
||||
cname, ctokens, ctime = self.current
|
||||
stats = PhaseStats(
|
||||
name=cname,
|
||||
tokens=tokens - ctokens,
|
||||
time=now - ctime,
|
||||
)
|
||||
self.phases.append(stats)
|
||||
|
||||
def phase(self, name: str, tokens: int = 0):
|
||||
"""
|
||||
Start a new phase, and terminate the current one,
|
||||
if one is ongoing.
|
||||
"""
|
||||
now = time.time()
|
||||
self.end_phase(tokens, now)
|
||||
self.current = (name, tokens, now)
|
||||
+99
@@ -0,0 +1,99 @@
|
||||
import torch
|
||||
from torch.utils import benchmark
|
||||
from torch import nn
|
||||
|
||||
from pack_weight import convert_weight_int8_to_int2
|
||||
from torch.profiler import profile, record_function, ProfilerActivity
|
||||
import ctypes
|
||||
import numpy as np
|
||||
# set all seed
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so')
|
||||
|
||||
def bitnet_int8xint2_linear(input0, input1, s, ws, ret):
|
||||
out_shape = list(input0.shape)
|
||||
out_shape[-1] = input1.shape[0]
|
||||
|
||||
stream = torch.cuda.current_stream()
|
||||
|
||||
M = input0.shape[0]
|
||||
if len(out_shape) == 3:
|
||||
M *= input0.shape[1]
|
||||
N = input1.shape[0]
|
||||
K = input1.shape[1] * 4
|
||||
|
||||
bitnet_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_void_p(s.data_ptr()), ctypes.c_void_p(ws.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)])
|
||||
|
||||
return ret
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_list = [
|
||||
(2560, 2560),
|
||||
(3840, 2560),
|
||||
(13824, 2560),
|
||||
(2560, 6912) ,
|
||||
(3200, 3200),
|
||||
(4800, 3200),
|
||||
(3200, 10240),
|
||||
(20480, 3200),
|
||||
]
|
||||
for N,K in test_list:
|
||||
weight = torch.randint(-1, 2, (N, K), dtype=torch.int8, device='cuda')
|
||||
weight_scale = torch.ones(1, dtype=torch.bfloat16, device='cuda')
|
||||
weight_compressed = convert_weight_int8_to_int2(weight).to('cuda')
|
||||
|
||||
for i in range(1):
|
||||
input0 = torch.randint(-128,127,(1, K),dtype=torch.int8, device='cuda')
|
||||
input0_bf16 = input0.to(torch.bfloat16)
|
||||
input_np = input0.cpu().to(torch.int32).numpy()
|
||||
weight_np = weight.cpu().to(torch.int32).T.numpy()
|
||||
out_np = np.matmul(input_np,weight_np)
|
||||
out_np = torch.tensor(out_np).cuda().to(torch.bfloat16)
|
||||
|
||||
s = torch.ones(1, dtype=torch.bfloat16, device='cuda')
|
||||
ws = torch.ones(6, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
ret = torch.empty((1,N), dtype=torch.bfloat16, device=input0.device)
|
||||
out = bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)
|
||||
|
||||
print(f'custom == np {torch.all(out==out_np)}')
|
||||
|
||||
input0 = torch.randint(-128,127,(1, K),dtype=torch.int8, device='cuda')
|
||||
input0_fp16 = input0.to(torch.float16)
|
||||
input0_bf16 = input0.to(torch.bfloat16)
|
||||
weight_fp16 = weight.to(torch.float16).T
|
||||
weight_bf16 = weight.to(torch.bfloat16).T
|
||||
ret = torch.empty((1,N), dtype=torch.bfloat16, device=input0.device)
|
||||
s = torch.ones(1, dtype=torch.bfloat16, device='cuda')
|
||||
ws = torch.ones(6, dtype=torch.bfloat16, device='cuda')
|
||||
t0 = benchmark.Timer(
|
||||
stmt="bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)",
|
||||
setup="from __main__ import input0, weight_compressed, s, ws, ret, bitnet_int8xint2_linear",
|
||||
num_threads=1,
|
||||
)
|
||||
|
||||
t1 = benchmark.Timer(
|
||||
stmt="torch.matmul(input0_bf16,weight_bf16)",
|
||||
setup="from __main__ import input0_bf16, weight_bf16",
|
||||
num_threads=1,
|
||||
)
|
||||
|
||||
time0 = t0.timeit(50)
|
||||
time1 = t1.timeit(50)
|
||||
|
||||
print(f'Shape{N,K}, W2A8: {time0.mean * 1e6:.2f}us, torch BF16: {time1.mean * 1e6:.2f}us')
|
||||
# activities = [ ProfilerActivity.CUDA,
|
||||
# # ProfilerActivity.CPU
|
||||
# ]
|
||||
# sort_by_keyword = 'cuda' + "_time_total"
|
||||
# with profile(activities=activities, record_shapes=True) as prof:
|
||||
# with record_function("model_inference1"):
|
||||
# for _ in range(10):
|
||||
# bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)
|
||||
# torch.matmul(input0_fp16,weight_fp16)
|
||||
# torch.matmul(input0_bf16,weight_bf16)
|
||||
|
||||
# print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=15))
|
||||
|
||||
Executable
+128000
File diff suppressed because it is too large
Executable
+257
@@ -0,0 +1,257 @@
import os
from logging import getLogger
from pathlib import Path
from typing import (
    AbstractSet,
    cast,
    Collection,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    TypedDict,
    Union,
)

import tiktoken
from tiktoken.load import load_tiktoken_bpe


logger = getLogger(__name__)

Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    role: Role
    content: str


Dialog = Sequence[Message]


class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
    """

    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 256

    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.
        """
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        logger.info(f"Reloaded tiktoken model from {model_path}")

        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.pad_id: int = self.n_words - 1
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): special tokens allowed in the string
            disallowed_special ("all"|set[str]): special tokens that raise an error when found in the string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will cause all text corresponding
          to special tokens to be encoded as special tokens.
        """
        assert type(s) is str

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]

class ChatFormat:
    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer
        self.eot_id = tokenizer.special_tokens["<|eot_id|>"]

    def decode(self, tokens: List[int]) -> str:
        # Decode the tokens to a string.
        decoded_str = self.tokenizer.decode(tokens)
        # Remove the special tokens from the decoded string.
        decoded_str = decoded_str.replace("<|eot_id|>", "")
        return decoded_str

    def encode_header(self, message: Message) -> List[int]:
        tokens = []
        if message["role"] == "system":
            tokens.extend(self.tokenizer.encode("System: ", bos=False, eos=False))
        elif message["role"] == "user":
            tokens.extend(self.tokenizer.encode("User: ", bos=False, eos=False))
        elif message["role"] == "assistant":
            tokens.extend(self.tokenizer.encode("Assistant: ", bos=False, eos=False))
        else:
            raise NotImplementedError(f"Role {message['role']} not implemented.")
        # tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
        # tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
        # tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        # tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_message(
        self, message: Message, return_target=False
    ) -> Tuple[List[int], Optional[List[int]]]:
        tokens, targets = [], []
        headers = self.encode_header(message)
        contents = self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
        contents.append(self.tokenizer.special_tokens["<|eot_id|>"])
        tokens = headers + contents

        # Targets mask out everything except assistant content (-1 marks ignored positions).
        if message["role"] == "assistant":
            targets = [-1] * len(headers) + contents
        else:
            targets = [-1] * len(tokens)

        if return_target:
            return tokens, targets

        return tokens, None

    def encode_dialog_prompt(
        self, dialog: Dialog, completion=False, return_target=False
    ) -> Union[List[int], Tuple[List[int], List[int]]]:
        tokens = [self.tokenizer.special_tokens["<|begin_of_text|>"]]
        targets = [-1]
        for message in dialog:
            _tokens, _targets = self.encode_message(message, return_target=return_target)
            tokens.extend(_tokens)
            if _targets is not None:
                targets.extend(_targets)
        # Add the start of an assistant message for the model to complete.
        if completion:
            tokens.extend(self.encode_header({"role": "assistant", "content": ""}))

        if return_target:
            return tokens, targets

        return tokens
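For orientation, a minimal usage sketch of the classes above follows. The module name, model path, and dialog content are hypothetical assumptions; the Tokenizer only requires a local tiktoken BPE file, as the assertion in `__init__` shows.

from tokenizer import ChatFormat, Tokenizer  # assuming this file is saved as tokenizer.py

tok = Tokenizer(model_path="./tokenizer.model")  # hypothetical local tiktoken BPE file
chat = ChatFormat(tok)

dialog = [{"role": "user", "content": "What does W2A8 mean?"}]
# completion=True appends the "Assistant: " header so the model continues from it.
prompt_tokens = chat.encode_dialog_prompt(dialog, completion=True)
print(len(prompt_tokens), chat.decode(prompt_tokens))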