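August starts today. At my previous job, the week of August 1 was summer vacation, but my current employer has no summer vacation.
Rest only after you've won, I guess...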
So I took a quick look at the graph optimization part of TensorFlow.
core/common_runtime/optimization_registry.h contains the definition of the GraphOptimizationPass class:
class GraphOptimizationPass {
 public:
  virtual ~GraphOptimizationPass() {}
  virtual Status Run(const GraphOptimizationPassOptions& options) = 0;
  void set_name(const string& name) { name_ = name; }
  string name() const { return name_; }
 private:
  // The name of the optimization pass, which is the same as the inherited
  // class name.
  string name_;
};
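To see the shape of the interface on its own, here is a minimal, self-contained sketch that mirrors the declaration above and adds one subclass. TensorFlow's real Status, Graph and GraphOptimizationPassOptions are not available outside the TensorFlow tree, so `Status`, the one-field `GraphOptimizationPassOptions`, and the `CountNodesPass` subclass are all hypothetical stand-ins.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for tensorflow::Status: only "ok or not".
struct Status {
  bool ok = true;
  static Status OK() { return Status{}; }
};

// Hypothetical stand-in for the real options struct (which carries the
// graph, session options, etc.). Here it only holds a node count.
struct GraphOptimizationPassOptions {
  int num_nodes = 0;
};

// Mirrors the interface quoted above, using std::string for TF's string.
class GraphOptimizationPass {
 public:
  virtual ~GraphOptimizationPass() {}
  virtual Status Run(const GraphOptimizationPassOptions& options) = 0;
  void set_name(const std::string& name) { name_ = name; }
  std::string name() const { return name_; }

 private:
  std::string name_;
};

// Hypothetical pass: records how many nodes it was given and succeeds.
class CountNodesPass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override {
    last_count_ = options.num_nodes;
    return Status::OK();
  }
  int last_count() const { return last_count_; }

 private:
  int last_count_ = -1;  // -1 means Run() has not been called yet
};
```

A real pass would rewrite the graph reachable from the options inside Run(); the only contract shown here is "override Run() and return a Status".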
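Next: at which phase of the graph transformation is each optimization pass run? That is chosen with the Grouping enum of OptimizationPassRegistry: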
// Groups of passes are run at different points in initialization.
enum Grouping {
  PRE_PLACEMENT,          // after cost model assignment, before placement.
  POST_PLACEMENT,         // after placement.
  POST_REWRITE_FOR_EXEC,  // after re-write using feed/fetch endpoints.
  POST_PARTITIONING,      // after partitioning
};
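The comments on the enumerators describe points in session setup, listed in the order they occur. As a rough sketch, each grouping can be thought of as a hook fired between the real setup steps. The driver below (`SimulatedSetup`, `RunGrouping`) is hypothetical; in TensorFlow these calls are spread across the runtime rather than living in one function.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Same enumerators as OptimizationPassRegistry::Grouping.
enum Grouping {
  PRE_PLACEMENT,
  POST_PLACEMENT,
  POST_REWRITE_FOR_EXEC,
  POST_PARTITIONING,
};

// Hypothetical hook: a real runtime would run every pass registered for
// grouping g at this point; this sketch only records that it was called.
void RunGrouping(Grouping g, std::vector<std::string>* log) {
  static const char* kNames[] = {"PRE_PLACEMENT", "POST_PLACEMENT",
                                 "POST_REWRITE_FOR_EXEC", "POST_PARTITIONING"};
  log->push_back(kNames[g]);
}

// Hypothetical, heavily simplified setup sequence showing where each
// grouping fires relative to the real work.
std::vector<std::string> SimulatedSetup() {
  std::vector<std::string> log;
  log.push_back("assign cost model");
  RunGrouping(PRE_PLACEMENT, &log);
  log.push_back("place nodes on devices");
  RunGrouping(POST_PLACEMENT, &log);
  log.push_back("rewrite for feed/fetch endpoints");
  RunGrouping(POST_REWRITE_FOR_EXEC, &log);
  log.push_back("partition graph per device");
  RunGrouping(POST_PARTITIONING, &log);
  return log;
}
```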
To implement a graph optimization pass, you subclass GraphOptimizationPass and register it,
using the convenient macro REGISTER_OPTIMIZATION(grouping, phase, optimization):
#define REGISTER_OPTIMIZATION(grouping, phase, optimization) \
  REGISTER_OPTIMIZATION_UNIQ_HELPER(__COUNTER__, grouping, phase, optimization)
#define REGISTER_OPTIMIZATION_UNIQ_HELPER(ctr, grouping, phase, optimization) \
  REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization)
#define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization) \
  static optimization_registration::OptimizationPassRegistration      \
      register_optimization_##ctr(                                      \
          grouping, phase,                                              \
          std::unique_ptr<GraphOptimizationPass>(new optimization()),   \
          #optimization)
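Why three macro levels? `##` suppresses macro expansion of its operands, so pasting `__COUNTER__` directly would produce the same literal name every time. The middle helper receives `__COUNTER__` as an ordinary argument, which gets expanded to an integer first; only then does the innermost macro paste it into a unique variable name. The static object's constructor runs during static initialization and performs the registration. Below is a minimal sketch of the same pattern; `Pass`, `Registry()`, `PassRegistration` and `REGISTER_PASS` are hypothetical names, not TensorFlow APIs, and `__COUNTER__` is a compiler extension (GCC, Clang and MSVC support it).

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical minimal pass interface for this sketch.
struct Pass {
  virtual ~Pass() {}
};

// Hypothetical global registry. A function-local static is constructed on
// first use, so it is safe to call from other static initializers.
std::vector<std::pair<std::string, std::unique_ptr<Pass>>>& Registry() {
  static std::vector<std::pair<std::string, std::unique_ptr<Pass>>> registry;
  return registry;
}

// Constructing one of these (at static-initialization time) registers a pass.
struct PassRegistration {
  PassRegistration(std::unique_ptr<Pass> pass, const std::string& name) {
    Registry().emplace_back(name, std::move(pass));
  }
};

// Same three-level shape as REGISTER_OPTIMIZATION:
//  1. REGISTER_PASS introduces __COUNTER__.
//  2. In the helper, ctr is not an operand of ##, so it is expanded to an
//     integer before being passed on.
//  3. REGISTER_PASS_UNIQ pastes that integer into a unique variable name
//     and stringifies the class name with #cls.
#define REGISTER_PASS(cls) REGISTER_PASS_UNIQ_HELPER(__COUNTER__, cls)
#define REGISTER_PASS_UNIQ_HELPER(ctr, cls) REGISTER_PASS_UNIQ(ctr, cls)
#define REGISTER_PASS_UNIQ(ctr, cls)           \
  static PassRegistration register_pass_##ctr( \
      std::unique_ptr<Pass>(new cls()), #cls)

struct FooPass : Pass {};
struct BarPass : Pass {};

// Two registrations in one file expand to two distinct static variables
// (e.g. register_pass_0 and register_pass_1), so they do not collide.
REGISTER_PASS(FooPass);
REGISTER_PASS(BarPass);
```

Without the counter, two registrations in the same .cc file would define the same variable name twice and fail to compile.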
Searching the source tree for REGISTER_OPTIMIZATION turns up the following.
compiler/jit/jit_compilation_pass_registration.cc registers three: the optimization passes for TensorFlow XLA's JIT.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
                      MarkForCompilationPass);
// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
// also need to run it after the graph has been rewritten to have _Send nodes
// added for fetches. Before the _Send nodes are added, fetch nodes are
// identified by name, and encapsulation might remove that node from the graph.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
                      EncapsulateSubgraphsPass);
// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
                      BuildXlaLaunchOpsPass);
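The "must run after" comments only hold if, within one grouping, passes run in ascending phase order (10, then 20, then 30). A natural way to get that is to key each grouping's passes by phase in a `std::map`, which iterates keys in ascending order. The sketch below assumes exactly that; `PassRegistrySketch` is hypothetical and stores pass names instead of pass objects.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical, simplified registry: per grouping, a std::map from phase
// number to pass names. std::map iterates keys in ascending order, so
// lower phase numbers come first.
class PassRegistrySketch {
 public:
  enum Grouping {
    PRE_PLACEMENT,
    POST_PLACEMENT,
    POST_REWRITE_FOR_EXEC,
    POST_PARTITIONING,
  };

  void Register(Grouping g, int phase, const std::string& name) {
    groups_[g][phase].push_back(name);
  }

  // Returns the pass names in the order they would run for grouping g.
  std::vector<std::string> RunGrouping(Grouping g) const {
    std::vector<std::string> ran;
    auto it = groups_.find(g);
    if (it == groups_.end()) return ran;
    for (const auto& phase : it->second)     // ascending phase number
      for (const auto& name : phase.second)  // registration order in a phase
        ran.push_back(name);
    return ran;
  }

 private:
  std::map<Grouping, std::map<int, std::vector<std::string>>> groups_;
};

// Registers the three XLA JIT passes deliberately out of order.
std::vector<std::string> XlaJitOrder() {
  PassRegistrySketch r;
  r.Register(PassRegistrySketch::POST_REWRITE_FOR_EXEC, 30,
             "BuildXlaLaunchOpsPass");
  r.Register(PassRegistrySketch::POST_REWRITE_FOR_EXEC, 10,
             "MarkForCompilationPass");
  r.Register(PassRegistrySketch::POST_REWRITE_FOR_EXEC, 20,
             "EncapsulateSubgraphsPass");
  return r.RunGrouping(PassRegistrySketch::POST_REWRITE_FOR_EXEC);
}
```

Under the same assumption, the two MKL passes registered in POST_PARTITIONING run as MklLayoutRewritePass (phase 1) followed by MklToTfConversionPass (phase 2).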
Besides those:
core/common_runtime/accumulate_n_optimizer.cc
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0, AccumulateNV2RemovePass);
core/common_runtime/parallel_concat_optimizer.cc
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0, ParallelConcatRemovePass);
contrib/nccl/kernels/nccl_rewrite.cc
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 0, NcclReplacePass);
core/common_runtime/lower_if_op.cc
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0, LowerIfOpPass);
core/graph/mkl_tfconversion_pass.cc
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
    OptimizationPassRegistry::POST_PARTITIONING;
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
core/graph/mkl_layout_pass.cc
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
    OptimizationPassRegistry::POST_PARTITIONING;
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);