TVM 节点反射#

tvm/include/tvm/node/reflection.h 是 TVM(Tensor Virtual Machine)库中的一个头文件,用于实现 TVM 中的反射机制。名为 Reflector 的类,它是整个反射机制的核心。Reflector 类的主要作用是通过序列化和反序列化操作,将计算图中的各种节点、参数和数据结构进行转换,以便在不同的硬件平台上进行部署。

Reflector 类中的方法包括:

  1. Reflector::Init():初始化 Reflector 对象。在构造函数中调用此方法。

  2. Reflector::Run():执行反射操作。首先对计算图进行序列化,然后根据目标平台对序列化后的数据进行反序列化,最后执行反序列化后的计算图。

  3. Reflector::Export():导出指定节点的信息。将指定节点的信息导出到一个字符串中。

  4. Reflector::Import():导入指定节点的信息。从一个字符串中读取节点信息,并将其反序列化为一个 ReflectorNode 对象。

  5. Reflector::GetAttrs():获取指定节点的属性列表。返回一个包含属性名称和值的映射(std::unordered_map<string, AttrValue>)。

  6. Reflector::SetAttrs():设置指定节点的属性列表。使用给定的属性值更新节点的属性。

  7. Reflector::ResetGraph():重置计算图。清除所有节点、参数和数据结构。

  8. Reflector::LoadGraph():加载计算图。从磁盘或其他存储介质中读取计算图的数据结构,并反序列化为 ReflectorNode 对象。

  9. Reflector::FindNode():查找指定名称的节点。返回一个指向具有指定名称的节点的指针。

  10. Reflector::FindOutput(const std::string& name):查找具有指定名称的输出节点。返回一个指向具有指定名称的输出节点的指针。

  11. Reflector::FindInput(const std::string& name):查找具有指定名称的输入节点。返回一个指向具有指定名称的输入节点的指针。

  12. Reflector::FindNextNode(const ReflectorNode* node):查找给定节点的下一个节点。返回一个指向下一个节点的指针,如果没有找到,则返回 nullptr

  13. Reflector::FindAllNodes(const std::function<bool(const ReflectorNode*)>& filter):查找满足给定过滤条件的所有节点。返回一个包含满足条件的节点指针的列表。

  14. Reflector::FindSubgraph(const std::vector<const ReflectorNode*>& nodes):查找给定节点集合所在的子图。返回一个表示子图的对象,该对象包含了子图中的所有节点和连接关系。

  15. Reflector::DumpGraph():将计算图以文本形式输出到标准输出(或指定的文件)。

NodeGetAttrNodeListAttrNamesMakeNode#

  1. NodeGetAttr 函数用于获取对象的属性值。它接受两个参数:argsretargs 是包含输入参数的数组,ret 是指向返回值的指针。首先,代码检查 args[0] 的类型码是否为 kTVMObjectHandle,然后将 args[0] 的值转换为 Object* 类型。接下来,它调用 ReflectionVTable::Global()->GetAttr 函数来获取对象的属性值,并将结果存储在 ret 指向的位置。

  2. NodeListAttrNames 函数用于列出对象的所有属性名称。它也接受两个参数:argsretargs 是包含输入参数的数组,ret 是指向返回值的指针。首先,代码检查 args[0] 的类型码是否为 kTVMObjectHandle,然后将 args[0] 的值转换为 Object* 类型。接下来,它调用 ReflectionVTable::Global()->ListAttrNames 函数来获取对象的属性名称列表,并将其存储在新的 std::vector<std::string> 对象中。最后,它创建包装器函数,该函数接受整数参数 i,并根据 i 的值返回相应的属性名称或属性名称列表的大小。

  3. MakeNode 函数用于创建新的对象。它接受 const TVMArgs& 类型的参数 args 和指向返回值的指针 rv。首先,代码从 args 中提取对象的类型键(type_key),并创建新的 TVMArgs 对象 kwargs,其中包含剩余的参数。然后,它调用 ReflectionVTable::Global()->CreateObject 函数来创建新的对象,并将结果存储在 rv 指向的位置。

查看对应的 Python 接口示例:

import tvm

# MakeNode -> tvm.ir.make_node
x = tvm.ir.make_node("IntImm", dtype="int32", value=10, span=None)
assert isinstance(x, tvm.tir.IntImm)
assert x.value == 10

其余两个类被打包到 Object:

tvm.runtime.Object.__getattr__??
Signature: tvm.runtime.Object.__getattr__(self, name)
Docstring: <no docstring>
Source:   
    def __getattr__(self, name):
        # specially check handle since
        # this is required for PackedFunc calls
        if name == "handle":
            raise AttributeError("handle is not set")

        try:
            return _ffi_node_api.NodeGetAttr(self, name)
        except AttributeError:
            raise AttributeError(f"{type(self)} has no attribute {name}") from None
File:      /media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/runtime/object.py
Type:      function
tvm.runtime.Object.__dir__??
Signature: tvm.runtime.Object.__dir__(self)
Docstring: Default dir() implementation.
Source:   
    def __dir__(self):
        class_names = dir(self.__class__)
        fnames = _ffi_node_api.NodeListAttrNames(self)
        size = fnames(-1)
        return sorted([fnames(i) for i in range(size)] + class_names)
File:      /media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/runtime/object.py
Type:      function