Relay Pass Instrument

Relay Pass Instrument

import tvm
from tvm import relay

@tvm.instrument.pass_instrument
class PassCounter:
    """Pass instrument that counts passes run under an instrumented PassContext.

    The count is reset to 0 in ``enter_pass_ctx`` and again in
    ``exit_pass_ctx``. It is incremented once per ``run_before_pass`` callback.
    """

    def __init__(self):
        # Deliberately non-zero sentinel: reading 0 right after entering the
        # PassContext shows that enter_pass_ctx was actually invoked.
        self.counts = 1234

    def enter_pass_ctx(self):
        """Reset the counter when the instrumented PassContext is entered."""
        self.counts = 0

    def exit_pass_ctx(self):
        """Reset the counter when the instrumented PassContext is exited."""
        self.counts = 0

    def run_before_pass(self, module, info):
        """Count one pass invocation; ``module`` and ``info`` are not inspected."""
        self.counts += 1

    def get_counts(self):
        """Return the current pass count."""
        return self.counts


def test_print_debug_callback():
    """Check that PassCounter observes the passes run inside its PassContext.

    Builds ``main(x) = (x + x) * 2`` over a float32 tensor of shape (1, 2, 3),
    runs it through InferType, FoldConstant and DeadCodeElimination, and checks
    the counter on context entry, after the pipeline, and after context exit.
    """
    tensor_ty = relay.TensorType((1, 2, 3), "float32")
    inp = relay.var("x", tensor_ty)
    doubled = relay.add(inp, inp)
    body = relay.multiply(doubled, relay.const(2, "float32"))
    mod = tvm.IRModule({"main": relay.Function([inp], body)})

    pipeline = tvm.transform.Sequential(
        [
            relay.transform.InferType(),
            relay.transform.FoldConstant(),
            relay.transform.DeadCodeElimination(),
        ]
    )

    counter = PassCounter()
    with tvm.transform.PassContext(opt_level=3, instruments=[counter]):
        # Entering the context should have reset the counter's sentinel value.
        assert counter.get_counts() == 0
        mod = pipeline(mod)

        # TODO(@jroesch): once the new fn pass behavior is removed, change this
        # expected count back to match the correct behavior.
        assert counter.get_counts() == 6

    # Exiting the context should reset the counter again.
    assert counter.get_counts() == 0