测试 Tedd

测试 Tedd

import numpy as np
from IPython.display import display_svg
from tvm import te, build, lower
from tvm_book.testing.relay.viz import graphviz_relay
from tvm.contrib import tedd
# Element-wise add over a single element: C[i] = A[i] + B[i].
# Both inputs use te.placeholder's default dtype (float32, as the printed
# TVMScript below shows).
A, B = (te.placeholder((1,), name=tensor_name) for tensor_name in ("A", "B"))
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
tensors = [A, B, C]

# Tensor-expression flow: create a schedule, lower it to an IRModule, then
# build that module with the LLVM target.
sch = te.create_schedule(C.op)
ir_mod = lower(sch, tensors, name="test_add")
rt_mod = build(ir_mod, target="llvm")

# The same computation expressed as a TensorIR PrimFunc; show() prints it.
func = te.create_prim_func(tensors)
func.show()
# Printed output of `func.show()`: the PrimFunc built by
# te.create_prim_func([A, B, C]), rendered as TVMScript. It is a reference
# listing only; `T` appears solely in the commented-out import below, so this
# definition is not meant to execute as part of this file.
# from tvm.script import tir as T


@T.prim_func
def main(
    A: T.Buffer((1,), "float32"),
    B: T.Buffer((1,), "float32"),
    C: T.Buffer((1,), "float32"),
):
    # Function attributes as printed by show().
    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
    # with T.block("root"):
    # A single iteration (extent 1). Block "C" binds spatial axis v_i to i,
    # reads A[v_i] and B[v_i], and writes their sum to C[v_i].
    for i in range(1):
        with T.block("C"):
            v_i = T.axis.spatial(1, i)
            T.reads(A[v_i], B[v_i])
            T.writes(C[v_i])
            C[v_i] = A[v_i] + B[v_i]
from graphviz import Source

# Render the dataflow graph of `sch` exactly once: ask tedd for the DOT source
# and display it ourselves, the same way the itervar-relationship cells do.
# show_svg=True is deliberately not passed; it would make tedd render the
# SVG itself in addition to the display_svg() call below.
graph = Source(tedd.viz_dataflow_graph(sch, output_dot_string=True))
display_svg(graph)
../../../_images/41180b72fe244555a8eb476d0cfe6d50508ebb54778f96bd1522be78cc7e3da2.svg
# Normalize the schedule before drawing its schedule tree.
sch = sch.normalize()
from graphviz import Source

# Render the schedule tree exactly once: fetch the DOT source and display it,
# as the itervar-relationship cells do. show_svg=True is not passed because
# tedd would then also render the SVG itself.
# To write the DOT file instead:
#   tedd.viz_schedule_tree(sch, dot_file_path="/tmp/scheduletree.dot")
tree = Source(tedd.viz_schedule_tree(sch, output_dot_string=True))
display_svg(tree)
../../../_images/db1639731e5b94b13869ee223b23df58982567889291277a88d72f1235eb44ca.svg
from graphviz import Source

# Render the IterVar relationship graph of `sch` from its DOT source.
# To write the DOT file instead:
#   tedd.viz_itervar_relationship_graph(sch, dot_file_path="/tmp/itervar.dot")
dot_string = tedd.viz_itervar_relationship_graph(sch, output_dot_string=True)
src = Source(dot_string)
display_svg(src)
../../../_images/964370680f12e3b95ae84bdd2b37e917d71f6a75ea4855e85828a3a95a33e7f4.svg
# A scan over the first axis of X (the cumulative-sum pattern):
#   row 0 of the state comes from s_init (a copy of X[0, :]);
#   row t comes from s_update = previous state row + X[t, :].
m, n = te.var("m"), te.var("n")
X = te.placeholder((m, n), name="X")
s_state = te.placeholder((m, n))
s_init = te.compute((1, n), lambda _, i: X[0, i])
s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
res = te.scan(s_init, s_update, s_state, X)

# Create the schedule and normalize it (type and shape inference).
sch = te.create_schedule(res.op).normalize()
from graphviz import Source

# Render the IterVar relationship graph of the scan schedule `sch`.
# To write the DOT file instead:
#   tedd.viz_itervar_relationship_graph(sch, dot_file_path="/tmp/itervar.dot")
dot_string = tedd.viz_itervar_relationship_graph(sch, output_dot_string=True)
src = Source(dot_string)
display_svg(src)
../../../_images/c19f7ef74c487d01dd5b05aca9b945fb907f57672b09146e02ed68b20ac88309.svg
from graphviz import Source

# Render the schedule tree of the scan schedule exactly once, from its DOT
# source. show_svg=True is not passed because tedd would then also render the
# SVG itself, in addition to the display_svg() call below.
# To write the DOT file instead:
#   tedd.viz_schedule_tree(sch, dot_file_path="/tmp/scheduletree.dot")
tree = Source(tedd.viz_schedule_tree(sch, output_dot_string=True))
display_svg(tree)
../../../_images/60a1683ebf2d839782067870264f32104fb692cadf9b5c62c38157808cec81fe.svg
from graphviz import Source

# Render the dataflow graph of the scan schedule exactly once, from its DOT
# source. show_svg=True is not passed because tedd would then also render the
# SVG itself, in addition to the display_svg() call below.
graph = Source(tedd.viz_dataflow_graph(sch, output_dot_string=True))
display_svg(graph)
../../../_images/d8d33b42d197478d12a3a3d42b045b99590ffee0f6612f3ead1f36995ccb3a2b.svg