Vengineer's Delusions

Daydreaming my way through life.

Graph optimization passes in TensorFlow


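August starts today. At my previous job, the week of August 1 was summer vacation, but my current employer has no summer vacation.
You're expected to take time off on your own...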

So I took a quick look at the graph optimization part of TensorFlow.


core/common_runtime/optimization_registry.h defines the GraphOptimizationPass class:
class GraphOptimizationPass {
 public:
  virtual ~GraphOptimizationPass() {}
  virtual Status Run(const GraphOptimizationPassOptions& options) = 0;
  void set_name(const string& name) { name_ = name; }
  string name() const { return name_; }

 private:
  // The name of the optimization pass, which is the same as the inherited
  // class name.
  string name_;
};
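Here is a rough sketch of what subclassing looks like. It is a toy, and the assumptions are stated up front. Status, Node, Graph and GraphOptimizationPassOptions are hypothetical minimal stand-ins, not TensorFlow's real types. RemoveNoOpPass is an invented example pass. Only the shape of GraphOptimizationPass matches the excerpt above. In this sketch the options simply carry a pointer to the graph that the pass rewrites in place.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins; these are NOT TensorFlow's real types.
struct Status {
  bool ok = true;
  std::string message;
  static Status OK() { return Status{}; }
};

struct Node {
  std::string op;
};

struct Graph {
  std::vector<Node> nodes;
};

struct GraphOptimizationPassOptions {
  Graph* graph = nullptr;  // graph the pass rewrites in place
};

// Same interface as the excerpt from optimization_registry.h.
class GraphOptimizationPass {
 public:
  virtual ~GraphOptimizationPass() {}
  virtual Status Run(const GraphOptimizationPassOptions& options) = 0;
  void set_name(const std::string& name) { name_ = name; }
  std::string name() const { return name_; }

 private:
  std::string name_;
};

// Hypothetical example pass: deletes every node whose op is "NoOp".
class RemoveNoOpPass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override {
    if (options.graph == nullptr) return Status{false, "no graph given"};
    std::vector<Node>& nodes = options.graph->nodes;
    nodes.erase(std::remove_if(nodes.begin(), nodes.end(),
                               [](const Node& n) { return n.op == "NoOp"; }),
                nodes.end());
    return Status::OK();
  }
};
```

The subclass only overrides Run(). The name is set from outside through set_name(), which mirrors the "same as the inherited class name" comment.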

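Next, which phase of the graph transformation does an optimization pass run in? That is chosen with OptimizationPassRegistry's Grouping enum: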
  // Groups of passes are run at different points in initialization.
  enum Grouping {
    PRE_PLACEMENT,          // after cost model assignment, before placement.
    POST_PLACEMENT,         // after placement.
    POST_REWRITE_FOR_EXEC,  // after re-write using feed/fetch endpoints.
    POST_PARTITIONING,      // after partitioning
  };

To implement a graph optimization pass, you subclass GraphOptimizationPass and register the subclass.
For registration there is a convenient macro, REGISTER_OPTIMIZATION(grouping, phase, optimization):
#define REGISTER_OPTIMIZATION(grouping, phase, optimization) \
  REGISTER_OPTIMIZATION_UNIQ_HELPER(__COUNTER__, grouping, phase, optimization)

#define REGISTER_OPTIMIZATION_UNIQ_HELPER(ctr, grouping, phase, optimization) \
  REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization)

#define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization) \
  static optimization_registration::OptimizationPassRegistration       \
      register_optimization_##ctr(                                     \
          grouping, phase,                                             \
          std::unique_ptr<GraphOptimizationPass>(new optimization()),  \
          #optimization)
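The three-level macro chain looks redundant, but it exists because of how the preprocessor handles ##. The sketch below reproduces the same pattern with hypothetical names: REGISTER_DEMO, Registration, RegisteredNames, PassA and PassB are all invented for illustration. __COUNTER__ is a compiler extension supported by GCC, Clang and MSVC, not standard C++.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical registry: just a list of registered pass names.
std::vector<std::string>& RegisteredNames() {
  static std::vector<std::string> names;  // built on first use
  return names;
}

// Constructing this object (at static-initialization time) registers a name.
struct Registration {
  explicit Registration(const char* name) {
    assert(name != nullptr);
    RegisteredNames().push_back(name);
  }
};

// Same three-step pattern as REGISTER_OPTIMIZATION. __COUNTER__ is passed
// through REGISTER_DEMO_UNIQ_HELPER so that it is expanded to a number
// before REGISTER_DEMO_UNIQ pastes it with ##. Pasting it directly would
// produce the literal name register_demo___COUNTER__ every time.
#define REGISTER_DEMO(pass) REGISTER_DEMO_UNIQ_HELPER(__COUNTER__, pass)
#define REGISTER_DEMO_UNIQ_HELPER(ctr, pass) REGISTER_DEMO_UNIQ(ctr, pass)
#define REGISTER_DEMO_UNIQ(ctr, pass) \
  static Registration register_demo_##ctr(#pass)

REGISTER_DEMO(PassA);  // e.g. static Registration register_demo_0("PassA");
REGISTER_DEMO(PassB);  // e.g. static Registration register_demo_1("PassB");
```

Each use defines a distinct static object, and #pass turns the class name into the string that ends up in the registry. The same applies to #optimization in the real macro.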

Searching for REGISTER_OPTIMIZATION turns up the following.

compiler/jit/jit_compilation_pass_registration.cc registers three of them, the optimization passes for TensorFlow XLA's JIT:
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
                      MarkForCompilationPass);

// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
// also need to run it after the graph been rewritten to have _Send nodes added
// for fetches. Before the _Send nodes are added, fetch nodes are identified by
// name, and encapsulation might remove that node from the graph.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
                      EncapsulateSubgraphsPass);

// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
                      BuildXlaLaunchOpsPass);
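Phases 10, 20 and 30 only make sense if passes within one grouping run in ascending phase order. That is exactly what the "must run after" comments assume. Below is a minimal sketch of a registry with that behavior. It assumes passes are kept in a map keyed by grouping and then by phase. SimpleRegistry and Pass are hypothetical simplifications, not TensorFlow's actual OptimizationPassRegistry.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Same enumerators, in the same order, as OptimizationPassRegistry::Grouping.
enum Grouping {
  PRE_PLACEMENT,
  POST_PLACEMENT,
  POST_REWRITE_FOR_EXEC,
  POST_PARTITIONING,
};

// Hypothetical simplified pass: a name plus a callback.
struct Pass {
  std::string name;
  std::function<void()> run;
};

// Hypothetical registry keyed by grouping, then by phase.
class SimpleRegistry {
 public:
  void Register(Grouping grouping, int phase, Pass pass) {
    assert(pass.run && "a pass needs a callback");
    groups_[grouping][phase].push_back(std::move(pass));
  }

  // std::map keeps phases sorted, so smaller phase numbers run first no
  // matter in which order the passes were registered.
  void RunGrouping(Grouping grouping) {
    auto it = groups_.find(grouping);
    if (it == groups_.end()) return;  // nothing registered for this group
    for (auto& phase_and_passes : it->second) {
      for (Pass& pass : phase_and_passes.second) pass.run();
    }
  }

 private:
  std::map<Grouping, std::map<int, std::vector<Pass>>> groups_;
};
```

Try registering MarkForCompilationPass, EncapsulateSubgraphsPass and BuildXlaLaunchOpsPass with phases 10, 20 and 30, in any order. Calling RunGrouping(POST_REWRITE_FOR_EXEC) on this sketch still runs them as 10, 20, 30. Sorting by phase matters because the REGISTER_OPTIMIZATION objects live in different .cc files, and C++ leaves their static-initialization order across files unspecified.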

Besides those,
core/common_runtime/accumulate_n_optimizer.cc has:
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
                      AccumulateNV2RemovePass);

core/common_runtime/parallel_concat_optimizer.cc
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
                      ParallelConcatRemovePass);

contrib/nccl/kernels/nccl_rewrite.cc
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 0,
                      NcclReplacePass);

core/common_runtime/lower_if_op.cc
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
                      LowerIfOpPass);

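The next two are optimization passes used with MKL. They run in POST_PARTITIONING,
that is, after the graph has been partitioned per device. Within that group, MklLayoutRewritePass (phase 1) comes before MklToTfConversionPass (phase 2).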

core/graph/mkl_tfconversion_pass.cc
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
    OptimizationPassRegistry::POST_PARTITIONING;
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);

core/graph/mkl_layout_pass.cc
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
    OptimizationPassRegistry::POST_PARTITIONING;
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
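Here is what running "after partitioning per device" buys a pass. The sketch assumes the pass is handed a map from device name to that device's subgraph. PartitionGraphs, RewriteCpuPartitions and the op name CpuConv2D are hypothetical. The sketch only illustrates the idea of a device-specific rewrite and is not how the MKL passes are actually implemented.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical minimal graph types; not TensorFlow's.
struct Node {
  std::string op;
};

struct Graph {
  std::vector<Node> nodes;
};

// After partitioning: one subgraph per device, keyed by device name.
using PartitionGraphs = std::map<std::string, std::unique_ptr<Graph>>;

// Hypothetical POST_PARTITIONING-style rewrite. Because every partition is
// already tied to one device, the pass can pick which partitions to touch.
// Here it renames Conv2D nodes, but only in CPU partitions, and returns how
// many nodes it changed.
int RewriteCpuPartitions(PartitionGraphs& partitions) {
  int rewritten = 0;
  for (auto& device_and_graph : partitions) {
    const std::string& device = device_and_graph.first;
    assert(device_and_graph.second != nullptr);
    if (device.find("CPU") == std::string::npos) continue;  // not a CPU
    for (Node& node : device_and_graph.second->nodes) {
      if (node.op == "Conv2D") {
        node.op = "CpuConv2D";  // hypothetical CPU-specific op name
        ++rewritten;
      }
    }
  }
  return rewritten;
}
```

Before partitioning, a pass would have to work out each node's device itself. After partitioning, the device comes for free with each subgraph.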