#include "bitnet_kernels.h"

namespace {

// Launches one compile-time specialization of ladder_int8xint2_kernel on
// `stream`, forwarding the five device pointers unchanged.
//
// NOTE(review): the launch geometry below is reconstructed, because the
// execution configuration was missing from the call sites. It assumes the last
// two template arguments are the per-block tile sizes along K and N, i.e. a
// thread block of (KBlock, NBlock, 1) threads and one block per NBlock output
// columns. Confirm against the kernel definition in bitnet_kernels.h (its
// thread/block indexing and any __launch_bounds__) before relying on it.
template <int M, int N, int K, int WsNum, int KBlock, int NBlock>
void launch_ladder_int8xint2(int8_t* input0, int8_t* input1,
                             __nv_bfloat16* output0, __nv_bfloat16* s,
                             __nv_bfloat16* ws, cudaStream_t stream) {
    static_assert(N % NBlock == 0,
                  "N must be a whole number of NBlock-wide tiles");
    const dim3 grid(N / NBlock, 1, 1);
    const dim3 block(KBlock, NBlock, 1);
    // No dynamic shared memory is requested; the launch is enqueued on the
    // caller's stream rather than the default stream.
    ladder_int8xint2_kernel<M, N, K, WsNum, KBlock, NBlock>
        <<<grid, block, 0, stream>>>(input0, input1, output0, s, ws);
}

}  // namespace

// C-linkage entry point that dispatches a runtime (M, N, K) problem shape to
// the matching compile-time specialization of ladder_int8xint2_kernel.
//
// Only M == 1 with the (N, K) pairs listed below is supported. For any other
// shape no kernel is launched: a diagnostic naming the requested shape is
// written to std::cout and the function returns normally, so the caller gets
// no error signal and `output0` is left untouched.
//
// Parameters:
//   input0, input1, output0, s, ws - forwarded verbatim to the kernel.
//       Presumably int8 activations, packed int2 weights, bf16 output and
//       bf16 scale factors (going by the function name) - TODO confirm
//       against the kernel.
//   M, N, K - problem shape used solely to select the specialization.
//   stream  - CUDA stream on which the kernel launch is enqueued.
extern "C" void bitlinear_int8xint2(int8_t* input0, int8_t* input1, __nv_bfloat16* output0, __nv_bfloat16* s, __nv_bfloat16* ws, int M, int N, int K, cudaStream_t stream) {
    if (M == 1 && N == 3840 && K == 2560) {
        launch_ladder_int8xint2<1, 3840, 2560, 3, 8, 16>(input0, input1, output0, s, ws, stream);
    } else if (M == 1 && N == 2560 && K == 2560) {
        launch_ladder_int8xint2<1, 2560, 2560, 1, 8, 16>(input0, input1, output0, s, ws, stream);
    } else if (M == 1 && N == 13824 && K == 2560) {
        launch_ladder_int8xint2<1, 13824, 2560, 2, 8, 16>(input0, input1, output0, s, ws, stream);
    } else if (M == 1 && N == 2560 && K == 6912) {
        launch_ladder_int8xint2<1, 2560, 6912, 1, 8, 16>(input0, input1, output0, s, ws, stream);
    } else if (M == 1 && N == 4800 && K == 3200) {
        launch_ladder_int8xint2<1, 4800, 3200, 6, 8, 16>(input0, input1, output0, s, ws, stream);
    } else if (M == 1 && N == 3200 && K == 3200) {
        launch_ladder_int8xint2<1, 3200, 3200, 1, 8, 16>(input0, input1, output0, s, ws, stream);
    } else if (M == 1 && N == 20480 && K == 3200) {
        launch_ladder_int8xint2<1, 20480, 3200, 2, 8, 16>(input0, input1, output0, s, ws, stream);
    } else if (M == 1 && N == 3200 && K == 10240) {
        launch_ladder_int8xint2<1, 3200, 10240, 1, 8, 16>(input0, input1, output0, s, ws, stream);
    } else if (M == 1 && N == 5120 && K == 27648) {
        launch_ladder_int8xint2<1, 5120, 27648, 1, 8, 16>(input0, input1, output0, s, ws, stream);
    } else if (M == 1 && N == 55296 && K == 5120) {
        launch_ladder_int8xint2<1, 55296, 5120, 1, 8, 16>(input0, input1, output0, s, ws, stream);
    } else {
        // Unsupported shape: report what was requested; nothing is launched.
        std::cout << "required ladder gemm kernel: M " << M << ", N " << N << ", K " << K << std::endl;
    }
}