diff --git a/cmake/configuring_primitive_list.cmake b/cmake/configuring_primitive_list.cmake index c1e012683a0..30018138e0a 100644 --- a/cmake/configuring_primitive_list.cmake +++ b/cmake/configuring_primitive_list.cmake @@ -1,5 +1,5 @@ #=============================================================================== -# Copyright 2021 Intel Corporation +# Copyright 2021-2023 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -66,6 +66,21 @@ else() endif() message(STATUS "Enabled primitive GPU ISA: ${DNNL_ENABLE_PRIMITIVE_GPU_ISA}") +if (ONEDNN_ENABLE_GEMM_KERNELS_ISA STREQUAL "ALL") + set(BUILD_GEMM_KERNELS_ALL TRUE) +elseif (ONEDNN_ENABLE_GEMM_KERNELS_ISA STREQUAL "NONE") + set(BUILD_GEMM_KERNELS_NONE TRUE) +else() + foreach(isa ${ONEDNN_ENABLE_GEMM_KERNELS_ISA}) + string(TOUPPER "${isa}" uisa) + if(NOT "${uisa}" MATCHES "^(SSE41|AVX2|AVX512)$") + message(FATAL_ERROR "Unsupported GeMM kernels ISA: ${uisa}") + endif() + set(BUILD_GEMM_${uisa} TRUE) + endforeach() +endif() +message(STATUS "Enabled GeMM kernels ISA: ${ONEDNN_ENABLE_GEMM_KERNELS_ISA}") + # When certain primitives or primitive ISA are switched off, some functions may # become unused which is expected. Switch off warning for unused functions in # such cases. diff --git a/cmake/dnnl_compat.cmake b/cmake/dnnl_compat.cmake index 27e136afc37..c600637ed13 100644 --- a/cmake/dnnl_compat.cmake +++ b/cmake/dnnl_compat.cmake @@ -61,6 +61,8 @@ set(COMPAT_CACHE_STRING_VARS "LIBRARY_NAME" "ENABLE_WORKLOAD" "ENABLE_PRIMITIVE" + "ENABLE_PRIMITIVE_CPU_ISA" + "ENABLE_PRIMITIVE_GPU_ISA" "ARCH_OPT_FLAGS" "CPU_RUNTIME" "GPU_RUNTIME" diff --git a/cmake/options.cmake b/cmake/options.cmake index 2db3de7439d..6b295049b96 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -149,6 +149,16 @@ set(DNNL_ENABLE_PRIMITIVE_GPU_ISA "ALL" CACHE STRING - ;;... Includes only selected ISA to be enabled. 
Possible values are: GEN9, GEN11, XELP, XEHP, XEHPG, XEHPC.") +set(ONEDNN_ENABLE_GEMM_KERNELS_ISA "ALL" CACHE STRING + "Specifies an ISA set of GeMM kernels residing in x64/gemm folder to be + available at build time. Valid values: + - ALL (the default). Includes all ISA kernels to be enabled. + - NONE. Removes all kernels and interfaces. + - <ISA_NAME>. Enables all ISA up to ISA_NAME included. + Possible values are: SSE41, AVX2, AVX512. The linear order is + SSE41 < AVX2 < AVX512 < AMX (or ALL). It means that if the user selects, e.g. + AVX2 ISA, SSE41 kernels will also be present at build time.") + # ============= # Optimizations # ============= diff --git a/doc/build/build_options.md b/doc/build/build_options.md index ca8f6630872..d00e29977f3 100644 --- a/doc/build/build_options.md +++ b/doc/build/build_options.md @@ -24,6 +24,7 @@ oneDNN supports the following build-time options. | ONEDNN_ENABLE_PRIMITIVE | **ALL**, PRIMITIVE_NAME | Specifies a set of functionality to be available based on primitives | | ONEDNN_ENABLE_PRIMITIVE_CPU_ISA | **ALL**, CPU_ISA_NAME | Specifies a set of functionality to be available for CPU backend based on CPU ISA | | ONEDNN_ENABLE_PRIMITIVE_GPU_ISA | **ALL**, GPU_ISA_NAME | Specifies a set of functionality to be available for GPU backend based on GPU ISA | +| ONEDNN_ENABLE_GEMM_KERNELS_ISA | **ALL**, NONE, ISA_NAME | Specifies a set of functionality to be available for GeMM kernels for CPU backend based on ISA | | ONEDNN_EXPERIMENTAL | ON, **OFF** | Enables [experimental features](@ref dev_guide_experimental) | | ONEDNN_VERBOSE | **ON**, OFF | Enables [verbose mode](@ref dev_guide_verbose) | | ONEDNN_AARCH64_USE_ACL | ON, **OFF** | Enables integration with Arm Compute Library for AArch64 builds | @@ -109,6 +110,17 @@ always be available. 
Example that enables XeLP and XeHP set: -DONEDNN_ENABLE_PRIMITIVE_GPU_ISA=XELP;XEHP ``` +#### ONEDNN_ENABLE_GEMM_KERNELS_ISA +This option supports several values: `ALL` (the default) which enables all +ISA kernels from x64/gemm folder, `NONE` which disables all kernels and removes +corresponding interfaces, or one of `SSE41`, `AVX2`, and `AVX512`. Values are +linearly ordered as `SSE41` < `AVX2` < `AVX512`. When specified, selected ISA +and all ISA that are "smaller" will be available. Example that leaves SSE41 and +AVX2 sets, but removes AVX512 and AMX kernels: +``` +-DONEDNN_ENABLE_GEMM_KERNELS_ISA=AVX2 +``` + ## CPU Options Intel Architecture Processors and compatible devices are supported by oneDNN CPU engine. The CPU engine is built by default but can be disabled diff --git a/include/oneapi/dnnl/dnnl_config.h.in b/include/oneapi/dnnl/dnnl_config.h.in index 7d8536ac641..5fb44873d0d 100644 --- a/include/oneapi/dnnl/dnnl_config.h.in +++ b/include/oneapi/dnnl/dnnl_config.h.in @@ -193,4 +193,10 @@ #cmakedefine01 BUILD_XEHP #cmakedefine01 BUILD_XEHPG #cmakedefine01 BUILD_XEHPC +// GeMM kernels ISA controls +#cmakedefine01 BUILD_GEMM_KERNELS_ALL +#cmakedefine01 BUILD_GEMM_KERNELS_NONE +#cmakedefine01 BUILD_GEMM_SSE41 +#cmakedefine01 BUILD_GEMM_AVX2 +#cmakedefine01 BUILD_GEMM_AVX512 #endif diff --git a/src/cpu/gemm/gemm.cpp b/src/cpu/gemm/gemm.cpp index ac398509af7..4dcb1fde288 100644 --- a/src/cpu/gemm/gemm.cpp +++ b/src/cpu/gemm/gemm.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2022 Intel Corporation +* Copyright 2018-2023 Intel Corporation * Copyright 2022 IBM Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -134,13 +134,14 @@ dnnl_status_t extended_sgemm(const char *transa, const char *transb, } #endif -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE if (mayiuse(sse41)) { float *dummy_ao = nullptr; float *dummy_bo = nullptr; - return gemm_driver(transa, 
transb, bias ? "C" : nullptr, M, N, K, alpha, - A, lda, dummy_ao, B, ldb, dummy_bo, beta, C, ldc, bias, + auto status = gemm_driver(transa, transb, bias ? "C" : nullptr, M, N, K, + alpha, A, lda, dummy_ao, B, ldb, dummy_bo, beta, C, ldc, bias, force_jit_nocopy_gemm); + if (status == status::success) return status; } #endif @@ -201,10 +202,12 @@ dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb, LDA, ao, B, LDB, bo, beta, C, LDC, co); if (status == dnnl_success) return status; -#if DNNL_X64 - if (mayiuse(sse41)) - return gemm_driver(transa, transb, offsetc, M, N, K, alpha, A, LDA, ao, - B, LDB, bo, beta, C, LDC, co, false); +#if DNNL_X64 && !__BUILD_GEMM_NONE + if (mayiuse(sse41)) { + auto status = gemm_driver(transa, transb, offsetc, M, N, K, alpha, A, + LDA, ao, B, LDB, bo, beta, C, LDC, co, false); + if (status == status::success) return status; + } #elif DNNL_PPC64 #ifdef __MMA__ int ATflag = (*transa == 'T') || (*transa == 't'); @@ -237,18 +240,23 @@ dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb, if (*M == 0 || *N == 0 || *K == 0) return dnnl_success; -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE bool use_jit = mayiuse(avx512_core); bool use_s8u8 = true && utils::everyone_is(0, *ao, *bo) // so far a requirement && IMPLICATION(USE_MKL_IGEMM == 0, mayiuse(sse41)); - if (use_jit) - return gemm_driver(transa, transb, offsetc, M, N, K, alpha, A, LDA, ao, - B, LDB, bo, beta, C, LDC, co, false); - else if (use_s8u8) - return simple_gemm_s8s8s32(transa, transb, offsetc, M, N, K, alpha, A, - LDA, ao, B, LDB, bo, beta, C, LDC, co); + if (use_jit) { + auto status = gemm_driver(transa, transb, offsetc, M, N, K, alpha, A, + LDA, ao, B, LDB, bo, beta, C, LDC, co, false); + if (status == status::success) return status; + } + + if (use_s8u8) { + auto status = simple_gemm_s8s8s32(transa, transb, offsetc, M, N, K, + alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); + if (status == status::success) return status; + } #endif #if DNNL_PPC64 
@@ -285,16 +293,18 @@ dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb, ldb, C, ldc, alpha, beta, false); if (status != dnnl_success) return status; -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE char *dummyOffsetC = nullptr; bfloat16_t *dummy_ao = nullptr; bfloat16_t *dummy_bo = nullptr; float *dummy_co = nullptr; - if (mayiuse(avx512_core)) - return gemm_driver(transa, transb, dummyOffsetC, M, N, K, alpha, + if (mayiuse(avx512_core)) { + auto status = gemm_driver(transa, transb, dummyOffsetC, M, N, K, alpha, (const bfloat16_t *)A, lda, dummy_ao, (const bfloat16_t *)B, ldb, dummy_bo, beta, (float *)C, ldc, dummy_co, false); + if (status == status::success) return status; + } #elif DNNL_PPC64 #if defined(USE_CBLAS) && defined(BLAS_HAS_SBGEMM) && defined(__MMA__) bool trA = *transa == 't' || *transa == 'T'; diff --git a/src/cpu/gemm/gemm.hpp b/src/cpu/gemm/gemm.hpp index 35082757a11..be583890f18 100644 --- a/src/cpu/gemm/gemm.hpp +++ b/src/cpu/gemm/gemm.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2022 Intel Corporation +* Copyright 2018-2023 Intel Corporation * Copyright 2022 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +28,19 @@ #if DNNL_X64 #include "cpu/x64/cpu_isa_traits.hpp" + +// Kernels ISA section for configuring knobs. 
+#define __BUILD_GEMM_AMX BUILD_GEMM_KERNELS_ALL +#define __BUILD_GEMM_AVX512 (__BUILD_GEMM_AMX || BUILD_GEMM_AVX512) +#define __BUILD_GEMM_AVX2 (__BUILD_GEMM_AVX512 || BUILD_GEMM_AVX2) +#define __BUILD_GEMM_SSE41 (__BUILD_GEMM_AVX2 || BUILD_GEMM_SSE41) +#define __BUILD_GEMM_NONE BUILD_GEMM_KERNELS_NONE +#else +#define __BUILD_GEMM_AMX 0 +#define __BUILD_GEMM_AVX512 0 +#define __BUILD_GEMM_AVX2 0 +#define __BUILD_GEMM_SSE41 0 +#define __BUILD_GEMM_NONE 0 #endif #if DNNL_AARCH64 diff --git a/src/cpu/gemm/gemm_pack.cpp b/src/cpu/gemm/gemm_pack.cpp index 496a32e5b66..f67549cbf27 100644 --- a/src/cpu/gemm/gemm_pack.cpp +++ b/src/cpu/gemm/gemm_pack.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020 Intel Corporation +* Copyright 2020-2023 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include "cpu/platform.hpp" +#include "cpu/gemm/gemm.hpp" #include "cpu/gemm/gemm_pack.hpp" #if DNNL_X64 @@ -27,13 +28,13 @@ namespace impl { namespace cpu { bool pack_sgemm_supported() { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::pack_sgemm_supported(); #endif return false; } bool pack_gemm_bf16bf16f32_supported() { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::pack_gemm_bf16bf16f32_supported(); #endif return false; @@ -42,7 +43,7 @@ bool pack_gemm_bf16bf16f32_supported() { dnnl_status_t sgemm_pack_get_size(const char *identifier, const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, bool *pack) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::sgemm_pack_get_size( identifier, transa, transb, M, N, K, lda, ldb, size, pack); #endif @@ -53,7 +54,7 @@ dnnl_status_t gemm_bf16bf16f32_pack_get_size(const char *identifier, const char *transa, const char *transb, const dim_t *M, 
const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, bool *pack) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_bf16bf16f32_pack_get_size( identifier, transa, transb, M, N, K, lda, ldb, size, pack); #endif @@ -64,7 +65,7 @@ dnnl_status_t gemm_s8u8s32_pack_get_size(const char *identifier, const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, bool *pack) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_s8u8s32_pack_get_size( identifier, transa, transb, M, N, K, lda, ldb, size, pack); #endif @@ -75,7 +76,7 @@ dnnl_status_t gemm_s8s8s32_pack_get_size(const char *identifier, const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, bool *pack) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_s8s8s32_pack_get_size( identifier, transa, transb, M, N, K, lda, ldb, size, pack); #endif @@ -85,7 +86,7 @@ dnnl_status_t gemm_s8s8s32_pack_get_size(const char *identifier, dnnl_status_t sgemm_pack(const char *identifier, const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, const float *src, float *dst) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::sgemm_pack( identifier, transa, transb, M, N, K, lda, ldb, src, dst); #endif @@ -96,7 +97,7 @@ dnnl_status_t gemm_bf16bf16f32_pack(const char *identifier, const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, const bfloat16_t *src, bfloat16_t *dst) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_bf16bf16f32_pack( identifier, transa, transb, M, N, K, lda, ldb, src, dst); #endif @@ -106,7 +107,7 @@ dnnl_status_t gemm_bf16bf16f32_pack(const char *identifier, const char *transa, dnnl_status_t gemm_s8u8s32_pack(const 
char *identifier, const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, const void *src, void *dst) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_s8u8s32_pack( identifier, transa, transb, M, N, K, lda, ldb, src, dst); #endif @@ -116,7 +117,7 @@ dnnl_status_t gemm_s8u8s32_pack(const char *identifier, const char *transa, dnnl_status_t gemm_s8s8s32_pack(const char *identifier, const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const dim_t *lda, const dim_t *ldb, const void *src, void *dst) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_s8s8s32_pack( identifier, transa, transb, M, N, K, lda, ldb, src, dst); #endif @@ -127,7 +128,7 @@ dnnl_status_t sgemm_compute(const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const float *A, const dim_t *lda, const float *B, const dim_t *ldb, const float *beta, float *C, const dim_t *ldc) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::sgemm_compute( transa, transb, M, N, K, A, lda, B, ldb, beta, C, ldc); #endif @@ -138,7 +139,7 @@ dnnl_status_t gemm_bf16bf16f32_compute(const char *transa, const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, const bfloat16_t *A, const dim_t *lda, const bfloat16_t *B, const dim_t *ldb, const float *beta, float *C, const dim_t *ldc) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_bf16bf16f32_compute( transa, transb, M, N, K, A, lda, B, ldb, beta, C, ldc); #endif @@ -149,7 +150,7 @@ dnnl_status_t gemm_s8u8s32_compute(const char *transa, const char *transb, const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K, const int8_t *A, const dim_t *lda, const uint8_t *B, const dim_t *ldb, const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_s8u8s32_compute( transa, transb, 
offsetc, M, N, K, A, lda, B, ldb, beta, C, ldc, co); #endif @@ -160,7 +161,7 @@ dnnl_status_t gemm_s8s8s32_compute(const char *transa, const char *transb, const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K, const int8_t *A, const dim_t *lda, const int8_t *B, const dim_t *ldb, const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co) { -#if DNNL_X64 +#if DNNL_X64 && !__BUILD_GEMM_NONE return x64::gemm_s8s8s32_compute( transa, transb, offsetc, M, N, K, A, lda, B, ldb, beta, C, ldc, co); #endif diff --git a/src/cpu/rnn/rnn_utils.hpp b/src/cpu/rnn/rnn_utils.hpp index 8af361d63a0..21ca3f11fb4 100644 --- a/src/cpu/rnn/rnn_utils.hpp +++ b/src/cpu/rnn/rnn_utils.hpp @@ -873,7 +873,7 @@ bool init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, rnn.diff_weights_overwrite = rd.flags & rnn_flags::diff_weights_overwrite; -#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL +#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL || BUILD_GEMM_KERNELS_NONE // XXX: Threadpool runtime may use different number of threads at execute // and create stages. GEMM packed API is not aware of number of threads as // of now. 
In order to synchronize all layers, GEMM pack API should be diff --git a/src/cpu/x64/CMakeLists.txt b/src/cpu/x64/CMakeLists.txt index 5cacdb215d0..75b007892b6 100644 --- a/src/cpu/x64/CMakeLists.txt +++ b/src/cpu/x64/CMakeLists.txt @@ -56,6 +56,35 @@ else() PROPERTIES COMPILE_FLAGS "${OPT_LEVEL}") endif() +# Discard GeMM kernel files when requested +if(ONEDNN_ENABLE_GEMM_KERNELS_ISA MATCHES "^(AVX512|AVX2|SSE41|NONE)$") + file(GLOB_RECURSE SOURCES_AMX ${CMAKE_CURRENT_SOURCE_DIR}/gemm/jit*amx*) + foreach(amx_file ${SOURCES_AMX}) + list(REMOVE_ITEM SOURCES "${amx_file}") + endforeach() +endif() + +if(ONEDNN_ENABLE_GEMM_KERNELS_ISA MATCHES "^(AVX2|SSE41|NONE)$") + file(GLOB_RECURSE SOURCES_AVX512 ${CMAKE_CURRENT_SOURCE_DIR}/gemm/jit*avx512*) + foreach(avx512_file ${SOURCES_AVX512}) + list(REMOVE_ITEM SOURCES "${avx512_file}") + endforeach() +endif() + +if(ONEDNN_ENABLE_GEMM_KERNELS_ISA MATCHES "^(SSE41|NONE)$") + file(GLOB_RECURSE SOURCES_AVX ${CMAKE_CURRENT_SOURCE_DIR}/gemm/jit*avx*) + foreach(avx_file ${SOURCES_AVX}) + list(REMOVE_ITEM SOURCES "${avx_file}") + endforeach() +endif() + +if(ONEDNN_ENABLE_GEMM_KERNELS_ISA MATCHES "^(NONE)$") + file(GLOB_RECURSE SOURCES_SSE41 ${CMAKE_CURRENT_SOURCE_DIR}/gemm/*) + foreach(sse41_file ${SOURCES_SSE41}) + list(REMOVE_ITEM SOURCES "${sse41_file}") + endforeach() +endif() + set(OBJ_LIB ${LIB_PACKAGE_NAME}_cpu_x64) add_library(${OBJ_LIB} OBJECT ${SOURCES}) set_property(GLOBAL APPEND PROPERTY DNNL_LIB_DEPS diff --git a/src/cpu/x64/gemm/gemm_driver.cpp b/src/cpu/x64/gemm/gemm_driver.cpp index 0b66a9adba7..0e035e75279 100644 --- a/src/cpu/x64/gemm/gemm_driver.cpp +++ b/src/cpu/x64/gemm/gemm_driver.cpp @@ -29,6 +29,7 @@ #include "cpu/platform.hpp" #include "cpu/gemm/f32/gemm_utils_f32.hpp" +#include "cpu/gemm/gemm.hpp" #include "cpu/gemm/gemm_msan_unpoison.hpp" #include "cpu/x64/jit_generator.hpp" @@ -1657,21 +1658,29 @@ static dnnl_status_t call_no_copy_sgemm( int nthrs, gemm_info_t *arg) { if (arg->packing == pack_type::none) { 
+#if __BUILD_GEMM_AVX2 auto transa_char = (arg->transa != do_trans) ? "N" : "T"; auto transb_char = (arg->transb != do_trans) ? "N" : "T"; +#endif - if (mayiuse(avx512_core)) + if (mayiuse(avx512_core)) { +#if __BUILD_GEMM_AVX512 return jit_avx512_common_gemm_f32(nthrs, transa_char, transb_char, &arg->m, &arg->n, &arg->k, &arg->alpha, (float *)arg->a, &arg->lda, (float *)arg->b, &arg->ldb, &arg->beta, (float *)arg->c, &arg->ldc, (float *)arg->co); - else +#endif + } else { +#if __BUILD_GEMM_AVX2 return jit_avx_gemm_f32(nthrs, transa_char, transb_char, &arg->m, &arg->n, &arg->k, &arg->alpha, (float *)arg->a, &arg->lda, (float *)arg->b, &arg->ldb, &arg->beta, (float *)arg->c, &arg->ldc, (float *)arg->co); +#endif + } } else return pack_no_copy(arg); + return status::unimplemented; } template @@ -1687,12 +1696,14 @@ static dnnl_status_t gemm_threading_driver( if ((arg->m <= 0) || (arg->n <= 0)) return dnnl_success; +#if __BUILD_GEMM_AVX512 if (!is_a_packed && !is_b_packed && jump_to_gemv_s8x8s32(arg)) return dnnl_success; if (!is_a_packed && !is_b_packed && jump_to_gemm_smalln_tn(arg) == dnnl_success) return dnnl_success; +#endif if (!is_a_packed && !is_b_packed && jump_to_gemv(arg) == dnnl_success) return dnnl_success; @@ -1927,6 +1938,7 @@ static dnnl_status_t gemm_threading_driver( assert(arg->packing == pack_type::none); if (mayiuse(avx512_core)) { +#if __BUILD_GEMM_AVX512 thread_arg[ithr].result = avx512_common_gemm_f32:: sgemm_nocopy_driver( arg->transa == no_trans ? "N" : "T", @@ -1935,7 +1947,9 @@ static dnnl_status_t gemm_threading_driver( arg->lda, (float *)b, arg->ldb, &beta_eff, (float *)c_eff, ldc_eff, nullptr); +#endif } else { +#if __BUILD_GEMM_AVX2 thread_arg[ithr].result = avx_gemm_f32::sgemm_nocopy_driver( arg->transa == no_trans ? 
"N" : "T", @@ -1944,6 +1958,7 @@ static dnnl_status_t gemm_threading_driver( arg->lda, (float *)b, arg->ldb, &beta_eff, (float *)c_eff, ldc_eff, nullptr); +#endif } break; } diff --git a/src/cpu/x64/gemm/gemm_info.cpp b/src/cpu/x64/gemm/gemm_info.cpp index 849edde47e1..acab57ca9eb 100644 --- a/src/cpu/x64/gemm/gemm_info.cpp +++ b/src/cpu/x64/gemm/gemm_info.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2022 Intel Corporation +* Copyright 2019-2023 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,8 @@ #include "common/bfloat16.hpp" #include "common/dnnl_traits.hpp" -#include "common/dnnl_sel_build.hpp" + +#include "cpu/gemm/gemm.hpp" #include "cpu/x64/cpu_isa_traits.hpp" #include "cpu/x64/jit_generator.hpp" @@ -359,7 +360,9 @@ void gemm_info_t::jit_init(void) { static std::once_flag initialized; static std::atomic st(dnnl_success); std::call_once(initialized, [&, um] { +#if __BUILD_GEMM_AVX512 const bool b_is_s8 = data_traits::data_type == data_type::s8; +#endif constexpr bool is_int8 = utils::one_of( data_traits::data_type, data_type::s8, data_type::u8); constexpr bool is_bf16 = data_traits::data_type == data_type::bf16; @@ -373,7 +376,7 @@ void gemm_info_t::jit_init(void) { switch (data_traits::data_type) { case data_type::s8: if (mayiuse(amx_int8)) { - DNNL_CSCOPE(jit_init_copy_kern_s8_amx_int8) { +#if __BUILD_GEMM_AMX for (int isTrans : {no_trans, do_trans}) { copy_a[isTrans][no_sum].reset( new jit_avx512_core_amx_copy_kern( @@ -382,124 +385,124 @@ void gemm_info_t::jit_init(void) { copy_b[isTrans][no_sum].reset( new jit_avx512_core_amx_copy_kern( false, isTrans, sizeof(b_t))); - } } +#endif } else if (mayiuse(avx512_core)) { - DNNL_CSCOPE(jit_init_copy_kern_s8_avx512_core) { - copy_a[no_trans][no_sum].reset( - new jit_avx512_core_u8_copy_an_kern()); - 
copy_a[do_trans][no_sum].reset( - new jit_avx512_core_u8_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx512_core_u8_copy_bn_kern(b_is_s8)); - copy_b[do_trans][no_sum].reset( - new jit_avx512_core_u8_copy_bt_kern(b_is_s8)); - - copy_a[no_trans][do_sum].reset( - new jit_avx512_core_u8_copy_sum_an_kern()); - copy_a[do_trans][do_sum].reset( - new jit_avx512_core_u8_copy_sum_at_kern()); - - copy_b[no_trans][do_sum].reset( - new jit_avx512_core_u8_copy_sum_bn_kern(b_is_s8)); - copy_b[do_trans][do_sum].reset( - new jit_avx512_core_u8_copy_sum_bt_kern(b_is_s8)); - } +#if __BUILD_GEMM_AVX512 + copy_a[no_trans][no_sum].reset( + new jit_avx512_core_u8_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx512_core_u8_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx512_core_u8_copy_bn_kern(b_is_s8)); + copy_b[do_trans][no_sum].reset( + new jit_avx512_core_u8_copy_bt_kern(b_is_s8)); + + copy_a[no_trans][do_sum].reset( + new jit_avx512_core_u8_copy_sum_an_kern()); + copy_a[do_trans][do_sum].reset( + new jit_avx512_core_u8_copy_sum_at_kern()); + + copy_b[no_trans][do_sum].reset( + new jit_avx512_core_u8_copy_sum_bn_kern(b_is_s8)); + copy_b[do_trans][do_sum].reset( + new jit_avx512_core_u8_copy_sum_bt_kern(b_is_s8)); +#endif } else if (mayiuse(avx2_vnni)) { - DNNL_CSCOPE(jit_init_copy_kern_s8_avx2_vnni) { - copy_a[no_trans][no_sum].reset( - new jit_avx2_vnni_u8_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx2_vnni_u8_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx2_vnni_u8_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx2_vnni_u8_copy_bt_kern()); - - copy_a[no_trans][do_sum].reset( - new jit_avx2_vnni_u8_copy_sum_an_kern()); - copy_a[do_trans][do_sum].reset( - new jit_avx2_vnni_u8_copy_sum_at_kern()); - - copy_b[no_trans][do_sum].reset( - new jit_avx2_vnni_u8_copy_sum_bn_kern()); - copy_b[do_trans][do_sum].reset( - new jit_avx2_vnni_u8_copy_sum_bt_kern()); - } +#if __BUILD_GEMM_AVX2 + 
copy_a[no_trans][no_sum].reset( + new jit_avx2_vnni_u8_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx2_vnni_u8_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx2_vnni_u8_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx2_vnni_u8_copy_bt_kern()); + + copy_a[no_trans][do_sum].reset( + new jit_avx2_vnni_u8_copy_sum_an_kern()); + copy_a[do_trans][do_sum].reset( + new jit_avx2_vnni_u8_copy_sum_at_kern()); + + copy_b[no_trans][do_sum].reset( + new jit_avx2_vnni_u8_copy_sum_bn_kern()); + copy_b[do_trans][do_sum].reset( + new jit_avx2_vnni_u8_copy_sum_bt_kern()); +#endif } else if (mayiuse(avx2)) { - DNNL_CSCOPE(jit_init_copy_kern_s8_avx2) { - copy_a[no_trans][no_sum].reset( - new jit_avx2_u8_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx2_u8_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx2_u8_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx2_u8_copy_bt_kern()); - - copy_a[no_trans][do_sum].reset( - new jit_avx2_u8_copy_sum_an_kern()); - copy_a[do_trans][do_sum].reset( - new jit_avx2_u8_copy_sum_at_kern()); - - copy_b[no_trans][do_sum].reset( - new jit_avx2_u8_copy_sum_bn_kern()); - copy_b[do_trans][do_sum].reset( - new jit_avx2_u8_copy_sum_bt_kern()); - } +#if __BUILD_GEMM_AVX2 + copy_a[no_trans][no_sum].reset( + new jit_avx2_u8_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx2_u8_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx2_u8_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx2_u8_copy_bt_kern()); + + copy_a[no_trans][do_sum].reset( + new jit_avx2_u8_copy_sum_an_kern()); + copy_a[do_trans][do_sum].reset( + new jit_avx2_u8_copy_sum_at_kern()); + + copy_b[no_trans][do_sum].reset( + new jit_avx2_u8_copy_sum_bn_kern()); + copy_b[do_trans][do_sum].reset( + new jit_avx2_u8_copy_sum_bt_kern()); +#endif } else if (mayiuse(avx)) { - DNNL_CSCOPE(jit_init_copy_kern_s8_avx) { - copy_a[no_trans][no_sum].reset( - new 
jit_avx_u8_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx_u8_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx_u8_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx_u8_copy_bt_kern()); - - copy_a[no_trans][do_sum].reset( - new jit_avx_u8_copy_sum_an_kern()); - copy_a[do_trans][do_sum].reset( - new jit_avx_u8_copy_sum_at_kern()); - - copy_b[no_trans][do_sum].reset( - new jit_avx_u8_copy_sum_bn_kern()); - copy_b[do_trans][do_sum].reset( - new jit_avx_u8_copy_sum_bt_kern()); - } +#if __BUILD_GEMM_AVX2 + copy_a[no_trans][no_sum].reset( + new jit_avx_u8_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx_u8_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx_u8_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx_u8_copy_bt_kern()); + + copy_a[no_trans][do_sum].reset( + new jit_avx_u8_copy_sum_an_kern()); + copy_a[do_trans][do_sum].reset( + new jit_avx_u8_copy_sum_at_kern()); + + copy_b[no_trans][do_sum].reset( + new jit_avx_u8_copy_sum_bn_kern()); + copy_b[do_trans][do_sum].reset( + new jit_avx_u8_copy_sum_bt_kern()); +#endif } else if (mayiuse(sse41)) { - DNNL_CSCOPE(jit_init_copy_kern_s8_sse41) { - copy_a[no_trans][no_sum].reset( - new jit_sse41_u8_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_sse41_u8_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_sse41_u8_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_sse41_u8_copy_bt_kern()); - - copy_a[no_trans][do_sum].reset( - new jit_sse41_u8_copy_sum_an_kern()); - copy_a[do_trans][do_sum].reset( - new jit_sse41_u8_copy_sum_at_kern()); - - copy_b[no_trans][do_sum].reset( - new jit_sse41_u8_copy_sum_bn_kern()); - copy_b[do_trans][do_sum].reset( - new jit_sse41_u8_copy_sum_bt_kern()); - } +#if __BUILD_GEMM_SSE41 + copy_a[no_trans][no_sum].reset( + new jit_sse41_u8_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_sse41_u8_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new 
jit_sse41_u8_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_sse41_u8_copy_bt_kern()); + + copy_a[no_trans][do_sum].reset( + new jit_sse41_u8_copy_sum_an_kern()); + copy_a[do_trans][do_sum].reset( + new jit_sse41_u8_copy_sum_at_kern()); + + copy_b[no_trans][do_sum].reset( + new jit_sse41_u8_copy_sum_bn_kern()); + copy_b[do_trans][do_sum].reset( + new jit_sse41_u8_copy_sum_bt_kern()); +#endif } break; case data_type::bf16: if (mayiuse(amx_bf16)) { - DNNL_CSCOPE(jit_init_copy_kern_bf16_amx_bf16) { +#if __BUILD_GEMM_AMX for (int isTrans : {no_trans, do_trans}) { copy_a[isTrans][no_sum].reset( new jit_avx512_core_amx_copy_kern( @@ -508,213 +511,215 @@ void gemm_info_t::jit_init(void) { copy_b[isTrans][no_sum].reset( new jit_avx512_core_amx_copy_kern( false, isTrans, sizeof(b_t))); - } } +#endif } else if (mayiuse(avx512_core) && !use_bf16_ymm) { - DNNL_CSCOPE(jit_init_copy_kern_bf16_avx512_core_not_use_bf16_ymm) { - copy_a[no_trans][no_sum].reset( - new jit_avx512_core_s16_48x8_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx512_core_s16_48x8_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx512_core_s16_48x8_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx512_core_s16_48x8_copy_bt_kern()); - } +#if __BUILD_GEMM_AVX512 + copy_a[no_trans][no_sum].reset( + new jit_avx512_core_s16_48x8_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx512_core_s16_48x8_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx512_core_s16_48x8_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx512_core_s16_48x8_copy_bt_kern()); +#endif } else if (mayiuse(avx512_core) && use_bf16_ymm) { - DNNL_CSCOPE(jit_init_copy_kern_bf16_avx512_core_use_bf16_ymm) { - copy_a[no_trans][no_sum].reset( - new jit_avx512_core_s16_24x8_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx512_core_s16_24x8_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx512_core_s16_24x8_copy_bn_kern()); 
- copy_b[do_trans][no_sum].reset( - new jit_avx512_core_s16_24x8_copy_bt_kern()); - } +#if __BUILD_GEMM_AVX512 + copy_a[no_trans][no_sum].reset( + new jit_avx512_core_s16_24x8_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx512_core_s16_24x8_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx512_core_s16_24x8_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx512_core_s16_24x8_copy_bt_kern()); +#endif } break; case data_type::f32: if (mayiuse(avx512_core)) { - DNNL_CSCOPE(jit_init_copy_kern_f32_avx512_core) { - copy_a[no_trans][no_sum].reset( - new jit_avx512_core_f32_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx512_core_f32_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx512_core_f32_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx512_core_f32_copy_bt_kern()); - } +#if __BUILD_GEMM_AVX512 + copy_a[no_trans][no_sum].reset( + new jit_avx512_core_f32_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx512_core_f32_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx512_core_f32_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx512_core_f32_copy_bt_kern()); +#endif } else if (mayiuse(avx2)) { - DNNL_CSCOPE(jit_init_copy_kern_f32_avx2) { - copy_a[no_trans][no_sum].reset( - new jit_avx2_f32_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx2_f32_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx2_f32_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx2_f32_copy_bt_kern()); - } +#if __BUILD_GEMM_AVX2 + copy_a[no_trans][no_sum].reset( + new jit_avx2_f32_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx2_f32_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx2_f32_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx2_f32_copy_bt_kern()); +#endif } else if (mayiuse(avx)) { - DNNL_CSCOPE(jit_init_copy_kern_f32_avx) { - copy_a[no_trans][no_sum].reset( - new 
jit_avx_f32_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_avx_f32_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_avx_f32_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_avx_f32_copy_bt_kern()); - } +#if __BUILD_GEMM_AVX2 + copy_a[no_trans][no_sum].reset( + new jit_avx_f32_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_avx_f32_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_avx_f32_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_avx_f32_copy_bt_kern()); +#endif } else if (mayiuse(sse41)) { - DNNL_CSCOPE(jit_init_copy_kern_f32_sse41) { - copy_a[no_trans][no_sum].reset( - new jit_sse41_f32_copy_an_kern()); - copy_a[do_trans][no_sum].reset( - new jit_sse41_f32_copy_at_kern()); - - copy_b[no_trans][no_sum].reset( - new jit_sse41_f32_copy_bn_kern()); - copy_b[do_trans][no_sum].reset( - new jit_sse41_f32_copy_bt_kern()); - } +#if __BUILD_GEMM_SSE41 + copy_a[no_trans][no_sum].reset( + new jit_sse41_f32_copy_an_kern()); + copy_a[do_trans][no_sum].reset( + new jit_sse41_f32_copy_at_kern()); + + copy_b[no_trans][no_sum].reset( + new jit_sse41_f32_copy_bn_kern()); + copy_b[do_trans][no_sum].reset( + new jit_sse41_f32_copy_bt_kern()); +#endif } break; default: break; } +#if __BUILD_GEMM_AMX constexpr bool is_a_s8 = data_traits::data_type == data_type::s8; constexpr bool is_b_s8 = data_traits::data_type == data_type::s8; constexpr bool is_c_s32 = data_traits::data_type == data_type::s32; +#endif static maybe_unique_ptr kernel[2][2][2][2] = {{{{nullptr}}}}; switch (data_traits::data_type) { case data_type::s8: if (mayiuse(avx512_core_amx)) { - DNNL_CSCOPE(jit_init_gemm_kern_s8_avx512_core_bf16_amx_int8) { - for (int isBeta0 : {no_beta0, do_beta0}) { - kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx512_core_amx_gemm_kern( - is_a_s8, is_b_s8, is_c_s32, isBeta0)); - } +#if __BUILD_GEMM_AMX + for (int isBeta0 : {no_beta0, do_beta0}) { +
kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( + new jit_avx512_core_amx_gemm_kern( + is_a_s8, is_b_s8, is_c_s32, isBeta0)); } +#endif } else if (mayiuse(avx512_core)) { - DNNL_CSCOPE(jit_init_gemm_kern_s8_avx512_core) { - for (int isBeta0 : {no_beta0, do_beta0}) - for (int doColSum : {no_sum, do_sum}) - for (int doRowSum : {no_sum, do_sum}) { - kernel[isBeta0][do_alpha1][doColSum][doRowSum].reset( - new jit_avx512_core_gemm_s8u8s32_kern( - isBeta0, doColSum, doRowSum)); - } - } +#if __BUILD_GEMM_AVX512 + for (int isBeta0 : {no_beta0, do_beta0}) + for (int doColSum : {no_sum, do_sum}) + for (int doRowSum : {no_sum, do_sum}) { + kernel[isBeta0][do_alpha1][doColSum][doRowSum].reset( + new jit_avx512_core_gemm_s8u8s32_kern( + isBeta0, doColSum, doRowSum)); + } +#endif } else if (mayiuse(avx2)) { - DNNL_CSCOPE(jit_init_gemm_kern_s8_avx2) { - for (int isBeta0 : {no_beta0, do_beta0}) - for (int doColSum : {no_sum, do_sum}) - for (int doRowSum : {no_sum, do_sum}) { - kernel[isBeta0][do_alpha1][doColSum][doRowSum] - .reset(new jit_avx2_gemm_s8u8s32_kern( - isBeta0, doColSum, doRowSum, - um)); - } - } +#if __BUILD_GEMM_AVX2 + for (int isBeta0 : {no_beta0, do_beta0}) + for (int doColSum : {no_sum, do_sum}) + for (int doRowSum : {no_sum, do_sum}) { + kernel[isBeta0][do_alpha1][doColSum][doRowSum] + .reset(new jit_avx2_gemm_s8u8s32_kern( + isBeta0, doColSum, doRowSum, + um)); + } +#endif } else if (mayiuse(avx)) { - DNNL_CSCOPE(jit_init_gemm_kern_s8_avx) { - kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx_kernel_gemm_s8u8s32_kern()); - kernel[no_beta0][do_alpha1][do_sum][no_sum].reset( - new jit_avx_kernel_c_gemm_s8u8s32_kern()); - kernel[no_beta0][do_alpha1][no_sum][do_sum].reset( - new jit_avx_kernel_r_gemm_s8u8s32_kern()); - kernel[no_beta0][do_alpha1][do_sum][do_sum].reset( - new jit_avx_kernel_b_gemm_s8u8s32_kern()); - - kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx_kernel_b0_gemm_s8u8s32_kern()); - 
kernel[do_beta0][do_alpha1][do_sum][no_sum].reset( - new jit_avx_kernel_b0_c_gemm_s8u8s32_kern()); - kernel[do_beta0][do_alpha1][no_sum][do_sum].reset( - new jit_avx_kernel_b0_r_gemm_s8u8s32_kern()); - kernel[do_beta0][do_alpha1][do_sum][do_sum].reset( - new jit_avx_kernel_b0_b_gemm_s8u8s32_kern()); - } +#if __BUILD_GEMM_AVX2 + kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_avx_kernel_gemm_s8u8s32_kern()); + kernel[no_beta0][do_alpha1][do_sum][no_sum].reset( + new jit_avx_kernel_c_gemm_s8u8s32_kern()); + kernel[no_beta0][do_alpha1][no_sum][do_sum].reset( + new jit_avx_kernel_r_gemm_s8u8s32_kern()); + kernel[no_beta0][do_alpha1][do_sum][do_sum].reset( + new jit_avx_kernel_b_gemm_s8u8s32_kern()); + + kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_avx_kernel_b0_gemm_s8u8s32_kern()); + kernel[do_beta0][do_alpha1][do_sum][no_sum].reset( + new jit_avx_kernel_b0_c_gemm_s8u8s32_kern()); + kernel[do_beta0][do_alpha1][no_sum][do_sum].reset( + new jit_avx_kernel_b0_r_gemm_s8u8s32_kern()); + kernel[do_beta0][do_alpha1][do_sum][do_sum].reset( + new jit_avx_kernel_b0_b_gemm_s8u8s32_kern()); +#endif } else if (mayiuse(sse41)) { - DNNL_CSCOPE(jit_init_gemm_kern_s8_sse41) { - kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_sse41_kernel_gemm_s8u8s32_kern()); - kernel[no_beta0][do_alpha1][do_sum][no_sum].reset( - new jit_sse41_kernel_c_gemm_s8u8s32_kern()); - kernel[no_beta0][do_alpha1][no_sum][do_sum].reset( - new jit_sse41_kernel_r_gemm_s8u8s32_kern()); - kernel[no_beta0][do_alpha1][do_sum][do_sum].reset( - new jit_sse41_kernel_b_gemm_s8u8s32_kern()); - - kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_sse41_kernel_b0_gemm_s8u8s32_kern()); - kernel[do_beta0][do_alpha1][do_sum][no_sum].reset( - new jit_sse41_kernel_b0_c_gemm_s8u8s32_kern()); - kernel[do_beta0][do_alpha1][no_sum][do_sum].reset( - new jit_sse41_kernel_b0_r_gemm_s8u8s32_kern()); - kernel[do_beta0][do_alpha1][do_sum][do_sum].reset( - new 
jit_sse41_kernel_b0_b_gemm_s8u8s32_kern()); - } +#if __BUILD_GEMM_SSE41 + kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_sse41_kernel_gemm_s8u8s32_kern()); + kernel[no_beta0][do_alpha1][do_sum][no_sum].reset( + new jit_sse41_kernel_c_gemm_s8u8s32_kern()); + kernel[no_beta0][do_alpha1][no_sum][do_sum].reset( + new jit_sse41_kernel_r_gemm_s8u8s32_kern()); + kernel[no_beta0][do_alpha1][do_sum][do_sum].reset( + new jit_sse41_kernel_b_gemm_s8u8s32_kern()); + + kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_sse41_kernel_b0_gemm_s8u8s32_kern()); + kernel[do_beta0][do_alpha1][do_sum][no_sum].reset( + new jit_sse41_kernel_b0_c_gemm_s8u8s32_kern()); + kernel[do_beta0][do_alpha1][no_sum][do_sum].reset( + new jit_sse41_kernel_b0_r_gemm_s8u8s32_kern()); + kernel[do_beta0][do_alpha1][do_sum][do_sum].reset( + new jit_sse41_kernel_b0_b_gemm_s8u8s32_kern()); +#endif } break; case data_type::bf16: if (mayiuse(avx512_core_amx)) { - DNNL_CSCOPE(jit_init_gemm_kern_bf16_avx512_core_bf16_amx_bf16) { - for (int isBeta0 : {no_beta0, do_beta0}) { - kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx512_core_amx_gemm_kern( - is_a_s8, is_b_s8, is_c_s32, isBeta0)); - } +#if __BUILD_GEMM_AMX + for (int isBeta0 : {no_beta0, do_beta0}) { + kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( + new jit_avx512_core_amx_gemm_kern( + is_a_s8, is_b_s8, is_c_s32, isBeta0)); } +#endif } else if (mayiuse(avx512_core)) { - DNNL_CSCOPE(jit_init_gemm_kern_bf16_avx512_core) { - for (int isBeta0 : {no_beta0, do_beta0}) - for (int isAlpha1 : {no_alpha1, do_alpha1}) { - kernel[isBeta0][isAlpha1][no_sum][no_sum].reset( - new jit_avx512_core_gemm_bf16bf16f32_kern( - isBeta0, isAlpha1, !use_bf16_ymm)); - } - } +#if __BUILD_GEMM_AVX512 + for (int isBeta0 : {no_beta0, do_beta0}) + for (int isAlpha1 : {no_alpha1, do_alpha1}) { + kernel[isBeta0][isAlpha1][no_sum][no_sum].reset( + new jit_avx512_core_gemm_bf16bf16f32_kern( + isBeta0, isAlpha1, !use_bf16_ymm)); + } +#endif } break; 
case data_type::f32: if (mayiuse(avx2)) { - DNNL_CSCOPE(jit_init_gemm_kern_f32_avx2) { - for (int isBeta0 : {no_beta0, do_beta0}) { - kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx2_kernel_sgemm_kern(isBeta0)); - } +#if __BUILD_GEMM_AVX2 + for (int isBeta0 : {no_beta0, do_beta0}) { + kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( + new jit_avx2_kernel_sgemm_kern(isBeta0)); } +#endif } else if (mayiuse(avx)) { - DNNL_CSCOPE(jit_init_gemm_kern_f32_avx) { - kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx_kernel_sgemm_kern()); - kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_avx_kernel_b0_sgemm_kern()); - } +#if __BUILD_GEMM_AVX2 + kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_avx_kernel_sgemm_kern()); + kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_avx_kernel_b0_sgemm_kern()); +#endif } else if (mayiuse(sse41)) { - DNNL_CSCOPE(jit_init_gemm_kern_f32_sse41) { - kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_sse41_kernel_sgemm_kern()); - kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( - new jit_sse41_kernel_b0_sgemm_kern()); - } +#if __BUILD_GEMM_SSE41 + kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_sse41_kernel_sgemm_kern()); + kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( + new jit_sse41_kernel_b0_sgemm_kern()); +#endif } break; @@ -728,42 +733,42 @@ void gemm_info_t::jit_init(void) { switch (data_traits::data_type) { case data_type::s8: if (mayiuse(avx512_core)) { - DNNL_CSCOPE(jit_init_gemv_kern_s8_avx512_core) { - gemv_s8s8s32_kernel.reset( - new jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8s8)); - gemv_s8u8s32_kernel.reset( - new jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8u8)); - gemv_u8s8s32_kernel.reset( - new jit_avx512_core_gemv_s8x8s32_kern(ver_t::u8s8)); - } +#if __BUILD_GEMM_AVX512 + gemv_s8s8s32_kernel.reset( + new jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8s8)); + gemv_s8u8s32_kernel.reset( + new 
jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8u8)); + gemv_u8s8s32_kernel.reset( + new jit_avx512_core_gemv_s8x8s32_kern(ver_t::u8s8)); +#endif } break; case data_type::bf16: if (mayiuse(avx512_core)) { - DNNL_CSCOPE(jit_init_gemv_kern_bf16_avx512_core) { - for (int isTrans : {no_trans, do_trans}) - gemv_kernel[isTrans].reset( - new jit_avx512_core_gemv_bf16bf16f32_kern( - isTrans)); - } +#if __BUILD_GEMM_AVX512 + for (int isTrans : {no_trans, do_trans}) + gemv_kernel[isTrans].reset( + new jit_avx512_core_gemv_bf16bf16f32_kern( + isTrans)); +#endif } break; case data_type::f32: if (mayiuse(avx)) { - DNNL_CSCOPE(jit_init_gemv_kern_f32_avx) { - gemv_kernel[no_trans].reset( - new jit_sse41_gemv_n_f32_kern()); - gemv_kernel[do_trans].reset(new jit_avx_gemv_t_f32_kern()); - } +#if __BUILD_GEMM_AVX2 + gemv_kernel[no_trans].reset( + new jit_sse41_gemv_n_f32_kern()); + gemv_kernel[do_trans].reset(new jit_avx_gemv_t_f32_kern()); +#endif } else if (mayiuse(sse41)) { - DNNL_CSCOPE(jit_init_gemv_kern_f32_sse41) { - gemv_kernel[no_trans].reset( - new jit_sse41_gemv_n_f32_kern()); - gemv_kernel[do_trans].reset( - new jit_sse41_gemv_t_f32_kern()); - } +#if __BUILD_GEMM_SSE41 + gemv_kernel[no_trans].reset( + new jit_sse41_gemv_n_f32_kern()); + gemv_kernel[do_trans].reset( + new jit_sse41_gemv_t_f32_kern()); +#endif } break; default: assert(!"unsupported data type!"); diff --git a/tests/gtests/in/gemm_in.h b/tests/gtests/in/gemm_in.h index bcabd886ce9..1ec1bf9b64a 100644 --- a/tests/gtests/in/gemm_in.h +++ b/tests/gtests/in/gemm_in.h @@ -128,6 +128,7 @@ CPU_INST_TEST_CASE(TestGEMM_stkmem, test_params {'n', 'n', 2, 16, 256, 1.0f, 0.0f, 256, 16, 16}); #if defined(FP32) || defined(BF16BF16F32) +#if !BUILD_GEMM_KERNELS_NONE INST_TEST_CASE(TestGEMM_packed, test_params {'t', 'n', 3, 2, 1, 1.0, 0.0, 2, 5, 8, {}, {false, true}, true, dnnl_invalid_arguments}, @@ -198,6 +199,7 @@ INST_TEST_CASE(TestGEMM_packed, make_test_params_pack({false, true}, 't', 'n', 200, 300, 8000, 1.0f, 3.0f, 200, 300, 
300)); #endif +#endif #elif defined(BF16BF16BF16) @@ -254,6 +256,7 @@ constexpr test_igemm_params fix_no_offsets = {'F', false, false, false}; constexpr test_igemm_params col_no_offsets = {'C', false, false, false}; constexpr test_igemm_params row_no_offsets = {'R', false, false, false}; +#if !BUILD_GEMM_KERNELS_NONE INST_TEST_CASE(TestGEMM_expected_failures, test_params {'t', 'n', 3, 2, 1, 1.0, 0.0, 2, 5, 8, {}, {}, true, dnnl_invalid_arguments}, @@ -290,6 +293,7 @@ INST_TEST_CASE(TestGEMM_expected_failures, true, dnnl_invalid_arguments}, test_params {'n', 'd', 3, 2, 1, 1.0, 0.0, 3, 3, 3, {}, {false, true}, true, dnnl_invalid_arguments}); +#endif CPU_INST_TEST_CASE(TestGEMM_stkmem, test_params {'n', 'n', 10, 4000, 2, 1.0, 0.0, 2, 4000, 4000, @@ -733,6 +737,7 @@ CPU_INST_TEST_CASE(TestGEMV_kblocking, test_params {'t', 'n', 1, 550, 7000, 1.0f, 1.0f, 7000, 550, 550, fix_no_offsets}); +#if !BUILD_GEMM_KERNELS_NONE CPU_INST_TEST_CASE(TestGEMM_packed, make_test_params_pack({false, true}, 'N', 'n', 30, 20, 10, 1.0f, 1.0f, 60, 50, 80, fix_use_oc), @@ -830,6 +835,7 @@ CPU_INST_TEST_CASE(TestGEMM_packed, make_test_params_pack({false, true}, 'n', 'T', 1, 200, 200, 1.0f, 0.0f, 200, 200, 200, fix_no_offsets)); #endif +#endif CPU_INST_TEST_CASE(TestGEMM_heavy, test_params {'n', 'n', 3000, 3000, 3000, 1.0, 0.0, 3000, 3000, 3000, @@ -841,6 +847,7 @@ CPU_INST_TEST_CASE(TestGEMM_heavy, test_params {'t', 't', 3000, 3000, 3000, 1.0, 0.0, 3000, 3000, 3000, fix_use_oc}); +#if !BUILD_GEMM_KERNELS_NONE CPU_INST_TEST_CASE(TestGEMM_packed_heavy, make_test_params_pack({false, true}, 'n', 'n', 3000, 3000, 3000, 1.0f, 0.0f, 3000, 3000, 3000, fix_use_oc), @@ -874,5 +881,5 @@ CPU_INST_TEST_CASE(TestGEMM_packed_heavy, 3.0f, 8000, 8000, 200, row_use_oc), make_test_params_pack({false, true}, 't', 'n', 200, 300, 8000, 1.0f, 0.0f, 200, 300, 300, col_use_oc)); - +#endif #endif