# DFPatternCallback: rewriting Relay expressions with tvm.relay.dataflow_pattern
import set_env
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import (
DFPatternCallback, rewrite,
is_constant, is_op, is_tuple, wildcard, is_tuple_get_item
)
# 替换加法为减法 (replace addition with subtraction)
# Patterns matching any binary ``add`` / ``subtract`` call, whatever the operands.
add_pattern = is_op("add")(wildcard(), wildcard())
sub_pattern = is_op("subtract")(wildcard(), wildcard())

# Free Relay variables used as operands in the first rewrite check.
x, y = relay.var("x"), relay.var("y")
class TestRewrite(DFPatternCallback):
    """Rewrite callback that replaces each matched ``add`` call with a subtraction.

    The pattern is the module-level ``add_pattern`` (``add`` of two wildcards).
    """

    def __init__(self):
        super().__init__()
        self.pattern = add_pattern

    def callback(self, pre, post, node_map):
        """Return ``post.args[0] - post.args[1]``.

        Only ``post`` is used; ``pre`` and ``node_map`` are ignored.
        """
        lhs = post.args[0]
        rhs = post.args[1]
        return lhs - rhs
# Simple case: a bare ``x + y`` must come back as a ``subtract`` call.
out = rewrite(TestRewrite(), x + y)
assert sub_pattern.match(out)

# Nested case: the left operand of the addition is a call to a Relay function
# (conv2d followed by relu); the outer addition must still be rewritten.
inpf, weightf = relay.var("input"), relay.var("weight")
func = relay.Function(
    [inpf, weightf],
    relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)),
    attrs=None,
)
x, w, y = (relay.var(name) for name in ("x", "w", "y"))
out = rewrite(TestRewrite(), func(x, w) + y)
assert sub_pattern.match(out)
class BatchnormCallback(DFPatternCallback):
    """Fold a hand-written batch normalisation back into ``nn.batch_norm``.

    Matches ``gamma * (x - mean) / sqrt(var + eps) + beta``. Here ``eps`` must be a
    constant and the other five operands are wildcards. The whole expression is
    replaced by the first output of ``relay.op.nn.batch_norm``.
    """

    def __init__(self):
        super().__init__()
        # Wildcards for the five non-constant operands.
        self.x = wildcard()
        self.var = wildcard()
        self.mean = wildcard()
        self.beta = wildcard()
        self.gamma = wildcard()
        # eps has to be a constant: ``callback`` reads its value back and passes
        # it as the ``epsilon`` attribute.
        self.eps = is_constant()

        # Same tree as gamma * (x - mean) / sqrt(var + eps) + beta.
        scaled = self.gamma * (self.x - self.mean)
        denom = is_op("sqrt")(self.var + self.eps)
        self.pattern = scaled / denom + self.beta

    def callback(self, pre, post, node_map):
        """Return ``nn.batch_norm(x, gamma, beta, mean, var, epsilon=eps)[0]``."""

        def first(pattern):
            # node_map is indexed by sub-pattern; only the first entry is used.
            return node_map[pattern][0]

        epsilon = first(self.eps).data.numpy().item()
        operands = [
            first(p) for p in (self.x, self.gamma, self.beta, self.mean, self.var)
        ]
        return relay.op.nn.batch_norm(*operands, epsilon=epsilon)[0]
# Spell batch normalisation out by hand (eps = 1e-5 as a constant) and check
# that BatchnormCallback folds it back into a single nn.batch_norm call.
x, var, mean, beta, gamma = (
    relay.var(name) for name in ("x", "var", "mean", "beta", "gamma")
)
BN = ((gamma * (x - mean)) / relay.op.sqrt(var + relay.const(1e-5))) + beta
out = rewrite(BatchnormCallback(), BN)
assert tvm.ir.structural_equal(
    out,
    relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0],
)
# Inspect the rewritten expression and the original hand-written one.
# Show the rewritten expression. The expected printout is a single
# nn.batch_norm call whose first result (%0.0) is the value of the expression.
print(out)
# Output of print(out):
# free_var %x;
# free_var %gamma;
# free_var %beta;
# free_var %mean;
# free_var %var;
# %0 = nn.batch_norm(%x, %gamma, %beta, %mean, %var);
# %0.0
# Show the original hand-written expression for comparison. It is the unfused
# subtract / add / multiply / sqrt / divide / add chain.
print(BN)
# Output of print(BN):
# free_var %x;
# free_var %mean;
# free_var %gamma;
# %0 = subtract(%x, %mean);
# free_var %var;
# %1 = add(%var, 1e-05f);
# %2 = multiply(%gamma, %0);
# %3 = sqrt(%1);
# %4 = divide(%2, %3);
# free_var %beta;
# add(%4, %beta)