env_func#

源码:tvm/src/ir/env_func.cc

EnvFuncNodeEnvFunc#

/*!
 * \brief A serializable function backed by TVM's global environment.
 *
 * This is a wrapper to enable serializable global PackedFunc.
 * An EnvFunc is saved by its name in the global registry
 * under the assumption that the same function is registered during load.
 * \sa EnvFunc
 */
class EnvFuncNode : public Object {
 public:
  /*! \brief Unique name of the global function */
  String name;
  /*! \brief The internal packed function */
  runtime::PackedFunc func;
  /*! \brief constructor */
  EnvFuncNode() {}

  void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }

  bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
    // name uniquely identifies the env function.
    return name == other->name;
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    // Name uniquely identifies the env function.
    hash_reduce(name);
  }

  static constexpr const char* _type_key = "EnvFunc";
  static constexpr bool _type_has_method_sequal_reduce = true;
  static constexpr bool _type_has_method_shash_reduce = true;
  TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
};

EnvFuncNode 是继承自 Object 的类,它包含字符串类型的成员变量 nameruntime::PackedFunc 类型的成员变量 func。这个类的主要目的是作为 TVM 全局环境的包装器,使得函数可以被序列化。在加载时,通过名称在全局注册表中查找相同的函数。此外,它还提供了一些方法来访问和操作这些成员变量。例如,VisitAttrs 方法允许访问 name 属性,而 SEqualReduceSHashReduce 方法则用于比较两个 EnvFuncNode 对象是否相等以及计算它们的哈希值。在类的声明中,还使用了一些宏来定义一些常量和类型信息。例如,_type_key 常量被定义为 "EnvFunc",表示这个类的类型键; _type_has_method_sequal_reduce_type_has_method_shash_reduce 常量被定义为 true,表示这个类支持相等性和哈希性的计算方法。最后,TVM_DECLARE_FINAL_OBJECT_INFO 宏用于声明这个类的最终对象信息。

/*!
 * \brief Managed reference to EnvFuncNode.
 * \sa EnvFuncNode
 */
class EnvFunc : public ObjectRef {
 public:
  EnvFunc() {}
  explicit EnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
  /*! \return The internal global function pointer */
  const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); }
  /*!
   * \brief Invoke the function.
   * \param args The arguments
   * \returns The return value.
   */
  template <typename... Args>
  runtime::TVMRetValue operator()(Args&&... args) const {
    const EnvFuncNode* n = operator->();
    ICHECK(n != nullptr);
    return n->func(std::forward<Args>(args)...);
  }
  /*!
   * \brief Get a global function based on the name.
   * \param name The name of the global function.
   * \return The created global function.
   * \note The function can be unique
   */
  TVM_DLL static EnvFunc Get(const String& name);
  /*! \brief specify container node */
  using ContainerType = EnvFuncNode;
};

EnvFuncEnvFuncNode 的引用类型,它提供了一种方法来调用内部存储的函数。这个类继承自 ObjectRef 类,并提供了以下功能:

  1. 构造函数:EnvFunc()explicit EnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}。这两个构造函数分别用于创建空的 EnvFunc 对象和用给定的 ObjectPtr<Object> 初始化的 EnvFunc 对象。

  2. 获取内部全局函数指针:const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); }。这个方法返回指向内部全局函数指针的常量指针。

  3. 调用函数:template <typename... Args> runtime::TVMRetValue operator()(Args&&... args) const。这个方法接受一系列参数,并使用这些参数调用内部全局函数。它返回内部全局函数的返回值。

  4. 根据名称获取全局函数:TVM_DLL static EnvFunc Get(const String& name);。这个方法根据给定的名称在全局环境中查找并返回对应的全局函数。

  5. 指定容器节点:using ContainerType = EnvFuncNode;。这行代码声明了 EnvFunc 类可以作为 EnvFuncNode 类型的容器节点。

总的来说,EnvFuncNodeEnvFunc 提供了一种机制,可以将 TVM 中的函数封装为可序列化的全局环境,并提供了方便的方法来调用这些函数。

TypedEnvFunc#

/*!
 * \brief Please refer to \ref TypedEnvFuncAnchor "TypedEnvFunc<R(Args..)>"
 */
template <typename FType>
class TypedEnvFunc;

/*!
 * \anchor TypedEnvFuncAnchor
 * \brief A typed version of EnvFunc.
 * It is backed by a GlobalFuncNode internally.
 *
 * \tparam R The return value of the function.
 * \tparam Args The argument signature of the function.
 * \sa EnvFunc
 */
template <typename R, typename... Args>
class TypedEnvFunc<R(Args...)> : public ObjectRef {
 public:
  /*! \brief short hand for this function type */
  using TSelf = TypedEnvFunc<R(Args...)>;
  TypedEnvFunc() {}
  explicit TypedEnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
  /*!
   * \brief Assign global function to a TypedEnvFunc
   * \param other Another global function.
   * \return reference to self.
   */
  TSelf& operator=(const EnvFunc& other) {
    ObjectRef::operator=(other);
    return *this;
  }
  /*! \return The internal global function pointer */
  const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); }
  /*!
   * \brief Invoke the function.
   * \param args The arguments
   * \returns The return value.
   */
  R operator()(Args... args) const {
    const EnvFuncNode* n = operator->();
    ICHECK(n != nullptr);
    return runtime::detail::typed_packed_call_dispatcher<R>::run(n->func,
                                                                 std::forward<Args>(args)...);
  }
  /*! \brief specify container node */
  using ContainerType = EnvFuncNode;
};

这段代码定义了名为 TypedEnvFunc 的模板类,它是对 EnvFunc 类的泛型版本。TypedEnvFunc 类的主要目的是将全局函数封装为类型安全的函数对象。

TypedEnvFunc 类有两个模板参数:RArgs。其中,R 表示函数的返回值类型,Args 表示函数的参数类型。TypedEnvFunc<R(Args...)> 表示接受 Args... 类型参数并返回 R 类型的函数对象。

TypedEnvFunc 类继承自 ObjectRef 类,因此它具有引用计数功能。它提供了一些成员函数,如 operator=operator->operator(),分别用于赋值、获取内部全局函数指针和调用函数。

operator() 函数中,首先通过 operator-() 获取内部全局函数指针,然后使用 runtime::detail::typed_packed_call_dispatcher<R>::run() 函数调用全局函数,并将结果返回。

此外,TypedEnvFunc 类还定义了名为 ContainerType 的类型别名,用于指定容器节点类型为 EnvFuncNode

EnvFunc 的实现#

/*!
 * \file env_func.cc
 */
#include <tvm/ir/env_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>

namespace tvm {

using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<EnvFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const EnvFuncNode*>(node.get());
      p->stream << "EnvFunc(" << op->name << ")";
    });

ObjectPtr<Object> CreateEnvNode(const std::string& name) {
  auto* f = runtime::Registry::Get(name);
  ICHECK(f != nullptr) << "Cannot find global function \'" << name << '\'';
  ObjectPtr<EnvFuncNode> n = make_object<EnvFuncNode>();
  n->func = *f;
  n->name = name;
  return n;
}

EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); }

TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get);

TVM_REGISTER_GLOBAL("ir.EnvFuncCall").set_body([](TVMArgs args, TVMRetValue* rv) {
  EnvFunc env = args[0];
  ICHECK_GE(args.size(), 1);
  env->func.CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), rv);
});

TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc").set_body_typed([](const EnvFunc& n) {
  return n->func;
});

TVM_REGISTER_NODE_TYPE(EnvFuncNode)
    .set_creator(CreateEnvNode)
    .set_repr_bytes([](const Object* n) -> std::string {
      return static_cast<const EnvFuncNode*>(n)->name;
    });

}  // namespace tvm

这段代码用于处理环境函数。环境函数是一种特殊类型的函数,它们在运行时被调用,而不是在编译时。

代码中定义了两个主要的函数:CreateEnvNodeEnvFunc::Get

CreateEnvNode 函数接受字符串参数 name,这个字符串应该是全局函数的名称。然后,它从注册表中获取这个函数,并创建新的 EnvFuncNode 对象,将这个函数和它的名称存储在这个对象中。最后,它返回这个新创建的对象。

EnvFunc::Get 函数接受字符串参数 name,并使用 CreateEnvNode 函数来创建对应的 EnvFuncNode 对象。然后,它返回这个新创建的对象。

此外,代码还注册了几个全局函数,包括 ir.EnvFuncGetir.EnvFuncCallir.EnvFuncGetPackedFunc。这些函数分别用于获取环境函数、调用环境函数和获取环境函数的打包函数。

最后,代码还注册了节点类型 EnvFuncNode,并设置了它的创建函数和表示函数。创建函数是 CreateEnvNode,表示函数是字符串,表示环境函数的名称。