DFPatternCallback

DFPatternCallback

参考：DFPatternCallback

import set_env
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import (
    DFPatternCallback, rewrite,
    is_constant, is_op, is_tuple, wildcard, is_tuple_get_item
)

替换加法为减法

# Operands for the first demo expression `x + y`.
x, y = relay.var("x"), relay.var("y")

# Patterns for any two-operand `add` / `subtract` call; `wildcard()` matches any operand.
add_pattern, sub_pattern = (
    is_op(op_name)(wildcard(), wildcard()) for op_name in ("add", "subtract")
)

class TestRewrite(DFPatternCallback):
    """Rewrite every matched ``add(a, b)`` call into ``subtract(a, b)``."""

    def __init__(self):
        super().__init__()
        # Fire on any two-operand `add` call (pattern defined at module level).
        self.pattern = add_pattern

    def callback(self, pre, post, node_map):
        """Return ``lhs - rhs`` built from the operands of the rewritten node ``post``."""
        lhs, rhs = post.args[0], post.args[1]
        return lhs - rhs

# Case 1: a bare `x + y` must come back as a `subtract` call.
out = rewrite(TestRewrite(), x + y)
assert sub_pattern.match(out)

# Case 2: the `add`'s left operand is a call to a Relay function
# (conv2d followed by relu); the outer `add` must still be rewritten.
x, w, y = relay.var("x"), relay.var("w"), relay.var("y")
inpf, weightf = relay.var("input"), relay.var("weight")
func = relay.Function(
    [inpf, weightf],
    relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)),
    attrs=None,
)
out = rewrite(TestRewrite(), func(x, w) + y)
assert sub_pattern.match(out)
class BatchnormCallback(DFPatternCallback):
    """Fuse a hand-written batch-norm expression into ``nn.batch_norm``.

    Matches ``gamma * (x - mean) / sqrt(var + eps) + beta``, where ``eps`` must be a
    constant. The match is replaced with element 0 of ``relay.op.nn.batch_norm``,
    whose ``epsilon`` is the constant's value.
    """

    def __init__(self):
        super().__init__()
        # One sub-pattern per operand, so the callback can look each one up in node_map.
        self.x = wildcard()
        self.var = wildcard()
        self.mean = wildcard()
        self.beta = wildcard()
        self.gamma = wildcard()
        self.eps = is_constant()

        centered = self.x - self.mean
        std = is_op("sqrt")(self.var + self.eps)
        self.pattern = self.gamma * centered / std + self.beta

    def callback(self, pre, post, node_map):
        """Build the fused ``nn.batch_norm`` from the matched sub-expressions."""

        def bound(pattern):
            # node_map maps each sub-pattern to the list of expressions it matched.
            return node_map[pattern][0]

        eps_const = bound(self.eps)
        fused = relay.op.nn.batch_norm(
            bound(self.x),
            bound(self.gamma),
            bound(self.beta),
            bound(self.mean),
            bound(self.var),
            # NOTE(review): .item() assumes eps is a single-element constant.
            epsilon=eps_const.data.numpy().item(),
        )
        # Only the normalized output (element 0) replaces the matched expression.
        return fused[0]
# Build a batch-norm computation by hand from elementary Relay ops:
#   gamma * (x - mean) / sqrt(var + eps) + beta, with eps = 1e-5.
x = relay.var("x")
var = relay.var("var")
mean = relay.var("mean")
beta = relay.var("beta")
gamma = relay.var("gamma")

BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta

# The callback should collapse the whole expression into a single nn.batch_norm
# call, with the constant eps carried over as `epsilon`.
out = rewrite(BatchnormCallback(), BN)
assert tvm.ir.structural_equal(
    out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]
)
print(out)
free_var %x;
free_var %gamma;
free_var %beta;
free_var %mean;
free_var %var;
%0 = nn.batch_norm(%x, %gamma, %beta, %mean, %var);
%0.0
# Print the original, un-fused expression so it can be compared with the rewritten `out`.
print(BN)
free_var %x;
free_var %mean;
free_var %gamma;
%0 = subtract(%x, %mean);
free_var %var;
%1 = add(%var, 1e-05f);
%2 = multiply(%gamma, %0);
%3 = sqrt(%1);
%4 = divide(%2, %3);
free_var %beta;
add(%4, %beta)