node functor

node functor#

源码:tvm/include/tvm/node/functor.h

NodeFunctor#

NodeFunctor 的模板类,它用于根据第一个参数的类型动态分派函数。这个类在构造基于 AST/IR 节点类型的多态分派时非常有用。

/*!
 * \brief A dynamically dispatched functor on the type of the first argument.
 *
 * This is a class that is useful to construct polymorphic dispatching
 * base on the AST/IR node's type.
 *
 * \code
 *   NodeFunctor<std::string (const ObjectRef& n, std::string prefix)> tostr;
 *   tostr.set_dispatch<Add>([](const ObjectRef& op, std::string prefix) {
 *     return prefix + "Add";
 *   });
 *   tostr.set_dispatch<IntImm>([](const ObjectRef& op, std::string prefix) {
 *     return prefix + "IntImm"
 *   });
 *
 *   Expr x = make_const(1);
 *   Expr y = x + x;
 *   // dispatch to IntImm, outputs "MyIntImm"
 *   LOG(INFO) << tostr(x, "My");
 *   // dispatch to IntImm, outputs "MyAdd"
 *   LOG(INFO) << tostr(y, "My");
 * \endcode
 *
 * \tparam FType function signiture
 *  This type if only defined for FType with function signature
 */
template <typename FType>
class NodeFunctor;

template <typename R, typename... Args>
class NodeFunctor<R(const ObjectRef& n, Args...)> {
 private:
  /*! \brief internal function pointer type */
  typedef R (*FPointer)(const ObjectRef& n, Args...);
  /*! \brief refer to itself. */
  using TSelf = NodeFunctor<R(const ObjectRef& n, Args...)>;
  /*! \brief internal function table */
  std::vector<FPointer> func_;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*!
   * \brief Whether the functor can dispatch the corresponding Node
   * \param n The node to be dispatched
   * \return Whether dispatching function is registered for n's type.
   */
  bool can_dispatch(const ObjectRef& n) const {
    uint32_t type_index = n->type_index();
    return type_index < func_.size() && func_[type_index] != nullptr;
  }
  /*!
   * \brief invoke the functor, dispatch on type of n
   * \param n The Node argument
   * \param args The additional arguments
   * \return The result.
   */
  R operator()(const ObjectRef& n, Args... args) const {
    ICHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type "
                            << n->GetTypeKey();
    return (*func_[n->type_index()])(n, std::forward<Args>(args)...);
  }
  /*!
   * \brief set the dispatcher for type TNode
   * \param f The function to be set.
   * \tparam TNode the type of Node to be dispatched.
   * \return reference to self.
   */
  template <typename TNode>
  TSelf& set_dispatch(FPointer f) {  // NOLINT(*)
    uint32_t tindex = TNode::RuntimeTypeIndex();
    if (func_.size() <= tindex) {
      func_.resize(tindex + 1, nullptr);
    }
    ICHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set";
    func_[tindex] = f;
    return *this;
  }
  /*!
   * \brief unset the dispatcher for type TNode
   *
   * \tparam TNode the type of Node to be dispatched.
   * \return reference to self.
   */
  template <typename TNode>
  TSelf& clear_dispatch() {  // NOLINT(*)
    uint32_t tindex = TNode::RuntimeTypeIndex();
    ICHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
    func_[tindex] = nullptr;
    return *this;
  }
};

NodeFunctor 类有两个模板参数:R 表示返回值类型,Args... 表示函数的其他参数类型。它的内部包含函数指针类型 FPointer,用于存储指向特定类型的函数的指针。此外,还有名为 func_ 的内部向量,用于存储这些函数指针。

NodeFunctor 类提供了以下成员函数:

  1. can_dispatch(const ObjectRef& n) const:检查是否可以对给定的节点进行分派。如果节点的类型索引小于 func_ 的大小且对应的函数指针不为空,则返回 true

  2. operator()(const ObjectRef& n, Args... args) const:调用分派函数。首先检查是否可以对给定的节点进行分派,然后使用节点的类型索引从 func_ 中获取相应的函数指针,并调用该函数。

  3. set_dispatch(FPointer f):为特定类型的节点设置分派函数。首先计算节点类型的运行时类型索引,然后调整 func_ 的大小以容纳新的函数指针(如果需要),并将新函数指针设置为给定的函数。

  4. clear_dispatch():清除特定类型的节点的分派函数。首先计算节点类型的运行时类型索引,然后将对应的函数指针设置为 nullptr

NodeFunctor<std::string (const ObjectRef& n, std::string prefix)> tostr;
tostr.set_dispatch<Add>([](const ObjectRef& op, std::string prefix) {
    return prefix + "Add";
});
tostr.set_dispatch<IntImm>([](const ObjectRef& op, std::string prefix) {
    return prefix + "IntImm"
});
Expr x = make_const(1);
Expr y = x + x;
// dispatch to IntImm, outputs "MyIntImm"
LOG(INFO) << tostr(x, "My");
// dispatch to IntImm, outputs "MyAdd"
LOG(INFO) << tostr(y, "My");

这段代码定义了名为 tostrNodeFunctor 对象,该对象用于将不同类型的节点转换为字符串。NodeFunctor 是模板类,接受函数签名作为参数,该函数签名表示如何将节点转换为字符串。

在这段代码中,NodeFunctor 的函数签名为std::string (const ObjectRef& n, std::string prefix),表示它接受 ObjectRef 类型的节点和字符串前缀,并返回字符串。

接下来,使用 set_dispatch 方法为 NodeFunctor 设置两个分派函数。第一个分派函数处理 Add 类型的节点,它将节点转换为字符串并将前缀添加到字符串末尾。第二个分派函数处理 IntImm 类型的节点,它也将节点转换为字符串并将前缀添加到字符串末尾。

然后,创建两个表达式 xy,其中 x 是常量节点,值为 1。通过将这两个表达式传递给 tostr 对象,可以将其转换为字符串。由于 xAdd 类型的节点,因此调用 tostr(x, "My") 时,将调用第一个分派函数,输出结果为 "MyAdd"。同样,由于 yIntImm 类型的节点,因此调用 tostr(y, "My") 时,将调用第二个分派函数,输出结果为 "MyIntImm"

TVM_STATIC_IR_FUNCTOR#

#define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName

这段代码是宏定义,用于生成一个名为 __make_functor##_##ClsName 的函数对象。这个函数对象是静态的(static)并且具有 TVM_ATTRIBUTE_UNUSED 属性,表示它不会被使用。

解析如下:

  1. #define 是 C/C++ 预处理器指令,用于定义宏。

  2. TVM_REG_FUNC_VAR_DEF(ClsName) 是宏的名称,其中 ClsName 是参数,表示类名。

  3. static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName 是宏的定义部分。

    • static 表示这是静态成员函数。

    • TVM_ATTRIBUTE_UNUSED 是属性,表示该变量或函数未被使用,编译器不会发出警告。

    • auto& 表示返回值类型为引用到自动类型的变量。

    • __make_functor##_##ClsName 是生成的函数对象的名称,其中 ## 是连接符,用于将两个字符串连接在一起。

综上所述,这段代码的作用是定义名为 __make_functor##_##ClsName 的静态函数对象,该函数对象具有 TVM_ATTRIBUTE_UNUSED 属性,表示它不会被使用。

/*!
 * \brief Useful macro to set NodeFunctor dispatch in a global static field.
 *
 * \code
 *  // Use NodeFunctor to implement ReprPrinter similar to Visitor Pattern.
 *  // vtable allows easy patch of new Node types, without changing
 *  // interface of ReprPrinter.
 *
 *  class ReprPrinter {
 *   public:
 *    std::ostream& stream;
 *    // the dispatch function.
 *    void print(Expr e) {
 *      const static FType& f = *vtable();
 *      f(e, this);
 *    }
 *
 *    using FType = NodeFunctor<void (const ObjectRef&, ReprPrinter* )>;
 *    // function to return global function table
 *    static FType& vtable();
 *  };
 *
 *  // in cpp/cc file
 *  ReprPrinter::FType& ReprPrinter::vtable() { // NOLINT(*)
 *    static FType inst; return inst;
 *  }
 *
 *  TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 *  .set_dispatch<Add>([](const ObjectRef& ref, ReprPrinter* p) {
 *    auto* n = static_cast<const Add*>(ref.get());
 *    p->print(n->a);
 *    p->stream << '+'
 *    p->print(n->b);
 *  });
 *
 *
 * \endcode
 *
 * \param ClsName The name of the class
 * \param FField The static function that returns a singleton of NodeFunctor.
 */
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
  TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = ClsName::FField()

这段代码定义了名为 TVM_STATIC_IR_FUNCTOR 的宏,用于设置 NodeFunctor 的调度。NodeFunctor 是一种用于实现类似于访问者模式的函数对象。

在这段代码中,ReprPrinter 类使用了 NodeFunctor 来实现打印功能。通过使用 vtable,可以轻松地为新的节点类型添加新的调度函数,而无需更改 ReprPrinter 接口。

ReprPrinter::FType& ReprPrinter::vtable() 函数返回全局函数表。这个函数表是一个静态成员变量,它存储了 NodeFunctor 的实例。

TVM_STATIC_IR_FUNCTOR(ClsName, FField) 宏的作用是将 ClsName 类的 FField 函数作为 NodeFunctor 的调度函数添加到全局函数表中。这样,当调用 print 方法时,会根据节点的类型选择相应的调度函数进行处理。