register_object

register_object#

tvm.register_object(type_key=None) 实现的关键接口是 _LIB.TVMObjectTypeKey2Index,函数的作用是根据给定的 key 获取对应的类型索引。(根据注释的说明,具体的实现细节可能在其他地方进行定义。如果你需要使用这个函数,可以在代码中包含该函数的声明,并在需要的地方调用它来获取类型索引。)拿到索引后,调用 _register_object(tindex, cls) 在 Python 端完成注册。

_LIB.TVMObjectTypeKey2Index 的实现如下查找链路:

//   src/runtime/object.cc
int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
  API_BEGIN();
  out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key);
  API_END();
}

-> tvm::runtime::ObjectInternal::ObjectTypeKey2Index 定义如下:

static uint32_t ObjectTypeKey2Index(const std::string& type_key) {
  return Object::TypeKey2Index(type_key);
}

-> Object::TypeKey2Index 定义如下:

uint32_t Object::TypeKey2Index(const std::string& key) {
  return TypeContext::Global()->TypeKey2Index(key);
}

->

uint32_t TypeKey2Index(const std::string& skey) {
    auto it = type_key2index_.find(skey);
    ICHECK(it != type_key2index_.end())
        << "Cannot find type " << skey
        << ". Did you forget to register the node by TVM_REGISTER_NODE_TYPE ?";
    return it->second;
  }

->

std::unordered_map<std::string, uint32_t> type_key2index_;

TVMObjectTypeKey2Index 接受两个参数:一个指向字符类型的指针 type_key 和一个指向无符号整数类型的指针 out_tindex。函数的返回类型是 int

函数的参数解释如下:

  • const char* type_key:表示类型键的字符串指针。

  • unsigned* out_tindex:指向无符号整数的指针,用于存储转换后的类型索引。

函数的返回值解释如下:

  • 当成功时,返回 0

  • 当失败时,返回非零值。

如果你需要使用这个函数,可以在代码中包含该函数的声明,并在需要的地方调用它来将类型键转换为类型索引。

在上述查找过程发现:TVM_REGISTER_NODE_TYPE 宏,用于注册 key2index 的绑定。

#define TVM_REGISTER_NODE_TYPE(TypeName)                                             \
  TVM_REGISTER_OBJECT_TYPE(TypeName);                                                \
  TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) \
      .set_creator([](const std::string&) -> ObjectPtr<Object> {                     \
        return ::tvm::runtime::make_object<TypeName>();                              \
      })

TVM_REGISTER_NODE_TYPE 宏用于在 C++ 中注册节点类型。

首先,TVM_REGISTER_NODE_TYPE(TypeName) 宏定义了函数调用,该函数调用了 TVM_REGISTER_OBJECT_TYPE(TypeName)TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) 两个函数。

  • TVM_REGISTER_OBJECT_TYPE(TypeName) 函数用于注册对象类型,将给定的类型名称 TypeName 与相应的对象类型关联起来。

  • TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) 函数用于注册反射虚函数表(vtable),将给定的类型名称 TypeName 与相应的反射虚函数表关联起来。这个虚函数表中包含了该类型的反射方法。

  • 接下来,.set_creator([](const std::string&) -> ObjectPtr<Object> {...}) 是可选的设置函数,用于指定如何创建该类型的对象实例。在这个例子中,使用了 lambda 表达式作为创建函数,它接受字符串参数,并返回新创建的 TypeName 类型的对象实例。

综上所述,这段代码的作用是注册节点类型,并提供了创建该类型对象实例的方法。

register_object 示例#

src/tvm_ext.cc 中定义 test.BaseObj

#include <string.h>
#include <tvm/runtime/object.h>
#include <tvm/node/reflection.h>


namespace tvm {
namespace runtime {
class TestNode :public Object {
public:
    // 对象字段
    std::string name;
    // 对象属性
    static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
    static constexpr const char* _type_key = "app.TestNode";
    // 告诉 TVM 编译器,TestNode 类是 Object 类的子类,
    // 并且需要在编译时进行一些特殊的处理。
    TVM_DECLARE_BASE_OBJECT_INFO(TestNode, Object);
    void VisitAttrs(AttrVisitor* v) {
        v->Visit("name", &name);
    }
};
TVM_REGISTER_NODE_TYPE(TestNode); // 注册节点类型
}
}

在 Python 端调用:

import tvm
from tvm.runtime import Object
from tvm._ffi.base import _LIB
import ctypes
# _LIB.TVMObjectTypeKey2Index

def load_dll(lib_path="lib/libtvm_ext.so"):
    """加载库,函数将被注册到 TVM"""
    # 作为全局加载,这样全局 extern symbol 对其他 dll 是可见的。
    # curr_path = f"{ROOT}/"
    lib = ctypes.CDLL(lib_path, ctypes.RTLD_GLOBAL)
    return lib
load_dll("./libs/libtvm_ext.so")
<CDLL './libs/libtvm_ext.so', handle 3b36bf0 at 0x7f1e3c4f68f0>
node = tvm.ir.make_node("app.TestNode", name="A")
node
app.TestNode(0x46bcb20)

或者:

@tvm._ffi.register_object("app.TestNode")
class TestNode(Object):
    def __init__(self, handle):
        """Initialize the function with handle

        Parameters
        ----------
        handle : SymbolHandle
            the handle to the underlying C++ Symbol
        """
        super().__init__(handle)
        self.handle = handle

如果想要改变 node 实例的显示内容,也可以在 C++ 端写入:

#include <tvm/node/repr_printer.h>
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<TestNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* op = static_cast<const TestNode*>(ref.get());
      p->stream << "Test(";
      p->stream << "name=" << op->name<< ", ";
      p->stream << ")";
    });
@tvm._ffi.register_object("app.TestNode")
class TestNode(Object):
    def __init__(self, handle):
        """Initialize the function with handle

        Parameters
        ----------
        handle : SymbolHandle
            the handle to the underlying C++ Symbol
        """
        super().__init__(handle)
        self.handle = handle

node = tvm.ir.make_node("app.TestNode", name="A")
node
Test(name=A, )

或者直接在 Python 端改写:

@tvm._ffi.register_object("app.TestNode")
class TestNode(Object):
    def __init__(self, handle):
        """Initialize the function with handle

        Parameters
        ----------
        handle : SymbolHandle
            the handle to the underlying C++ Symbol
        """
        super().__init__(handle)
        self.handle = handle

    def __repr__(self):
        return f"{self.__class__.__name__}_{self.name}"

node = tvm.ir.make_node("app.TestNode", name="A")
node
TestNode_A