PyTorch 2.0.0 Beginner Tutorial: The Main Way to Register Operators

Based on the official PyTorch 2.0.0 tutorial, this article works through the main way to register operators, covering the relevant C/C++ details along the way; it is aimed at beginners.

Note: this is a beginner-level article, corrections are welcome! Everything below is based on PyTorch 2.0.0.

The official PyTorch tutorial at pytorch.org/tutorials/a… says the main way to register an operator is:

TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
  m.impl(<myadd_schema>, &myadd_autograd);
}

Note that in C/C++, `&function` and `function` are equivalent (a function name decays to a pointer to the function); see:
stackoverflow.com/questions/6…
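A quick sanity check of that equivalence (the function name `myadd` here is just an illustration, not the tutorial's `myadd_autograd`):

```cpp
#include <cassert>

// A free function used only for this demo.
int myadd(int a, int b) { return a + b; }

// A function name decays to a pointer to the function, so both
// initializations below yield the same pointer value.
int (*p1)(int, int) = myadd;   // implicit function-to-pointer decay
int (*p2)(int, int) = &myadd;  // explicit address-of

bool same_pointer() { return p1 == p2; }
```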

In the PyTorch source, /home/pytorch/torch/library.h defines the TORCH_LIBRARY_IMPL macro:

#define TORCH_LIBRARY_IMPL(ns, k, m) _TORCH_LIBRARY_IMPL(ns, k, m, C10_UID)

The `_TORCH_LIBRARY_IMPL` macro is defined as follows:

#define _TORCH_LIBRARY_IMPL(ns, k, m, uid)                             \
  static void C10_CONCATENATE(                                         \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&);    \
  static const torch::detail::TorchLibraryInit C10_CONCATENATE(        \
      TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)(              \
      torch::Library::IMPL,                                            \
      c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check( \
          c10::DispatchKey::k)>(                                       \
          []() {                                                       \
            return &C10_CONCATENATE(                                   \
                TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid);           \
          },                                                           \
          []() { return [](torch::Library&) -> void {}; }),            \
      #ns,                                                             \
      c10::make_optional(c10::DispatchKey::k),                         \
      __FILE__,                                                        \
      __LINE__);                                                       \
  void C10_CONCATENATE(                                                \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library & m)

First look at `C10_UID`, which is defined as:

#define C10_UID __COUNTER__
#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __COUNTER__)

So it is effectively a unique ID: each expansion of `__COUNTER__` yields the next integer within the translation unit.
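A tiny demo of `__COUNTER__` (the concrete values depend on how often it was used earlier in the file; what matters is the step of one between consecutive uses):

```cpp
// __COUNTER__ expands to an integer that increments by one on each use
// within a translation unit; that is what makes the generated names unique.
const int first = __COUNTER__;   // some value N
const int second = __COUNTER__;  // N + 1
```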

`C10_CONCATENATE` is defined as follows:

#define C10_CONCATENATE_IMPL(s1, s2) s1##s2
#define C10_CONCATENATE(s1, s2) C10_CONCATENATE_IMPL(s1, s2)

As you can see, it simply concatenates two tokens; if this is unfamiliar, look up the `##` token-pasting operator in the C/C++ preprocessor.
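The reason for the two-level definition is that `##` pastes its operands before they are macro-expanded; the extra indirection forces arguments such as `__COUNTER__` to be expanded first. A minimal sketch (the macro and variable names here are hypothetical):

```cpp
#define CONCAT_IMPL(s1, s2) s1##s2          // pastes tokens as-is, no expansion
#define CONCAT(s1, s2) CONCAT_IMPL(s1, s2)  // expands arguments first, then pastes

#define SUFFIX _20

// With the indirection, SUFFIX expands to _20 before pasting,
// producing the identifier init_aten_20.
int CONCAT(init_aten, SUFFIX) = 42;
// Without it, CONCAT_IMPL(init_aten, SUFFIX) would paste the literal
// token "SUFFIX" and produce init_atenSUFFIX instead.

int read_generated_variable() { return init_aten_20; }
```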

The definition of `_TORCH_LIBRARY_IMPL` can be split into three parts:

  1. Declare a static function:

static void C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&);

The function name is `TORCH_LIBRARY_IMPL_init_` + ns + k + uid. Assuming the UID of TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) is 20, the function name is:
TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20

  2. Define a constant internal to the cpp file:

  static const torch::detail::TorchLibraryInit C10_CONCATENATE(        \
      TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)(              \
      torch::Library::IMPL,                                            \
      c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check( \
          c10::DispatchKey::k)>(                                       \
          []() {                                                       \
            return &C10_CONCATENATE(                                   \
                TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid);           \
          },                                                           \
          []() { return [](torch::Library&) -> void {}; }),            \
      #ns,                                                             \
      c10::make_optional(c10::DispatchKey::k),                         \
      __FILE__,                                                        \
      __LINE__);                                                       \

This constant has type `static const torch::detail::TorchLibraryInit`. Continuing the example above, its name is:
TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20, which differs from the static function declared above only by the extra `static` in the name. After macro expansion, this whole piece of code reads:

  static const torch::detail::TorchLibraryInit                // type of the constant
  TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20(
      torch::Library::IMPL,                                   // arg 1, of type Library::Kind
      c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::AutogradPrivateUse1)>(
          []() {return &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20;},
          []() { return [](torch::Library&) -> void {}; }
          ),                                                  // arg 2, of type InitFn*
      "aten",                                                 // arg 3, of type const char*
      c10::make_optional(c10::DispatchKey::AutogradPrivateUse1),  // arg 4, of type c10::optional<c10::DispatchKey>
      __FILE__,                                               // arg 5, of type const char*
      __LINE__);                                              // arg 6, of type uint32_t

The `TorchLibraryInit` class is defined as follows:

class TorchLibraryInit final {
 private:
  using InitFn = void(Library&);
  Library lib_;

 public:
  TorchLibraryInit(
      Library::Kind kind,
      InitFn* fn,
      const char* ns,
      c10::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line)
      : lib_(kind, ns, k, file, line) {
    fn(lib_);
  }
};

It contains only a single private member variable of type `Library`. Note that in its constructor, `lib_` is first initialized with `kind, ns, k, file, line`, and then the passed-in `InitFn*` argument, i.e., a function of type `void(Library&)`, is invoked on `lib_`.
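This pattern is worth isolating: a static object whose constructor runs an init function before `main()`. A stripped-down sketch, with all names (`MiniLibrary`, `MiniLibraryInit`, `my_init`) invented for illustration:

```cpp
#include <functional>
#include <map>
#include <string>

// A toy stand-in for torch::Library: it just maps operator names to kernels.
struct MiniLibrary {
  std::map<std::string, std::function<int(int, int)>> ops;
  void impl(const std::string& name, std::function<int(int, int)> f) {
    ops[name] = std::move(f);
  }
};

// A toy stand-in for TorchLibraryInit: construct the library member,
// then run the init function on it, exactly as fn(lib_) does above.
struct MiniLibraryInit {
  using InitFn = void(MiniLibrary&);
  MiniLibrary lib_;
  explicit MiniLibraryInit(InitFn* fn) { fn(lib_); }
};

// The "body" of the macro: what the user writes inside the braces.
static void my_init(MiniLibrary& m) {
  m.impl("myadd", [](int a, int b) { return a + b; });
}

// The static object; its constructor runs before main(), registering the op.
static MiniLibraryInit my_init_holder(&my_init);
```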

When TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20 is defined, the first argument (`Library::Kind kind`) is `torch::Library::IMPL`, and the second argument is:

c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::AutogradPrivateUse1)>(                                       
          []() {return &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20;},           
          []() { return [](torch::Library&) -> void {}; }
          ),                                                      // arg 2, of type InitFn*

First look at the template argument, `c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::AutogradPrivateUse1)`, which is defined as:

constexpr bool dispatch_key_allowlist_check(DispatchKey /*k*/) {
#ifdef C10_MOBILE
 return true;
 // Disabled for now: to be enabled later!
 // return k == DispatchKey::CPU || k == DispatchKey::Vulkan || k == DispatchKey::QuantizedCPU || k == DispatchKey::BackendSelect || k == DispatchKey::CatchAll;
#else
 return true;
#endif
} 

As you can see, it currently just returns true unconditionally, so the second argument becomes:

c10::guts::if_constexpr<true>(                                       
         []() {return &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20;},           
         []() { return [](torch::Library&) -> void {}; }
         ),                                                      // arg 2, of type InitFn*

`if_constexpr` is defined as follows:


template <bool Condition, class ThenCallback, class ElseCallback>
decltype(auto) if_constexpr(
   ThenCallback&& thenCallback,
   ElseCallback&& elseCallback) {
#if defined(__cpp_if_constexpr)
 // If we have C++17, just use it's "if constexpr" feature instead of wrapping
 // it. This will give us better error messages.
 if constexpr (Condition) {
   if constexpr (detail::function_takes_identity_argument<
                     ThenCallback>::value) {
     // Note that we use static_cast<T&&>(t) instead of std::forward (or
     // ::std::forward) because using the latter produces some compilation
     // errors about ambiguous `std` on MSVC when using C++17. This static_cast
     // is just what std::forward is doing under the hood, and is equivalent.
     return static_cast<ThenCallback&&>(thenCallback)(detail::_identity());
   } else {
     return static_cast<ThenCallback&&>(thenCallback)();
   }
 } else {
   if constexpr (detail::function_takes_identity_argument<
                     ElseCallback>::value) {
     return static_cast<ElseCallback&&>(elseCallback)(detail::_identity());
   } else {
     return static_cast<ElseCallback&&>(elseCallback)();
   }
 }
#else
 // C++14 implementation of if constexpr
 return detail::_if_constexpr<Condition>::call(
     static_cast<ThenCallback&&>(thenCallback),
     static_cast<ElseCallback&&>(elseCallback));
#endif
}

This is a bit of showing off; let's go straight to the comment in the source:

Example 1: simple constexpr if/then/else
 template<int arg> int increment_absolute_value() {
   int result = arg;
   if_constexpr<(arg > 0)>(
     [&] { ++result; }  // then-case
     , [&] { --result; }  // else-case
   );
   return result;
 }

So this is just a simple compile-time if/else built from templates. Since its template argument is true, the second constructor argument reduces to the static function TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20 declared in part 1. The remaining arguments are straightforward; note, though, that the fourth argument, `c10::make_optional(c10::DispatchKey::AutogradPrivateUse1)`, is fairly involved.
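For comparison, here is the same example written directly with C++17's `if constexpr`, which is what `if_constexpr` falls back to when `__cpp_if_constexpr` is defined:

```cpp
// C++17 version of the increment_absolute_value example: the branch is
// selected at compile time from the template argument, and the untaken
// branch is discarded.
template <int arg>
int increment_absolute_value() {
  int result = arg;
  if constexpr (arg > 0) {
    ++result;  // then-case
  } else {
    --result;  // else-case
  }
  return result;
}
```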

  3. Formally define the static function declared in step 1; after macro expansion it becomes:
    void TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20(torch::Library & m){ 
        m.impl(<myadd_schema>, &myadd_autograd); 
    }
    

Before expansion, the whole code is:

TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
  m.impl(<myadd_schema>, &myadd_autograd);
}

After macro expansion plus simplification, it becomes:

static void TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20(torch::Library & m);

static const torch::detail::TorchLibraryInit TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20(          
  torch::Library::IMPL,                                       // arg 1, of type Library::Kind
  &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20,       // arg 2, of type InitFn*
  "aten",                                                     // arg 3, of type const char*
  c10::make_optional(c10::DispatchKey::AutogradPrivateUse1),  // arg 4, of type c10::optional<c10::DispatchKey>
  __FILE__,                                                   // arg 5, of type const char*
  __LINE__);                                                  // arg 6, of type uint32_t
  
void TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20(torch::Library & m){ 
   m.impl(<myadd_schema>, &myadd_autograd); 
}

// Definition of TorchLibraryInit, found in library.h
class TorchLibraryInit final {
 private:
  using InitFn = void(Library&);
  Library lib_;

 public:
  TorchLibraryInit(
      Library::Kind kind,
      InitFn* fn,
      const char* ns,
      c10::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line)
      : lib_(kind, ns, k, file, line) {
    fn(lib_);
  }
};

To recap so far:

① Part 1 declares a static function, TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20.

② Part 2 defines a static constant TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20 of type torch::detail::TorchLibraryInit. It holds a member variable of type Library, which is initialized from the constructor arguments and then passed to the static function declared in part 1.

③ Part 3 implements the function declared in part 1. Note that this function registers the operator by calling the `impl` member function of its `torch::Library` parameter, and the actual argument passed in is the private member variable of the static constant defined in part 2. That constant's name is TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_##uid, i.e., it depends on the namespace, the device/dispatch key (CPU, CUDA, or some other backend), and the UID.

The constructor of TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20 uses TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20 to set up its private member variable `lib_`: it invokes that function on `lib_`, which in turn calls the `impl` method of the `torch::Library` object.

Next, let's look at the `impl` method of the `torch::Library` class, defined as follows:

  /// Register an implementation for an operator.  You may register multiple
  /// implementations for a single operator at different dispatch keys
  /// (see torch::dispatch()).  Implementations must have a corresponding
  /// declaration (from def()), otherwise they are invalid.  If you plan
  /// to register multiple implementations, DO NOT provide a function
  /// implementation when you def() the operator.
  ///
  /// param name The name of the operator to implement.  Do NOT provide
  ///   schema here.
  /// param raw_f The C++ function that implements this operator.  Any
  ///   valid constructor of torch::CppFunction is accepted here;
  ///   typically you provide a function pointer or lambda.
  ///
  /// ```
  /// // Example:
  /// TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  ///   m.impl("add", add_cuda);
  /// }
  /// ```
  template <typename Name, typename Func>
  Library& impl(Name name, Func&& raw_f, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
    // TODO: need to raise an error when you impl a function that has a
    // catch all def
#if defined C10_MOBILE
    CppFunction f(std::forward<Func>(raw_f), NoInferSchemaTag());
#else
    CppFunction f(std::forward<Func>(raw_f));
#endif
    return _impl(name, std::move(f), rv);
  }

Clearly, this is a function using a universal (forwarding) reference plus perfect forwarding (the first time I've seen one outside of an interview question). Internally it first constructs an object of type CppFunction. Since the registration is:
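A minimal sketch of the universal-reference + perfect-forwarding idiom `impl` uses; the `Wrapper`/`wrap` names are invented for illustration, with `Wrapper` standing in for CppFunction:

```cpp
#include <string>
#include <utility>

// A wrapper that records whether it was constructed from an lvalue or
// an rvalue, so the forwarding behavior is observable.
struct Wrapper {
  std::string tag;
  explicit Wrapper(const std::string& s) : tag("lvalue:" + s) {}
  explicit Wrapper(std::string&& s) : tag("rvalue:" + std::move(s)) {}
};

// T&& in a deduced context is a forwarding (universal) reference;
// std::forward preserves the value category of v, so the matching
// Wrapper constructor is chosen.
template <typename T>
Wrapper wrap(T&& v) {
  return Wrapper(std::forward<T>(v));
}
```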

TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
  m.impl(<myadd_schema>, &myadd_autograd);
}

the CppFunction constructor that gets selected is the following:

  template <typename Func>
  explicit CppFunction(
      Func* f,
      std::enable_if_t<
          c10::guts::is_function_type<Func>::value,
          std::nullptr_t> = nullptr)
      : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f)),
        cpp_signature_(c10::impl::CppSignature::make<Func>()),
        schema_(
            c10::detail::inferFunctionSchemaFromFunctor<std::decay_t<Func>>()),
        debug_() {}

As you can see, its initializer list initializes the three private member variables `func_`, `cpp_signature_`, and `schema_`; as the names suggest, these hold the function, the signature, and the schema. The `_impl` function is defined as follows:
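The `std::enable_if_t<c10::guts::is_function_type<Func>::value, std::nullptr_t>` parameter in that constructor is a SFINAE guard: the overload participates in overload resolution only when `Func` is a plain function type, which is why passing `&myadd_autograd` selects it. A small sketch of the same trick using the standard `std::is_function` trait (the `is_function_pointer` helper is invented for illustration):

```cpp
#include <type_traits>

// This overload exists only when F is a function type, so it is chosen
// for function pointers such as &myadd below.
template <typename F>
bool is_function_pointer(
    F*, std::enable_if_t<std::is_function<F>::value, std::nullptr_t> = nullptr) {
  return true;
}

// Fallback overload: chosen for pointers to non-function objects.
template <typename T>
bool is_function_pointer(
    T*, std::enable_if_t<!std::is_function<T>::value, std::nullptr_t> = nullptr) {
  return false;
}

int myadd(int a, int b) { return a + b; }
int some_object = 0;
```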

Library& Library::_impl(const char* name_str, CppFunction&& f, _RegisterOrVerify rv) & {
  at::OperatorName name = _parseNameForLib(name_str);
  // See Note [Redundancy in registration code is OK]
  TORCH_CHECK(!(f.dispatch_key_.has_value() &&
                dispatch_key_.has_value() &&
                *f.dispatch_key_ != *dispatch_key_),
    IMPL_PRELUDE,
    "Explicitly provided dispatch key (", *f.dispatch_key_, ") is inconsistent "
    "with the dispatch key of the enclosing ", toString(kind_), " block (", *dispatch_key_, ").  "
    "Please declare a separate ", toString(kind_), " block for this dispatch key and "
    "move your impl() there.  "
    ERROR_CONTEXT
  );
  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
  switch (rv) {
    case _RegisterOrVerify::REGISTER:
      registrars_.emplace_back(
        c10::Dispatcher::singleton().registerImpl(
          std::move(name),
          dispatch_key,
          std::move(f.func_),
          std::move(f.cpp_signature_),
          std::move(f.schema_),
          debugString(std::move(f.debug_), file_, line_)
        )
      );
      break;
    case _RegisterOrVerify::VERIFY:
      c10::Dispatcher::singleton().waitForImpl(name, dispatch_key);
      break;
  }
  return *this;
}

As you can see, it registers the operator into the Dispatcher via the `registerImpl` member function of the global `Dispatcher` singleton.

To sum up:
Every use of the TORCH_LIBRARY_IMPL macro to register an operator generates a uniquely named static variable of type TorchLibraryInit; its constructor calls the generated, uniquely named init function, which registers the operator into the Dispatcher. The Dispatcher is the component of PyTorch that routes each call to the appropriate backend kernel based on the tensors' properties.
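The whole chain can be sketched end to end with a toy dispatcher. All names here (`MiniDispatcher`, `registerImpl`, `call`, `RegisterMyAdd`) are simplified stand-ins for illustration, not the real c10 API:

```cpp
#include <functional>
#include <map>
#include <string>
#include <utility>

using Kernel = std::function<int(int, int)>;

// A toy Dispatcher: a global singleton mapping (op name, dispatch key)
// to a kernel, mimicking registerImpl followed by a dispatched call.
struct MiniDispatcher {
  static MiniDispatcher& singleton() {
    static MiniDispatcher d;  // created on first use, like c10::Dispatcher
    return d;
  }
  void registerImpl(const std::string& op, const std::string& key, Kernel k) {
    table_[{op, key}] = std::move(k);
  }
  int call(const std::string& op, const std::string& key, int a, int b) {
    return table_.at({op, key})(a, b);  // look up the kernel for this backend
  }
 private:
  std::map<std::pair<std::string, std::string>, Kernel> table_;
};

// A static object whose constructor registers a kernel before main(),
// playing the role of the TORCH_LIBRARY_IMPL expansion.
struct RegisterMyAdd {
  RegisterMyAdd() {
    MiniDispatcher::singleton().registerImpl(
        "myadd", "AutogradPrivateUse1", [](int a, int b) { return a + b; });
  }
};
static RegisterMyAdd reg_myadd;
```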
