TVMのソースコードを眺めていたら、TensorFlow Lite の Importer が投入されているのに気が付きました。
いつものように、テストコード(test_forward.py)から。。。
def test_forward_mobilenet(): '''test mobilenet v1 tflite model''' # MobilenetV1 temp = util.tempdir() tflite_model_file = nnvm.testing.tf.get_workload_official( "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz", "mobilenet_v1_1.0_224.tflite", temp) tflite_model_buf = open(tflite_model_file, "rb").read() data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32') tvm_data = np.transpose(data, axes=(0, 3, 1, 2)) tflite_output = run_tflite_graph(tflite_model_buf, data) tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input') tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5)
def run_tflite_graph(tflite_model_buf, input_data): """ Generic function to execute TFLite """ input_data = convert_to_list(input_data) interpreter = interpreter_wrapper.Interpreter(model_content=tflite_model_buf) interpreter.allocate_tensors() input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() # set input assert len(input_data) == len(input_details) for i in range(len(input_details)): interpreter.set_tensor(input_details[i]['index'], input_data[i]) # Run interpreter.invoke() # get output tflite_output = list() for i in range(len(output_details)): tflite_output.append(interpreter.get_tensor(output_details[i]['index'])) return tflite_output
run_tvm_graphメソッドでは、relay.frontend.from_tfliteにて、TensorFlow Liteのモデルを TVMのモデル(func)とパラメータ(params)に変換して、run.build にてビルドし、graph_runtime.createにて、ランタイムの実行環境を作り、set_inputメソッドにて入力データとパラメータを設定して、runメソッドにて推論を実行。get_outputメソッドにて出力データを獲得するって感じ。
def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm', out_names=None): """ Generic function to compile on relay and execute on tvm """ try: import tflite.Model except ImportError: raise ImportError("The tflite package must be installed") # get TFLite model from buffer tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) input_data = convert_to_list(input_data) input_node = convert_to_list(input_node) shape_dict = {} dtype_dict = {} for i, e in enumerate(input_node): shape_dict[e] = input_data[i].shape dtype_dict[e] = input_data[i].dtype.name func, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict) with relay.build_config(opt_level=3): graph, lib, params = relay.build(func, target, params=params) ctx = tvm.context(target, 0) from tvm.contrib import graph_runtime m = graph_runtime.create(graph, lib, ctx) # set inputs for i, e in enumerate(input_node): m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype))) m.set_input(**params) # execute m.run() # get outputs assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format( out_names, num_output) tvm_output_list = [] for i in range(0, num_output): tvm_output = m.get_output(i) tvm_output_list.append(tvm_output.asnumpy()) return tvm_output_list
これで、TVMのRelayでは、
・Keras ・MXNet ・ONNX ・TensorFlow Lite
TVMのNNVMでは、
・Keras ・MXNet ・ONNX ・Caffe2 ・CoreML ・DarkNet ・TensorFlow
からモデルを取り込めるようになったようです。