Vengineerの戯言

人生は短いけど、長いです。人生を楽しみましょう!

TensorFlowの中のTensorRT


TensorFlowの中で NVIDIAのTensorRTを使っていますが、
ソースコードの位置が tensorflow/contrib/tensorrt から tensorflow/compiler の下に移った模様。


TRTOptimizationPass_Registrarにて、
TRTOptimizationPassを登録しているみたい。
static VerboseCustomGraphOptimizerRegistrar TRTOptimizationPass_Registrar(
    []() {
      VLOG(1)
          << "Instantiating CustomOptimizationPass object TensorRTOptimizer";
      return new tensorflow::tensorrt::convert::TRTOptimizationPass(
          "TensorRTOptimizer");
    },
    ("TensorRTOptimizer"));

また、Accelerating Inference In TensorFlow With TensorRT User GuideというドキュメントをNVIDIAは提供しています。

SavedModelからTensorRTを使って、TensorFlowを実行する例が、ドキュメントに載っています。
引用します
# Import TensorFlow and TensorRT
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
# Inference with TF-TRT `SavedModel` workflow:
graph = tf.Graph()
with graph.as_default():
    with tf.Session() as sess:
        # Create a TensorRT inference graph from a SavedModel:
        trt_graph = trt.create_inference_graph(
            input_saved_model_dir=”/path/to/your/saved/model”,
            input_saved_model_tags=[”your_saved_model_tags”],
            max_batch_size=your_batch_size,
            max_workspace_size_bytes=max_GPU_mem_size_for_TRT,
            precision_mode=”your_precision_mode”) 
        # Import the TensorRT graph into a new graph and run:
        output_node = tf.import_graph_def(
            trt_graph,
            return_elements=[“your_outputs”])
       sess.run(output_node)

trt.create_inference_graphにて、SavedModel から TensorRT inference graph を生成し、
そのモデルを tf.import_graph_def の出力を使って、sess.run で実行するんですね。

A Frozen Graphを使う場合も同様に、trt.create_infecence_graphにて、Frozen Graphから読み込んだモデルを変換して、
後は、SavedModelの時と同じ。
# Import TensorFlow and TensorRT
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
# Inference with TF-TRT frozen graph workflow:
graph = tf.Graph()
with graph.as_default():
    with tf.Session() as sess:
        # First deserialize your frozen graph:
        with tf.gfile.GFile(“/path/to/your/frozen/graph.pb”, ‘rb’) as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # Now you can create a TensorRT inference graph from your
        # frozen graph:
        trt_graph = trt.create_inference_graph(
            input_graph_def=graph_def,
            outputs=[“your_output_node_names”],
            max_batch_size=your_batch_size,
            max_workspace_size_bytes=max_GPU_mem_size_for_TRT,
            precision_mode=”your_precision_mode”)
        # Import the TensorRT graph into a new graph and run:
        output_node = tf.import_graph_def(
            trt_graph,
            return_elements=[“your_outputs”])
        sess.run(output_node)

TensorFlowでGPU使って推論する場合は、TensorRTを使えばかなり速くなりそう。。。