mirror of
https://github.com/microsoft/BitNet.git
synced 2026-05-06 04:40:52 +00:00
commit paper code
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -517,6 +517,92 @@ def preprocess_weights_tl1(
|
||||
return weight
|
||||
|
||||
|
||||
def preprocess_two_weights_tl2_loss(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
|
||||
weight = np.reshape(weight, (weight_num // 2, 2))
|
||||
hi_weight = np.multiply(np.split(weight, 2, axis=1)[0], 3)
|
||||
lo_weight = np.split(weight, 2, axis=1)[1]
|
||||
|
||||
weight = np.reshape((hi_weight + lo_weight), weight_num // 2)
|
||||
weight = weight + 4
|
||||
weight = np.reshape(weight, (M, K // 2)).astype(np.uint8)
|
||||
weight = weight.reshape((M // BM, BM, K // 2)).transpose(0, 2, 1)
|
||||
weight = weight.reshape((M // BM, K // BY, BY // 2, BM)).transpose(0, 1, 3, 2)
|
||||
weight = weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 2)).transpose(0, 1, 2, 4, 3)
|
||||
weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, by // 2, bm)).transpose(0, 1, 2, 3, 5, 4)
|
||||
weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, bm, by // 2))
|
||||
weight_0 = weight[:, :, :, :, :, 0]
|
||||
weight_1 = weight[:, :, :, :, :, 1]
|
||||
weight_0 = weight_0 << 4
|
||||
weight_1 = weight_1
|
||||
weight = weight_0 + weight_1
|
||||
weight = weight.reshape(M * K // bm // by, bm).reshape(M * K // by // 16, 16)
|
||||
|
||||
for i in range(weight.shape[0]):
|
||||
final_weight.append(weight[i, :])
|
||||
|
||||
def preprocess_three_weights_tl2_loss(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
|
||||
weight = np.reshape(weight, (weight_num // 3, 3))
|
||||
split_weights = np.split(weight, 3, axis=1)
|
||||
first_weight = np.multiply(split_weights[0], 9)
|
||||
second_weight = np.multiply(split_weights[1], 3)
|
||||
third_weight = split_weights[2]
|
||||
|
||||
weight = np.reshape((first_weight + second_weight + third_weight), weight_num // 3)
|
||||
sign_weight = np.sign(weight)
|
||||
sign_weight = np.where(sign_weight < 1, 0, sign_weight)
|
||||
weight = np.abs(weight)
|
||||
|
||||
weight = np.reshape(weight, (M, K // 3)).astype(np.uint8)
|
||||
sign_weight = np.reshape(sign_weight, (M, K // 3)).astype(np.uint8)
|
||||
|
||||
weight = weight.reshape((M // BM, BM, K // 3)).transpose(0, 2, 1)
|
||||
weight = weight.reshape((M // BM, K // BY, BY // 3, BM)).transpose(0, 1, 3, 2)
|
||||
weight = weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 3)).transpose(0, 1, 2, 4, 3)
|
||||
weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, by // 3, bm)).transpose(0, 1, 2, 3, 5, 4)
|
||||
weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, bm, by // 3))
|
||||
|
||||
weight_list = []
|
||||
for i in range(by // 3):
|
||||
weight_list.append(weight[:, :, :, :, :, i])
|
||||
|
||||
for i in range(by // 3 // 2):
|
||||
weight_list[i] = weight_list[i] << 4
|
||||
weight_list[i + by // 3 // 2] = weight_list[i + by // 3 // 2]
|
||||
weight_list[i] = weight_list[i] + weight_list[i + by // 3 // 2]
|
||||
weight_list[i] = weight_list[i].reshape(M * K // bm // by, bm).reshape(M * K // by // 16, 16)
|
||||
|
||||
for i in range(weight_list[0].shape[0]):
|
||||
for j in range(by // 3 // 2):
|
||||
final_weight.append(weight_list[j][i, :])
|
||||
|
||||
sign_weight = sign_weight.reshape((M // BM, BM, K // 3)).transpose(0, 2, 1)
|
||||
sign_weight = sign_weight.reshape((M // BM, K // BY, BY // 3, BM)).transpose(0, 1, 3, 2)
|
||||
sign_weight = sign_weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 3)).transpose(0, 1, 2, 4, 3)
|
||||
sign_weight = sign_weight.reshape((M // BM, K // BY, BM // bm, BY // (by * 4), by // 3 * 4, bm)).transpose(0, 1, 2, 3, 5, 4).astype(np.uint8)
|
||||
|
||||
combine_weight_list = []
|
||||
for i in range(by // 3 // 2):
|
||||
combine_weight = np.zeros((M // BM, K // BY, BM // bm, BY // (by * 4), bm), dtype=np.uint8)
|
||||
combine_weight_list.append(combine_weight)
|
||||
|
||||
for i in range(8):
|
||||
for j in range(by // 3 // 2):
|
||||
if bm == 16:
|
||||
combine_weight_list[j] = combine_weight_list[j] + (sign_weight[:, :, :, :, :, by // 3 // 2 * i + j] << 7 - i)
|
||||
elif bm == 32:
|
||||
if i > 3 :
|
||||
ti = (i - 4) * 2 + 1
|
||||
else:
|
||||
ti = i * 2
|
||||
combine_weight_list[j] = combine_weight_list[j] + (sign_weight[:, :, :, :, :, by // 3 // 2 * ti + j] << 7 - i)
|
||||
|
||||
for i in range(by // 3 // 2):
|
||||
combine_weight_list[i] = combine_weight_list[i].reshape((M * K // (by * 4)) // 16, 16)
|
||||
|
||||
for i in range(combine_weight_list[0].shape[0]):
|
||||
for j in range(by // 3 // 2):
|
||||
final_weight.append(combine_weight_list[j][i, :])
|
||||
|
||||
def preprocess_two_weights_tl2(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
|
||||
weight = np.reshape(weight, (weight_num // 2, 2))
|
||||
hi_weight = np.multiply(np.split(weight, 2, axis=1)[0], 3)
|
||||
@@ -603,7 +689,6 @@ def preprocess_weights_tl2(
|
||||
weight = w
|
||||
weight = np.where(np.abs(weight) < 1e-6, 0, weight).astype(np.float32)
|
||||
weight = np.sign(weight)
|
||||
weight_num = np.prod(weight.shape)
|
||||
|
||||
config.read('include/kernel_config.ini')
|
||||
BM = -1
|
||||
@@ -631,7 +716,8 @@ def preprocess_weights_tl2(
|
||||
|
||||
final_weight = []
|
||||
|
||||
preprocess_three_weights_tl2(three_weight.shape[0],
|
||||
if args.loss:
|
||||
preprocess_three_weights_tl2_loss(three_weight.shape[0],
|
||||
three_weight.shape[1],
|
||||
three_weight.shape[0] * three_weight.shape[1],
|
||||
BM,
|
||||
@@ -641,8 +727,29 @@ def preprocess_weights_tl2(
|
||||
three_weight,
|
||||
final_weight)
|
||||
|
||||
if (weight.shape[1] % BY != 0):
|
||||
preprocess_two_weights_tl2( two_weight.shape[0],
|
||||
if (weight.shape[1] % BY != 0):
|
||||
preprocess_two_weights_tl2_loss(two_weight.shape[0],
|
||||
two_weight.shape[1],
|
||||
two_weight.shape[0] * two_weight.shape[1],
|
||||
BM,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
two_weight,
|
||||
final_weight)
|
||||
else:
|
||||
preprocess_three_weights_tl2(three_weight.shape[0],
|
||||
three_weight.shape[1],
|
||||
three_weight.shape[0] * three_weight.shape[1],
|
||||
BM,
|
||||
BY,
|
||||
bm,
|
||||
by,
|
||||
three_weight,
|
||||
final_weight)
|
||||
|
||||
if (weight.shape[1] % BY != 0):
|
||||
preprocess_two_weights_tl2(two_weight.shape[0],
|
||||
two_weight.shape[1],
|
||||
two_weight.shape[0] * two_weight.shape[1],
|
||||
BM,
|
||||
@@ -652,8 +759,10 @@ def preprocess_weights_tl2(
|
||||
two_weight,
|
||||
final_weight)
|
||||
weight = np.array(final_weight, dtype=np.uint8).reshape(-1)
|
||||
weight = np.pad(weight, (0, (K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8 -
|
||||
weight.shape[0]), mode='constant', constant_values=0)
|
||||
pad_nums = (K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8
|
||||
pad_align_nums = 32 - ((K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8) % 32
|
||||
pad_nums = pad_nums + pad_align_nums
|
||||
weight = np.pad(weight, (0, pad_nums - weight.shape[0]), mode='constant', constant_values=0)
|
||||
return weight
|
||||
|
||||
def transform_to_tl1(x: np.ndarray):
|
||||
@@ -1116,6 +1225,7 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--model-name", type=str, default=None, help="name of the model")
|
||||
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||
parser.add_argument("--quant-embd", action="store_true", help="quantize the embedding layer")
|
||||
parser.add_argument("--loss", action="store_true", help="use loss tl2")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user