初识 Relay 变换

初识 Relay 变换

本节从 Relay 变换入手，初步了解 TVM 的 FFI（外部函数接口）机制。

研读源码可以发现，目录 tvm/src/relay/transforms/ 下实现了大量 Relay 变换。下面以 tvm/src/relay/transforms/div_to_mul.cc 中的 DivToMul Pass 为例，了解 Relay 变换是如何注册并生效的。

namespace tvm {
namespace relay {
namespace transform {
/*!
 * \brief Builds the function-level pass that turns division by a constant
 *        into multiplication by the constant's reciprocal.
 *
 * Each function is rewritten by DivToMulRewrite. The pass is created at
 * opt level 0 under the name "DivToMul", listing "InferType" and
 * "FoldConstant" as its required passes.
 */
Pass DivToMul() {
  // Only the function itself is needed; the module and pass context are unused.
  auto rewrite_function = [](Function f, IRModule, PassContext) -> Function {
    return Downcast<Function>(DivToMulRewrite().Mutate(f));
  };
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func(rewrite_function);
  return CreateFunctionPass(pass_func, 0, "DivToMul", {"InferType", "FoldConstant"});
}

// Register the factory in TVM's global function table so it can be looked up
// by name (e.g. from Python through the "relay._transform" FFI namespace).
TVM_REGISTER_GLOBAL("relay._transform.DivToMul").set_body_typed(DivToMul);
}  // namespace transform
}  // namespace relay
}  // namespace tvm

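这里在命名空间 tvm::relay::transform 下定义了变换函数 DivToMul()：它通过 CreateFunctionPass 创建一个函数级 Pass，并借助 TVM_REGISTER_GLOBAL 以名称 "relay._transform.DivToMul" 注册到全局函数表。真正的改写逻辑由下面定义在 tvm::relay 命名空间中的 DivToMulRewrite 完成（由于 DivToMul() 中用到了它，实际源文件中该类需定义在 DivToMul() 之前）。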

namespace tvm {
namespace relay {
class DivToMulRewrite : public MixedModeMutator {
  Expr Rewrite_(const CallNode* pre, const Expr& post) final {
    if (const CallNode* call_node = post.as<CallNode>()) {
      if (call_node->op == Op::Get("divide")) {
        auto rhs = call_node->args[1].as<ConstantNode>();
        if (rhs != nullptr) {
          auto inv =
              runtime::NDArray::Empty(rhs->data.Shape(), rhs->data.DataType(), rhs->data->device);
          std::string dtype = DLDataType2String(rhs->data.DataType());
          if (dtype == "float32") {
            float rhs_val = static_cast<float*>(rhs->data->data)[0];
            // Check for division by zero
            if (rhs_val == 0.) {
              return post;
            }
            static_cast<float*>(inv->data)[0] = 1. / rhs_val;
          } else if (dtype == "float64") {
            double rhs_val = static_cast<double*>(rhs->data->data)[0];
            // Check for division by zero
            if (rhs_val == 0.) {
              return post;
            }
            static_cast<double*>(inv->data)[0] = 1. / rhs_val;
          } else if (dtype == "float16") {
            // Do f16 math in f32
            float rhs_val = __gnu_h2f_ieee(static_cast<uint16_t*>(rhs->data->data)[0]);
            // Check for division by zero
            if (rhs_val == 0.) {
              return post;
            }
            static_cast<uint16_t*>(inv->data)[0] = __gnu_f2h_ieee(1. / rhs_val);
          } else {
            // Cannot do 1/int because it will truncate
            return post;
          }
          return Multiply(call_node->args[0], Constant(inv));
        }
      }
    }
    return post;
  }
};
}
}

若要在 Python 端使用该 Pass，还需要在 tvm/python/tvm/relay/transform/transform.py 中定义对应的包装函数：

def DivToMul():
    """Create the DivToMul pass.

    The pass rewrites a division by a constant into a multiplication by the
    reciprocal of that constant.

    Returns
    -------
    ret : tvm.transform.Pass
        The pass object built by the C++ function registered as
        "relay._transform.DivToMul".
    """
    div_to_mul_pass = _ffi_api.DivToMul()
    return div_to_mul_pass

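其中的关键在于 _ffi_api 模块：tvm/python/tvm/relay/transform/_ffi_api.py 中的下面这行代码，会把 C++ 端以 "relay._transform" 为前缀注册的全局函数（如上面的 "relay._transform.DivToMul"）绑定为该模块的同名属性，因此 Python 端可以直接调用 _ffi_api.DivToMul()：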

# Bind the global functions registered under the "relay._transform" prefix as
# attributes of this module, so that e.g. "relay._transform.DivToMul"
# (registered in C++ via TVM_REGISTER_GLOBAL) is callable as _ffi_api.DivToMul().
tvm._ffi._init_api("relay._transform", __name__)

Python 端的测试代码见《除法转乘法》一节。