op IR#

import set_env
import tvm
from tvm import relay
from tvm.relay.testing.temp_op_attr import TempOpAttr
from tvm.relay.op import op as _op

op 属性#

属性访问:

log_op = relay.op.get("log")
assert log_op.num_inputs == 1

注册 op 属性函数:

@tvm.ir.register_op_attr("exp", "ftest")
def test(x):
    return x + 1

assert log_op.get_attr("ftest") is None
assert relay.op.get("exp").get_attr("ftest")(1) == 2

重置属性函数:

def add1(x):
        return x + 1

def add2(x):
    return x + 2

# 注册 fadd1 和 fadd2 属性
tvm.ir.register_op_attr("exp", "fadd1", add1)
tvm.ir.register_op_attr("log", "fadd1", add1)
tvm.ir.register_op_attr("log", "fadd2", add2)
<function __main__.add2(x)>

重置 log 属性函数:

log_op = relay.op.get("log")
log_op.reset_attr("fadd1")
# 检查 fadd1 属性是否已重置。
assert log_op.get_attr("fadd1") is None
# 检查其他算子的 fadd1 属性是否完好无损。
assert relay.op.get("exp").get_attr("fadd1")(1) == 2
# 检查 log 算子的其他属性是否完好无损。
assert relay.op.get("log").get_attr("fadd2")(1) == 3

op 临时属性#

def add1(x):
    return x + 1

def add2(x):
    return x + 2

# 将原始 attr 值设置为add1。
tvm.ir.register_op_attr("sqrt", "ftest", add1)

with TempOpAttr("sqrt", "ftest", add2):
    # 检查 attr 值是否已更新为 add2。
    assert relay.op.get("sqrt").get_attr("ftest")(1) == 3

# 检查 attr 值是否已恢复为 add1。
assert relay.op.get("sqrt").get_attr("ftest")(1) == 2

op 注册#

op_name = "custom_op"

_op.register(op_name, r"code(Add two tensor with inner broadcasting.)code")
_op.get(op_name).set_num_inputs(2)
_op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
_op.get(op_name).add_argument("data_1", "Tensor", "The input data tensor.")
# 调用默认关系函数
_op.get(op_name).add_type_rel("Identity")
_op.get(op_name).set_support_level(1)
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
_op.register_stateful(op_name, False)

assert _op.get(op_name).name == op_name
assert _op.get(op_name).num_inputs == 2
assert _op.get(op_name).get_attr("TOpPattern") == _op.OpPattern.ELEMWISE
assert _op.get(op_name).get_attr("TOpIsStateful") == False
_op.register??
Signature: _op.register(op_name, describe='')
Source:   
def register(op_name, describe=""):
    """Get the Op for a given name.
    when the op_name is not registered, create a new empty op with the given name.
    when the op_name has been registered, abort with an error message.

    Parameters
    ----------
    op_name : str
        The operator name

    describe : Optional[str]
        The operator description
    """

    tvm.ir._ffi_api.RegisterOp(op_name, describe)
File:      /media/pc/data/lxw/ai/tvm/python/tvm/relay/op/op.py
Type:      function