diff --git a/3rdparty/llama.cpp b/3rdparty/llama.cpp
index 40ed0f2..0f0e7da 160000
--- a/3rdparty/llama.cpp
+++ b/3rdparty/llama.cpp
@@ -1 +1 @@
-Subproject commit 40ed0f290203a9a78540b8f7eb18bd828043fe21
+Subproject commit 0f0e7daec25c467800af808b55ce28a69461f904
diff --git a/include/gemm-config.h b/include/gemm-config.h
new file mode 100644
index 0000000..d766dfb
--- /dev/null
+++ b/include/gemm-config.h
@@ -0,0 +1,23 @@
+#define ACT_PARALLEL
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+#if defined(ACT_PARALLEL)
+    #define ROW_BLOCK_SIZE 4
+    #define COL_BLOCK_SIZE 128
+    #define PARALLEL_SIZE 4
+#else
+    #define ROW_BLOCK_SIZE 32
+    #define COL_BLOCK_SIZE 4
+    #define PARALLEL_SIZE 4
+#endif
+#elif defined(__ARM_NEON)
+#if defined(ACT_PARALLEL)
+    #define ROW_BLOCK_SIZE 8
+    #define COL_BLOCK_SIZE 64
+    #define PARALLEL_SIZE 8
+#else
+    #define ROW_BLOCK_SIZE 16
+    #define COL_BLOCK_SIZE 4
+    #define PARALLEL_SIZE 4
+#endif
+#endif
+
diff --git a/src/ggml-bitnet-mad.cpp b/src/ggml-bitnet-mad.cpp
index eeca82b..4ba9d65 100644
--- a/src/ggml-bitnet-mad.cpp
+++ b/src/ggml-bitnet-mad.cpp
@@ -1,13 +1,18 @@
 #include <cstring>
 #include <cstdlib>
-
+#include <cassert>
 #include "ggml-bitnet.h"
 #include "ggml-quants.h"
+#include "gemm-config.h"
+#include "ggml-cpu-impl.h"
 #include <cmath>
 #include <cstring>
 
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 #define QK_I2_S 128
-#define QK_I2 128
+#elif defined(__ARM_NEON)
+#define QK_I2_S 64
+#endif
 
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 #include <immintrin.h>
@@ -44,8 +49,8 @@ static inline int hsum_i32_8(const __m256i a) {
 #endif
 
 size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
-    // 2 bits per weight
-
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+#if defined(ACT_PARALLEL)
     size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
 
     int n = nrow * n_per_row;
@@ -73,11 +78,11 @@ size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_
     // -1, 0, 1
 
     uint8_t* i2_weight = (uint8_t*)dst;
-    for (int i = 0; i < n / QK_I2; i++) {
-        for (int j = 0; j < QK_I2; j++) {
+    for (int i = 0; i < n / QK_I2_S; i++) {
+        for (int j = 0; j < QK_I2_S; j++) {
             int group_idx = j / 32;
             int group_pos = j % 32;
-            uint8_t temp = (q8[i * QK_I2 + j] << (6 - 2 * group_idx));
+            uint8_t temp = (q8[i * QK_I2_S + j] << (6 - 2 * group_idx));
             i2_weight[i * 32 + group_pos] |= temp;
         }
     }
@@ -89,9 +94,207 @@ size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_
 
     // 32B for alignment
     return nrow * row_size / 4 + 32;
+#else
+    assert((nrow % 4) == 0 && "quantize_i2_s_1x4 requires nrow % 4 == 0");
+
+    size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
+    int64_t n = nrow * n_per_row;
+
+    double max = 0;
+    for (int64_t i = 0; i < n; ++i) {
+        max = fmax(max, (double)fabs((double)src[i]));
+    }
+    double i2_scale = max;
+
+    uint8_t* q8 = (uint8_t*)malloc(n * sizeof(uint8_t));
+    for (int64_t i = 0; i < n; i++) {
+        if (fabs((double)(src[i])) < 1e-6) {
+            q8[i] = 1;
+            continue;
+        }
+        q8[i] = (double)src[i] * i2_scale > 0 ? 2 : 0;
+    }
+
+    uint8_t* out = (uint8_t*)dst;
+    memset(out, 0, (size_t)(n / 4));
+
+    // for each group of 4 rows, for each column, write one byte
+    int64_t nrow4 = nrow / 4;
+    for (int64_t rg = 0; rg < nrow4; rg++) {
+        int64_t r0 = rg * 4 + 0;
+        int64_t r1 = rg * 4 + 1;
+        int64_t r2 = rg * 4 + 2;
+        int64_t r3 = rg * 4 + 3;
+
+        int64_t base = rg * n_per_row;
+
+        for (int64_t col = 0; col < n_per_row; col++) {
+            uint8_t q0 = q8[r0 * n_per_row + col];
+            uint8_t q1 = q8[r1 * n_per_row + col];
+            uint8_t q2 = q8[r2 * n_per_row + col];
+            uint8_t q3 = q8[r3 * n_per_row + col];
+
+            uint8_t packed = (uint8_t)((q0 << 6) | (q1 << 4) | (q2 << 2) | (q3 << 0));
+            out[base + col] = packed;
+        }
+    }
+
+    // store scale at the end of quantized data (same location pattern as quantize_i2_s)
+    float* scale_ptr = (float*)((char*)out + n / 4);
+    scale_ptr[0] = (float)i2_scale;
+
+    free(q8);
+
+    // return size (keep same formula as quantize_i2_s)
+    return nrow * row_size / 4 + 32;
+#endif
+#elif defined(__ARM_NEON)
+    size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
+
+    int n = nrow * n_per_row;
+
+    // f32 -> q8
+    double max = 0;
+    for (int i = 0; i < n; ++i) {
+        max = fmax(max, (double)fabs((double)src[i]));
+    }
+    double i2_scale = max;
+
+    uint8_t* q8 = (uint8_t*)malloc(n * sizeof(uint8_t));
+    for (int i = 0; i < n; i++) {
+        if (fabs((double)(src[i])) < 1e-6) {
+            q8[i] = 1;
+            continue;
+        }
+        q8[i] = (double)src[i] * i2_scale > 0 ? 2 : 0;
+    }
+
+    memset(dst, 0, n * sizeof(uint8_t) / 4);
+
+    // q8 -> 0, 1, 2
+    //       |  |  |
+    //      -1, 0, 1
+
+    uint8_t* i2_weight = (uint8_t*)dst;
+    for (int i = 0; i < n / QK_I2_S; i++) {
+        for (int j = 0; j < QK_I2_S; j++) {
+            int group_idx = j / 16;
+            int group_pos = j % 16;
+            uint8_t temp = (q8[i * QK_I2_S + j] << (6 - 2 * group_idx));
+            i2_weight[i * 16 + group_pos] |= temp;
+        }
+    }
+
+    float* scale_ptr = (float*)((char*)i2_weight + n / 4);
+    scale_ptr[0] = i2_scale;
+
+    free(q8);
+
+    // 32B for alignment
+    return nrow * row_size / 4 + 32;
+#endif
 }
 
-void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+void ggml_vec_dot_i2_i8_s_1x1(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if defined(__AVX2__)
+    const uint8_t * x = (uint8_t *)vx;
+    const int8_t * y = (int8_t *)vy;
+
+    const int nb = n / QK_I2_S;
+    const int group32_num = nb / 32;
+    const int la_num = nb % 32;
+    const int groupla_num = nb % 32 != 0 ? 1 : 0;
+
+    __m256i mask = _mm256_set1_epi8(0x03);
+    __m256i one16 = _mm256_set1_epi16(1);
+
+    // process multiple rows; nrc is the number of rows to handle
+    for (int row = 0; row < nrc; row++) {
+        __m256i accu = _mm256_setzero_si256();
+
+        // compute the x pointer offset for the current row
+        const uint8_t * x_row = x + row * bx / 4;
+
+        for (int i = 0; i < group32_num; i++) {
+            const uint8_t *px = x_row + i * 1024; // 32 * 32
+            const int8_t *py = y + i * 4096;      // 32 * 128
+            __m256i accu32 = _mm256_setzero_si256();
+
+            for (int j = 0; j < 32; j++) {
+                // 128 index
+                __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(px));
+                __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
+                __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
+                __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
+
+                // each 32 index
+                xq8_3 = _mm256_and_si256(xq8_3, mask);
+                xq8_2 = _mm256_and_si256(xq8_2, mask);
+                xq8_1 = _mm256_and_si256(xq8_1, mask);
+                xq8_0 = _mm256_and_si256(xq8_0, mask);
+
+                // each 32 index
+                __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py));
+                __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(py + 32));
+                __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(py + 64));
+                __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(py + 96));
+
+                xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
+                xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
+                xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
+                xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
+
+                accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_0, xq8_1));
+                accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_2, xq8_3));
+
+                px += 32;
+                py += 128;
+            }
+            accu = _mm256_add_epi32(_mm256_madd_epi16(accu32, one16), accu);
+        }
+
+        for (int i = 0; i < groupla_num; i++) {
+            __m256i accula = _mm256_setzero_si256();
+            const uint8_t *px = x_row + group32_num * 1024; // 32 * 32
+            const int8_t *py = y + group32_num * 4096;      // 32 * 128
+
+            for (int j = 0; j < la_num; j++) {
+                // 128 index
+                __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(px));
+                __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
+                __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
+                __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
+
+                // each 32 index
+                xq8_3 = _mm256_and_si256(xq8_3, mask);
+                xq8_2 = _mm256_and_si256(xq8_2, mask);
+                xq8_1 = _mm256_and_si256(xq8_1, mask);
+                xq8_0 = _mm256_and_si256(xq8_0, mask);
+
+                // each 32 index
+                __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py));
+                __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(py + 32));
+                __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(py + 64));
+                __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(py + 96));
+
+                xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
+                xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
+                xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
+                xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
+
+                accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_0, xq8_1));
+                accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_2, xq8_3));
+
+                px += 32;
+                py += 128;
+            }
+            accu = _mm256_add_epi32(accu, _mm256_madd_epi16(accula, one16));
+        }
+
+        int sumi = hsum_i32_8(accu);
+        s[row] = (float)sumi;
+    }
+#elif defined(__ARM_NEON)
     const uint8_t * x = (uint8_t *)vx;
     const int8_t * y = (int8_t *)vy;
 
@@ -100,264 +303,754 @@ void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t b
     const int la_num = nb % 32;
     const int groupla_num = nb % 32 != 0 ? 1 : 0;
 
-#if defined(__AVX2__)
-
-    __m256i mask = _mm256_set1_epi8(0x03);
-    __m256i accu = _mm256_setzero_si256();
-
-    for (int i=0; i < group32_num; i++){
-        __m256i accu32 = _mm256_setzero_si256();
-        for (int j=0; j < 32; j++) {
-            // 128 index
-            __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + i * 32 * 32 + j * 32));
-            __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
-            __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
-            __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
-
-            // each 32 index
-            xq8_3 = _mm256_and_si256(xq8_3, mask);
-            xq8_2 = _mm256_and_si256(xq8_2, mask);
-            xq8_1 = _mm256_and_si256(xq8_1, mask);
-            xq8_0 = _mm256_and_si256(xq8_0, mask);
-
-            // each 32 index
-            __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 0));
-            __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 32));
-            __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 64));
-            __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 96));
-
-            // 128 index accumulation add
-            // split into 32 accumulation block
-            // each block each 128 index accumulated 4index
-            // each index maximum 256
-            // each block maximum 4 * 256
-            // each block accumulation maximum 127 * 256
-            // each 32 group index (128 index in one group) needs cast to int32
-            xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
-            xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
-            xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
-            xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
-
-            accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_0, xq8_1));
-            accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_2, xq8_3));
-        }
-        accu = _mm256_add_epi32(_mm256_madd_epi16(accu32, _mm256_set1_epi16(1)), accu);
-    }
-
-    for (int i = 0; i < groupla_num; i++){
-        __m256i accula = _mm256_setzero_si256();
-        for (int j = 0; j < la_num; j++) {
-            // 128 index
-            __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + group32_num * 32 * 32 + j * 32));
-            __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
-            __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
-            __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
-
-            // each 32 index
-            xq8_3 = _mm256_and_si256(xq8_3, mask);
-            xq8_2 = _mm256_and_si256(xq8_2, mask);
-            xq8_1 = _mm256_and_si256(xq8_1, mask);
-            xq8_0 = _mm256_and_si256(xq8_0, mask);
-
-            // each 32 index
-            __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 0));
-            __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 32));
-            __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 64));
-            __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 96));
-
-            // 128 index accumulation add
-            // split into 32 accumulation block
-            // each block each 128 index accumulated 4index
-            // each index maximum 256
-            // each block maximum 4 * 256
-            // each block accumulation maximum 127 * 256
-            // each 32 group index (128 index in one group) needs cast to int32
-            xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
-            xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
-            xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
-            xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
-
-            accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_0, xq8_1));
-            accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_2, xq8_3));
-        }
-        accu = _mm256_add_epi32(accu, _mm256_madd_epi16(accula, _mm256_set1_epi16(1)));
-    }
-    int sumi = hsum_i32_8(accu);
-    *s = (float)sumi;
-
-#elif defined(__ARM_NEON)
-
-    int32x4_t accu_0 = vdupq_n_s32(0);
-    int32x4_t accu_1 = vdupq_n_s32(0);
-    int32x4_t accu_2 = vdupq_n_s32(0);
-    int32x4_t accu_3 = vdupq_n_s32(0);
     const uint8x16_t mask = vdupq_n_u8(3);
 
-    for (int i=0; i < group32_num; i++) {
+    // process multiple rows; nrc is the number of rows to handle
+    for (int row = 0; row < nrc; row++) {
+        int32x4_t accu = vdupq_n_s32(0);
+
+        // compute the x pointer offset for the current row
+        const uint8_t * x_row = x + row * bx / 4;
+
+        for (int i=0; i < group32_num; i++) {
 #if defined(__ARM_FEATURE_DOTPROD)
 
 #else
-        int16x8_t accu32_0 = vdupq_n_s16(0);
-        int16x8_t accu32_1 = vdupq_n_s16(0);
-        int16x8_t accu32_2 = vdupq_n_s16(0);
-        int16x8_t accu32_3 = vdupq_n_s16(0);
+            int16x8_t accu32 = vdupq_n_s16(0);
 #endif
+            for (int j=0; j < 32; j++) {
+                uint8x16_t xq8_3 = vld1q_u8(x_row + i * 32 * 16 + j * 16);
+                uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
+                uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
+                uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
 
-        for (int j=0; j < 32; j++) {
-            uint8x16_t xq8_6 = vld1q_u8(x + i * 32 * 32 + j * 32);
-            uint8x16_t xq8_7 = vld1q_u8(x + i * 32 * 32 + j * 32 + 16);
-            uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2);
-            uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2);
-            uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4);
-            uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4);
-            uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6);
-            uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6);
+                int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
+                int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
+                int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
+                int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
 
-            int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
-            int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
-            int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
-            int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
-            int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask));
-            int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask));
-            int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask));
-            int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask));
-
-            const int8x16_t yq8_0 = vld1q_s8(y + i * 128 * 32 + j * 128 + 0);
-            const int8x16_t yq8_1 = vld1q_s8(y + i * 128 * 32 + j * 128 + 16);
-            const int8x16_t yq8_2 = vld1q_s8(y + i * 128 * 32 + j * 128 + 32);
-            const int8x16_t yq8_3 = vld1q_s8(y + i * 128 * 32 + j * 128 + 48);
-            const int8x16_t yq8_4 = vld1q_s8(y + i * 128 * 32 + j * 128 + 64);
-            const int8x16_t yq8_5 = vld1q_s8(y + i * 128 * 32 + j * 128 + 80);
-            const int8x16_t yq8_6 = vld1q_s8(y + i * 128 * 32 + j * 128 + 96);
-            const int8x16_t yq8_7 = vld1q_s8(y + i * 128 * 32 + j * 128 + 112);
+                const int8x16_t yq8_0 = vld1q_s8(y + i * 32 * 64 + j * 64 + 0);
+                const int8x16_t yq8_1 = vld1q_s8(y + i * 32 * 64 + j * 64 + 16);
+                const int8x16_t yq8_2 = vld1q_s8(y + i * 32 * 64 + j * 64 + 32);
+                const int8x16_t yq8_3 = vld1q_s8(y + i * 32 * 64 + j * 64 + 48);
 
 #if defined(__ARM_FEATURE_DOTPROD)
-            accu_0 = vdotq_s32(accu_0, q8_0, yq8_0);
-            accu_1 = vdotq_s32(accu_1, q8_1, yq8_1);
-            accu_2 = vdotq_s32(accu_2, q8_2, yq8_2);
-            accu_3 = vdotq_s32(accu_3, q8_3, yq8_3);
-            accu_0 = vdotq_s32(accu_0, q8_4, yq8_4);
-            accu_1 = vdotq_s32(accu_1, q8_5, yq8_5);
-            accu_2 = vdotq_s32(accu_2, q8_6, yq8_6);
-            accu_3 = vdotq_s32(accu_3, q8_7, yq8_7);
+                accu = vdotq_s32(accu, q8_0, yq8_0);
+                accu = vdotq_s32(accu, q8_1, yq8_1);
+                accu = vdotq_s32(accu, q8_2, yq8_2);
+                accu = vdotq_s32(accu, q8_3, yq8_3);
#else
-            accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_0), vget_low_s8(yq8_0));
-            accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_0), vget_high_s8(yq8_0));
-            accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_1), vget_low_s8(yq8_1));
-            accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_1), vget_high_s8(yq8_1));
-            accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_2), vget_low_s8(yq8_2));
-            accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_2), vget_high_s8(yq8_2));
-            accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_3), vget_low_s8(yq8_3));
-            accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_3), vget_high_s8(yq8_3));
-            accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_4), vget_low_s8(yq8_4));
-            accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_4), vget_high_s8(yq8_4));
-            accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_5), vget_low_s8(yq8_5));
-            accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_5), vget_high_s8(yq8_5));
-            accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_6), vget_low_s8(yq8_6));
-            accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_6), vget_high_s8(yq8_6));
-            accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_7), vget_low_s8(yq8_7));
-            accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_7), vget_high_s8(yq8_7));
+                accu32 = vmlal_s8(accu32, vget_low_s8(q8_0), vget_low_s8(yq8_0));
+                accu32 = vmlal_s8(accu32, vget_high_s8(q8_0), vget_high_s8(yq8_0));
+                accu32 = vmlal_s8(accu32, vget_low_s8(q8_1), vget_low_s8(yq8_1));
+                accu32 = vmlal_s8(accu32, vget_high_s8(q8_1), vget_high_s8(yq8_1));
+                accu32 = vmlal_s8(accu32, vget_low_s8(q8_2), vget_low_s8(yq8_2));
+                accu32 = vmlal_s8(accu32, vget_high_s8(q8_2), vget_high_s8(yq8_2));
+                accu32 = vmlal_s8(accu32, vget_low_s8(q8_3), vget_low_s8(yq8_3));
+                accu32 = vmlal_s8(accu32, vget_high_s8(q8_3), vget_high_s8(yq8_3));
+#endif
+            }
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+            accu = vaddq_s32(accu, vmovl_s16(vget_low_s16(accu32)));
+            accu = vaddq_s32(accu, vmovl_high_s16(accu32));
 #endif
         }
 
+        for (int i = 0; i < groupla_num; i++){
 #if defined(__ARM_FEATURE_DOTPROD)
 
 #else
-        accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accu32_0)));
-        accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accu32_0));
-        accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accu32_1)));
-        accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accu32_1));
-        accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accu32_2)));
-        accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accu32_2));
-        accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accu32_3)));
-        accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accu32_3));
+            int16x8_t accula = vdupq_n_s16(0);
 #endif
-    }
+            for (int j = 0; j < la_num; j++) {
+                uint8x16_t xq8_3 = vld1q_u8(x_row + group32_num * 32 * 16 + j * 16);
+                uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
+                uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
+                uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
+
+                int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
+                int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
+                int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
+                int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
+
+                const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 0);
+                const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 16);
+                const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 32);
+                const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 48);
 
-    for (int i = 0; i < groupla_num; i++){
 #if defined(__ARM_FEATURE_DOTPROD)
-
+                accu = vdotq_s32(accu, q8_0, yq8_0);
+                accu = vdotq_s32(accu, q8_1, yq8_1);
+                accu = vdotq_s32(accu, q8_2, yq8_2);
+                accu = vdotq_s32(accu, q8_3, yq8_3);
 #else
-        int16x8_t accula_0 = vdupq_n_s16(0);
-        int16x8_t accula_1 = vdupq_n_s16(0);
-        int16x8_t accula_2 = vdupq_n_s16(0);
-        int16x8_t accula_3 = vdupq_n_s16(0);
+                accula = vmlal_s8(accula, vget_low_s8(q8_0), vget_low_s8(yq8_0));
+                accula = vmlal_s8(accula, vget_high_s8(q8_0), vget_high_s8(yq8_0));
+                accula = vmlal_s8(accula, vget_low_s8(q8_1), vget_low_s8(yq8_1));
+                accula = vmlal_s8(accula, vget_high_s8(q8_1), vget_high_s8(yq8_1));
+                accula = vmlal_s8(accula, vget_low_s8(q8_2), vget_low_s8(yq8_2));
+                accula = vmlal_s8(accula, vget_high_s8(q8_2), vget_high_s8(yq8_2));
+                accula = vmlal_s8(accula, vget_low_s8(q8_3), vget_low_s8(yq8_3));
+                accula = vmlal_s8(accula, vget_high_s8(q8_3), vget_high_s8(yq8_3));
 #endif
-        for (int j = 0; j < la_num; j++) {
-            uint8x16_t xq8_6 = vld1q_u8(x + group32_num * 32 * 32 + j * 32);
-            uint8x16_t xq8_7 = vld1q_u8(x + group32_num * 32 * 32 + j * 32 + 16);
-            uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2);
-            uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2);
-            uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4);
-            uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4);
-            uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6);
-            uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6);
-
-            int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
-            int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
-            int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
-            int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
-            int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask));
-            int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask));
-            int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask));
-            int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask));
-
-            const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 0);
-            const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 16);
-            const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 32);
-            const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 48);
-            const int8x16_t yq8_4 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 64);
-            const int8x16_t yq8_5 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 80);
-            const int8x16_t yq8_6 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 96);
-            const int8x16_t yq8_7 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 112);
-
+            }
 #if defined(__ARM_FEATURE_DOTPROD)
-            accu_0 = vdotq_s32(accu_0, q8_0, yq8_0);
-            accu_1 = vdotq_s32(accu_1, q8_1, yq8_1);
-            accu_2 = vdotq_s32(accu_2, q8_2, yq8_2);
-            accu_3 = vdotq_s32(accu_3, q8_3, yq8_3);
-            accu_0 = vdotq_s32(accu_0, q8_4, yq8_4);
-            accu_1 = vdotq_s32(accu_1, q8_5, yq8_5);
-            accu_2 = vdotq_s32(accu_2, q8_6, yq8_6);
-            accu_3 = vdotq_s32(accu_3, q8_7, yq8_7);
+
#else
-            accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_0), vget_low_s8(yq8_0));
-            accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_0), vget_high_s8(yq8_0));
-            accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_1), vget_low_s8(yq8_1));
-            accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_1), vget_high_s8(yq8_1));
-            accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_2), vget_low_s8(yq8_2));
-            accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_2), vget_high_s8(yq8_2));
-            accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_3), vget_low_s8(yq8_3));
-            accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_3), vget_high_s8(yq8_3));
-            accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_4), vget_low_s8(yq8_4));
-            accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_4), vget_high_s8(yq8_4));
-            accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_5), vget_low_s8(yq8_5));
-            accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_5), vget_high_s8(yq8_5));
-            accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_6), vget_low_s8(yq8_6));
-            accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_6), vget_high_s8(yq8_6));
-            accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_7), vget_low_s8(yq8_7));
-            accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_7), vget_high_s8(yq8_7));
+            accu = vaddq_s32(accu, vmovl_s16(vget_low_s16(accula)));
+            accu = vaddq_s32(accu, vmovl_high_s16(accula));
 #endif
         }
+        int sumi = vaddlvq_s32(accu);
+        s[row] = (float)sumi;
+    }
+#endif
+}
+
+void ggml_vec_dot_i2_i8_s_1x4_32W(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if defined(__AVX2__)
+    const uint8_t * x = (uint8_t *)vx;
+    const int8_t * y = (int8_t *)vy;
+
+    const int nb = n / QK_I2_S;
+    const int group32_num = nb / 32;
+    const int la_num = nb % 32;
+    const int groupla_num = nb % 32 != 0 ? 1 : 0;
+
+    const __m256i mask = _mm256_set1_epi8(0x03);
+    const __m256i one16 = _mm256_set1_epi16(1);
+
+    // process multiple rows; nrc is the number of rows to handle
+    for (int row = 0; row < nrc; row+=4) {
+        __m256i accu[4];
+        for(int rb = 0; rb < 4; rb++) {
+            accu[rb] = _mm256_setzero_si256();
+        }
+        const uint8_t * x_row = x + (row) * bx / 4;
+        // pointer offset of the current 4-row block in x
+
+        for (int i = 0; i < group32_num; i++) {
+            const uint8_t * px = x_row + i * 1024 * 4;
+            __m256i accu32[4];
+            for(int rb = 0; rb < 4; rb++) {
+                accu32[rb] = _mm256_setzero_si256();
+            }
+            const int8_t *py = y + i * 4096;
+
+            for (int j = 0; j < 32 * 4; j++) {
+                // each 32 index
+                __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py));
+                __m256i xq8[4];
+                xq8[3] = _mm256_loadu_si256((const __m256i*)(px));
+                xq8[2] = _mm256_srli_epi16(xq8[3], 2);
+                xq8[1] = _mm256_srli_epi16(xq8[3], 4);
+                xq8[0] = _mm256_srli_epi16(xq8[3], 6);
+                xq8[3] = _mm256_and_si256(xq8[3], mask);
+                xq8[2] = _mm256_and_si256(xq8[2], mask);
+                xq8[1] = _mm256_and_si256(xq8[1], mask);
+                xq8[0] = _mm256_and_si256(xq8[0], mask);
+                for (int rb = 0; rb < 4; rb++)
+                {
+                    xq8[rb] = _mm256_maddubs_epi16(xq8[rb], yq8_0);
+                    accu32[rb] = _mm256_add_epi16(accu32[rb], xq8[rb]);
+                }
+                px += 32;
+                py += 32;
+            }
+            for(int rb = 0; rb < 4; rb++) {
+                accu[rb] = _mm256_add_epi32(_mm256_madd_epi16(accu32[rb], one16), accu[rb]);
+            }
+        }
+
+        for (int i = 0; i < groupla_num; i++) {
+            const int8_t *py = y + group32_num * 4096; // 32 * 128
+            __m256i accula[4];
+            for(int rb = 0; rb < 4; rb++) {
+                accula[rb] = _mm256_setzero_si256();
+            }
+            const uint8_t * px = x_row + group32_num * 1024 * 4;
+
+            for (int j = 0; j < la_num * 4; j++) {
+                // each 32 index
+                __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py));
+                __m256i xq8[4];
+                xq8[3] = _mm256_loadu_si256((const __m256i*)(px));
+                xq8[2] = _mm256_srli_epi16(xq8[3], 2);
+                xq8[1] = _mm256_srli_epi16(xq8[3], 4);
+                xq8[0] = _mm256_srli_epi16(xq8[3], 6);
+                xq8[3] = _mm256_and_si256(xq8[3], mask);
+                xq8[2] = _mm256_and_si256(xq8[2], mask);
+                xq8[1] = _mm256_and_si256(xq8[1], mask);
+                xq8[0] = _mm256_and_si256(xq8[0], mask);
+
+                for (int rb = 0; rb < 4; rb++) {
+                    xq8[rb] = _mm256_maddubs_epi16(xq8[rb], yq8_0);
+                    accula[rb] = _mm256_add_epi16(accula[rb], xq8[rb]);
+                }
+                px += 32;
+                py += 32;
+            }
+            for(int rb = 0; rb < 4; rb++) {
+                accu[rb] = _mm256_add_epi32(accu[rb], _mm256_madd_epi16(accula[rb], one16));
+            }
+        }
+
+        for(int rb = 0; rb < 4; rb++) {
+            int sumi = hsum_i32_8(accu[rb]);
+            s[row + rb] = (float)sumi;
+        }
+    }
+#elif defined(__ARM_NEON)
+
+#endif
+}
+
+void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if defined(__AVX2__)
+    const uint8_t * x = (uint8_t *)vx;
+    const int8_t * y = (int8_t *)vy;
+
+    const int nb = n / QK_I2_S;
+    const int group32_num = nb / 32;
+    const int la_num = nb % 32;
+    const int groupla_num = nb % 32 != 0 ? 1 : 0;
+
+    const __m256i mask = _mm256_set1_epi8(0x03);
+    const __m256i one16 = _mm256_set1_epi16(1);
+
+    // process multiple rows; nrc is the number of rows to handle
+    for (int row = 0; row < nrc; row+=PARALLEL_SIZE) {
+        //__m256i accu = _mm256_setzero_si256();
+        __m256i accu[PARALLEL_SIZE];
+        const uint8_t * x_row[PARALLEL_SIZE];
+        for(int rb = 0; rb < PARALLEL_SIZE; rb++) {
+            accu[rb] = _mm256_setzero_si256();
+            x_row[rb] = x + (row + rb) * bx / 4;
+        }
+        // compute the x pointer offsets for the current rows
+
+        for (int i = 0; i < group32_num; i++) {
+            const uint8_t * px[PARALLEL_SIZE];
+            __m256i accu32[PARALLEL_SIZE];
+            for(int rb = 0; rb < PARALLEL_SIZE; rb++) {
+                px[rb] = x_row[rb] + i * 1024; // 32 * 32
+                accu32[rb] = _mm256_setzero_si256();
+            }
+            const int8_t *py = y + i * 4096; // 32 * 128
+
+            for (int j = 0; j < 32; j++) {
+                // each 32 index
+                __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py));
+                __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(py + 32));
+                __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(py + 64));
+                __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(py + 96));
+                for (int rb = 0; rb < PARALLEL_SIZE; rb++)
+                {
+                    __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(px[rb]));
+                    __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
+                    __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
+                    __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
+
+                    // each 32 index
+                    xq8_3 = _mm256_and_si256(xq8_3, mask);
+                    xq8_2 = _mm256_and_si256(xq8_2, mask);
+                    xq8_1 = _mm256_and_si256(xq8_1, mask);
+                    xq8_0 = _mm256_and_si256(xq8_0, mask);
+
+                    xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
+                    xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
+                    xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
+                    xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
+
+                    accu32[rb] = _mm256_add_epi16(accu32[rb], _mm256_add_epi16(xq8_0, xq8_1));
+                    accu32[rb] = _mm256_add_epi16(accu32[rb], _mm256_add_epi16(xq8_2, xq8_3));
+
+                    px[rb] += 32;
+                }
+                py += 128;
+            }
+            for(int rb = 0; rb < PARALLEL_SIZE; rb++) {
+                accu[rb] = _mm256_add_epi32(_mm256_madd_epi16(accu32[rb], one16), accu[rb]);
+            }
+        }
+
+        for (int i = 0; i < groupla_num; i++) {
+            const int8_t *py = y + group32_num * 4096; // 32 * 128
+            const uint8_t * px[PARALLEL_SIZE];
+            __m256i accula[PARALLEL_SIZE];
+            for(int rb = 0; rb < PARALLEL_SIZE; rb++) {
+                px[rb] = x_row[rb] + group32_num * 1024; // 32 * 32
+                accula[rb] = _mm256_setzero_si256();
+            }
+
+            for (int j = 0; j < la_num; j++) {
+                // each 32 index
+                __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py));
+                __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(py + 32));
+                __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(py + 64));
+                __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(py + 96));
+
+                for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+                    // 128 index
+                    __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(px[rb]));
+                    __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
+                    __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
+                    __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
+
+                    // each 32 index
+                    xq8_3 = _mm256_and_si256(xq8_3, mask);
+                    xq8_2 = _mm256_and_si256(xq8_2, mask);
+                    xq8_1 = _mm256_and_si256(xq8_1, mask);
+                    xq8_0 = _mm256_and_si256(xq8_0, mask);
+
+                    xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
+                    xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
+                    xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
+                    xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
+
+                    accula[rb] = _mm256_add_epi16(accula[rb], _mm256_add_epi16(xq8_0, xq8_1));
+                    accula[rb] = _mm256_add_epi16(accula[rb], _mm256_add_epi16(xq8_2, xq8_3));
+
+                    px[rb] += 32;
+                }
+                py += 128;
+            }
+            for(int rb = 0; rb < PARALLEL_SIZE; rb++) {
+                accu[rb] = _mm256_add_epi32(accu[rb], _mm256_madd_epi16(accula[rb], one16));
+            }
+        }
+
+        for(int rb = 0; rb < PARALLEL_SIZE; rb++) {
+            int sumi = hsum_i32_8(accu[rb]);
+            s[row + rb] = (float)sumi;
+        }
+    }
+#elif defined(__ARM_NEON)
+    const uint8_t * x = (uint8_t *)vx;
+    const int8_t * y = (int8_t *)vy;
+
+    const int nb = n / QK_I2_S;
+    const int group32_num = nb / 32;
+    const int la_num = nb % 32;
+    const int groupla_num = nb % 32 != 0 ? 1 : 0;
+
+    const uint8x16_t mask = vdupq_n_u8(3);
+
+    // process multiple rows; nrc is the number of rows to handle
+    for (int row = 0; row < nrc; row += PARALLEL_SIZE) {
+
+        int32x4_t accu[PARALLEL_SIZE];
+        const uint8_t * x_row[PARALLEL_SIZE];
+
+        for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+            accu[rb] = vdupq_n_s32(0);
+            x_row[rb] = x + (row + rb) * bx / 4;
+        }
+
+        for (int i = 0; i < group32_num; i++) {
 #if defined(__ARM_FEATURE_DOTPROD)
 
 #else
-        accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accula_0)));
-        accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accula_0));
-        accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accula_1)));
-        accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accula_1));
-        accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accula_2)));
-        accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accula_2));
-        accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accula_3)));
-        accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accula_3));
+            int16x8_t accu32[PARALLEL_SIZE];
+            for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+                accu32[rb] = vdupq_n_s16(0);
+            }
 #endif
-    }
-    accu_0 = vaddq_s32(accu_0, accu_1);
-    accu_2 = vaddq_s32(accu_2, accu_3);
-    accu_0 = vaddq_s32(accu_0, accu_2);
-    int sumi = vaddlvq_s32(accu_0);
-    *s = (float)sumi;
+            const uint8_t * px[PARALLEL_SIZE];
+            for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+                px[rb] = x_row[rb] + i * 32 * 16;
+            }
+
+            for (int j = 0; j < 32; j++) {
+                // load y data (shared across all rows)
+                const int8x16_t yq8_0 = vld1q_s8(y + i * 32 * 64 + j * 64 + 0);
+                const int8x16_t yq8_1 = vld1q_s8(y + i * 32 * 64 + j * 64 + 16);
+                const int8x16_t yq8_2 = vld1q_s8(y + i * 32 * 64 + j * 64 + 32);
+                const int8x16_t yq8_3 = vld1q_s8(y + i * 32 * 64 + j * 64 + 48);
+
+                // process each row
+                for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+                    uint8x16_t xq8_3 = vld1q_u8(px[rb] + 0);
+                    uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
+                    uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
+                    uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
+
+                    int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
+                    int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
+                    int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
+                    int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+                    accu[rb] = vdotq_s32(accu[rb], q8_0, yq8_0);
+                    accu[rb] = vdotq_s32(accu[rb], q8_1, yq8_1);
+                    accu[rb] = vdotq_s32(accu[rb], q8_2, yq8_2);
+                    accu[rb] = vdotq_s32(accu[rb], q8_3, yq8_3);
+#else
+                    accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_3), vget_low_s8(yq8_3));
+                    accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_3), vget_high_s8(yq8_3));
+                    accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_2), vget_low_s8(yq8_2));
+                    accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_2), vget_high_s8(yq8_2));
+                    accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_1), vget_low_s8(yq8_1));
+                    accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_1), vget_high_s8(yq8_1));
+                    accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_0), vget_low_s8(yq8_0));
+                    accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_0), vget_high_s8(yq8_0));
+
+#endif
+                    px[rb] += 16;
+                }
+            }
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+            for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+                accu[rb] = vaddq_s32(accu[rb], vmovl_s16(vget_low_s16(accu32[rb])));
+                accu[rb] = vaddq_s32(accu[rb], vmovl_high_s16(accu32[rb]));
+            }
+#endif
+        }
+
+        for (int i = 0; i < groupla_num; i++) {
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+            int16x8_t accula[PARALLEL_SIZE];
+            for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+                accula[rb] = vdupq_n_s16(0);
+            }
+#endif
+            const uint8_t * px[PARALLEL_SIZE];
+            for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+                px[rb] = x_row[rb] + group32_num * 32 * 16;
+            }
+
+            for (int j = 0; j < la_num; j++) {
+                // load y data (shared across all rows)
+                const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 0);
+                const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 16);
+                const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 32);
+                const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 48);
+
+                // process each row
+                for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+                    uint8x16_t xq8_3 = vld1q_u8(px[rb] + 0);
+                    uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
+                    uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
+                    uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
+
+                    int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
+                    int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
+                    int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
+                    int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+                    accu[rb] = vdotq_s32(accu[rb], q8_0, yq8_0);
+                    accu[rb] = vdotq_s32(accu[rb], q8_1, yq8_1);
+                    accu[rb] = vdotq_s32(accu[rb], q8_2, yq8_2);
+                    accu[rb] = vdotq_s32(accu[rb], q8_3, yq8_3);
+#else
+                    accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_3), vget_low_s8(yq8_3));
+                    accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_3), vget_high_s8(yq8_3));
+                    accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_2), vget_low_s8(yq8_2));
+                    accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_2), vget_high_s8(yq8_2));
+                    accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_1), vget_low_s8(yq8_1));
+                    accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_1), vget_high_s8(yq8_1));
+                    accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_0), vget_low_s8(yq8_0));
+                    accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_0), vget_high_s8(yq8_0));
+#endif
+                    px[rb] += 16;
+                }
+            }
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+            for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+                accu[rb] = vaddq_s32(accu[rb], vmovl_s16(vget_low_s16(accula[rb])));
+                accu[rb] = vaddq_s32(accu[rb], vmovl_high_s16(accula[rb]));
+            }
+#endif
+        }
+
+        // reduce the accumulators and write back the results
+        for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
+            int sumi = vaddlvq_s32(accu[rb]);
+            s[row + rb] = (float)sumi;
+        }
+    }
+#endif
+}
+
+void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if defined(__AVX2__)
+    const uint8_t * x = (uint8_t *)vx;
+    const int8_t * y = (int8_t *)vy;
+
+    const int nb = n / QK_I2_S;
+    const int group32_num = nb / 32;
+    const int la_num = nb % 32;
+    const int groupla_num = nb % 32 != 0 ? 1 : 0;
+
+    __m256i mask = _mm256_set1_epi8(0x03);
+    __m256i one16 = _mm256_set1_epi16(1);
+
+    for (int col = 0; col < nrc; col += PARALLEL_SIZE) {
+        __m256i accu[PARALLEL_SIZE];
+
+        for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+            accu[iy] = _mm256_setzero_si256();
+        }
+
+        const int8_t * y_col = y + col * by;
+
+        for (int i = 0; i < group32_num; i++) {
+            const uint8_t *px = x + i * 1024;
+            const int8_t *py = y_col + i * 4096;
+            __m256i accu32[PARALLEL_SIZE];
+
+            for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+                accu32[iy] = _mm256_setzero_si256();
+            }
+
+            for (int j = 0; j < 32; j++) {
+
+                __m256i xq8 = _mm256_loadu_si256((const __m256i*)(px));
+                __m256i xq8_3 = _mm256_and_si256(xq8, mask);
+                __m256i xq8_2 = _mm256_and_si256(_mm256_srli_epi16(xq8, 2), mask);
+                __m256i xq8_1 = _mm256_and_si256(_mm256_srli_epi16(xq8, 4), mask);
+                __m256i xq8_0 = _mm256_and_si256(_mm256_srli_epi16(xq8, 6), mask);
+
+                for (int iy = 0; iy < PARALLEL_SIZE; iy++)
+                {
+                    accu32[iy] = _mm256_add_epi16(accu32[iy], _mm256_add_epi16(
+                        _mm256_add_epi16(_mm256_maddubs_epi16(xq8_0, _mm256_loadu_si256((const __m256i*)(py + 0 * 32 + iy * by))),
+                                         _mm256_maddubs_epi16(xq8_1, _mm256_loadu_si256((const __m256i*)(py + 1 * 32 + iy * by)))),
+                        _mm256_add_epi16(_mm256_maddubs_epi16(xq8_2, _mm256_loadu_si256((const __m256i*)(py + 2 * 32 + iy * by))),
+                                         _mm256_maddubs_epi16(xq8_3, _mm256_loadu_si256((const __m256i*)(py + 3 * 32 + iy * by))))));
+                }
+
+                px += 32;
+                py += 128;
+            }
+
+            for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+                accu[iy] = _mm256_add_epi32(_mm256_madd_epi16(accu32[iy], one16), accu[iy]);
+            }
+        }
+
+        for (int i = 0; i < groupla_num; i++) {
+            const uint8_t *px = x + group32_num * 1024;
+            const int8_t *py = y_col + group32_num * 4096;
+            __m256i accula[PARALLEL_SIZE];
+
+            for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+                accula[iy] = _mm256_setzero_si256();
+            }
+
+            for (int j = 0; j < la_num; j++) {
+
+                __m256i xq8 = _mm256_loadu_si256((const __m256i*)(px));
+                __m256i xq8_3 = _mm256_and_si256(xq8, mask);
+                __m256i xq8_2 = _mm256_and_si256(_mm256_srli_epi16(xq8, 2), mask);
+                __m256i xq8_1 = _mm256_and_si256(_mm256_srli_epi16(xq8, 4), mask);
+                __m256i xq8_0 = _mm256_and_si256(_mm256_srli_epi16(xq8, 6), mask);
+
+                for (int iy = 0; iy < PARALLEL_SIZE; iy++)
+                {
+                    accula[iy] = _mm256_add_epi16(accula[iy], _mm256_add_epi16(
+                        _mm256_add_epi16(_mm256_maddubs_epi16(xq8_0, _mm256_loadu_si256((const __m256i*)(py + 0 * 32 + iy * by))),
+                                         _mm256_maddubs_epi16(xq8_1, _mm256_loadu_si256((const __m256i*)(py + 1 * 32 + iy * by)))),
+                        _mm256_add_epi16(_mm256_maddubs_epi16(xq8_2, _mm256_loadu_si256((const __m256i*)(py + 2 * 32 + iy * by))),
+                                         _mm256_maddubs_epi16(xq8_3, _mm256_loadu_si256((const __m256i*)(py + 3 * 32 + iy * by))))));
+                }
+
+                px += 32;
+                py += 128;
+            }
+
+            for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+                accu[iy] = _mm256_add_epi32(_mm256_madd_epi16(accula[iy], one16), accu[iy]);
+            }
+        }
+
+        for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+            int sumi = hsum_i32_8(accu[iy]);
+            s[(col + iy) * bs] = (float)sumi;
+        }
+    }
+#elif defined(__ARM_NEON)
+    const uint8_t * x = (uint8_t *)vx;
+    const int8_t * y = (int8_t *)vy;
+
+    const int nb = n / QK_I2_S;
+    const int group32_num = nb / 32;
+    const int la_num = nb % 32;
+    const int groupla_num = nb % 32 != 0 ? 1 : 0;
+
+    const uint8x16_t mask = vdupq_n_u8(3);
+
+    for (int col = 0; col < nrc; col += PARALLEL_SIZE) {
+        int32x4_t accu[PARALLEL_SIZE];
+
+        for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+            accu[iy] = vdupq_n_s32(0);
+        }
+
+        const int8_t * y_col = y + col * by;
+
+        for (int i = 0; i < group32_num; i++) {
+            const uint8_t *px = x + i * 512;        // i * 32 * 16
+            const int8_t *py = y_col + i * 2048;    // i * 32 * 64
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+            int16x8_t accu32[PARALLEL_SIZE];
+
+            for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+                accu32[iy] = vdupq_n_s16(0);
+            }
+#endif
+            for (int j = 0; j < 32; j++) {
+                // load and unpack x data (shared across all columns)
+                uint8x16_t xq8_3 = vld1q_u8(px + 0);
+                uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
+                uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
+                uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
+
+                int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
+                int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
+                int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
+                int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
+
+                // process each column
+                for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+                    const int8x16_t yq8_0 = vld1q_s8(py + 0 * 16 + iy * by);
+                    const int8x16_t yq8_1 = vld1q_s8(py + 1 * 16 + iy * by);
+                    const int8x16_t yq8_2 = vld1q_s8(py + 2 * 16 + iy * by);
+                    const int8x16_t yq8_3 = vld1q_s8(py + 3 * 16 + iy * by);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+                    accu[iy] = vdotq_s32(accu[iy], q8_0, yq8_0);
+                    accu[iy] = vdotq_s32(accu[iy], q8_1, yq8_1);
+                    accu[iy] = vdotq_s32(accu[iy], q8_2, yq8_2);
+                    accu[iy] = vdotq_s32(accu[iy], q8_3, yq8_3);
+#else
+                    accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_0), vget_low_s8(yq8_0));
+                    accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_0), vget_high_s8(yq8_0));
+                    accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_1), vget_low_s8(yq8_1));
+                    accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_1), vget_high_s8(yq8_1));
+                    accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_2), vget_low_s8(yq8_2));
+                    accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_2), vget_high_s8(yq8_2));
+                    accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_3), vget_low_s8(yq8_3));
+                    accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_3), vget_high_s8(yq8_3));
+#endif
+                }
+
+                px += 16;
+                py += 64;
+            }
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+            for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+                accu[iy] = vaddq_s32(accu[iy], vaddq_s32(vmovl_high_s16(accu32[iy]), vmovl_s16(vget_low_s16(accu32[iy]))));
+            }
+#endif
+        }
+
+        for (int i = 0; i < groupla_num; i++) {
+            const uint8_t *px = x + group32_num * 512;
+            const int8_t *py = y_col + group32_num * 2048;
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+            int16x8_t accula[PARALLEL_SIZE];
+
+            for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+                accula[iy] = vdupq_n_s16(0);
+            }
+#endif
+
+            for (int j = 0; j < la_num; j++) {
+                // load and unpack x data (shared across all columns)
+                uint8x16_t xq8_3 = vld1q_u8(px + 0);
+                uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
+                uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
+                uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
+
+                int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
+                int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
+                int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
+                int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
+
+                // process each column
+                for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+                    const int8x16_t yq8_0 = vld1q_s8(py + 0 * 16 + iy * by);
+                    const int8x16_t yq8_1 = vld1q_s8(py + 1 * 16 + iy * by);
+                    const int8x16_t yq8_2 = vld1q_s8(py + 2 * 16 + iy * by);
+                    const int8x16_t yq8_3 = vld1q_s8(py + 3 * 16 + iy * by);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+                    accu[iy] = vdotq_s32(accu[iy], q8_0, yq8_0);
+                    accu[iy] = vdotq_s32(accu[iy], q8_1, yq8_1);
+                    accu[iy] = vdotq_s32(accu[iy], q8_2, yq8_2);
+                    accu[iy] = vdotq_s32(accu[iy], q8_3, yq8_3);
+#else
+                    accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_0), vget_low_s8(yq8_0));
+                    accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_0), vget_high_s8(yq8_0));
+                    accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_1), vget_low_s8(yq8_1));
+                    accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_1), vget_high_s8(yq8_1));
+                    accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_2), vget_low_s8(yq8_2));
+                    accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_2), vget_high_s8(yq8_2));
+                    accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_3), vget_low_s8(yq8_3));
+                    accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_3), vget_high_s8(yq8_3));
+#endif
+                }
+
+                px += 16;
+                py += 64;
+            }
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+#else
+            for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+                accu[iy] = vaddq_s32(accu[iy], vaddq_s32(vmovl_high_s16(accula[iy]), vmovl_s16(vget_low_s16(accula[iy]))));
+            }
+#endif
+        }
+
+        // reduce the accumulators and write back the results
+        for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
+            int sumi = vaddlvq_s32(accu[iy]);
+            s[(col + iy) * bs] = (float)sumi;
+        }
+    }
+#endif
+}
+
+
+void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+    if (nrc % PARALLEL_SIZE == 0)
+    {
+#if defined(ACT_PARALLEL)
+        ggml_vec_dot_i2_i8_s_Nx1(n, s, bs, vx, bx, vy, by, nrc);
+#else
+        ggml_vec_dot_i2_i8_s_1xN(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+    }
+    else
+    {
+        ggml_vec_dot_i2_i8_s_1x1(n, s, bs, vx, bx, vy, by, nrc);
+    }
+}
\ No newline at end of file
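
The packing scheme written by quantize_i2_s and consumed by the kernels above can be checked in isolation. The sketch below (not part of the patch) round-trips the x86 ACT_PARALLEL layout: with QK_I2_S = 128, four groups of 32 ternary weights share a 32-byte block, and the weight for group g lands at bit position 6 - 2*g of its byte, which is exactly what the kernels' shift-by-6/4/2/0-then-mask sequence recovers.

```cpp
// Standalone round-trip check for the i2_s packing (x86 layout, QK_I2_S = 128).
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    const int QK = 128;
    std::vector<uint8_t> q8(QK);               // ternary weights encoded as 0/1/2
    for (int j = 0; j < QK; j++) q8[j] = uint8_t(j % 3);

    // pack: weight j goes to byte (j % 32) of the block, at bit 6 - 2*(j / 32),
    // mirroring the packing loop in quantize_i2_s
    std::vector<uint8_t> packed(QK / 4, 0);
    for (int j = 0; j < QK; j++) {
        packed[j % 32] |= uint8_t(q8[j] << (6 - 2 * (j / 32)));
    }

    // unpack: shift by 6/4/2/0 and mask with 0x03, as the SIMD kernels do with
    // _mm256_srli_epi16 / vshrq_n_u8 followed by an AND against 0x03
    for (int j = 0; j < QK; j++) {
        uint8_t w = (packed[j % 32] >> (6 - 2 * (j / 32))) & 0x03;
        assert(w == q8[j]);                    // subtracting 1 maps 0/1/2 to -1/0/+1
    }
    return 0;
}
```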
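What all four kernels accumulate is the integer sum of q8[j] * y[j] with the weights still in their 0/1/2 encoding; nothing in this file subtracts the offset or applies the stored scale, so the caller is presumably expected to recover the ternary result as sum(q8*y) - sum(y) (since w = q8 - 1) and then multiply by the scale stored after the packed data. A scalar reference for one row under that reading; the helper name is hypothetical, not part of the patch:

```cpp
#include <cstdint>

// Hypothetical scalar reference for ggml_vec_dot_i2_i8_s_1x1 on one row
// (x86 layout: QK_I2_S = 128 weights per 32-byte block).
float dot_i2_i8_reference(int n, const uint8_t * packed, const int8_t * y) {
    int64_t sumi = 0;
    for (int j = 0; j < n; j++) {
        int block     = j / 128;
        int group_idx = (j % 128) / 32;
        int group_pos = (j % 128) % 32;
        uint8_t q8 = (packed[block * 32 + group_pos] >> (6 - 2 * group_idx)) & 0x03;
        sumi += (int64_t)q8 * y[j];            // weights used as 0/1/2, like maddubs
    }
    return (float)sumi;                        // raw sum; offset/scale applied by the caller
}
```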
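The dispatcher at the bottom of the file picks a kernel purely from nrc and the compile-time configuration in gemm-config.h: batches divisible by PARALLEL_SIZE go to the activation-parallel Nx1 kernel (one weight row, PARALLEL_SIZE activation columns at stride by) when ACT_PARALLEL is defined, otherwise to the weight-parallel 1xN kernel (one activation vector, PARALLEL_SIZE weight rows at stride bx); everything else falls back to the 1x1 kernel. A sketch of that selection, with the preprocessor choice modeled as a runtime flag:

```cpp
enum class I2Kernel { Nx1, OneXN, OneX1 };

// Mirrors the routing in ggml_vec_dot_i2_i8_s; act_parallel stands in for the
// compile-time ACT_PARALLEL switch from gemm-config.h.
I2Kernel select_i2_kernel(int nrc, int parallel_size, bool act_parallel) {
    if (nrc % parallel_size == 0) {
        return act_parallel ? I2Kernel::Nx1 : I2Kernel::OneXN;
    }
    return I2Kernel::OneX1;   // handles any nrc, one output at a time
}
```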