Skip to content

Latest commit

 

History

History
1444 lines (1170 loc) · 73.8 KB

PHI_kernel_registration.md

File metadata and controls

1444 lines (1170 loc) · 73.8 KB

PHI算子库kernel注册全流程——以bitwise_add算子为例

bitwise_add这个算子为例。例如我们能在.cc.cu文件里面看到:

PD_REGISTER_KERNEL(bitwise_and,
                   CPU,
                   ALL_LAYOUT,
                   phi::BitwiseAndKernel,
                   bool,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

面对一个个类型,如何让他们一个个注册?

显然,PD_REGISTER_KERNEL是注册算子的宏,我们进去看看(前方是一大堆套娃):

#define PD_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \
  _PD_REGISTER_KERNEL(::phi::RegType::INNER,                                  \
                      kernel_name,                                            \
                      backend,                                                \
                      ::phi::backend##Context,                                \
                      layout,                                                 \
                      meta_kernel_fn,                                         \
                      FUNCTION_KERNEL_INSTANTIATION,                          \
                      ARG_PARSE_FUNCTOR,                                      \
                      PHI_KERNEL,                                             \
                      PHI_VARIADIC_KERNEL,                                    \
                      __VA_ARGS__)

首先要明确,#define会在编译的预处理阶段展开,通俗来说就是复制粘贴,所以.cc.cu最后短短的一个注册,其实会根据下面的套娃宏不断一层层展开。

可以看到,这里调用了_PD_REGISTER_KERNEL这个宏,我们继续展开:

#define _PD_REGISTER_KERNEL(reg_type,                                      \
                            kernel_name,                                   \
                            backend,                                       \
                            context,                                       \
                            layout,                                        \
                            meta_kernel_fn,                                \
                            kernel_instantiation_macro,                    \
                            arg_parse_functor_macro,                       \
                            kernel_unfold_macro,                           \
                            variadic_kernel_unfold_marco,                  \
                            ...)                                           \
  PD_STATIC_ASSERT_GLOBAL_NAMESPACE(                                       \
      PD_REGISTER_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
      "PD_REGISTER_KERNEL must be called in global namespace.");           \
  PD_EXPAND(_PD_REGISTER_2TA_KERNEL(reg_type,                              \
                                    kernel_name,                           \
                                    backend,                               \
                                    context,                               \
                                    layout,                                \
                                    meta_kernel_fn,                        \
                                    kernel_instantiation_macro,            \
                                    arg_parse_functor_macro,               \
                                    kernel_unfold_macro,                   \
                                    variadic_kernel_unfold_marco,          \
                                    __VA_ARGS__))

可以发现,这里调用了PD_STATIC_ASSERT_GLOBAL_NAMESPACEPD_EXPAND(x) x

  1. PD_STATIC_ASSERT_GLOBAL_NAMESPACE中,我们继续展开:

    #define PD_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
      _PD_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)

    可以看到是调用了_PD_STATIC_ASSERT_GLOBAL_NAMESPACE宏,我们继续展开:

    #define _PD_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)                    \
      struct __test_global_namespace_##uniq_name##__ {};                          \
      static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
                                 __test_global_namespace_##uniq_name##__>::value, \
                    msg)

    这一小部分套娃结束,我们回忆一下这里的uniq_name是什么?就是把kernel_name(例子中的bitwise_and)、backend(例子中的CPU)、layout(例子中的ALL_LAYOUT)拼接一下,例子中就是PD_REGISTER_tp_kernel_ns_check_bitwise_and_CPU_ALL_LAYOUT,然后目的就是用static_assert判断一下当前注册的时候,是不是在全局namespace中注册,如果不是在全局注册,则报错。

  2. PD_EXPAND中,我们继续展开:

    #define PD_EXPAND(x) x

    看样子它直接返回了输入?具体为什么加这个,还不太清楚。它包了一层_PD_REGISTER_2TA_KERNEL

    然后是在_PD_REGISTER_2TA_KERNEL,我们继续展开(以linux下为例):

    #define _PD_REGISTER_2TA_KERNEL(reg_type,                                   \
                                    kernel_name,                                \
                                    backend,                                    \
                                    context,                                    \
                                    layout,                                     \
                                    meta_kernel_fn,                             \
                                    kernel_instantiation_macro,                 \
                                    arg_parse_functor_macro,                    \
                                    kernel_unfold_macro,                        \
                                    variadic_kernel_unfold_marco,               \
                                    ...)                                        \
      static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
          const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);           \
      PD_EXPAND(PD_KERNEL_REGISTRAR_INIT(                                       \
          reg_type,                                                             \
          kernel_name,                                                          \
          backend,                                                              \
          context,                                                              \
          layout,                                                               \
          &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,        \
          meta_kernel_fn,                                                       \
          arg_parse_functor_macro,                                              \
          kernel_unfold_macro,                                                  \
          variadic_kernel_unfold_marco,                                         \
          __VA_ARGS__));                                                        \
      void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
          const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)

    其中:

    Note: 2TA means 2 template argument

    在这个宏中,

    1. 声明了一个函数

      static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
            const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);

      在这个例子中,就是

      static void __PD_KERNEL_args_def_FN_bitwise_add_CPU_ALL_LAYOUT( 
            const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);

      这里暂时不展开,在稍后会提到这部分。

    2. 这个函数,在这里例子中,就是__PD_KERNEL_args_def_FN_bitwise_and_CPU_ALL_LAYOUT,而后,又是一个PD_EXPAND,套了一个PD_KERNEL_REGISTRAR_INIT宏,我们继续展开它:

      #define PD_KERNEL_REGISTRAR_INIT(reg_type,                          \
                                       kernel_name,                       \
                                       backend,                           \
                                       context,                           \
                                       layout,                            \
                                       args_def_fn,                       \
                                       meta_kernel_fn,                    \
                                       arg_parse_functor_macro,           \
                                       kernel_unfold_macro,               \
                                       variadic_kernel_unfold_marco,      \
                                       ...)                               \
        PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT(PD_NARGS(__VA_ARGS__),        \
                                            reg_type,                     \
                                            kernel_name,                  \
                                            backend,                      \
                                            context,                      \
                                            layout,                       \
                                            args_def_fn,                  \
                                            meta_kernel_fn,               \
                                            arg_parse_functor_macro,      \
                                            kernel_unfold_macro,          \
                                            variadic_kernel_unfold_marco, \
                                            __VA_ARGS__))

      然后又是一个PD_EXPAND套了一层_PD_KERNEL_REGISTRAR_INIT,我们继续展开:

      #define _PD_KERNEL_REGISTRAR_INIT(N,                       \
                                        reg_type,                \
                                        kernel_name,             \
                                        backend,                 \
                                        context,                 \
                                        layout,                  \
                                        args_def_fn,             \
                                        meta_kernel_fn,          \
                                        arg_parse_functor_macro,       \
                                        kernel_unfold_macro,               \
                                        variadic_kernel_unfold_marco,      \
                                        ...)                     \
        PD_EXPAND(PD_CONCATENATE(_PD_KERNEL_REGISTRAR_INIT_, N) ( \
          reg_type,                                              \
          kernel_name,                                           \
          backend,                                               \
          context,                                               \
          layout,                                                \
          PD_ID,                                                 \
          args_def_fn,                                           \
          meta_kernel_fn,                                        \
          arg_parse_functor_macro,                                     \
          kernel_unfold_macro,                                             \
          variadic_kernel_unfold_marco,                                    \
          __VA_ARGS__))

      这里PD_EXPAND套了一层PD_CONCATENATE,这个PD_CONCATENATE就是:

      #define PD_CONCATENATE(arg1, arg2) PD_CONCATENATE1(arg1, arg2)
      #define PD_CONCATENATE1(arg1, arg2) PD_CONCATENATE2(arg1, arg2)
      #define PD_CONCATENATE2(arg1, arg2) arg1##arg2

      (暂时不明白为什么要套娃这么多层,为什么不直接#define PD_CONCATENATE(arg1, arg2) arg1##arg2呢?)

      这里concat的目的是把_PD_KERNEL_REGISTRAR_INIT_N连接起来,这里的N就是PD_NARGS(__VA_ARGS__),也就是在这里:

      #define PD_NARGS(...) _PD_NARGS((__VA_ARGS__, _PD_RESQ_N()))
      #define _PD_NARGS(...) _PD_ARG_N(__VA_ARGS__)
      #define _PD_ARG_N_EXPAND(                                                     \
          _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, ...) \
        N
      #define _PD_ARG_N(args) _PD_ARG_N_EXPAND args
      #define _PD_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0

      可以一步步看看宏展开了什么,第一步先把_PD_RESQ_N()展开,得到:

      _PD_NARGS((__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))

      然后把_PD_NARGS展开,得到:

      _PD_ARG_N((__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))

      然后把_PD_ARG_N展开,得到:

      _PD_ARG_N_EXPAND (__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
      

      然后我们回忆一下,__VA_ARGS__其实就是我们注册算子的时候的那几个类型,在bitwise_add的例子中,就是:

                         bool,
                         uint8_t,
                         int8_t,
                         int16_t,
                         int,
                         int64_t) {}

      所以把_PD_ARG_N_EXPAND展开,得到:

      _PD_ARG_N_EXPAND(bool, uint8_t, int8_t, int16_t, int, int64_t, 15, 14, 13, 12, 11, 10, 9, 8, 7, [[6]], 5, 4, 3, 2, 1, 0)
      
      _PD_ARG_N_EXPAND(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, ...) N

      上下对照宏来看,会发现宏定义中N的位置在具体例子中是6,所以展开得到的是6,后面的5,4,3,2,1,0都进入了...部分,而6也恰好是我们注册的type数量。

      所以总的来看,PD_EXPAND(PD_CONCATENATE(_PD_KERNEL_REGISTRAR_INIT_, N)N就是表示我们注册的类型数量,所以在

        PD_EXPAND(PD_CONCATENATE(_PD_KERNEL_REGISTRAR_INIT_, N) ( \
          reg_type,                                              \
          kernel_name,                                           \
          backend,                                               \
          context,                                               \
          layout,                                                \
          PD_ID,                                                 \
          args_def_fn,                                           \
          meta_kernel_fn,                                        \
          arg_parse_functor_macro,                                     \
          kernel_unfold_macro,                                             \
          variadic_kernel_unfold_marco,                                    \
          __VA_ARGS__))

      中,对bitwise_and这个例子而言,就是变成了:

        _PD_KERNEL_REGISTRAR_INIT_6 ( \
          reg_type,                                              \
          kernel_name,                                           \
          backend,                                               \
          context,                                               \
          layout,                                                \
          PD_ID,                                                 \
          args_def_fn,                                           \
          meta_kernel_fn,                                        \
          arg_parse_functor_macro,                                     \
          kernel_unfold_macro,                                             \
          variadic_kernel_unfold_marco,                                    \
          __VA_ARGS__))

      然后我们展开_PD_KERNEL_REGISTRAR_INIT_6

      #define _PD_KERNEL_REGISTRAR_INIT_6(reg_type,                         \
                                          kernel_name,                      \
                                          backend,                          \
                                          context,                          \
                                          layout,                           \
                                          registrar_id,                     \
                                          args_def_fn,                      \
                                          meta_kernel_fn,                   \
                                          arg_parse_functor_macro,          \
                                          kernel_unfold_macro,              \
                                          variadic_kernel_unfold_marco,     \
                                          cpp_dtype,                        \
                                          ...)                              \
        _PD_CREATE_REGISTRAR_OBJECT(reg_type,                               \
                                    kernel_name,                            \
                                    backend,                                \
                                    context,                                \
                                    layout,                                 \
                                    registrar_id,                           \
                                    args_def_fn,                            \
                                    meta_kernel_fn,                         \
                                    arg_parse_functor_macro,                \
                                    kernel_unfold_macro,                    \
                                    variadic_kernel_unfold_marco,           \
                                    cpp_dtype)                              \
        PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_5(reg_type,                     \
                                              kernel_name,                  \
                                              backend,                      \
                                              context,                      \
                                              layout,                       \
                                              PD_ID,                        \
                                              args_def_fn,                  \
                                              meta_kernel_fn,               \
                                              arg_parse_functor_macro,      \
                                              kernel_unfold_macro,          \
                                              variadic_kernel_unfold_marco, \
                                              __VA_ARGS__))

      可以看到,主要是调用了两个宏:_PD_CREATE_REGISTRAR_OBJECT_PD_KERNEL_REGISTRAR_INIT_5,我们先来看第一个:

      1. _PD_CREATE_REGISTRAR_OBJECT的定义:

        #define _PD_CREATE_REGISTRAR_OBJECT(reg_type,                                  \
                                            kernel_name,                               \
                                            backend,                                   \
                                            context,                                   \
                                            layout,                                    \
                                            registrar_id,                              \
                                            args_def_fn,                               \
                                            meta_kernel_fn,                            \
                                            arg_parse_functor_macro,                   \
                                            kernel_unfold_macro,                       \
                                            variadic_kernel_unfold_marco,              \
                                            cpp_dtype)                                 \
          static const ::phi::KernelRegistrar PD_CONCATENATE(                          \
              __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
              reg_type,                                                                \
              #kernel_name,                                                            \
              #backend,                                                                \
              DATA_LAYOUT(layout),                                                     \
              ::phi::CppTypeToDataType<cpp_dtype>::Type(),                             \
              arg_parse_functor_macro(meta_kernel_fn, cpp_dtype, context),             \
              args_def_fn,                                                             \
              kernel_unfold_macro(meta_kernel_fn<cpp_dtype, context>),                 \
              variadic_kernel_unfold_marco(meta_kernel_fn<cpp_dtype, context>));

        这里的cpp_dtype就是把例子中的类型一个个拆解开来,当前在_PD_KERNEL_REGISTRAR_INIT_6中的_PD_CREATE_REGISTRAR_OBJECT传入的cpp_dtype就是bool

        然后把PD_CONCATENATE展开,可以看到这个宏中主要是声明了一个函数初始化了一个类,也就是调用了KernelRegistrar的构造函数,在具体例子中,就是这样:

          static const ::phi::KernelRegistrar __reg_phi_kernel_bitwise_add_CPU_ALL_LAYOUT_0(
              reg_type,                                                                
              "bitwise_add",                                                            
              "CPU",                                                                
              phi::DataLayout::ALL_LAYOUT,                                                     
              ::phi::CppTypeToDataType<bool>::Type(),// 这里做了cpp中type到paddle中DateType的映射,也就是得到了DataType::BOOL,本质其实是指代一个枚举"1"
              arg_parse_functor_macro(meta_kernel_fn, bool, context),             
              args_def_fn,                                                             
              kernel_unfold_macro(meta_kernel_fn<cpp_dtype, context>),                 
              variadic_kernel_unfold_marco(meta_kernel_fn<cpp_dtype, context>));
      2. _PD_KERNEL_REGISTRAR_INIT_5_PD_KERNEL_REGISTRAR_INIT_6中的第二个部分,可以发现他们长得很像,就是最后一个数字不一样,并且可预料的,_PD_KERNEL_REGISTRAR_INIT_5中也会像_PD_KERNEL_REGISTRAR_INIT_6一样,由_PD_CREATE_REGISTRAR_OBJECT_PD_KERNEL_REGISTRAR_INIT_4构成,这样不断递归下去。

        其实留心可以发现,例如在_PD_KERNEL_REGISTRAR_INIT_6的调用时:

          _PD_KERNEL_REGISTRAR_INIT_6 ( \
            reg_type,                                              \
            kernel_name,                                           \
            backend,                                               \
            context,                                               \
            layout,                                                \
            PD_ID,                                                 \
            args_def_fn,                                           \
            meta_kernel_fn,                                        \
            arg_parse_functor_macro,                                     \
            kernel_unfold_macro,                                             \
            variadic_kernel_unfold_marco,                                    \
            __VA_ARGS__))

        传入了11个参数+1个可变参数宏__VA_ARGS__

        然后在_PD_KERNEL_REGISTRAR_INIT_6的定义时:

        #define _PD_KERNEL_REGISTRAR_INIT_6(reg_type,                         \
                                            kernel_name,                      \
                                            backend,                          \
                                            context,                          \
                                            layout,                           \
                                            registrar_id,                     \
                                            args_def_fn,                      \
                                            meta_kernel_fn,                   \
                                            arg_parse_functor_macro,          \
                                            kernel_unfold_macro,              \
                                            variadic_kernel_unfold_marco,     \
                                            cpp_dtype,                        \
                                            ...)                              \

        是有12个参数,这就意味着,第二十个参数是从传入的__VA_ARGS__中解析出来的,而我们知道__VA_ARGS__里面存的是一个个注册的cpp type,所以,_PD_KERNEL_REGISTRAR_INIT_6->_PD_KERNEL_REGISTRAR_INIT_1这样不断递归的行为,就是把所有类型一个个拿出来,并且给他们做_PD_CREATE_REGISTRAR_OBJECT操作。

        然后我们可以直接直接看到_PD_KERNEL_REGISTRAR_INIT_1的宏定义:

        #define _PD_KERNEL_REGISTRAR_INIT_1(reg_type,                                \
                                            kernel_name,                             \
                                            backend,                                 \
                                            context,                                 \
                                            layout,                                  \
                                            registrar_id,                            \
                                            args_def_fn,                             \
                                            meta_kernel_fn,                          \
                                            arg_parse_functor_macro,                 \
                                            kernel_unfold_macro,                     \
                                            variadic_kernel_unfold_marco,            \
                                            cpp_dtype)                               \
          _PD_CREATE_REGISTRAR_OBJECT(reg_type,                                      \
                                      kernel_name,                                   \
                                      backend,                                       \
                                      context,                                       \
                                      layout,                                        \
                                      registrar_id,                                  \
                                      args_def_fn,                                   \
                                      meta_kernel_fn,                                \
                                      arg_parse_functor_macro,                       \
                                      kernel_unfold_macro,                           \
                                      variadic_kernel_unfold_marco,                  \
                                      cpp_dtype)                                     \
          TEST_API int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { \
            return 0;                                                                \
          }

        在我们的例子中,此时传入的cpp_dtype应该是int64_t了:

        PD_REGISTER_KERNEL(bitwise_and,
                           CPU,
                           ALL_LAYOUT,
                           phi::BitwiseAndKernel,
                           bool,
                           uint8_t,
                           int8_t,
                           int16_t,
                           int,
                           int64_t) {}

        然后最后定义了一个函数:

          TEST_API int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { \
            return 0;                                                                \
          }

        在这个例子中,就是

        TEST_API int TouchKernelSymbolFor_bitwise_add_CPU_ALL_LAYOUT() { 
        	return 0;                                                                
        }

        明白了这里是递归注册所有传入的类型,那么具体的注册过程,就要看_PD_CREATE_REGISTRAR_OBJECT


对于某一个类型,如何进行“注册”?

飞桨高可复用算子库 PHI 设计文档中,可以知道,"注册"就是把kernel相关信息插入到一个全局的哈希表中。

前面提到_PD_CREATE_REGISTRAR_OBJECT,是用来注册具体的某个类,在例子中,就是:

  static const ::phi::KernelRegistrar __reg_phi_kernel_bitwise_add_CPU_ALL_LAYOUT_0(
      reg_type,                                                                
      "bitwise_add",                                                            
      "CPU",                                                                
      phi::DataLayout::ALL_LAYOUT,                                                     
      ::phi::CppTypeToDataType<bool>::Type(),// 这里做了cpp中type到paddle中DateType的映射,也就是得到了DataType::BOOL,本质其实是指代一个枚举"1"
      arg_parse_functor_macro(meta_kernel_fn, bool, context),             
      args_def_fn,                                                             
      kernel_unfold_macro(meta_kernel_fn<cpp_dtype, context>),                 
      variadic_kernel_unfold_marco(meta_kernel_fn<cpp_dtype, context>));

这是KernelRegistrar类的构造函数,所以我们进入KernelRegistrar类中看看:

可以发现,它有两个构造函数:

  KernelRegistrar(RegType reg_type,
                  const char* kernel_name_cstr,
                  const char* backend_cstr,
                  DataLayout layout,
                  DataType dtype,  // 传入了dtype
                  KernelArgsParseFn args_parse_fn,
                  KernelArgsDefFn args_def_fn,
                  KernelFn kernel_fn,
                  void* variadic_kernel_fn) {
    ConstructKernel(reg_type,
                    kernel_name_cstr,
                    backend_cstr,
                    layout,
                    dtype,
                    args_parse_fn,
                    args_def_fn,
                    kernel_fn,
                    variadic_kernel_fn);
  }

  KernelRegistrar(RegType reg_type,
                  const char* kernel_name_cstr,
                  const char* backend_cstr,
                  DataLayout layout,
                  KernelArgsParseFn args_parse_fn,
                  KernelArgsDefFn args_def_fn,
                  KernelFn kernel_fn,
                  void* variadic_kernel_fn) {
    for (size_t dtype = static_cast<size_t>(DataType::BOOL);
         dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
         dtype++) {
      // NOTE(zhiqiu): why skip these types, because fluid kernel has no kernel
      // of these type.
      if (dtype == static_cast<size_t>(DataType::UINT32) ||
          dtype == static_cast<size_t>(DataType::UINT64) ||
          dtype == static_cast<size_t>(DataType::UINT16)) {
        continue;
      }
      // NOTE(zhoushunjie): Only the strings kernels can support pstring dtype
      constexpr char strings_kernels_prefix[] = "strings_";
      if (dtype == static_cast<size_t>(DataType::PSTRING) &&
          strncmp(kernel_name_cstr,
                  strings_kernels_prefix,
                  strlen(strings_kernels_prefix))) {
        continue;
      }
      ConstructKernel(reg_type,
                      kernel_name_cstr,
                      backend_cstr,
                      layout,
                      static_cast<DataType>(dtype),
                      args_parse_fn,
                      args_def_fn,
                      kernel_fn,
                      variadic_kernel_fn);
    }
  }

两者区别就在于,有没有传入dtype,在bitwise_add的注册过程中,传入了dtype,所以是走第一个构造函数。然后可以看到里面是调用了ConstructKernel,而且发现参数都是一样的,所以这里单纯包了一层,转发了一下参数,我们继续看ConstructKernel

  void ConstructKernel(RegType reg_type,
                       const char* kernel_name_cstr,
                       const char* backend_cstr,
                       DataLayout layout,
                       DataType dtype,
                       KernelArgsParseFn args_parse_fn,
                       KernelArgsDefFn args_def_fn,
                       KernelFn kernel_fn,
                       void* variadic_kernel_fn) {
    std::string kernel_name(kernel_name_cstr);
    KernelKey kernel_key(
        paddle::experimental::StringToBackend(backend_cstr), layout, dtype);
    Kernel kernel(kernel_fn, variadic_kernel_fn);
    if (kernel.GetKernelRegisteredType() == KernelRegisteredType::FUNCTION) {
      args_parse_fn(kernel_key, kernel.mutable_args_def());
    }
    args_def_fn(kernel_key, &kernel);
    if (reg_type == RegType::INNER) {
      KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
    } else {
      CustomKernelMap::Instance().RegisterCustomKernel(
          kernel_name, kernel_key, kernel);
    }
  }

我们一行行来看他的实现:

  1. std::string kernel_name(kernel_name_cstr);

    这里就是转成string类,得到的kernel_name在例子中就是"bitwise_add"这个字符串(从最初注册时输入的bitwise_add,被#kernel_name变成chat*,然后在这里转成string类)

  2. KernelKey kernel_key(paddle::experimental::StringToBackend(backend_cstr), layout, dtype);

    这里传入了三个参数,类型分别为Backend,DataLayout,DataType

    • 第一个参数先把backend_cstr这个字符串转成Backend类,其实Backend类也还是个枚举类(从最初注册时输入的CPU,被#backend变成char*,然后在这里变成Backend类)

    • 第二个参数,在例子中就是ALL_LAYOUT(从最初注册时输入ALL_LAYOUT,被DATA_LAYOUT(layout)中的DATA_LAYOUT宏:#define DATA_LAYOUT(arg__) phi::DataLayout::arg__直接变成了DataLayout类,也就是现在这里的类型)

    • 第三个参数,在例子中,这里以init_6为例,也就是抽出第一个注册的dtype时,就是DataType::BOOL(从最初注册时输入的bool,被::phi::CppTypeToDataType<cpp_dtype>::Type()中的CppTypeToDataType<bool>模板类,从cpp的基础类型转化成了DataType类)

    然后利用这三个参数,构造了一个KernelKey的对象,我们看看KernelKey的构造函数:

    class KernelKey {
     public:
      KernelKey() = default;
    
      KernelKey(Backend backend, DataLayout layout, DataType dtype)
          : backend_(backend), layout_(layout), dtype_(dtype) {}
    
      explicit KernelKey(const Place& place)
          : backend_(TransToPhiBackend(place)),
            layout_(DataLayout::ALL_LAYOUT),
            dtype_(DataType::ALL_DTYPE) {}
    
      explicit KernelKey(const int& dtype, const Place& place)
          : backend_(TransToPhiBackend(place)),
            layout_(DataLayout::ALL_LAYOUT),
            dtype_(phi::TransToPhiDataType(dtype)) {}
    
      explicit KernelKey(const Place& place,
                         const DataLayout& layout,
                         const DataType& dtype)
          : backend_(TransToPhiBackend(place)), layout_(layout), dtype_(dtype) {}

    可以看到有四个,在我们的例子中,这里会调用第一个构造函数,就是存一下backend,layout,dtype。其他三个构造函数功能其实一样,只是做了一下兼容性方面的处理。

  3. Kernel kernel(kernel_fn, variadic_kernel_fn);

    • 第一个参数kernel_fn,我们回溯回去看看它是什么,可以在_PD_CREATE_REGISTRAR_OBJECT宏中,发现它就是kernel_unfold_macro(meta_kernel_fn<cpp_dtype, context>),继续分析:

      • 这里的cpp_dtpe此时就是bool(以N=6时解析出第一个类型为例),而context就是之前::phi::backend##Context根据backend得到的,例子中的context就是::phi::CPUContext

        meta_kernel_fn则是我们在注册时候传入的phi::BitwiseAndKernel,这就是我们在.cc中实现的kernel

        #define DEFINE_BITWISE_KERNEL(op_type)                                 \
          template <typename T, typename Context>                              \
          void Bitwise##op_type##Kernel(const Context& dev_ctx,                \
                                        const DenseTensor& x,                  \
                                        const DenseTensor& y,                  \
                                        DenseTensor* out) {                    \
            funcs::Bitwise##op_type##Functor<T> func;                          \
            funcs::ElementwiseCompute<funcs::Bitwise##op_type##Functor<T>, T>( \
                dev_ctx, x, y, func, out);                                     \
          }
        
        DEFINE_BITWISE_KERNEL(And)

        可以看到,在kernel的template中,就能对上了,T就是此时注册的cpp类型,然后context就是根据backend得到的上下文信息CPUContext

      • 接下来是调用了kernel_unfold_macro这个宏,而这个kernel_unfold_macro就是PHI_KERNEL宏,它一直作为参数一层层传到这里才发生展开,我们看看它做了什么:

        #define PHI_KERNEL(...) \
          ::phi::KernelImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::Compute
        

        可以看到,&__VA_ARGS__此时就是BitwiseAddKernel的函数指针,在这个例子中,就是变成:

        ::phi::KernelImpl<decltype(&BitwiseAddKernel), &BitwiseAddKernel>::Compute
        

        这是一个Kernel_Fn类型:

        using KernelFn = std::function<void(KernelContext* ctx)>;

        这里由于::phi::KernelImpl<decltype(&BitwiseAddKernel), &BitwiseAddKernel>::Compute传入的是"函数指针类别"和"具体的函数指针",刚好和 KernelImpl<Return (*)(DevCtx, Args...), kernel_fn>匹配。

        可以看到调用了KernelImpl中的Compute方法:

        template <typename Fn, Fn fn>
        struct KernelImpl;
        
        template <typename Return,
                  typename DevCtx,
                  typename... Args,
                  Return (*kernel_fn)(DevCtx, Args...)>
        struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
          static void Compute(KernelContext* ctx) {
            KernelCallHelper<DevCtx, Args..., TypeTag<int>>::
                template Compute<0, 0, 0, 0>(ctx);
          }

        可以看到它的静态方法Compute调用了KernelCallHelper,此时模板中的Return被自动推导得到voidDevCtx被自动推导为Context类,Args则是将后面的其他参数打包成了参数包。

        至此,kernel_unfold_macro的展开就是得到了这里的Compute方法。

        也就是Kernel对象构建:

        Kernel kernel(kernel_fn, variadic_kernel_fn);

        这里的第一个参数kernel_fn

    • 第二个参数variadic_kernel_fn的流程和第一个参数kernel_fn基本一致。

      传入的是PHI_VARIADIC_KERNEL(kernel_fn),可以看看PHI_VARIADIC_KERNEL这个宏:

      #define PHI_VARIADIC_KERNEL(...)                                     \
        reinterpret_cast<void*>(&::phi::KernelImpl<decltype(&__VA_ARGS__), \
                                                   &__VA_ARGS__>::VariadicCompute)

      在这个例子中,就是展开变成

      reinterpret_cast<void*>(&::phi::KernelImpl<decltype(&BitwiseAddKernel),&BitwiseAddKernel>::VariadicCompute)

      和前面的kernel_fn相比,这里依然是用的同一个KernelImpl实例(传入给模板的参数和前面一样),调用了VariadicCompute静态方法:

        static void VariadicCompute(const DeviceContext& dev_ctx, Args... args) {
          return kernel_fn(static_cast<DevCtx>(dev_ctx), std::forward<Args>(args)...);
        }

      得到的是这个函数指针,指向由BitwiseAddKernel相关信息特化的VariadicCompute函数。

      在编译时期,传入的kernel_fn,也就是BitwiseAddKernel,作为宏体中的__VA_ARGS__,顺利地将::phi::KernelImpl进行了实例化,所以此时variadic_kernel_fn这个参数应该不为空。

    然后构造Kernel对象:

      explicit Kernel(KernelFn fn, void* variadic_fn)
          : fn_(fn), variadic_fn_(variadic_fn) {
        if (variadic_fn == nullptr) {
          kernel_registered_type_ = KernelRegisteredType::STRUCTURE;
        } else {
          kernel_registered_type_ = KernelRegisteredType::FUNCTION;
        }
      }

    可以发现,主要是存了一下传入的fnvariadic_fn,因为variadic_fn不为空,所以kernel_registered_type_赋值为KernelRegisteredType::FUNCTION(看到这里涉及structurefunction,猜测这块可能是兼容老的op体系用的?老的fluid体系为结构体算子,新的phi体系算子为函数式算子)

  4. if (kernel.GetKernelRegisteredType() == KernelRegisteredType::FUNCTION) {
          args_parse_fn(kernel_key, kernel.mutable_args_def());
    }

    前面我们知道,在例子中,kernel的GetKernelRegisteredTypeKernelRegisteredType::FUNCTION,所以这里是要走if的。

    args_parse_fn是来自于arg_parse_functor_macro(meta_kernel_fn, cpp_dtype, context),在例子中,就是:

    arg_parse_functor_macro(phi::BitwiseAddKernel, bool, CPUContext)

    arg_parse_functor_macro就是一个宏,我们看看他的定义:

    // The macro for passing KernelArgsParseFunctor's function
    #define ARG_PARSE_FUNCTOR(meta_kernel_fn, cpp_dtype, context) \
      ::phi::KernelArgsParseFunctor<                              \
          decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse

    可以发现这个例子中,展开后是这样的:

    ::phi::KernelArgsParseFunctor<decltype(&phi::BitwiseAddKernel<bool, CPUContext>)>::Parse

    看看KernelArgsParseFunctor中的实现:

    template <typename Return_, typename... Args_>
    struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
      using Args = std::tuple<Args_...>;
      enum : std::size_t { Arity = sizeof...(Args_) };
      using Indices = std::make_index_sequence<Arity>;
      template <std::size_t Index>
      using Arg = typename std::tuple_element<Index, Args>::type;
    
      static void Parse(const KernelKey& default_key, KernelArgsDef* args_def) {
        // TODO(chenweihang): The fluid Tensor's default layout is NCHW,
        // it is not same as kernel's layout, we should fix this error on
        // fluid Tensor
    
        auto args_type = ParseArgType(Indices{});
        SetKernelArgsDef(args_type, default_key, args_def);
      }
    
     private:
      template <std::size_t... INDEX>
      static std::vector<std::type_index> ParseArgType(
          std::index_sequence<INDEX...>) {
        return {std::type_index(typeid(Arg<INDEX>))...};
      }
    };

    在前面的if中,调用了args_parse_fn(kernel_key, kernel.mutable_args_def());,所以就是:

    ::phi::KernelArgsParseFunctor<decltype(&phi::BitwiseAddKernel<bool, CPUContext>)>::Parse(kernel_key, kernel.mutable_args_def())

    传入的第一个参数是带有backend,layout,dtypeKernelKey信息,第二个参数是kernel对象的指向成员变量args_def_KernelArgsDef*指针:

    KernelArgsDef* mutable_args_def() { return &args_def_; }

    当然此时这个kernelargs_def_是空的,而接下来要做的Parse操作,就是从传入的kernel_key中提取参数信息,去填充kernel对象的args_def变量(而不是利用这个变量去做什么其他事情,他是作为输出传进来的)。我们看看这里的Parse具体是怎么做的,下面有比较多的std标准库元函数的使用:

    • auto args_type = ParseArgType(Indices{});

      这里的Indicesstd::make_index_sequence<Arity>;,而Arityenum : std::size_t { Arity = sizeof...(Args_) };,可以知道,这里的Arity表示的是参数包大小,在例子中,就是表示phi::BitwiseAddKernel的参数量:

      #define DEFINE_BITWISE_KERNEL(op_type)                                 \
        template <typename T, typename Context>                              \
        void Bitwise##op_type##Kernel(const Context& dev_ctx,                \
                                      const DenseTensor& x,                  \
                                      const DenseTensor& y,                  \
                                      DenseTensor* out) {                    \
          funcs::Bitwise##op_type##Functor<T> func;                          \
          funcs::ElementwiseCompute<funcs::Bitwise##op_type##Functor<T>, T>( \
              dev_ctx, x, y, func, out);                                     \
        }
      
      DEFINE_BITWISE_KERNEL(And)
      

      可见,Arity就是4,所以Indices{}将会是一个从 0 到 3 的整数序列。我们继续看ParseArgType的实现:

        template <std::size_t... INDEX>
        static std::vector<std::type_index> ParseArgType(
            std::index_sequence<INDEX...>) {
          return {std::type_index(typeid(Arg<INDEX>))...};
        }

      这里是通过传入的Indices{},推导出了INDEX就是0到3的整数序列。

      然后Arg是:

        template <std::size_t Index>
        using Arg = typename std::tuple_element<Index, Args>::type;

      其中的Args是一个tuple:

      using Args = std::tuple<Args_...>;

      可以发现,Arg<Index>就是在Args这个tuple中取下标为Index的元素。

      而后

      {std::type_index(typeid(Arg<INDEX>))...}是一个折叠表达式,展开可以得到:

      return {std::type_index(typeid(Arg<0>)), std::type_index(typeid(Arg<1>)), std::type_index(typeid(Arg<2>)), std::type_index(typeid(Arg<3>))}

      然后将Arg展开,在具体例子中,得到:

      return {std::type_index(typeid(const CPUContext&)), std::type_index(typeid(const DenseTensor&)), std::type_index(typeid(const DenseTensor&)), std::type_index(typeid(DenseTensor*))}

      所以ParseArgType这个函数就是巧妙地利用了模板自动推导,在::phi::KernelArgsParseFunctor<decltype(&phi::BitwiseAddKernel<bool, CPUContext>)>实例化的时候拆解了kernel,自动推导得到了当前kernel的参数类型,保存在args_type变量中。然后我们看接下来如何把这些类型信息赋值给args_def

    • SetKernelArgsDef(args_type, default_key, args_def);

      可以看到,这里调用了SetKernelArgsDef函数。这个函数比较简单,篇幅较长,下面我们一段段来看:

      void SetKernelArgsDef(const std::vector<std::type_index>& args_type,
                            const KernelKey& default_key,
                            KernelArgsDef* args_def) {
        auto default_tensor_layout = phi::DataLayout::NCHW;
        if (default_key.layout() != phi::DataLayout::ANY) {
          default_tensor_layout = default_key.layout();
        }

      首先,这里做了一下特殊的处理,应该是和前面的todo相关:

        static void Parse(const KernelKey& default_key, KernelArgsDef* args_def) {
          // TODO(chenweihang): The fluid Tensor's default layout is NCHW,
          // it is not same as kernel's layout, we should fix this error on
          // fluid Tensor

      具体例子中,由于传入的kernel_key的layout是ALL_LAYOUT,也就是ANY,所以tensor的默认layout并不影响,所以不用走这个if语句。

      然后是for循环,遍历我们前面拿到的kernel的所有参数:

        for (auto arg_type : args_type) {
          if (arg_type == std::type_index(typeid(const CPUContext&))
      #if defined(PADDLE_WITH_DNNL)
              || arg_type == std::type_index(typeid(const OneDNNContext&))
      #endif
      #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
              || arg_type == std::type_index(typeid(const GPUContext&))
      #elif defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
              || arg_type == std::type_index(typeid(const XPUContext&))
      #elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_KP)
                || arg_type == std::type_index(typeid(const KPSContext&))
      #endif
      #if defined(PADDLE_WITH_CUSTOM_DEVICE)
              || arg_type == std::type_index(typeid(const CustomContext&))) {
      #else
          ) {
      #endif
            // do nothing, skip context arg now
          }

      可以发现,如果遇到参数类型是Context相关的,就跳过这个参数,在具体例子中,就会跳过std::type_index(typeid(const CPUContext&))这个参数,剩下还有:

      {std::type_index(typeid(const DenseTensor&)), std::type_index(typeid(const DenseTensor&)), std::type_index(typeid(DenseTensor*))}

      这三个参数。

      而后,是检测输入相关的参数:

      else if (arg_type == std::type_index(typeid(const DenseTensor&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type ==
                     std::type_index(typeid(const paddle::optional<DenseTensor>&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type ==
                     std::type_index(typeid(
                         const paddle::optional<std::vector<const DenseTensor*>>&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type ==
                     std::type_index(typeid(const paddle::optional<SelectedRows>&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type == std::type_index(typeid(
                                     const std::vector<const DenseTensor*>&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type ==
                     std::type_index(typeid(const phi::ExtendedTensor&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type == std::type_index(typeid(
                                     const std::vector<const ExtendedTensor*>&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type == std::type_index(typeid(
                                     const std::vector<const SelectedRows*>&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type ==
                     std::type_index(typeid(const std::vector<const TensorBase*>&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type == std::type_index(typeid(
                                     const std::vector<const TensorArray*>&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type == std::type_index(typeid(const StringTensor&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type == std::type_index(typeid(const SparseCooTensor&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type == std::type_index(typeid(
                                     paddle::optional<const SparseCooTensor&>))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type == std::type_index(typeid(const SparseCsrTensor&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type == std::type_index(typeid(
                                     paddle::optional<const SparseCsrTensor&>))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);
          } else if (arg_type == std::type_index(typeid(const TensorArray&))) {
            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);

      从这里其实就可以知道,我们在算子注册的时候,输入变量的类型为什么必须严格带const,而且常常需要DenseTensor的指针了吧。因为要在这里对上,才能顺利地把参数类型信息存到args_def中去。

      具体例子中,我们有输入相关的std::type_index(typeid(const DenseTensor&)), std::type_index(typeid(const DenseTensor&))这两个参数,都是走这个:

            args_def->AppendInput(default_key.backend(),
                                  default_tensor_layout,
                                  default_key.dtype(),
                                  arg_type);

      所以args_def中的input_defs_现在存了TensorArgDef(Backend::CPU, DataLayout::ALL_LAYOUT, DataType::BOOL, std::type_index(typeid(const DenseTensor&)))TensorArgDef(Backend::CPU, DataLayout::ALL_LAYOUT, DataType::BOOL, std::type_index(typeid(const DenseTensor&)))这两个相同元素

      接下来是一系列输出类型的判断:

      else if (arg_type == std::type_index(typeid(DenseTensor*))) {
            args_def->AppendOutput(default_key.backend(),
                                   default_tensor_layout,
                                   default_key.dtype(),
                                   arg_type);
          } else if (arg_type == std::type_index(typeid(std::vector<DenseTensor*>))) {
            args_def->AppendOutput(default_key.backend(),
                                   default_tensor_layout,
                                   default_key.dtype(),
                                   arg_type);
          } else if (arg_type == std::type_index(typeid(SelectedRows*))) {
            args_def->AppendOutput(default_key.backend(),
                                   default_tensor_layout,
                                   default_key.dtype(),
                                   arg_type);
          } else if (arg_type == std::type_index(typeid(TensorArray*))) {
            args_def->AppendOutput(default_key.backend(),
                                   default_tensor_layout,
                                   default_key.dtype(),
                                   arg_type);
          } else if (arg_type == std::type_index(typeid(SparseCooTensor*))) {
            args_def->AppendOutput(default_key.backend(),
                                   default_tensor_layout,
                                   default_key.dtype(),
                                   arg_type);
          } else if (arg_type == std::type_index(typeid(SparseCsrTensor*))) {
            args_def->AppendOutput(default_key.backend(),
                                   default_tensor_layout,
                                   default_key.dtype(),
                                   arg_type);
          } else if (arg_type == std::type_index(typeid(StringTensor*))) {
            args_def->AppendOutput(default_key.backend(),
                                   default_tensor_layout,
                                   default_key.dtype(),
                                   arg_type);
          } else if (arg_type == std::type_index(typeid(ExtendedTensor*))) {
            args_def->AppendOutput(default_key.backend(),
                                   default_tensor_layout,
                                   default_key.dtype(),
                                   arg_type);

      具体例子中,此时遍历到了最后一个参数std::type_index(typeid(DenseTensor*))

      显然会走这个:

            args_def->AppendOutput(default_key.backend(),
                                   default_tensor_layout,
                                   default_key.dtype(),
                                   arg_type);

      所以给args_defoutput_defs_中加入了TensorArgDef(backend, layout, dtype, type_index)

      具体例子中就是:

      TensorArgDef(Backend::CPU, DataLayout::ALL_LAYOUT, DataType::BOOL, std::type_index(typeid(DenseTensor*)))

      这样就遍历完了四个参数,存入了args_def中,完成了Parse操作。

    1. args_def_fn(kernel_key, &kernel);
      

      我们继续回到ConstructKernel中,前面利用kernel_key中的backend, layout, dtype,结合具体kernel实现时的参数类型,已经完善了kernel对象的args_def_成员变量,存好了这些相关信息。

      接下来是args_def_fn操作,我们需要回溯到_PD_REGISTER_2TA_KERNEL这个宏定义中去查看,以linux下为例:

      #define _PD_REGISTER_2TA_KERNEL(reg_type,                                   \
                                      kernel_name,                                \
                                      backend,                                    \
                                      context,                                    \
                                      layout,                                     \
                                      meta_kernel_fn,                             \
                                      kernel_instantiation_macro,                 \
                                      arg_parse_functor_macro,                    \
                                      kernel_unfold_macro,                        \
                                      variadic_kernel_unfold_marco,               \
                                      ...)                                        \
        static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
            const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);           \
        PD_EXPAND(PD_KERNEL_REGISTRAR_INIT(                                       \
            reg_type,                                                             \
            kernel_name,                                                          \
            backend,                                                              \
            context,                                                              \
            layout,                                                               \
            &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,        \ /* 此为之后的args_def_fn */
            meta_kernel_fn,                                                       \
            arg_parse_functor_macro,                                              \
            kernel_unfold_macro,                                                  \
            variadic_kernel_unfold_marco,                                         \
            __VA_ARGS__));                                                        \
        void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
            const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)

      可以看到,args_def_fn是一个函数指针&__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,具体例子中,这里就是:

      &__PD_KERNEL_args_def_FN_bitwise_add_CPU_ALL_LAYOUT

      这里有他的声明:

      static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( 
            const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);

      具体则是:

      static void __PD_KERNEL_args_def_FN_bitwise_add_CPU_ALL_LAYOUT( 
            const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel);

      定义则在下面:

        void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(        \
            const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)

      这里的定义长得比较奇怪,第一眼:这是定义???括号呢???具体实现呢???

      然后仔细观察了一下,应该是注册的时候:

      PD_REGISTER_KERNEL(bitwise_and,
                         CPU,
                         ALL_LAYOUT,
                         phi::BitwiseAndKernel,
                         bool,
                         uint8_t,
                         int8_t,
                         int16_t,
                         int,
                         int64_t) {}

      最后不是带了个大括号{}吗?宏展开后,这个大括号就跑到这个定义下来了,这也是为什么定义写在了最后面,而不是直接跟声明写在一起,就是为了把这个地方的实现暴露给开发者,可以直接在注册的时候调整kernelkernel_key,很巧妙。所以这个函数在这个例子中,应该确实是什么都不干,不过在其他例子中,就有用到这块设计的,例如在full_kernel.cc中:

      namespace phi {
      
      template <typename T, typename Context>
      void FullBatchSizeLikeKernel(const Context& dev_ctx,
                                   const DenseTensor& x,
                                   const std::vector<int>& shape UNUSED,
                                   const Scalar& val,
                                   DataType dtype,
                                   int x_batch_size_dim,
                                   int out_batch_size_dim,
                                   DenseTensor* out) {
        if (!x.lod().empty() && x_batch_size_dim == 0) {
          // set the correct batch size for the LoDTensor.
          auto odims = out->dims();
          odims[out_batch_size_dim] = static_cast<int>(x.lod().back().size()) - 1;
          FullKernel<T, Context>(dev_ctx, common::vectorize(odims), val, dtype, out);
        }
        FullLikeKernel<T, Context>(dev_ctx, x, val, dtype, out);
      }
      
      }  // namespace phi
      
      PD_REGISTER_KERNEL(full_batch_size_like,
                         CPU,
                         ALL_LAYOUT,
                         phi::FullBatchSizeLikeKernel,
                         float,
                         double,
                         int,
                         int64_t,
                         bool) {
        kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
      }
      #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      PD_REGISTER_KERNEL(full_batch_size_like,
                         GPU,
                         ALL_LAYOUT,
                         phi::FullBatchSizeLikeKernel,
                         float,
                         double,
                         int,
                         int64_t,
                         bool,
                         phi::dtype::float16,
                         phi::dtype::bfloat16) {
        kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
      }
      #endif
      

      这里的模板参数Context在注册cpu下的kernel时,有这样的定义:

      void __PD_KERNEL_args_def_FN_full_batch_size_like_CPU_ALL_LAYOUT(        
            const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel) {
          kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
      }

      这里的kernel->InputAt(0)其实就是kernel的const DenseTensor& x

      所以在args_def_fn(kernel_key, &kernel);的时候,在这之前,由于注册时指定了CPU,所以kernel_key的backend是cpu,所以之前存入的参数x的backend就是cpu。现在就是”逆天改命“,把x的backend变成ALL_BACKEND;第二个也同理,注册的backend为gpu,也是在这时候改成ALL_BACKEND

      (思考:对一个输入tensor而言,他的backend(或者说,处在这个位置的参数的backend意味着什么?)这是否意味着,在调用算子进行计算的时候,框架自动对其backend进行检测?有的也对输出的dtype进行修改,可能这里会调整kernel允许输出的类别?譬如什么都不加,就是输入T输出T,加入类似于kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);这种就是输入T,输出bool

    2. if (reg_type == RegType::INNER) {
          KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
      } else {
          CustomKernelMap::Instance().RegisterCustomKernel(
              kernel_name, kernel_key, kernel);
      }

      最后,就是把kernel_keykernel存到KernelFactory里面,在具体例子中,reg_typeRegType::INNER,所以走if分支。

      然后可以看看KernelFactory这个工厂模式的设计:

      class KernelFactory {
       public:
        static KernelFactory& Instance();
        KernelNameMap& kernels() { return kernels_; }
        bool HasCompatiblePhiKernel(const std::string& op_type) const;
        bool HasStructuredKernel(const std::string& op_type) const;
        KernelResult SelectKernelOrThrowError(const std::string& kernel_name,
                                              const KernelKey& kernel_key,
                                              bool use_strided_kernel = false) const;
        bool HasKernel(const std::string& kernel_name,
                       const KernelKey& kernel_key) const;
        const Kernel& SelectKernel(const std::string& kernel_name,
                                   const KernelKey& kernel_key) const;
        const Kernel& SelectKernelWithGPUDNN(const std::string& kernel_name,
                                             const KernelKey& kernel_key) const;
        KernelKeyMap SelectKernelMap(const std::string& kernel_name) const;
        const KernelArgsDef& GetFirstKernelArgsDef(
            const std::string& kernel_name) const;
        void AddToLowPrecisionKernelList(const std::string& name,
                                         const DataType& kernel_key_type);
        std::map<const std::string, OpCount> GetLowPrecisionKernelList();
        void ClearLowPrecisionKernelList() { low_precision_kernels_.clear(); }
       private:
        KernelFactory() = default;
        KernelNameMap kernels_;
        // Get the low precision kernel list of current module.
        std::map<const std::string, OpCount> low_precision_kernels_;
      };

      全局的静态实例Instance,掌管着一个KernelNameMap类的成员变量kernels_

      KernelNameMap就是using KernelNameMap = paddle::flat_hash_map<std::string, KernelKeyMap>;这样一个哈希表,然后KernelKeyMap同样是using KernelKeyMap = paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>;这样一个哈希表。所以这部分其实就是通过kernel_name找到一个哈希表,然后再通过kernel_key再找一层,这样就能找到想要的具体kernel对象了。可以从飞桨高可复用算子库 PHI 设计文档中看到,整体的设计如下图,可以看到图中组织地非常清晰了。

      将当前的kernel相关信息插入到哈希表中,这样就完成了这个算子的注册。

    kernel-design.png

    • KernelFactory作为管理 Kernel 的全局单例数据结构,和 fluid 的 OpKernelMap 类似,两级 map,第一层根据 name 找到 Kernel 集合,第二层根据 KernelKey 找到具体的 Kernel
    • KernelKey和原先的 OpKernelType 类似,但将 place 和 library_type 字段合二为一称之为 Backend,因为原先的 LibraryType 是一个有局限的枚举类,原本就和 place 是强相关的,拆分反而增加了理解成本
    • Kernel相比原先的 OpKernel 持有了更多信息,除了执行时的 Function,还持有了具体参数的信息,即KernelArgsDef,对于 Tensor 类输入输出,保存了 Tensor 类型信息、Device,数据类型、数据布局,对于 Attribute 类输入输出,保存了类型信息

参考资料

飞桨高可复用算子库 PHI 设计文档

Kernel选择分发体系梳理与优化

文中的一些QA:

Q1:

然后构造Kernel对象:

  explicit Kernel(KernelFn fn, void* variadic_fn)
      : fn_(fn), variadic_fn_(variadic_fn) {
    if (variadic_fn == nullptr) {
      kernel_registered_type_ = KernelRegisteredType::STRUCTURE;
    } else {
      kernel_registered_type_ = KernelRegisteredType::FUNCTION;
    }
  }

可以发现,主要是存了一下传入的fnvariadic_fn,因为variadic_fn不为空,所以kernel_registered_type_赋值为KernelRegisteredType::FUNCTION(看到这里涉及structurefunction,猜测这块可能是兼容老的op体系用的?老的fluid体系为结构体算子,新的phi体系算子为函数式算子)

A1:

这个猜测是对的。

Q2:

​ (思考:对一个输入tensor而言,他的backend(或者说,处在这个位置的参数的backend意味着什么?)这是否意味着,在调用算子进行计算的时候,框架自动对其backend进行检测?有的也对输出的dtype进行修改,可能这里会调整kernel允许输出的类别?譬如什么都不加,就是输入T输出T,加入类似于kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);这种就是输入T,输出bool

A2:

因为默认情况下我们认为所有输入和输出信息比如dtype或者backend这种一般是一致的,但并不是所有算子都是这样,有的存在这些信息不一致的情况,所以这里可以对其进行修改,修改后框架会对Tensor做相应的变换以满足处理条件

Q3:

kernel->InputAt(0).SetDataType(phi::DataType::BOOL);
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);

在.cc或者.cu中注册时,这样修改了第0个输入的Tensor和第0个输出的Tensor,是不是意味着调度的时候会预先检查输入和输出的dtype呢?

如果是的话,想问如果注册的时候限定了layout是如何做限制的呢?毕竟我们调用的时候就是传入一个paddle.to_tensor([1,2,3])这样的tensor,看上去不带什么layout的信息。

A3:

你可以理解成是一种检查,但实际情况比较复杂,需要根据这里的信息判断是否要做transform,具体需要熟悉了调度流程才会明白,layout的话其实我们现在一般都是用的ALL_LAYOUT,也就是默认情况,只有一些sparse kernel会对其进行修改,设置成SPARSE_COO等格式,以满足sparse kernel的执行条件

Q4:

#define PD_EXPAND(x) x

看样子它直接返回了输入?具体为什么加这个,还不太清楚。它包了一层_PD_REGISTER_2TA_KERNEL

A4:

这里加这个主要是由于嵌套宏有时候无法正常展开,比如带有##连接的时候,多使用一个额外的宏包裹一下,比如这里的PD_EXPAND,可以让嵌套宏能够正常展开。嵌套宏展开比较复杂,不同平台不同编译器的处理情况可能不一样,可以提交一个pr把所有PD_EXPAND删了,看看ci上会不会有什么问题