Vengineerの戯言

人生は短いけど、長いです。人生を楽しみましょう!

Intel nGraph で Halide はどのように使われているか?(その2)



実装は、src/ngraph/runtime/cpu/pass/halide_subgraph_extraction.cppの HalideSubgraphExtraction::run_on_function メソッドです。

引用
bool runtime::cpu::pass::HalideSubgraphExtraction::run_on_function(
    std::shared_ptr<ngraph::Function> function)
{
    list<shared_ptr<Node>> worklist;
    auto results = function->get_results();  

    // Artificial limitation
    if (results.size() > 1)
    {
        return false;
    }

    if (function->get_result()->get_element_type() != element::f32)
    {
        return false;
    }

    for (const auto& result : results)
    {
        worklist.emplace_back(result);
    }

ここまでで、出力、つまり、戻り値の数が1であることを確認後、戻り値の方が f32 であることも確認。最後に、
戻り値を worklist に保存。

    unordered_set<shared_ptr<Node>> ops;
    list<shared_ptr<Node>> ordered_ops;

    while (!worklist.empty())
    {
        const auto& node = worklist.front();

        if (!halide::skiplist.count(TI(*node)))
        {
            if (halide::whitelist.count(TI(*node)))
            {
                ops.emplace(node);
                ordered_ops.emplace_back(node);
            }
            else
            {
                break;
            }
        }
        const auto& args = node->get_arguments();
        for (const auto& arg : args)
        {
            worklist.emplace_back(arg);
        }
        worklist.pop_front();
    }

戻り値から逆に戻り、各ノードを探索していく。。探索したノードは、ops と ordered_ops に保存。

    NodeVector liveins;
    for (const auto& op : ops)
    {
        const auto& args = op->get_arguments();
        for (const auto& arg : args)
        {
            if (!ops.count(arg))
            {
                liveins.emplace_back(arg);
            }
        }
    }
    ordered_ops.reverse();
    if (ordered_ops.size() > 1)
    {
        auto subgraph = make_shared<cpu::op::HalideOp>(liveins,
                                                       ordered_ops,
                                                       function->get_result()->get_element_type(),
                                                       function->get_result()->get_shape());

        replace_node(function->get_result()->get_argument(0), subgraph);
        return true;
    }
    else
    {
        return false;
    }
}

ops と ordered_ops から cpu::op::HalideOp を使って、subgraph を生成します。
作った、replace_node を使って、生成した subgraph に置き換えます。

replace_node は、/src/ngraph/runtime/cpu/pass/halide_subgraph_extraction.cppで次のように定義されています。
void Function::replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> repl)
{
    ngraph::replace_node(old, repl);
}

nraph::replace_node を呼んでいます。
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
{
    if (target->is_output())
    {
        throw ngraph_error("Result nodes cannot be replaced.");
    }

    if (target->get_users().empty())
    {
        throw ngraph_error("replacing an unreachable node");
    }
    // Fix input/output descriptors
    assert(target->get_outputs().size() == replacement->get_outputs().size());


置き換える側のノードをチェックをしてから下記のように置き換えていきます。

    // For each of target's output O with replacement output O_rep:
    //     For each O's connected downstream input I:
    //         Change I's connected upstream output to O_rep
    for (size_t i = 0; i < target->get_outputs().size(); i++)
    {
        auto& target_output = target->get_outputs().at(i);
        std::set<ngraph::descriptor::Input*> copy_inputs{begin(target_output.get_inputs()),
                                                         end(target_output.get_inputs())};
        for (auto input : copy_inputs)
        {
            input->replace_output(replacement->get_outputs().at(i));
        }
    }
}

test/halide.cppの下記のコードを見てみましょう。

TEST(halide, halide_subgraph)
{
    Shape shape{8};
    auto A = make_shared<op::Parameter>(element::f32, shape);
    auto B = make_shared<op::Parameter>(element::f32, shape);
    auto C = make_shared<op::Parameter>(element::f32, shape);
    auto D = make_shared<op::Parameter>(element::f32, shape);

    auto relu = make_shared<op::Relu>((A + B) * C);

    auto f = make_shared<Function>(relu + D, ParameterVector{A, B, C, D});

    auto backend = runtime::Backend::create("CPU");
    shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shape);
    shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shape);
    shared_ptr<runtime::Tensor> c = backend->create_tensor(element::f32, shape);
    shared_ptr<runtime::Tensor> d = backend->create_tensor(element::f32, shape);

    shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shape);

    vector<float> data{-1, 4, -2, 5, 1, 5, 7, 9};

    copy_data(a, data);
    copy_data(b, data);
    copy_data(c, data);
    copy_data(d, data);

    vector<float> expected{1, 36, 6, 55, 3, 55, 105, 171};

    backend->call_with_validate(backend->compile(f), {result}, {a, b, c, d});

    EXPECT_TRUE(test::all_close(read_vector<float>(result), expected, 1.0e-4f, 1.0e-4f));
}

    auto relu = make_shared<op::Relu>((A + B) * C);
と
    auto f = make_shared<Function>(relu + D, ParameterVector{A, B, C, D});
の
    relu + D
のどこかが HalideのSubgraph として構築されるのでしょうね。

明日は、cpu::op::HalideOp を見ていきます。