VTA 运行时数组

VTA 运行时数组

import tvm
import numpy as np

import vta.testing

# Seed NumPy's global RNG so the random test tensor built in _run is reproducible.
np.random.seed(seed=0xDEADB)
def _run(env, remote):
    """Round-trip a random int8 tensor through a device array and verify it.

    Creates a NumPy array on the host, wraps it in a ``tvm.nd.array`` placed on
    ``remote.ext_dev(0)``, prints the device that array reports, and asserts
    that converting it back with ``.numpy()`` yields exactly the original data.

    Args:
        env: Environment object whose ``BATCH`` and ``BLOCK_OUT`` attributes
            give the two trailing dimensions of the test tensor (presumably the
            VTA environment supplied by ``vta.testing.run`` -- confirm).
        remote: Object exposing ``ext_dev(0)``; presumably an RPC session
            supplied by ``vta.testing.run`` -- confirm.

    Raises:
        AssertionError: if the data read back differs from the host data.
    """
    side = 100
    device = remote.ext_dev(0)
    shape = (side, side, env.BATCH, env.BLOCK_OUT)
    # randint's upper bound is exclusive, so values lie in [1, 9] -- well within int8.
    host_data = np.random.randint(1, 10, size=shape).astype("int8")
    device_data = tvm.nd.array(host_data, device)
    print(device_data.device)
    np.testing.assert_equal(host_data, device_data.numpy())

# Hand _run to VTA's test harness; it is expected to call _run(env, remote)
# with a suitable environment and remote session (assumption -- the harness
# itself is not shown here).
vta.testing.run(_run)
remote[0]:ext_dev(0)