TVM Pass Instrument

TVM Pass Instrument

参考：如何使用 TVM Pass Instrument

import tvm
import tvm.relay as relay
from tvm.relay.testing import resnet
from tvm.contrib.download import download_testdata
from tvm.relay.build_module import bind_params_by_name
from tvm.ir.instrument import (
    PassTimingInstrument,
    pass_instrument,
)
# Configuration of the ResNet-18 workload profiled in this example.
batch_size = 1
num_of_image_class = 1000
image_shape = (3, 224, 224)
# Shape of the network's output for this configuration (not used further here).
output_shape = (batch_size, num_of_image_class)

# Build the Relay module and its parameters from the configuration above.
# Passing the variables (instead of literal 1 / the default class count) keeps
# the workload in sync if batch_size or num_of_image_class is edited.
relay_mod, relay_params = resnet.get_workload(
    num_layers=18,
    batch_size=batch_size,
    num_classes=num_of_image_class,
    image_shape=image_shape,
)

# Register a PassTimingInstrument on the PassContext so that the passes run
# inside the `with` block (and the passes they invoke) are timed.
timing_inst = PassTimingInstrument()
with tvm.transform.PassContext(instruments=[timing_inst]):
    relay_mod = relay.transform.InferType()(relay_mod)
    relay_mod = relay.transform.FoldScaleAxis()(relay_mod)
    # Fetch the profile results before exiting the context.
    # NOTE(review): presumably the recorded timings are not available after
    # the PassContext exits — confirm against the TVM docs.
    profiles = timing_inst.render()
print("Printing results of timing profile...")
print(profiles)
Printing results of timing profile...
InferType: 11228us [11228us] (53.85%; 53.85%)
FoldScaleAxis: 9621us [7us] (46.15%; 46.15%)
	FoldConstant: 9614us [2007us] (46.11%; 99.92%)
		InferType: 7607us [7607us] (36.49%; 79.13%)