Vengineerの戯言

人生は短いけど、長いです。人生を楽しみましょう!

TVMに TensorFlow Lite Importer を投入


TVMのソースコードを眺めていたら、TensorFlow Lite の Importer が投入されているのに気が付きました。

いつものように、テストコード(test_forward.py)から。。。

def test_forward_mobilenet():
    '''test mobilenet v1 tflite model'''
    # MobilenetV1
    temp = util.tempdir()
    tflite_model_file = nnvm.testing.tf.get_workload_official(
        "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz",
        "mobilenet_v1_1.0_224.tflite", temp)
    tflite_model_buf = open(tflite_model_file, "rb").read()
    data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
    tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
    tflite_output = run_tflite_graph(tflite_model_buf, data)
    tvm_output = run_tvm_graph(tflite_model_buf, tvm_data, 'input')
    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                rtol=1e-5, atol=1e-5)

の、run_tflite_graph メソッドにて、TensorFlow LiteのInterpreterによる実行。

def run_tflite_graph(tflite_model_buf, input_data):
    """ Generic function to execute TFLite """
    input_data = convert_to_list(input_data)

    interpreter = interpreter_wrapper.Interpreter(model_content=tflite_model_buf)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # set input
    assert len(input_data) == len(input_details)
    for i in range(len(input_details)):
        interpreter.set_tensor(input_details[i]['index'], input_data[i])

    # Run
    interpreter.invoke()

    # get output
    tflite_output = list()
    for i in range(len(output_details)):
        tflite_output.append(interpreter.get_tensor(output_details[i]['index']))

    return tflite_output

run_tvm_graphメソッドでは、relay.frontend.from_tfliteにて、TensorFlow Liteのモデルを TVMのモデル(func)とパラメータ(params)に変換して、run.build にてビルドし、graph_runtime.createにて、ランタイムの実行環境を作り、set_inputメソッドにて入力データとパラメータを設定して、runメソッドにて推論を実行。get_outputメソッドにて出力データを獲得するって感じ。

def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
                  out_names=None):
    """ Generic function to compile on relay and execute on tvm """
    try:
        import tflite.Model
    except ImportError:
        raise ImportError("The tflite package must be installed")

    # get TFLite model from buffer
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)

    input_data = convert_to_list(input_data)
    input_node = convert_to_list(input_node)

    shape_dict = {}
    dtype_dict = {}
    for i, e in enumerate(input_node):
        shape_dict[e] = input_data[i].shape
        dtype_dict[e] = input_data[i].dtype.name

    func, params = relay.frontend.from_tflite(tflite_model,
                                              shape_dict=shape_dict,
                                              dtype_dict=dtype_dict)
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(func, target, params=params)

    ctx = tvm.context(target, 0)
    from tvm.contrib import graph_runtime
    m = graph_runtime.create(graph, lib, ctx)
    # set inputs
    for i, e in enumerate(input_node):
        m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))

    m.set_input(**params)
    # execute
    m.run()
    # get outputs
    assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format(
        out_names, num_output)
    tvm_output_list = []
    for i in range(0, num_output):
        tvm_output = m.get_output(i)
        tvm_output_list.append(tvm_output.asnumpy())
    return tvm_output_list

これで、TVMのRelayでは、
 ・Keras
 ・MXNet
 ・ONNX
 ・TensorFlow Lite

TVMのNNVMでは、
 ・Keras
 ・MXNet
 ・ONNX
 ・Caffe2
 ・CoreML
 ・DarkNet
 ・TensorFlow

からモデルを取り込めるようになったようです。