对象属性辅助函数#

源码:tvm/include/tvm/ir/attrs.h & tvm/src/support/ffi_testing.cc

示例:

struct MyAttrs : public tvm::AttrsNode<MyAttrs> {
    float learning_rate;
    int num_hidden;
    String name;
    // 声明属性字段和头文件
    TVM_DECLARE_ATTRS(MyAttrs, "attrs.MyAttrs") {
    TVM_ATTR_FIELD(num_hidden).set_lower_bound(1);
    TVM_ATTR_FIELD(learning_rate).set_default(0.01f);
    TVM_ATTR_FIELD(name).set_default("hello");
    }
};
// 在 cc 文件中注册
TVM_REGISTER_NODE_TYPE(MyAttrs);

首先,定义了名为 MyAttrs 的结构体,它继承自 tvm::AttrsNode<MyAttrs>。这个结构体包含了三个属性:learning_ratenum_hiddenname。在头文件中,可以使用 TVM_DECLARE_ATTRS 宏来声明这些属性。在这个例子中,声明了三个属性字段:num_hiddenlearning_ratename。每个字段都使用了 TVM_ATTR_FIELD 宏来指定属性的名称和一些选项。

  • 对于 num_hidden 字段,使用 set_lower_bound(1) 来设置它的下界为 1,表示该属性的有效取值范围是大于等于 \(1\) 的整数。

  • 对于 learning_rate 字段,使用 set_default(0.01f) 来设置它的默认值为 0.01。这意味着如果在代码中没有显式地给 learning_rate 赋值,那么它将被初始化为 0.01

  • 对于 name 字段,使用 set_default("hello") 来设置它的默认值为 "hello"。同样地,如果在代码中没有显式地给 name 赋值,那么它将被初始化为 "hello"

在源文件(通常是 C++)中,需要注册这个新的节点类型。使用 TVM_REGISTER_NODE_TYPE(MyAttrs) 宏可以将 MyAttrs 节点类型注册到 TVM 中,使其可以在后续的编译和执行过程中被识别和使用。

总结起来,这段代码演示了如何使用TVM的 AttrsNode 类和相关宏来声明和使用具有默认值和边界检查的命名属性。通过这种方式,可以更方便地管理和使用模型的属性信息。

TVM_DECLARE_ATTRSTVM_ATTR_FIELD#

#define TVM_DECLARE_ATTRS(ClassName, TypeKey)                    \
  static constexpr const char* _type_key = TypeKey;              \
  TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
  template <typename FVisit>                                     \
  void _tvm_VisitAttrs(FVisit& _tvm_fvisit)  // NOLINT(*)
#define TVM_ATTR_FIELD(FieldName) _tvm_fvisit(#FieldName, &FieldName)

这段代码定义了两个宏,用于在 TVM 中声明属性函数和属性字段。

第一个宏 TVM_DECLARE_ATTRS 用于声明属性函数。它接受两个参数:ClassName 表示类的名称,TypeKey 表示类型键,用于在 TVM 节点系统中标识该属性的类型。

  • 在宏的定义中,首先使用 static constexpr 关键字将 TypeKey 赋值给了名为 _type_key 的静态常量字符指针。这样做的目的是确保 _type_key 的值在编译时就已经确定,并且只初始化一次。

  • 接下来,通过调用 TVM_DECLARE_FINAL_OBJECT_INFO 宏来声明最终对象信息,其中包含了类名 ClassName 和基类 ::tvm::BaseAttrsNode。这个宏的作用是将类名和基类传递给 TVM 的元编程系统,以便正确地生成派生类的元数据。

  • 最后,通过模板函数 _tvm_VisitAttrs 来实现属性访问的机制。这个函数接受类型为 FVisit 的函数对象作为参数,用于遍历和访问属性。函数体中的注释 // NOLINT(*) 表示编译器在编译时应忽略该函数未被使用的错误警告。

第二个宏 TVM_ATTR_FIELD 用于声明属性字段。它也接受两个参数:FieldName 表示属性的字段名,&FieldName 表示对应的变量。

  • 在宏的定义中,使用之前定义的模板函数 _tvm_fvisit 来处理属性字段的访问。_tvm_fvisit(#FieldName, &FieldName) 将属性字段的名称和对应的变量传递给 _tvm_fvisit 函数。这样,在后续的编译和执行过程中,就可以正确地访问和操作该属性字段了。

备注

#FieldName 中,# 符号是 C++ 中的预处理器指令的开始。它用于指示编译器将紧随其后的文本视为预处理指令。

在这种情况下,#FieldName 被视为宏定义的开始。宏定义是一种在编译之前进行文本替换的技术。通过使用 #FieldName,我们可以定义名为 FieldName 的宏。

在给定的代码片段中,#FieldName 被用作参数传递给宏 TVM_ATTR_FIELD。这意味着当该宏被调用时,FieldName 将被替换为实际的值。

总结起来,#FieldName 是 C++ 中预处理指令的语法,用于定义宏。在这个特定的代码片段中,它被用作参数传递给宏 TVM_ATTR_FIELD,并在宏展开时被替换为相应的值。

总结起来,这段代码定义了两个宏,用于在 TVM 中声明属性函数和属性字段。通过使用这些宏,可以在 C++ 代码中方便地声明和使用带有类型键的属性。

NullValueNullValue<DataType>#

模板函数

template <typename TObjectRef>
inline TObjectRef NullValue() {
  static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types");
  return TObjectRef(ObjectPtr<Object>(nullptr));
}

template <>
inline DataType NullValue<DataType>() {
  return DataType(DataType::kHandle, 0, 0);
}

这段代码定义了两个模板函数:NullValueNullValue<DataType>

第一个函数 NullValue 是泛型函数,它接受类型参数 TObjectRef。这个函数的作用是创建表示空值的 TObjectRef 类型的对象。在函数内部,首先使用 static_assert 进行编译时断言,确保传入的类型 TObjectRef 是可空类型(即具有 _type_is_nullable 成员变量)。然后,创建 ObjectPtr<Object> 类型的空指针,并将其作为参数传递给 TObjectRef 类型的构造函数,创建表示空值的对象。最后,返回创建的空值对象。

第二个函数 NullValue<DataType> 是特化版本的 NullValue 函数,它专门用于处理 DataType 类型的对象。在这个函数内部,直接返回表示空值的 DataType 类型的对象,其值为 DataType::kHandle, 0, 0

AttrError#

/*! \brief Error thrown during attribute checking. */
struct AttrError : public Error {
  /*!
   * \brief constructor
   * \param msg error message
   */
  explicit AttrError(std::string msg) : Error("AttributeError:" + msg) {}
};

AttrError 结构体,它继承自 Error 类。AttrError 结构体用于表示在属性检查过程中抛出的错误。

AttrError 结构体中,定义了构造函数,该构造函数接受字符串参数 msg,表示错误信息。构造函数通过调用父类 Error 的构造函数来初始化 AttrError 对象,并将错误信息与字符串 "AttributeError:" 拼接起来。

AttrFieldInfoNode#

AttrFieldInfoNode 类,该类继承自 Object 类。这个类主要用于存储属性字段的信息。

/*!
 * \brief Information about attribute fields in string representations.
 */
class AttrFieldInfoNode : public Object {
 public:
  /*! \brief name of the field */
  String name;
  /*! \brief type docstring information in str. */
  String type_info;
  /*! \brief detailed description of the type */
  String description;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("type_info", &type_info);
    v->Visit("description", &description);
  }

  static constexpr const char* _type_key = "AttrFieldInfo";
  static constexpr bool _type_has_method_sequal_reduce = false;
  static constexpr bool _type_has_method_shash_reduce = false;
  TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};

AttrFieldInfoNode 类中,定义了三个成员变量:

  • name:表示字段的名称;

  • type_info:表示字段类型的文档字符串信息;

  • description:表示字段类型的详细描述。

此外,还定义了名为 VisitAttrs 的成员函数,该函数接受指向 AttrVisitor 类的指针作为参数。在 VisitAttrs 函数中,调用 AttrVisitor 类的 Visit 方法,将 nametype_infodescription 这三个成员变量的值传递给 Visit 方法。

最后,定义了两个静态常量成员变量:

  • _type_key:表示该类的类型键,值为 "AttrFieldInfo"

  • _type_has_method_sequal_reduce_type_has_method_shash_reduce:这两个静态常量成员变量的值都为 false

此外,还使用 TVM_DECLARE_FINAL_OBJECT_INFO 宏声明了 AttrFieldInfoNode 类的对象信息。

AttrFieldInfo#

AttrFieldInfo 类,该类继承自 ObjectRef 类。

/*! \brief AttrFieldInfo */
class AttrFieldInfo : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode);
};

AttrFieldInfo 类中,使用 TVM_DEFINE_OBJECT_REF_METHODS 宏来定义一些方法,这些方法用于操作 AttrFieldInfo 对象和 ObjectRef 对象之间的关联关系。具体来说, TVM_DEFINE_OBJECT_REF_METHODS 宏将 AttrFieldInfo 类的方法映射到 ObjectRef 类的方法上,从而实现了 AttrFieldInfo 对象与 ObjectRef 对象的相互操作。

BaseAttrsNode#

BaseAttrsNode 的基类,它继承自 Object 类。这个基类主要用于表示所有属性类的基类。

/*!
 * \brief Base class of all attribute class
 * \note Do not subclass AttrBaseNode directly,
 *       subclass AttrsNode instead.
 * \sa AttrsNode
 */
class BaseAttrsNode : public Object {
 public:
  using TVMArgs = runtime::TVMArgs;
  using TVMRetValue = runtime::TVMRetValue;
  /*! \brief virtual destructor */
  virtual ~BaseAttrsNode() {}
  // visit function
  virtual void VisitAttrs(AttrVisitor* v) {}
  /*!
   * \brief Initialize the attributes by sequence of arguments
   * \param args The positional arguments in the form
   *        [key0, value0, key1, value1, ..., key_n, value_n]
   */
  template <typename... Args>
  inline void InitBySeq(Args&&... args);
  /*!
   * \brief Print readible docstring to ostream, add newline.
   * \param os the stream to print the docstring to.
   */
  inline void PrintDocString(std::ostream& os) const;  // NOLINT(*)
  /*!
   * \brief Visit attributes that do not equal the default value.
   *
   * \note This is useful to extract fields for concise printing.
   * \param v The visitor
   */
  TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0;
  /*!
   * \brief Get the field information
   * \return The fields in the Attrs.
   */
  TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0;
  /*!
   * \brief Initialize the attributes by arguments.
   * \param kwargs The key value pairs for initialization.
   *        [key0, value0, key1, value1, ..., key_n, value_n]
   * \param allow_unknown Whether allow additional unknown fields.
   * \note This function throws when the required field is not present.
   */
  TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;

  static constexpr const bool _type_has_method_sequal_reduce = true;
  static constexpr const bool _type_has_method_shash_reduce = true;
  static constexpr const char* _type_key = "Attrs";
  TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};

BaseAttrsNode 类中,定义了一些类型别名和虚函数:

  • TVMArgsTVMRetValue 分别表示 TVM 的参数和返回值类型;

  • VisitAttrs 虚函数,用于访问属性;

  • InitBySeq 模板函数,用于通过序列化的方式初始化属性;

  • PrintDocString 虚函数,用于打印可读的文档字符串;

  • VisitNonDefaultAttrs 纯虚函数,用于访问不等于默认值的属性;

  • ListFieldInfo 纯虚函数,用于获取属性字段信息;

  • InitByPackedArgs 纯虚函数,用于通过键值对的方式初始化属性。

此外,还定义了一些静态常量和宏:

  • _type_has_method_sequal_reduce_type_has_method_shash_reduce 表示该类型是否具有相等和哈希减少的方法;

  • _type_key 表示该类型的键;

  • TVM_DECLARE_BASE_OBJECT_INFO 宏用于声明基对象信息。

Attrs#

/*!
 * \brief Managed reference to BaseAttrsNode.
 * \sa AttrsNode, BaseAttrsNode
 */
class Attrs : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode);
};

Attrs 的类,该类继承自 ObjectRef 类。在 Attrs 类中,使用 TVM_DEFINE_OBJECT_REF_METHODS 宏来定义一些方法,这些方法用于操作 Attrs 对象和 ObjectRef 对象之间的关联关系。具体来说,TVM_DEFINE_OBJECT_REF_METHODS 宏将 Attrs 类的方法映射到 ObjectRef 类的方法上,从而实现了 Attrs 对象与 ObjectRef 对象的相互操作。

TestAttrs 示例#

// Attrs used to python API
struct TestAttrs : public AttrsNode<TestAttrs> {
  int axis;
  String name;
  Array<PrimExpr> padding;
  TypedEnvFunc<int(int)> func;

  TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") {
    TVM_ATTR_FIELD(axis).set_default(10).set_lower_bound(1).set_upper_bound(10).describe(
        "axis field");
    TVM_ATTR_FIELD(name).describe("name");
    TVM_ATTR_FIELD(padding).describe("padding of input").set_default(Array<PrimExpr>({0, 0}));
    TVM_ATTR_FIELD(func)
        .describe("some random env function")
        .set_default(TypedEnvFunc<int(int)>(nullptr));
  }
};

TVM_REGISTER_NODE_TYPE(TestAttrs);

这段代码定义了名为 TestAttrs 的结构体,它继承自 AttrsNode<TestAttrs>。这个结构体包含以下成员变量:

  1. int axis:整数类型的变量,表示轴。

  2. String name:字符串类型的变量,表示名称。

  3. Array<PrimExpr> paddingPrimExpr 类型的数组,表示输入的填充。

  4. TypedEnvFunc<int(int)> func:类型为 TypedEnvFunc<int(int)> 的函数对象,表示一些随机的环境函数。

TVM_DECLARE_ATTRS 宏中,为这些成员变量设置了默认值、边界值和描述信息。例如,axis 字段的默认值为 10,下界为 1,上界为 10name 字段没有设置默认值,也没有描述信息;padding 字段的默认值为 {0, 0}func 字段的默认值为 nullptr,并描述了这是随机的环境函数。

import set_env
import tvm
import pytest
with pytest.raises(AttributeError):
    x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx")

with pytest.raises(AttributeError):
    x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx")
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
assert x.name == "xx"
assert x.padding[0].value == 3
assert x.padding[1].value == 4
assert x.axis == 10

DictAttrsDictAttrsNode#

/*!
 * \brief Specialized attribute type that is backed by a map.
 *  The DictAttrsNode implements the Attrs behavior,
 *  its fields are directly accessible via object.field_name
 *  like other normal nodes.
 */
class DictAttrsNode : public BaseAttrsNode {
 public:
  /*! \brief internal attrs map */
  Map<String, ObjectRef> dict;

  bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
    return equal(dict, other->dict);
  }

  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); }

  // implementations
  void VisitAttrs(AttrVisitor* v) final;
  void VisitNonDefaultAttrs(AttrVisitor* v) final;
  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
  Array<AttrFieldInfo> ListFieldInfo() const final;

  // type info
  static constexpr const char* _type_key = "DictAttrs";
  TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
};

/*!
 * \brief Managed reference to DictAttrsNode
 * \sa DictAttrsNode.
 */
class DictAttrs : public Attrs {
 public:
  /*!
   * \brief Consruct a Attrs backed by DictAttrsNode.
   * \param dict The attributes.
   */
  TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);

  // Utils for accessing attributes
  // This needs to be on DictAttrs, not DictAttrsNode because we return the default
  // value if DictAttrsNode is not defined.
  /*!
   * \brief Get a function attribute.
   *
   * \param attr_key The attribute key.
   * \param default_value The default value if the key does not exist, defaults to nullptr.
   *
   * \return The result
   *
   * \tparam TOBjectRef the expected object type.
   * \throw Error if the key exists but the value does not match TObjectRef
   *
   * \code
   *
   *  void GetAttrExample(const BaseFunc& f) {
   *    auto value = f->attrs.GetAttr<Integer>("AttrKey", 0);
   *  }
   *
   * \endcode
   */
  template <typename TObjectRef>
  Optional<TObjectRef> GetAttr(
      const std::string& attr_key,
      Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
    static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
                  "Can only call GetAttr with ObjectRef types.");
    if (!defined()) return default_value;
    const DictAttrsNode* node = this->as<DictAttrsNode>();

    auto it = node->dict.find(attr_key);
    if (it != node->dict.end()) {
      return Downcast<Optional<TObjectRef>>((*it).second);
    } else {
      return default_value;
    }
  }
  // variant that uses TObjectRef to enable implicit conversion to default value.
  template <typename TObjectRef>
  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
    return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
  }
  /*!
   * \brief Check whether the function has an non-zero integer attr.
   *
   * This function can be used to check whether an optional
   * attribute mark(e.g. inline) exists.
   *
   * \param attr_key The key to the attribute.
   * \return The check result.
   *
   * \code
   *
   *  void HasNonzeroAttrExample(const BaseFunc& f) {
   *    if (f->HasNonzeroAttr(attr::kInline)) {
   *      // inline the function.
   *    }
   *  }
   *
   * \endcode
   */
  bool HasNonzeroAttr(const std::string& attr_key) const {
    return GetAttr<Integer>(attr_key, 0).value_or(0).IntValue() != 0;
  }

  TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
};

这段代码定义了两个类:DictAttrsNodeDictAttrs

DictAttrsNode 是特殊的属性类型,它由映射(map)支持。这个映射的键是字符串,值是对象引用(ObjectRef)。这个类实现了 BaseAttrsNode 的属性行为,它的字段可以直接通过 object.field_name 访问,就像其他普通的节点一样。

DictAttrs 是管理 DictAttrsNode 引用的类。它有构造函数,接受映射作为参数,这个映射就是属性。此外,它还提供了一些实用的方法来访问这些属性。例如,GetAttr 方法可以获取特定的属性,如果该属性不存在,则返回默认值。HasNonzeroAttr 方法可以检查特定的属性是否存在并且其值是否不为零。

总的来说,这段代码提供了灵活的方式来存储和管理属性,特别是那些需要动态访问的属性。

dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0, 0))
assert dattr.x.value == 1
datrr = tvm.ir.load_json(tvm.ir.save_json(dattr))
assert dattr.name == "xyz"
assert isinstance(dattr, tvm.ir.DictAttrs)
assert "name" in dattr
assert dattr["x"].value == 1
assert len(dattr) == 4
assert len([x for x in dattr.keys()]) == 4
assert len(dattr.items()) == 4

AttrsWithDefaultValuesWithAttr#

/*!
 * \brief Create an Attr object with all default values.
 * \tparam TAttrNode the type to be created.
 * \return A instance that will represent None.
 */
template <typename TAttrs>
inline TAttrs AttrsWithDefaultValues() {
  static_assert(std::is_base_of<Attrs, TAttrs>::value, "Can only take attr nodes");
  auto n = make_object<typename TAttrs::ContainerType>();
  n->InitByPackedArgs(runtime::TVMArgs(nullptr, nullptr, 0), false);
  return TAttrs(n);
}

/*!
 * \brief Copy the function or module, but overrides
 *        the attribute value key with the value.
 *
 * \param input The thing to annotate (BaseFunc or IRModule)
 * \param attr_key The attribute key.
 * \param attr_value The value attribute value.
 *
 * \tparam TFunc The corresponding function or module type.
 *
 * \returns The new function or module with updated attributes.
 *
 * \note This function performs copy on write optimization for func and module.
 *       If we move a uniquely referenced func or module into WithAttr,
 *       then no additional copy will be performed.
 *
 *       This is also why we make it as a function instead of a member function
 *       and why we pass by value in the first argument.
 *
 * \code
 *
 *  // Recommended way to trigger copy on write
 *  func = WithAttr(std::move(func), "key1", value1);
 *  func = WithAttr(std::move(func), "key2", value2);
 *
 * \endcode
 */
template <typename TFunc>
inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) {
  using TNode = typename TFunc::ContainerType;
  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
  TNode* node = input.CopyOnWrite();
  if (node->attrs.defined()) {
    node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
  } else {
    Map<String, ObjectRef> dict = {{attr_key, attr_value}};
    node->attrs = DictAttrs(dict);
  }
  return input;
}

这段代码主要包含两个模板函数:AttrsWithDefaultValuesWithAttr

  1. AttrsWithDefaultValues 函数用于创建具有默认值的 TAttrs 对象。它首先检查传入的类型是否为 Attrs 的子类,然后创建新的 TAttrs::ContainerType 对象,并使用 InitByPackedArgs 方法初始化其属性。最后,返回包含这些属性的新 TAttrs 对象。\

  2. WithAttr 函数用于复制函数或模块,并覆盖其属性值键与值。它接受三个参数:要注释的对象(BaseFuncIRModule)、属性键和属性值。函数首先确保操作的是叶子节点,然后对输入进行写时复制。接下来,如果节点已经定义了属性,它将在属性字典中设置新的键值对;否则,它将创建一个新的属性字典并将其设置为节点的属性。最后,返回输入对象。

这两个函数都使用了 C++ 的模板编程特性,允许它们处理不同类型的对象和属性。

WithAttrsWithoutAttr#

/*!
 * \brief Copy the function or module, but overrides the attributes with the entries from \p attrs.
 *
 * \param input The thing to annotate (BaseFunc or IRModule)
 * \param attrs Key/values attributes to add to \p input.
 *
 * \tparam TFunc The corresponding function or module type.
 *
 * \returns The new function or module with updated attributes.
 */
template <typename TFunc>
inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
  using TNode = typename TFunc::ContainerType;
  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
  TNode* node = input.CopyOnWrite();
  if (node->attrs.defined()) {
    for (const auto& pair : attrs) {
      node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
    }
  } else {
    node->attrs = DictAttrs(std::move(attrs));
  }
  return input;
}

/*!
 * \brief Copy the function or module, but removes the specified
 *        attribute.
 *
 * \param input The thing to annotate (BaseFunc or IRModule)
 * \param attr_key The attribute key.
 *
 * \tparam TFunc The corresponding function or module type.
 *
 * \returns The new function or module with removed attribute.
 *
 * \note This function performs copy on write optimization for func and module.
 *       If we move a uniquely referenced func or module into WithoutAttr,
 *       then no additional copy will be performed.
 *
 *       This is also why we make it as a function instead of a member function
 *       and why we pass by value in the first argument.
 *
 * \code
 *
 *  // Recommended way to trigger copy on write
 *  func = WithoutAttr(std::move(func), "key1");
 *  func = WithoutAttr(std::move(func), "key2");
 *
 * \endcode
 */
template <typename TFunc>
inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
  using TNode = typename TFunc::ContainerType;
  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");

  if (input->attrs.defined()) {
    TNode* node = input.CopyOnWrite();
    node->attrs.CopyOnWrite()->dict.erase(attr_key);
    if (node->attrs->dict.size() == 0) {
      node->attrs = NullValue<DictAttrs>();
    }
  }
  return input;
}

这段代码定义了两个模板函数:WithAttrsWithoutAttr。这两个函数都用于处理函数或模块的属性,但它们的操作方式有所不同。

  1. WithAttrs 函数接受输入参数 input 和属性映射 attrs,然后将这些属性添加到 input 中。如果 input 已经定义了属性,那么它会遍历 attrs 中的每个键值对,并将它们添加到 input 的属性字典中。如果 input 没有定义属性,那么它会创建一个新的属性字典,并将 attrs 中的所有键值对添加到这个新的字典中。最后,函数返回更新后的 input

  2. WithoutAttr函数接受输入参数 input 和属性键 attr_key,然后从 input 中删除指定的属性。如果 input 已经定义了属性,那么它会复制 input,然后在复制的属性字典中删除指定的键。如果删除后的属性字典为空,那么它会将 input 的属性设置为 NullValue<DictAttrs>。最后,函数返回更新后的 input

这两个函数都使用了 C++ 的模板编程特性,允许它们处理不同类型的对象和属性。同时,它们都使用了写时复制(CopyOnWrite)优化,以提高性能。

属性实现细节#

AttrNopEntry#

// helper entry that does nothing in set_default/bound/describe calls.
struct AttrNopEntry {
  using TSelf = AttrNopEntry;

  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
  template <typename T>
  TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
    return *this;
  }
  template <typename T>
  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
    return *this;
  }
  template <typename T>
  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
    return *this;
  }
};

AttrNopEntry 结构体,它没有在 set_defaultbounddescribe 方法中执行任何操作。这个结构体主要用于作为占位符,当需要实现这些方法但不需要实际功能时使用。

AttrNopEntry 结构体包含以下成员函数:

  1. describe:接受 const char* 类型的参数 str,并返回 TSelf& 类型的引用。这个方法什么都不做,直接返回当前对象的引用。

  2. set_default:接受泛型类型 T 的参数 value,并返回 TSelf& 类型的引用。这个方法什么都不做,直接返回当前对象的引用。

  3. set_lower_bound:接受泛型类型 T 的参数 begin,并返回 TSelf& 类型的引用。这个方法什么都不做,直接返回当前对象的引用。

  4. set_upper_bound:接受泛型类型 T 的参数 end,并返回 TSelf& 类型的引用。这个方法什么都不做,直接返回当前对象的引用。

这个结构体的主要作用是提供空操作的占位符,以便在需要实现这些方法但不需要实际功能的情况下使用。

AttrNormalVisitor#

// Wrapper for normal visitor.
class AttrNormalVisitor {
 public:
  explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
  template <typename T>
  AttrNopEntry operator()(const char* key, T* value) {
    visitor_->Visit(key, value);
    return AttrNopEntry();
  }

 private:
  AttrVisitor* visitor_;
};

这段代码定义了名为 AttrNormalVisitor 的类,它是对普通访问者(AttrVisitor)的包装。这个类的主要目的是在不改变原有访问者接口的情况下,为其添加一些额外的功能。

AttrNormalVisitor 类有构造函数,接受指向 AttrVisitor 对象的指针作为参数,并将其存储在私有成员变量 visitor_ 中。这样,AttrNormalVisitor 对象就可以通过调用 visitor_ 来访问原始的访问者对象。

此外,AttrNormalVisitor 类还重载了 operator(),使其可以像函数一样被调用。这个重载版本的 operator() 接受两个模板参数:const char* 类型的键和泛型类型 T 的值。在函数体中,它首先调用 visitor_->Visit(key, value),将键和值传递给原始的访问者对象。然后,它返回 AttrNopEntry 对象。

总之,这段代码定义了 AttrNormalVisitor 类,用于包装普通的访问者对象,并为其添加一些额外的功能。

AttrsSEqualVisitorAttrsSHashVisitor#

class AttrsSEqualVisitor {
 public:
  bool result_{true};
  // constructor
  AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal)
      : lhs_(lhs), rhs_(rhs), equal_(equal) {}
  template <typename T>
  AttrNopEntry operator()(const char* key, T* lhs_value) {
    if (!result_) return AttrNopEntry();
    const T* rhs_value = reinterpret_cast<const T*>(
        reinterpret_cast<const char*>(rhs_) +
        (reinterpret_cast<const char*>(lhs_value) - reinterpret_cast<const char*>(lhs_)));
    if (!equal_(*lhs_value, *rhs_value)) {
      result_ = false;
    }
    return AttrNopEntry();
  }

 private:
  const Object* lhs_;
  const Object* rhs_;
  const SEqualReducer& equal_;
};

class AttrsSHashVisitor {
 public:
  explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {}

  template <typename T>
  AttrNopEntry operator()(const char* key, T* value) {
    hash_reducer_(*value);
    return AttrNopEntry();
  }

 private:
  const SHashReducer& hash_reducer_;
};
  1. AttrsSEqualVisitor 类用于比较两个对象(lhs_rhs_)的属性是否相等。它接受 Object 指针作为左操作数,Object 指针作为右操作数,以及 SEqualReducer 对象来执行实际的比较操作。在构造函数中,它将这三个参数存储为私有成员变量。然后,它重载了 operator() 函数,该函数接受一个键和一个值,并使用 SEqualReducer 对象来比较这两个值。如果它们不相等,result_ 成员变量将被设置为false。最后,operator() 函数返回 AttrNopEntry 对象。

  2. AttrsSHashVisitor 类用于计算对象的哈希值。它接受 SHashReducer 对象作为参数,并在构造函数中将其存储为私有成员变量。然后,它重载了 operator() 函数,该函数接受一个键和一个值,并使用 SHashReducer 对象来计算这个值的哈希值。最后,operator() 函数返回 AttrNopEntry 对象。

dattr0 = tvm.ir.make_node("DictAttrs", x=1, y=[10, 20])
dattr1 = tvm.ir.make_node("DictAttrs", y=[10, 20], x=1)
dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=None)
assert tvm.ir.structural_equal(dattr0, dattr1)
assert not tvm.ir.structural_equal(dattr0, dattr2)
assert not tvm.ir.structural_equal({"x": 1}, tvm.runtime.convert(1))
assert not tvm.ir.structural_equal([1, 2], tvm.runtime.convert(1))

AttrInitEntry#

// helper entry that does initialization, set default.
template <typename T>
struct AttrInitEntry {
  // The attributes
  using TSelf = AttrInitEntry<T>;
  // The type key
  const char* type_key_;
  // field name
  const char* key_;
  // internal value.
  T* value_;
  // whether the value is missing.
  // NOTE: initialize to false so that the destructor does not throw unless
  // AttrInitVisitor::operator() is committed to returning an instance of this class.
  // It is expected not to set this to true until that is true.
  bool value_missing_{false};

  AttrInitEntry() = default;

  AttrInitEntry(AttrInitEntry&& other) {
    type_key_ = other.type_key_;
    key_ = other.key_;
    value_ = other.value_;
    value_missing_ = other.value_missing_;
    // avoid unexpected throw
    other.value_missing_ = false;
  }

  // If the value is still missing in destruction time throw an error.
  ~AttrInitEntry() DMLC_THROW_EXCEPTION {
    if (value_missing_) {
      std::ostringstream os;
      os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization. "
         << "If the key is defined check that its type matches the declared type.";
      throw AttrError(os.str());
    }
  }
  // override fields.
  // This function sets the lower bound of the attribute
  TSelf& set_lower_bound(const T& begin) {
    if (this->value_missing_) return *this;
    const T& val = *value_;
    if (begin > val) {
      std::ostringstream os;
      os << type_key_ << "." << key_ << ": "
         << "value " << val << " is smaller than the lower bound " << begin;
      throw AttrError(os.str());
    }
    return *this;
  }
  // This function sets the upper bound of the attribute
  TSelf& set_upper_bound(const T& end) {
    if (this->value_missing_) return *this;
    const T& val = *value_;
    if (val > end) {
      std::ostringstream os;
      os << type_key_ << "." << key_ << ": "
         << "value " << val << " is bigger than the upper bound " << end;
      throw AttrError(os.str());
    }
    return *this;
  }
  // set default when
  TSelf& set_default(const T& value) {
    if (!value_missing_) return *this;
    *value_ = value;
    value_missing_ = false;
    return *this;
  }
  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
};

这段代码定义了模板结构体 AttrInitEntry,用于初始化和设置属性的默认值。这个结构体包含以下成员:

  1. type_key_:类型键,用于标识属性的类型。

  2. key_:字段名,用于标识属性的名称。

  3. value_:内部值,指向属性的值的指针。

  4. value_missing_:一个布尔值,表示属性值是否缺失。在构造函数中初始化为 false,表示属性值已经存在。如果在析构时仍然缺失,则抛出异常。

结构体还提供了以下成员函数:

  1. 构造函数:接受 AttrInitEntry 对象作为右值引用,将其他对象的属性复制到当前对象。

  2. 拷贝构造函数:实现深拷贝,避免意外抛出异常。

  3. 析构函数:在析构时检查 value_missing_ 是否为 true,如果是,则抛出异常。

  4. set_lower_bound():设置属性的下界,如果属性值小于下界,则抛出异常。

  5. set_upper_bound():设置属性的上界,如果属性值大于上界,则抛出异常。

  6. set_default():设置属性的默认值,如果属性值缺失,则设置为给定的默认值。

  7. describe():一个虚函数,用于描述属性。在这个实现中,它什么都不做,直接返回 *this

SetValueSetIntValue#

这段代码是 C++ 模板函数的实现,用于将不同类型的 TVMArgValue 对象转换为其他类型的值。这些转换包括从 Expr 类型到常量、从 int 类型到 DataType 类型、从 std::string 类型到字符串等。

  1. SetValue 函数是一个通用模板函数,用于将 TVMArgValue 对象转换为指定类型的指针所指向的值。它首先尝试使用 operator T() 进行转换,如果失败则尝试使用 static_cast<T>() 进行转换。

  2. SetIntValue 函数是专门用于处理整数类型的函数。它首先检查 val 的类型是否为 kDLInt,如果是,则直接将其转换为T类型并赋值给指针;否则,它会创建 IntImm 表达式,并将其值转换为 T 类型后赋值给指针。

  3. 接下来的几个模板特化函数是对特定类型的特殊处理。例如,对于 DataTypestd::string 类型,它们分别使用 operator DataType()operator std::string() 进行转换。对于 double 类型,它首先检查 val 的类型是否为 kDLFloatkDLInt,如果是,则直接将其转换为 double 类型并赋值给指针;否则,它会创建 ObjectRef 表达式,并根据表达式的类型选择相应的转换方式。

  4. 最后,对于 intint64_tuint64_tbool类型,它们都调用了 SetIntValue 函数进行处理。

AttrInitVisitor#

// Visitor for value initialization
template <typename FFind>
class AttrInitVisitor {
 public:
  // Counter of number of matched attributes during visit.
  // This is used to decide if there is additional unmatched attributes.
  size_t hit_count_{0};
  // constructor
  AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {}

  template <typename T>
  AttrInitEntry<T> operator()(const char* key, T* value) {
    TVMArgValue val;
    AttrInitEntry<T> opt;
    opt.type_key_ = type_key_;
    opt.key_ = key;
    opt.value_ = value;
    if (ffind_(key, &val)) {
      SetValue(value, val);
      opt.value_missing_ = false;
      ++hit_count_;
    } else {
      opt.value_missing_ = true;
    }
#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Wpessimizing-move"
#endif
    return std::move(opt);
  }

 private:
  // the type key
  const char* type_key_;
  FFind ffind_;
};

template <typename FFind>
inline AttrInitVisitor<FFind> CreateInitVisitor(const char* type_key, FFind ffind) {
  return AttrInitVisitor<FFind>(type_key, ffind);
}

这段代码定义了一个名为 AttrInitVisitor 的模板类,用于在访问过程中初始化值。该类接受一个类型参数FFind,并包含以下成员:

  1. hit_count_:匹配属性的数量计数器,用于判断是否有未匹配的属性。

  2. 构造函数:接受一个类型键(type_key)和一个查找函数(ffind),并将它们分别赋值给成员变量type_key_ffind_

  3. 重载的调用运算符operator():接受一个键(key)和一个指向值的指针(value),并返回一个AttrInitEntry<T>类型的对象。在这个函数中,首先创建一个TVMArgValue类型的变量val和一个AttrInitEntry<T>类型的变量opt。然后,将type_key_keyvalue分别赋值给opt的成员变量。接下来,使用ffind_函数查找键对应的值,如果找到了,就调用SetValue函数设置值,并将opt.value_missing_设置为false,同时增加hit_count_的值;如果没有找到,就将opt.value_missing_设置为true。最后,返回移动后的opt对象。

  4. 私有成员变量:type_key_ffind_,分别表示类型键和查找函数。

此外,还定义了一个名为 CreateInitVisitor 的内联函数,用于创建 AttrInitVisitor 对象。这个函数接受一个类型键(type_key)和一个查找函数(ffind),并返回一个AttrInitVisitor<FFind>类型的对象。

其他属性模板函数#

这段代码定义了一些模板类和函数,用于处理不同类型的属性文档。以下是对每个部分的简要解释:

  1. TypeName 结构体模板:这个模板用于获取已知于 tvm 的类型名称。它接受一个类型参数 T,并返回一个指向常量字符指针的静态成员变量 value。对于一些特定的类型(如 int、int64_t、uint64_t、DataType、std::string、bool、void* 和 double),已经为它们提供了特化版本。

  2. AttrDocEntry 类:这个类表示属性文档条目。它有一个构造函数,接受一个 ObjectPtr 类型的参数 info,并将其存储在私有成员变量 info_ 中。此外,它还提供了一些方法,如 describe、set_default、set_lower_bound 和 set_upper_bound,用于设置属性文档条目的属性。

  3. AttrDocVisitor 类:这个类是一个访问者模式的实现,用于遍历属性文档条目。它有一个模板方法 operator(),接受一个键和一个值,然后创建一个 AttrFieldInfoNode 对象,并将其添加到 fields_ 数组中。最后,它返回一个 AttrDocEntry 对象。

  4. AttrExistVisitor 类:这个类也实现了访问者模式,用于检查属性是否存在。它有一个模板方法 operator(),接受一个键和一个值,然后检查 key 是否等于已存在的键。如果找到匹配的键,它将 exist_ 设置为 true。

  5. AttrTriggerNonDefaultEntry 结构体模板:这个模板用于触发非默认属性的访问。它接受一个 AttrVisitor 指针、一个键和一个数据指针作为参数。当触发器为 true 时,它会调用 visitor_ 的 Visit 方法。此外,它还提供了一些方法,如 describe、set_default、set_lower_bound 和 set_upper_bound,用于设置属性文档条目的属性。

  6. AttrNonDefaultVisitor 类:这个类是 AttrTriggerNonDefaultEntry 的一个特化版本,用于非默认属性的访问。它接受一个 AttrVisitor 指针作为参数,并在调用操作符 () 时返回一个 AttrTriggerNonDefaultEntry 对象。

AttrsNodeBaseAttrsNode#

这段代码定义了一个名为 AttrsNode 的模板类,它是所有属性节点的基类。这个类使用了一种被称为“递归模板模式”的技巧,通过模板参数DerivedType来表示最终的属性类型。

AttrsNode类提供了以下方法:

  1. VisitAttrs(AttrVisitor* v):访问属性并调用相应的访问者函数。

  2. VisitNonDefaultAttrs(AttrVisitor* v):访问非默认属性并调用相应的访问者函数。

  3. InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown):通过打包的参数初始化属性。它首先检查参数的数量是否为偶数,然后根据参数数量选择线性搜索或构造映射进行查找。如果允许未知属性且未找到匹配项,则抛出异常。

  4. SEqualReduce(const DerivedType* other, SEqualReducer equal) const:使用给定的相等比较器对当前对象和另一个派生类型的对象进行相等性比较。

  5. SHashReduce(SHashReducer hash_reducer) const:使用给定的哈希减少器对当前对象的哈希值进行计算。

  6. ListFieldInfo() const final:列出当前对象的所有属性字段信息。

  7. PrintDocString(std::ostream& os) const:将当前对象的属性文档字符串打印到输出流中。

此外,还提供了一个名为 BaseAttrsNode 的友元类,它提供了两个静态方法InitBySeqPrintDocStringInitBySeq方法接受一系列参数,并将它们传递给InitByPackedArgs方法进行初始化。PrintDocString方法用于打印当前对象的属性文档字符串。