commit paper code

This commit is contained in:
Eddie-Wang1120
2025-02-16 15:03:25 +08:00
parent 437b321dcf
commit 4c736e3728
10 changed files with 1956 additions and 26 deletions
File diff suppressed because it is too large Load Diff
+116 -6
View File
@@ -517,6 +517,92 @@ def preprocess_weights_tl1(
return weight
def preprocess_two_weights_tl2_loss(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
weight = np.reshape(weight, (weight_num // 2, 2))
hi_weight = np.multiply(np.split(weight, 2, axis=1)[0], 3)
lo_weight = np.split(weight, 2, axis=1)[1]
weight = np.reshape((hi_weight + lo_weight), weight_num // 2)
weight = weight + 4
weight = np.reshape(weight, (M, K // 2)).astype(np.uint8)
weight = weight.reshape((M // BM, BM, K // 2)).transpose(0, 2, 1)
weight = weight.reshape((M // BM, K // BY, BY // 2, BM)).transpose(0, 1, 3, 2)
weight = weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 2)).transpose(0, 1, 2, 4, 3)
weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, by // 2, bm)).transpose(0, 1, 2, 3, 5, 4)
weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, bm, by // 2))
weight_0 = weight[:, :, :, :, :, 0]
weight_1 = weight[:, :, :, :, :, 1]
weight_0 = weight_0 << 4
weight_1 = weight_1
weight = weight_0 + weight_1
weight = weight.reshape(M * K // bm // by, bm).reshape(M * K // by // 16, 16)
for i in range(weight.shape[0]):
final_weight.append(weight[i, :])
def preprocess_three_weights_tl2_loss(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
weight = np.reshape(weight, (weight_num // 3, 3))
split_weights = np.split(weight, 3, axis=1)
first_weight = np.multiply(split_weights[0], 9)
second_weight = np.multiply(split_weights[1], 3)
third_weight = split_weights[2]
weight = np.reshape((first_weight + second_weight + third_weight), weight_num // 3)
sign_weight = np.sign(weight)
sign_weight = np.where(sign_weight < 1, 0, sign_weight)
weight = np.abs(weight)
weight = np.reshape(weight, (M, K // 3)).astype(np.uint8)
sign_weight = np.reshape(sign_weight, (M, K // 3)).astype(np.uint8)
weight = weight.reshape((M // BM, BM, K // 3)).transpose(0, 2, 1)
weight = weight.reshape((M // BM, K // BY, BY // 3, BM)).transpose(0, 1, 3, 2)
weight = weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 3)).transpose(0, 1, 2, 4, 3)
weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, by // 3, bm)).transpose(0, 1, 2, 3, 5, 4)
weight = weight.reshape((M // BM, K // BY, BM // bm, BY // by, bm, by // 3))
weight_list = []
for i in range(by // 3):
weight_list.append(weight[:, :, :, :, :, i])
for i in range(by // 3 // 2):
weight_list[i] = weight_list[i] << 4
weight_list[i + by // 3 // 2] = weight_list[i + by // 3 // 2]
weight_list[i] = weight_list[i] + weight_list[i + by // 3 // 2]
weight_list[i] = weight_list[i].reshape(M * K // bm // by, bm).reshape(M * K // by // 16, 16)
for i in range(weight_list[0].shape[0]):
for j in range(by // 3 // 2):
final_weight.append(weight_list[j][i, :])
sign_weight = sign_weight.reshape((M // BM, BM, K // 3)).transpose(0, 2, 1)
sign_weight = sign_weight.reshape((M // BM, K // BY, BY // 3, BM)).transpose(0, 1, 3, 2)
sign_weight = sign_weight.reshape((M // BM, K // BY, BM // bm, bm, BY // 3)).transpose(0, 1, 2, 4, 3)
sign_weight = sign_weight.reshape((M // BM, K // BY, BM // bm, BY // (by * 4), by // 3 * 4, bm)).transpose(0, 1, 2, 3, 5, 4).astype(np.uint8)
combine_weight_list = []
for i in range(by // 3 // 2):
combine_weight = np.zeros((M // BM, K // BY, BM // bm, BY // (by * 4), bm), dtype=np.uint8)
combine_weight_list.append(combine_weight)
for i in range(8):
for j in range(by // 3 // 2):
if bm == 16:
combine_weight_list[j] = combine_weight_list[j] + (sign_weight[:, :, :, :, :, by // 3 // 2 * i + j] << 7 - i)
elif bm == 32:
if i > 3 :
ti = (i - 4) * 2 + 1
else:
ti = i * 2
combine_weight_list[j] = combine_weight_list[j] + (sign_weight[:, :, :, :, :, by // 3 // 2 * ti + j] << 7 - i)
for i in range(by // 3 // 2):
combine_weight_list[i] = combine_weight_list[i].reshape((M * K // (by * 4)) // 16, 16)
for i in range(combine_weight_list[0].shape[0]):
for j in range(by // 3 // 2):
final_weight.append(combine_weight_list[j][i, :])
def preprocess_two_weights_tl2(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
weight = np.reshape(weight, (weight_num // 2, 2))
hi_weight = np.multiply(np.split(weight, 2, axis=1)[0], 3)
@@ -603,7 +689,6 @@ def preprocess_weights_tl2(
weight = w
weight = np.where(np.abs(weight) < 1e-6, 0, weight).astype(np.float32)
weight = np.sign(weight)
weight_num = np.prod(weight.shape)
config.read('include/kernel_config.ini')
BM = -1
@@ -631,7 +716,8 @@ def preprocess_weights_tl2(
final_weight = []
preprocess_three_weights_tl2(three_weight.shape[0],
if args.loss:
preprocess_three_weights_tl2_loss(three_weight.shape[0],
three_weight.shape[1],
three_weight.shape[0] * three_weight.shape[1],
BM,
@@ -641,8 +727,29 @@ def preprocess_weights_tl2(
three_weight,
final_weight)
if (weight.shape[1] % BY != 0):
preprocess_two_weights_tl2( two_weight.shape[0],
if (weight.shape[1] % BY != 0):
preprocess_two_weights_tl2_loss(two_weight.shape[0],
two_weight.shape[1],
two_weight.shape[0] * two_weight.shape[1],
BM,
32,
32,
4,
two_weight,
final_weight)
else:
preprocess_three_weights_tl2(three_weight.shape[0],
three_weight.shape[1],
three_weight.shape[0] * three_weight.shape[1],
BM,
BY,
bm,
by,
three_weight,
final_weight)
if (weight.shape[1] % BY != 0):
preprocess_two_weights_tl2(two_weight.shape[0],
two_weight.shape[1],
two_weight.shape[0] * two_weight.shape[1],
BM,
@@ -652,8 +759,10 @@ def preprocess_weights_tl2(
two_weight,
final_weight)
weight = np.array(final_weight, dtype=np.uint8).reshape(-1)
weight = np.pad(weight, (0, (K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8 -
weight.shape[0]), mode='constant', constant_values=0)
pad_nums = (K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8
pad_align_nums = 32 - ((K - 256) * M // 3 * 5 // 8 + 256 * M // 2 * 4 // 8) % 32
pad_nums = pad_nums + pad_align_nums
weight = np.pad(weight, (0, pad_nums - weight.shape[0]), mode='constant', constant_values=0)
return weight
def transform_to_tl1(x: np.ndarray):
@@ -1116,6 +1225,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--model-name", type=str, default=None, help="name of the model")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
parser.add_argument("--quant-embd", action="store_true", help="quantize the embedding layer")
parser.add_argument("--loss", action="store_true", help="use loss tl2")
return parser.parse_args()