Understanding GraphExecutorCodegen

A two-headed network as a starting example

Create a small network with two outputs:

import numpy as np
import tvm
from tvm import relay
from tvm.relay.build_module import bind_params_by_name

x = relay.var("x", shape=(1, 1, 8, 8), dtype="int8")
w = relay.var("w", shape=(2, 1, 3, 3), dtype="int8")
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
mod = tvm.IRModule.from_expr(relay.Tuple([conv2d, relu]))
mod["main"] = bind_params_by_name(mod["main"], 
                                  {"w": tvm.nd.array(np.ones(shape=(2, 1, 3, 3), 
                                                             dtype="int8"))})
rt_lib = relay.build(mod, target="llvm")
rt_lib.params.keys(), rt_lib.params["p0"].shape, rt_lib.params["p0"].dtype
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
(dict_keys(['p0']), (2, 1, 3, 3), 'int8')
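
To confirm that the built module really exposes two outputs, it can be executed with the graph executor runtime (a minimal sketch; the all-ones input is only for illustration):

from tvm.contrib import graph_executor

dev = tvm.cpu()
# wrap the compiled factory module in a graph executor
gmod = graph_executor.GraphModule(rt_lib["default"](dev))
gmod.set_input("x", np.ones((1, 1, 8, 8), dtype="int8"))
gmod.run()
# both heads of the output tuple are materialized as separate outputs
gmod.get_num_outputs()      # expected: 2
gmod.get_output(0).shape    # conv2d head: (1, 2, 6, 6)
gmod.get_output(1).shape    # relu head:   (1, 2, 6, 6)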

The network structure is as follows:

print(rt_lib.ir_mod)
def @main(%x: Tensor[(1, 1, 8, 8), int8]) {
  %0 = nn.conv2d(%x, meta[relay.Constant][0], padding=[0, 0, 0, 0]);
  %1 = nn.relu(%0);
  (%0, %1)
}

Inspect the Graph JSON:

import toml

# rt_lib.graph_json is a JSON string; parse it and re-dump it as TOML for readability
bunch = eval(rt_lib.graph_json)
print(toml.dumps(bunch))
arg_nodes = [ 0, 1,]
heads = [ [ 2, 0, 0,], [ 3, 0, 0,],]
node_row_ptr = [ 0, 1, 2, 3, 4,]
[[nodes]]
op = "null"
name = "x"
inputs = []

[[nodes]]
op = "null"
name = "p0"
inputs = []

[[nodes]]
op = "tvm_op"
name = "tvmgen_default_fused_nn_conv2d"
inputs = [ [ 0, 0, 0,], [ 1, 0, 0,],]

[nodes.attrs]
num_outputs = "1"
num_inputs = "2"
flatten_data = "0"
func_name = "tvmgen_default_fused_nn_conv2d"
out_layout = ""
kernel_layout = "OIHW"
data_layout = "NCHW"
hash = "8f5bab575bcb83dc"
[[nodes]]
op = "tvm_op"
name = "tvmgen_default_fused_nn_relu"
inputs = [ [ 2, 0, 0,],]

[nodes.attrs]
num_outputs = "1"
num_inputs = "1"
flatten_data = "0"
func_name = "tvmgen_default_fused_nn_relu"
hash = "fd6e720bc47ba75c"

[attrs]
dltype = [ "list_str", [ "int8", "int8", "int8", "int8",],]
device_index = [ "list_int", [ 1, 1, 1, 1,],]
storage_id = [ "list_int", [ 0, 1, 2, 3,],]
shape = [ "list_shape", [ [ 1, 1, 8, 8,], [ 2, 1, 3, 3,], [ 1, 2, 6, 6,], [ 1, 2, 6, 6,],],]

Reading the CreateGraphCodegenMod source code

The graph node type enum is defined as:

/*! \brief Node types */
enum GraphNodeType {
  kGraphNop,
  kGraphInputNode,
  kGraphOpNode,
};

A Python equivalent:

from enum import Enum


class GraphNodeType(Enum):
    """Node type enum.

    Attrs:
        kGraphNop: a non-operator node
        kGraphInputNode: an input node, i.e. a placeholder/variable/input of the
            graph, or a constant/param
        kGraphOpNode: an operator node
    """
    kGraphNop = 0
    kGraphInputNode = 1
    kGraphOpNode = 2

The node base class is defined as follows:

/*! \brief Base Node class */
class GraphNode {
 public:
  GraphNode() {}
  virtual void Save(dmlc::JSONWriter* writer) const {}
  virtual void Load(dmlc::JSONReader* reader) {}
  virtual GraphNodeType Type() const { return kGraphNop; }
  virtual ~GraphNode() {}

 public:
  int num_outputs_{1};
  std::string name_;
  GraphAttrs attrs_;
};

A Python version:

from typing import Any
from dataclasses import dataclass
from abc import ABC, abstractmethod

GraphAttrs = dict[str, Any]

@dataclass
class GraphNode(ABC):
    name: str
    attrs: GraphAttrs
    
    @abstractmethod
    def Save(self, writer) -> None:
        ...

    @abstractmethod
    def Load(self, reader) -> None:
        ...

    @abstractmethod
    def Type(self) -> GraphNodeType:
        return GraphNodeType.kGraphNop

The input node:

/*! \brief Input Node */
class GraphInputNode : public GraphNode {
 public:
  GraphInputNode() {}
  GraphInputNode(const std::string& name, const GraphAttrs& attrs) {
    name_ = name;
    attrs_ = attrs;
  }

  GraphNodeType Type() const override { return kGraphInputNode; }

  void Save(dmlc::JSONWriter* writer) const override {
    const std::string op_name{"null"};
    writer->BeginObject();
    writer->WriteObjectKeyValue("op", op_name);
    writer->WriteObjectKeyValue("name", this->name_);
    writer->WriteObjectKeyValue("inputs", std::list<int>());
    writer->EndObject();
  }
  static std::shared_ptr<GraphNode> make_node_ptr(const std::string& name,
                                                  const GraphAttrs& attrs) {
    auto ptr = std::make_shared<GraphInputNode>(name, attrs);
    return std::dynamic_pointer_cast<GraphNode>(ptr);
  }
};

Python implementation:

@dataclass
class GraphInputNode(GraphNode):
    inputs: list[int]

    def Type(self) -> GraphNodeType:
        return GraphNodeType.kGraphInputNode

    def Save(self, writer) -> None:
        bunch = {
            "op": "null",
            "name": self.name,
            "inputs": []
        }
        # write the JSON object out through the writer handle
        ...

    def Load(self, reader) -> None:
        ...

    def make_node_ptr(self):
        # make_node(name, attrs)
        ...
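
As a quick illustration, the placeholder "x" from the example graph maps onto this class as follows; its Save method would serialize {"op": "null", "name": "x", "inputs": []}:

x_node = GraphInputNode(name="x", attrs={}, inputs=[])
x_node.Type()   # GraphNodeType.kGraphInputNode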

Similarly, the operator node class in Python:

@dataclass
class GraphNodeRef:
    ident: int  # index of the referenced node in the node list
    index: int = 0  # which output of that node is referenced (0 for single-output nodes)
    version: int = 0  # reserved for SSA versioning; stays 0 in the graphs shown here

@dataclass
class GraphOpNode(GraphNode):
    nd_attrs: GraphAttrs
    op_name: str
    inputs: list[GraphNodeRef]
    num_outputs: int = 1

    def __post_init__(self):
        self.attrs["func_name"] = self.op_name
        self.attrs["flatten_data"] = "0"
        self.attrs["num_inputs"] = str(len(self.inputs))
        self.attrs["num_outputs"] = str(self.num_outputs)

    def Type(self) -> GraphNodeType:
        return GraphNodeType.kGraphOpNode
    
    def Save(self, writer) -> None:
        bunch = {
            "op": "tvm_op",
            "name": self.name,
            "attrs": self.attrs,
            "inputs": self.inputs
        }
        # write the JSON object out through the writer handle
        ...

    def Load(self, reader) -> None:
        ...

    def make_node_ptr(self):
        # make_node(name, nd_attrs, op_name, inputs, attrs, num_outputs)
        ...
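
As a sanity check, the fused conv2d node from the example graph can be rebuilt with this class; the attrs filled in by __post_init__ then match the func_name/flatten_data/num_inputs/num_outputs fields seen in the Graph JSON (the layout and hash attributes are added during lowering and are not modelled here):

conv_node = GraphOpNode(
    name="tvmgen_default_fused_nn_conv2d",
    attrs={},
    nd_attrs={},
    op_name="tvmgen_default_fused_nn_conv2d",
    inputs=[GraphNodeRef(0), GraphNodeRef(1)],  # x and p0
)
conv_node.attrs
# {'func_name': 'tvmgen_default_fused_nn_conv2d', 'flatten_data': '0',
#  'num_inputs': '2', 'num_outputs': '1'}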

Now to the main topic:

The GraphExecutorCodegen code generator

The code generator for the graph executor: it produces the Graph JSON, the lowered module, and the module's parameters.

@dataclass
class LoweredOutput:
    graph_json: str
    lowered_funcs: dict[str, tvm.IRModule]
    external_mods: list[tvm.IRModule]
    params: dict[str, tvm.runtime.NDArray]


@dataclass
class GraphExecutorCodegen:
    mod: tvm.runtime.Module
    targets: list[tvm.target.Target]

    def GetStorageInfo(self, expr) -> "tvm.relay.backend.StorageInfo":
        """Get the storage information for a single expression."""
        ...

    def Codegen(self, mod: tvm.IRModule,
                func: relay.Function,
                mod_name: str) -> "tvm.relay.backend.LoweredOutput":
        """
        1. Plan memory and update workspace sizes before lowering.
        2. Get the lowered main function (lowered_main_func).
        3. Convert all parameters into input nodes.
        4. Collect any runtime modules produced by external codegen.
        5. Collect any constants extracted by external codegen.
        6. Collect any constants extracted during lowering.
        7. Separate the functions in the module by target.
        8. Save the Graph JSON into the output.
        """
        ...
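
To make these eight steps concrete, here is a deliberately simplified, illustrative sketch (not the actual TVM implementation) of how such a pass could assemble the node list, arg_nodes and heads for the two-headed example, using only the helper classes defined above:

def sketch_codegen_graph():
    """Toy walk-through of steps 3-8 for the example graph; memory planning,
    lowering and per-target handling are omitted."""
    nodes = []
    # step 3: every graph input / bound constant becomes a "null" input node
    nodes.append(GraphInputNode(name="x", attrs={}, inputs=[]))
    nodes.append(GraphInputNode(name="p0", attrs={}, inputs=[]))
    arg_nodes = [0, 1]
    # each lowered (fused) call becomes a "tvm_op" node referencing its inputs
    nodes.append(GraphOpNode(
        name="tvmgen_default_fused_nn_conv2d", attrs={}, nd_attrs={},
        op_name="tvmgen_default_fused_nn_conv2d",
        inputs=[GraphNodeRef(0), GraphNodeRef(1)]))
    nodes.append(GraphOpNode(
        name="tvmgen_default_fused_nn_relu", attrs={}, nd_attrs={},
        op_name="tvmgen_default_fused_nn_relu",
        inputs=[GraphNodeRef(2)]))
    # step 8: the tuple returned by main contributes one head per element
    heads = [GraphNodeRef(2), GraphNodeRef(3)]
    return nodes, arg_nodes, heads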

Back to the two-headed network example

Let us now read through the Graph JSON in detail.

Since the two-headed network has two outputs:

  1. heads = [ [ 2, 0, 0,], [ 3, 0, 0,],] gives the indices of the two output nodes (the conv2d and relu nodes).

  2. arg_nodes = [ 0, 1,] gives the positions of the parameter nodes (the input x and the bound constant p0); both claims are checked in the snippet below.
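
A quick check against the node list (reusing the parsed graph_json dict; expected values are shown as comments):

bunch = eval(rt_lib.graph_json)
[bunch["nodes"][nid]["name"] for nid in bunch["arg_nodes"]]
# ['x', 'p0']
[bunch["nodes"][nid]["name"] for nid, _, _ in bunch["heads"]]
# ['tvmgen_default_fused_nn_conv2d', 'tvmgen_default_fused_nn_relu']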

Python implementation:

from dataclasses import field


@dataclass
class GraphAttrs:
    """`
    Args:
        dltype: 每个节点的数据类型按顺序排列。
        device_index: 按顺序为每个节点分配设备。
        storage_id: 存储布局中每个节点的内存 slot id。
        shape: 每个节点的 k 阶形状。
        storage_id: 存储布局中每个节点的内存 slot id。
                    将参数名称映射到一对 ({storage_id: tvm.runtime.NDArray})。在运行时,可以使用 storage_id 查找参数。
    """
    dltype: list
    device_index: list
    storage_id: list
    shape: list


@dataclass
class GraphNodeAttrs:
    """
    Args:
        flatten_data: 是否需要在执行前将数据扁平化(flattened)
        func_name: 融合函数名,对应于 Relay 编译过程生成的库中的符号。
        num_inputs: 此节点的 inputs 个数
        num_outputs: 此节点产生的 outputs 个数
    """
    func_name: str
    num_inputs: str
    num_outputs: str
    flatten_data: str = "0"
    hash: str|None = None
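
For example, the attrs of the relu node in the Graph JSON map directly onto this dataclass (the conv2d node additionally carries layout keys such as out_layout, kernel_layout and data_layout, which this dataclass does not model):

relu_attrs = eval(rt_lib.graph_json)["nodes"][3]["attrs"]
GraphNodeAttrs(**relu_attrs)
# GraphNodeAttrs(func_name='tvmgen_default_fused_nn_relu', num_inputs='1',
#                num_outputs='1', flatten_data='0', hash='fd6e720bc47ba75c')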
    


@dataclass
class GraphNode:
    """
    Args:
        op: 运算类型,`null` 意味着它是占位符/变量/输入节点,`tvm_op` 意味着这个节点可以被执行
        name: 节点名字
        inputs: 运算的 inputs 位置,inputs 是包含 `(nodeid, index, version)` 的元组列表。(可选)
    """
    op: str
    name: str
    inputs: list[list[int]] = field(default_factory=list)
    attrs: Any = None


@dataclass
class GraphJson:
    """
    Args:
        arg_nodes:参数节点的索引列表,它是计算图的占位符/变量/输入节点或 constant/param。
        heads: 输出节点的索引列表。
        node_row_ptr: 存储 forward 路径的历史,所以推断任务中可以跳过某些算子来构建子图。
        attrs: 可以包含版本号或类似的有用信息。
        nodes: 节点是占位符或可计算节点。
    """
    arg_nodes: list[int]
    heads: list[GraphNodeRef]
    node_row_ptr: list[int]
    attrs: GraphAttrs
    nodes: list[GraphNode]

    def __post_init__(self):
        self.heads = [GraphNodeRef(*head) for head in self.heads]
        self.attrs = GraphAttrs(**self.attrs)
        self.nodes = [GraphNode(**node) for node in self.nodes]
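
node_row_ptr is simply the running sum of num_outputs over the nodes (with a leading 0); the graph executor uses it to turn a (nodeid, index, version) reference into a flat entry id, roughly node_row_ptr[nodeid] + index. A small check on the example graph:

bunch = eval(rt_lib.graph_json)
row_ptr = [0]
for node in bunch["nodes"]:
    # "null" nodes carry no attrs and have a single output
    num_outputs = int(node.get("attrs", {}).get("num_outputs", 1))
    row_ptr.append(row_ptr[-1] + num_outputs)
assert row_ptr == bunch["node_row_ptr"]   # [0, 1, 2, 3, 4]
# e.g. the first head [2, 0, 0] refers to entry node_row_ptr[2] + 0 == 2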

Note

The code is maintained in the tvm_book API.

from dataclasses import asdict
from tvm_book.tvm_utils.graph_json import GraphJson
from tvm_book.data.dataclass import TensorType


@dataclass
class Node:
    inputs: list[TensorType]
    outputs: list[TensorType]
    attrs: dict[str, Any]


graph_json = GraphJson(**eval(rt_lib.graph_json))

Convert it to a dict:

asdict(graph_json).keys()
dict_keys(['arg_nodes', 'heads', 'node_row_ptr', 'attrs', 'nodes'])

Other fields:

graph_json.heads
[GraphNodeRef(ident=2, index=0, version=0),
 GraphNodeRef(ident=3, index=0, version=0)]
graph_json.attrs
GraphAttrs(dltype=['list_str', ['int8', 'int8', 'int8', 'int8']], device_index=['list_int', [1, 1, 1, 1]], storage_id=['list_int', [0, 1, 2, 3]], shape=['list_shape', [[1, 1, 8, 8], [2, 1, 3, 3], [1, 2, 6, 6], [1, 2, 6, 6]]])
graph_json.nodes
[GraphNode(op='null', name='x', inputs=[], attrs=None),
 GraphNode(op='null', name='p0', inputs=[], attrs=None),
 GraphNode(op='tvm_op', name='tvmgen_default_fused_nn_conv2d', inputs=[[0, 0, 0], [1, 0, 0]], attrs={'num_outputs': '1', 'num_inputs': '2', 'flatten_data': '0', 'func_name': 'tvmgen_default_fused_nn_conv2d', 'out_layout': '', 'kernel_layout': 'OIHW', 'data_layout': 'NCHW', 'hash': '8f5bab575bcb83dc'}),
 GraphNode(op='tvm_op', name='tvmgen_default_fused_nn_relu', inputs=[[2, 0, 0]], attrs={'num_outputs': '1', 'num_inputs': '1', 'flatten_data': '0', 'func_name': 'tvmgen_default_fused_nn_relu', 'hash': 'fd6e720bc47ba75c'})]
graph_json.attrs.shape
['list_shape', [[1, 1, 8, 8], [2, 1, 3, 3], [1, 2, 6, 6], [1, 2, 6, 6]]]
attrs = []
dtypes = graph_json.attrs.dltype[1]
device_indexes = graph_json.attrs.device_index[1]
storage_ids = graph_json.attrs.storage_id[1]
shapes = graph_json.attrs.shape[1]
for shape, dtype, storage_id, device_index, node in zip(shapes, dtypes, storage_ids, device_indexes, graph_json.nodes):
    attr = {
        "storage_id": storage_id,
        "device_index": device_index,
        "inputs": node.inputs,
        "op": node.op,
        "op_type": TensorType(shape=shape, dtype=dtype, name=node.name),
    }
    if node.op == "tvm_op":
        attr.update(**node.attrs)
    attrs.append(attr)
attrs
[{'storage_id': 0,
  'device_index': 1,
  'inputs': [],
  'op': 'null',
  'op_type': TensorType(shape=[1, 1, 8, 8], dtype='int8', name='x')},
 {'storage_id': 1,
  'device_index': 1,
  'inputs': [],
  'op': 'null',
  'op_type': TensorType(shape=[2, 1, 3, 3], dtype='int8', name='p0')},
 {'storage_id': 2,
  'device_index': 1,
  'inputs': [[0, 0, 0], [1, 0, 0]],
  'op': 'tvm_op',
  'op_type': TensorType(shape=[1, 2, 6, 6], dtype='int8', name='tvmgen_default_fused_nn_conv2d'),
  'num_outputs': '1',
  'num_inputs': '2',
  'flatten_data': '0',
  'func_name': 'tvmgen_default_fused_nn_conv2d',
  'out_layout': '',
  'kernel_layout': 'OIHW',
  'data_layout': 'NCHW',
  'hash': '8f5bab575bcb83dc'},
 {'storage_id': 3,
  'device_index': 1,
  'inputs': [[2, 0, 0]],
  'op': 'tvm_op',
  'op_type': TensorType(shape=[1, 2, 6, 6], dtype='int8', name='tvmgen_default_fused_nn_relu'),
  'num_outputs': '1',
  'num_inputs': '1',
  'flatten_data': '0',
  'func_name': 'tvmgen_default_fused_nn_relu',
  'hash': 'fd6e720bc47ba75c'}]