解读 ExprFunctor#

tvm/include/tvm/relay/expr_functor.h 是名为 expr_functor 的函数访问者(visitor),它具有更强大的动态分派功能,可以定义具有基于第一个参数的类型分派的任意函数签名。

在计算机编程中,访问者模式是一种设计模式,用于处理不同类型的对象结构。通过使用访问者模式,可以将对不同对象的操作集中在一个或多个访问者类中,从而实现统一的接口和逻辑。

ExprFunctor#

template <typename FType>
class ExprFunctor;

// functions to be overriden.
#define EXPR_FUNCTOR_DEFAULT \
  { return VisitExprDefault_(op, std::forward<Args>(args)...); }

#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)                                                    \
  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {     \
    return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
  });

这段代码是 C++ 模板类的定义,用于实现动态函数对象(functional object),该函数对象可以根据第一个表达式参数的类型进行分派。具体来说,这个类名为 ExprFunctor,它是一个模板类,使用了类型参数 FType 来表示函数签名。根据注释中的描述,FType 应该具有函数签名 R(const Expr&, Args...),其中 R 是返回类型,Expr 是第一个参数的类型,Args 是其他参数的类型。

在代码中,看到两个宏定义:

  • EXPR_FUNCTOR_DEFAULT:这是默认的函数体,用于处理没有特定重载版本的函数调用。它使用 VisitExprDefault_ 函数来处理传入的表达式,并将结果返回。

  • RELAY_EXPR_FUNCTOR_DISPATCH(OP):这是用于分发函数调用的宏定义。它使用了虚函数表(vtable)和 set_dispatch 方法来实现基于算子(OP)的类型分派。当调用该函数对象时,会根据传入的算子类型选择相应的重载版本进行处理。

ExprFunctor 类的主要目的是在表达式树中进行访问操作。它通过重载 operator() 函数来实现对表达式节点的调用,并通过 VisitExpr 函数来处理不同类型的节点。EXPR_FUNCTOR_DEFAULT 宏用于生成默认的可调用对象,而 RELAY_EXPR_FUNCTOR_DISPATCH 宏用于设置节点分派的函数对象。

  • ExprFunctor 类的实现中,首先定义了私有成员变量 vtable,它是类型为 FType 的函数对象。然后,通过调用 InitVTable 函数来初始化 vtable

  • InitVTable 函数使用 RELAY_EXPR_FUNCTOR_DISPATCH 宏来设置不同类型节点的分派函数对象。每个分派函数对象都接受 ConstantNode 指针作为参数,并返回结果。

  • 最后,ExprFunctor 类的构造函数是虚析构函数,确保当删除 ExprFunctor 对象时,能够正确地调用其析构函数。

总的来说,这段代码实现了灵活强大的函数对象,可以在表达式树中进行访问操作,并根据节点的类型选择相应的处理方式。

ExprVisitor#

ExprVisitortvm::relay::ExprFunctor 的子类。ExprVisitorExpr 视为数据流图,并且每个 Expr 节点只访问一次。

ExprVisitor 类中包含了多个重载的 VisitExpr 函数,每个函数都接受 const Expr& 类型的参数,用于处理不同类型的 Expr 节点。这些重载函数根据节点的类型调用相应的 VisitExpr_ 函数进行处理。除了处理 Expr 节点外,ExprVisitor 还定义了一些其他的虚函数,如 VisitTypeVisitClauseVisitPatternVisitSpan,用于处理其他类型的节点。在 ExprVisitor 类中还定义了受保护的成员变量 visit_counter_,它是无序的哈希表,用于记录每个节点被访问的次数。

MixedModeVisitor#

MixedModeVisitortvm::relay::ExprVisitor 的子类。MixedModeVisitorExpr 视为数据流图,并按照后序深度优先搜索(DFS)的顺序进行访问。MixedModeVisitor 提供了与 ExprVisitor 相同的递归 API,并使用递归来遍历 IR 的大多数形式,但在底层,它会展开图中嵌套的数据流区域,并以迭代的方式处理它们,以防止堆栈溢出。

MixedModeVisitor 类中还定义了一些受保护的成员变量和函数。其中,visit_limit_ 表示允许访问节点的最大次数,通常为 1,有时为 2(例如用于消除死代码),但限制为 10 作为合理性检查。

  • VisitLeaf 是虚函数,当到达图的叶子节点时调用,以非递归方式应用。

  • CheckVisited 是虚函数,用于确定表达式是否已经被访问过或者需要重新访问。

VisitExpr 函数被声明为 final,以保留数据流区域的调用扩展。它还重载了多个版本的 VisitExpr_ 函数,用于处理不同类型的节点。

ExprMutator#

ExprMutator 类是 tvm::relay::ExprFunctor 的子类。ExprMutatorExpr 视为数据流图,并且每个 Expr 只进行一次变更。ExprMutator 类中包含了多个重载的 VisitExpr 函数,每个函数都接受 const Expr& 类型的参数,用于处理不同类型的 Expr 节点。这些重载函数根据节点的类型调用相应的 VisitExpr_ 函数进行处理。除了处理 Expr 节点外,ExprMutator 还定义了一些其他的虚函数,如 VisitTypeVisitClauseVisitPattern,用于处理其他类型的节点。

ExprMutator 类中还定义了受保护的成员变量 memo_,它是无序的哈希表,用于记录每个节点被访问的次数。这个哈希表用于实现结果的缓存,以提高后续相同表达式的访问效率。

MixedModeMutator#

MixedModeMutatortvm::relay::ExprMutator 的子类。MixedModeMutatorExpr视为数据流图,并只重写每个 Expr 一次。重写后的结果被缓存在映射中并重复使用,以便数据流上的局部转换保持图结构。

MixedModeMutator 提供了与 ExprMutator 相同的递归 API,并使用递归来遍历IR的大多数形式,但在实际实现中,它会展开图中嵌套的数据流区域,并以迭代的方式处理它们,以防止堆栈溢出。

该类使用了 ExprRewriterRewrite_ API,以实现递归和非递归行为之间的更清晰的分离。

MixedModeMutator 类中还定义了一些受保护的成员变量和函数。其中,pre_ 表示是否为预处理模式。

VisitExpr 函数被声明为final,以保留数据流区域重写的调用扩展。它还重载了多个版本的 VisitExpr_ 函数,用于处理不同类型的节点。

DispatchVisitExpr 函数是一个虚拟函数,用于分发访问表达式节点的操作。

Rewrite_ 函数是用户应该重写的虚函数,用于实现他们的传递。这些重写函数应该能够仅使用原始节点 pre 的数据以及具有修改输入的相同节点 post 进行重写,并且不应递归。

VisitLeafCheckVisited 是受保护的虚函数,用于在叶子节点上进行处理和检查是否已访问。

ExprRewriter#

ExprRewriter 类是非迭代式的表达式重写器。

ExprRewriter 提供了重写接口,用于以后序 DFS 顺序修改图。预期是,ExprRewriter 对象将被传递给 PostOrderRewrite,它将非递归地展开图并调用重写输入。然后,它将传递原始节点(称为 pre)和使用任何更改的输入重新创建的节点(称为 post)给 ExprRewriter。然后,ExprRewriter 可以使用这两个节点中的信息执行更复杂的图重写。

在私有成员中,它定义了类型为 FType 的静态成员变量 vtable,并通过调用 InitVTable 函数进行初始化。InitVTable 函数返回 lambda 表达式,该表达式调用了 Relay_Expr_Rewriter_Dispatch 宏来设置分派。

在公共成员中,它定义了一个虚析构函数,以及重载的括号运算符 operator(),该运算符调用了 Rewrite 函数。它还定义了一些可以被子类覆盖的虚函数,这些函数不应递归。

最后,它还定义了一些重写的虚函数,这些函数默认不执行任何操作,但可以在子类中被覆盖以执行更复杂的重写逻辑。

PostOrderRewrite#

PostOrderRewrite 函数,它执行对图的非递归后序 DFS 遍历,并在输入被重写后调用 ExprRewriterRewrite 函数。在每次重写调用时,PostOrderRewrite 提供原始节点和具有更改的输入的节点,供 ExprRewriter 使用。

该函数接受两个参数:Expr 类型的 expr,表示要遍历的表达式;ExprRewriter* 类型的 rewriter,表示用于重写的表达式重写器。

函数的返回类型是 Expr,表示经过重写后的表达式。

PostOrderVisit#

PostOrderVisit 函数,它以后序 DFS 顺序递归地访问 IR(中间表示),并对每个节点应用 fvisit 访问者函数。

该函数接受两个参数:Expr 类型的 node,表示要访问的 IR 节点;std::function<void(const Expr&)> 类型的 fvisit,表示要应用的访问者函数。

函数没有返回值。

该函数的具体实现并未给出,但从函数注释中可以了解到它的大致功能和用途。它的作用是按照后序 DFS 顺序递归地访问IR中的每个节点,并对每个节点应用给定的访问者函数 fvisit。由于每个节点只被访问一次,因此可以确保节点的访问是正确且高效的。

ExpandDataflowExpandANormalForm#

ExpandDataflow 是一个模板函数,用于以深度优先顺序遍历一个表达式的 IR(中间表示)数据流区域。它接受四个参数:要遍历的表达式 expr、一个检查节点是否被访问过的函数 fcheck_visited、一个访问叶子节点的函数 fvisit_leaf 以及一个扩展表达式的函数 fexpand_expr

该函数使用一个栈来管理遍历过程中的数据流节点。它首先将输入表达式压入栈中,然后进入循环,直到栈为空为止。在每次迭代中,它从栈顶取出一个节点,并检查该节点是否满足数据流类型。如果满足,则将该节点的子节点压入栈中;如果不满足或者该节点的所有输入都已经被处理过,则调用 fvisit_leaf 函数访问当前叶子节点。

ExpandDataflow 函数通过模板参数 FCheckVisitedFVisitLeafFExpandExpr 来实现重用。这些参数是类型别名,分别对应于检查节点是否被访问过的函数、访问叶子节点的函数和扩展表达式的函数的类型。这样,用户可以根据需要提供不同的实现,以便在不同的场景下进行遍历分析。

ExpandANormalForm 函数是一个辅助函数,用于展开一个正常的 LetNode 表达式。它接受三个参数:要展开的表达式 op、一个在访问 LetNode 之前执行的函数 pre_visit 和一个在访问LetNode之后执行的函数 post_visit

ExpandANormalForm 的作用是在展开表达式之前和之后执行一些额外的操作,例如预处理或后处理。