[opset5] ngraph implementation of Loop op (#2583)
* Loop op ngraph implementation, update IE IR Reader and ngraph to cnn converter * refactoring SubGraphOp class * type prop unit tests * ngraph code style * update comment * single layer tests for Loop operation * fix file name * Add SpecialBodyPorts attribute in Loop op, update single layer tests * add several new tests cases, strict checks in Loop impl, temporary disable single layer tests * ngraph codestyle, refactoring, clone_new_args test * resolve review remarks * fix build * fix tests * add a new constructor of Loop op, resolve review remarks
This commit is contained in:
@@ -43,6 +43,7 @@
|
||||
#include "caseless.hpp"
|
||||
#include <debug.h>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include "transformations/utils/utils.hpp"
|
||||
#include "transformations/rt_info/fused_names_attribute.hpp"
|
||||
#include "transformations/rt_info/primitives_priority_attribute.hpp"
|
||||
@@ -809,6 +810,7 @@ void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::TopKIE>>(),
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::Unsqueeze>>(),
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::TensorIterator>>(),
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::opset5::Loop>>(),
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::HardSigmoid_IE>>(),
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::v1::LogicalNot>>(),
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::ShuffleChannels>>(),
|
||||
|
||||
@@ -43,6 +43,7 @@
|
||||
#include <cpp/ie_cnn_network.h>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <ngraph/variant.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
|
||||
#include <legacy/convert_function_to_cnn_network.hpp>
|
||||
#include "legacy/graph_transformer.h"
|
||||
@@ -114,8 +115,7 @@ CNNLayer::Ptr NodeConverter<ngraph::op::GenericIE>::createLayer(const std::share
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
CNNLayer::Ptr NodeConverter<ngraph::op::TensorIterator>::createLayer(const std::shared_ptr<ngraph::Node>& layer) const {
|
||||
CNNLayer::Ptr createSubGraphLayer(const std::shared_ptr<ngraph::Node>& layer) {
|
||||
auto find_input_idx = [](const CNNLayerPtr& where, const DataPtr& what) {
|
||||
auto it = std::find_if(where->insData.begin(), where->insData.end(), [&](const DataWeakPtr& wk_ptr) {
|
||||
auto layer_data = wk_ptr.lock();
|
||||
@@ -129,7 +129,7 @@ CNNLayer::Ptr NodeConverter<ngraph::op::TensorIterator>::createLayer(const std::
|
||||
return it - where->insData.begin();
|
||||
};
|
||||
|
||||
auto tensor_iterator = ngraph::as_type_ptr<ngraph::op::TensorIterator>(layer);
|
||||
auto tensor_iterator = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(layer);
|
||||
if (!tensor_iterator) {
|
||||
THROW_IE_EXCEPTION << "Cannot cast layer to TensorIterator.";
|
||||
}
|
||||
@@ -142,8 +142,8 @@ CNNLayer::Ptr NodeConverter<ngraph::op::TensorIterator>::createLayer(const std::
|
||||
std::map<std::string, DataPtr> out_info_map;
|
||||
|
||||
// inputs/outputs of TensorIterator (ngraph representation)
|
||||
auto parameters = tensor_iterator->get_body()->get_parameters();
|
||||
auto results = tensor_iterator->get_body()->get_results();
|
||||
auto parameters = tensor_iterator->get_function()->get_parameters();
|
||||
auto results = tensor_iterator->get_function()->get_results();
|
||||
|
||||
// Convert body (ngraph representation) to CNNNetwork.
|
||||
// This network will contain nodes of type = "Input" and data nodes with wrong names.
|
||||
@@ -155,7 +155,7 @@ CNNLayer::Ptr NodeConverter<ngraph::op::TensorIterator>::createLayer(const std::
|
||||
// This map will save information about data nodes
|
||||
std::map<std::string, std::vector<TensorDesc>> layer_name_to_tensor_desc;
|
||||
{
|
||||
CNNNetwork body_net(tensor_iterator->get_body());
|
||||
CNNNetwork body_net(tensor_iterator->get_function());
|
||||
CNNNetwork net(InferenceEngine::details::convertFunctionToICNNNetwork(body_net.getFunction(), body_net));
|
||||
// Paranoid check for cycles
|
||||
bool res = CNNNetForestDFS(
|
||||
@@ -356,6 +356,20 @@ CNNLayer::Ptr NodeConverter<ngraph::op::TensorIterator>::createLayer(const std::
|
||||
return res;
|
||||
}
|
||||
|
||||
template<>
|
||||
CNNLayer::Ptr NodeConverter<ngraph::op::TensorIterator>::createLayer(const std::shared_ptr<ngraph::Node>& layer) const {
|
||||
auto res = createSubGraphLayer(layer);
|
||||
res->type = "TensorIterator";
|
||||
return res;
|
||||
}
|
||||
|
||||
template<>
|
||||
CNNLayer::Ptr NodeConverter<ngraph::opset5::Loop>::createLayer(const std::shared_ptr<ngraph::Node>& layer) const {
|
||||
auto res = createSubGraphLayer(layer);
|
||||
res->type = "Loop";
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
CNNLayer::Ptr NodeConverter<ngraph::op::Constant>::createLayer(const std::shared_ptr<ngraph::Node>& layer) const {
|
||||
LayerParams params = {layer->get_friendly_name(), "Const",
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include <ngraph/opsets/opset.hpp>
|
||||
#include <ngraph/opsets/opset2.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <ngraph/variant.hpp>
|
||||
|
||||
#include <cpp/ie_cnn_network.h>
|
||||
@@ -477,6 +478,7 @@ std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Ou
|
||||
std::make_shared<LayerCreator<ngraph::op::v0::Tile>>("Tile"),
|
||||
std::make_shared<LayerCreator<ngraph::op::v1::TopK>>("TopK"),
|
||||
std::make_shared<LayerCreator<ngraph::op::TensorIterator>>("TensorIterator"),
|
||||
std::make_shared<LayerCreator<ngraph::opset5::Loop>>("Loop"),
|
||||
std::make_shared<LayerCreator<ngraph::op::Transpose>>("Transpose"),
|
||||
std::make_shared<LayerCreator<ngraph::op::Unsqueeze>>("Unsqueeze"),
|
||||
std::make_shared<LayerCreator<ngraph::op::v1::LogicalAnd>>("LogicalAnd"),
|
||||
@@ -662,12 +664,12 @@ std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::DetectionOutpu
|
||||
}
|
||||
}
|
||||
|
||||
// TensorIterator layer
|
||||
template <>
|
||||
std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::TensorIterator>::createLayer(
|
||||
const ngraph::OutputVector& inputs, const pugi::xml_node& node, std::istream& binStream,
|
||||
const GenericLayerParams& layerParsePrms) {
|
||||
auto tensor_iterator = std::make_shared<ngraph::op::TensorIterator>();
|
||||
// SubGraph layer
|
||||
std::shared_ptr<ngraph::Node>
|
||||
V10Parser::LayerBaseCreator::fillSubGraphLayer(const ngraph::OutputVector &inputs, const pugi::xml_node &node,
|
||||
std::istream &binStream,
|
||||
const V10Parser::GenericLayerParams &layerParsePrms,
|
||||
std::shared_ptr<ngraph::op::util::SubGraphOp> tensor_iterator) {
|
||||
tensor_iterator->set_friendly_name(GetStrAttr(node, "name"));
|
||||
auto body_node = node.child("body");
|
||||
|
||||
@@ -695,7 +697,7 @@ std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::TensorIterator
|
||||
// Disabled reshape for generic operations in the TI body
|
||||
::ngraph::op::GenericIE::DisableReshape noReshape(ngraph_function);
|
||||
auto body = std::make_shared<ngraph::Function>(result_nodes, parameter_nodes);
|
||||
tensor_iterator->set_body(body);
|
||||
tensor_iterator->set_function(body);
|
||||
|
||||
// Parse PortMap: inputs
|
||||
std::map<uint64_t, pugi::xml_node> input_map;
|
||||
@@ -795,7 +797,8 @@ std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::TensorIterator
|
||||
tensor_iterator->get_concatenated_slices(*body_result, start, stride, part_size, end, axis);
|
||||
|
||||
if (!is_sliced_input_exists) {
|
||||
tensor_iterator->set_num_iterations((std::abs(end - start)) / part_size);
|
||||
if (auto ti = std::dynamic_pointer_cast<ngraph::op::TensorIterator>(tensor_iterator))
|
||||
ti->set_num_iterations((std::abs(end - start)) / part_size);
|
||||
}
|
||||
} else {
|
||||
// otherwise create ngraph::TensorIterator::BodyOutput. -1 means last iteration.
|
||||
@@ -807,6 +810,25 @@ std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::TensorIterator
|
||||
return tensor_iterator;
|
||||
}
|
||||
|
||||
|
||||
// TensorIterator layer
|
||||
template <>
|
||||
std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::TensorIterator>::createLayer(
|
||||
const ngraph::OutputVector& inputs, const pugi::xml_node& node, std::istream& binStream,
|
||||
const GenericLayerParams& layerParsePrms) {
|
||||
auto ti = std::make_shared<ngraph::op::TensorIterator>();
|
||||
return fillSubGraphLayer(inputs, node, binStream, layerParsePrms, ti);
|
||||
}
|
||||
|
||||
// Loop layer
|
||||
template <>
|
||||
std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::opset5::Loop>::createLayer(
|
||||
const ngraph::OutputVector& inputs, const pugi::xml_node& node, std::istream& binStream,
|
||||
const GenericLayerParams& layerParsePrms) {
|
||||
auto loop = std::make_shared<ngraph::opset5::Loop>();
|
||||
return fillSubGraphLayer(inputs, node, binStream, layerParsePrms, loop);
|
||||
}
|
||||
|
||||
// PriorBoxClustered layer
|
||||
template <>
|
||||
std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::PriorBoxClustered>::createLayer(
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#ifdef IR_READER_V10
|
||||
# include <ngraph/node.hpp>
|
||||
# include <ngraph/op/util/sub_graph_base.hpp>
|
||||
# include <legacy/ie_ngraph_utils.hpp>
|
||||
# include <cpp/ie_cnn_network.h>
|
||||
#endif // IR_READER_V10
|
||||
@@ -102,6 +103,10 @@ private:
|
||||
std::string type;
|
||||
|
||||
protected:
|
||||
static std::shared_ptr<ngraph::Node> fillSubGraphLayer(const ngraph::OutputVector& inputs, const pugi::xml_node& node,
|
||||
std::istream& binStream,
|
||||
const GenericLayerParams& layerParsePrms,
|
||||
std::shared_ptr<ngraph::op::util::SubGraphOp> sub_graph_node);
|
||||
explicit LayerBaseCreator(const std::string& type): type(type) {}
|
||||
std::string getType() {
|
||||
return type;
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
// Copyright (C) 2019 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
#include <ngraph/op/util/attr_types.hpp>
|
||||
#include "single_layer_tests/loop.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
// without clip values increase rapidly, so use only seq_lenghts = 2
|
||||
std::vector<bool> execute_first_iteration{true};
|
||||
std::vector<bool> is_body_condition_const{true, false};
|
||||
std::vector<bool> body_condition{true, false}; // works only if is_body_condition_const == true
|
||||
std::vector<int64_t> trip_count{1, 10, -1}; // -1 means infinity
|
||||
std::vector<std::vector<std::pair<std::vector<size_t>, LOOP_IN_TYPE>>> inputs = {
|
||||
{{{32, 1, 10}, LOOP_IN_TYPE::INVARIANT}, {{32, 1, 10}, LOOP_IN_TYPE::INVARIANT}, {{32, 1, 10}, LOOP_IN_TYPE::MERGED}},
|
||||
};
|
||||
std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_LoopCommonZeroClip, LoopTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(execute_first_iteration),
|
||||
::testing::ValuesIn(is_body_condition_const),
|
||||
::testing::ValuesIn(body_condition),
|
||||
::testing::ValuesIn(trip_count),
|
||||
::testing::ValuesIn(inputs),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
LoopTest::getTestCaseName);
|
||||
} // namespace
|
||||
@@ -53,6 +53,9 @@ std::vector<std::string> disabledTestPatterns() {
|
||||
R"(.*ReverseSequenceLayerTest.*netPRC=(I8|U8).*)",
|
||||
// TODO: Issue: 38841
|
||||
R"(.*TopKLayerTest.*k=10.*mode=min.*sort=index.*)",
|
||||
R"(.*TopKLayerTest.*k=5.*sort=(none|index).*)"
|
||||
R"(.*TopKLayerTest.*k=5.*sort=(none|index).*)",
|
||||
|
||||
// TODO: not supported yet, ticket 37690
|
||||
R"(.*Loop.*)"
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (C) 2019 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <ngraph/op/util/attr_types.hpp>
|
||||
#include "functional_test_utils/layer_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "ngraph_functions/utils/ngraph_helpers.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
enum LOOP_IN_TYPE {
|
||||
INVARIANT,
|
||||
MERGED
|
||||
};
|
||||
|
||||
using LoopParams = typename std::tuple<
|
||||
bool, // ExecuteFirstIteration
|
||||
bool, // BodyCondition is a constant?
|
||||
bool, // BodyCondition value, if it is a Const
|
||||
int64_t, // TripCount, -1 means infinity
|
||||
std::vector<std::pair<std::vector<size_t>, LOOP_IN_TYPE>>, // inputs
|
||||
InferenceEngine::Precision, // Network precision
|
||||
std::string>; // Device name
|
||||
|
||||
class LoopTest : public testing::WithParamInterface<LoopParams>,
|
||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<LoopParams> &obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
@@ -0,0 +1,159 @@
|
||||
// Copyright (C) 2019 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
|
||||
#include "ie_core.hpp"
|
||||
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
#include "functional_test_utils/blob_utils.hpp"
|
||||
#include "functional_test_utils/precision_utils.hpp"
|
||||
#include "functional_test_utils/plugin_cache.hpp"
|
||||
#include "functional_test_utils/skip_tests_config.hpp"
|
||||
|
||||
#include "single_layer_tests/loop.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
std::string LoopTest::getTestCaseName(const testing::TestParamInfo<LoopParams> &obj) {
|
||||
bool execute_first_iteration;
|
||||
bool is_body_condition_const;
|
||||
bool body_condition; // works only if is_body_condition_const ==
|
||||
int64_t trip_count;
|
||||
std::vector<std::pair<std::vector<size_t>, LOOP_IN_TYPE>> inputs;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::string targetDevice;
|
||||
std::tie(execute_first_iteration, is_body_condition_const, body_condition, trip_count, inputs, netPrecision,
|
||||
targetDevice) = obj.param;
|
||||
|
||||
std::vector<std::vector<size_t>> inputs_separate;
|
||||
std::vector<LOOP_IN_TYPE> types_separate;
|
||||
for (auto &el : inputs) {
|
||||
inputs_separate.push_back(el.first);
|
||||
types_separate.push_back(el.second);
|
||||
}
|
||||
std::ostringstream result;
|
||||
result << "execute_first_iteration" << execute_first_iteration << "_";
|
||||
result << "is_body_condition_const=" << is_body_condition_const << "_";
|
||||
result << "body_condition=" << body_condition << "_";
|
||||
result << "trip_count=" << trip_count << "_";
|
||||
result << "IS=" << CommonTestUtils::vec2str(inputs_separate) << "_";
|
||||
result << "types=" << CommonTestUtils::vec2str(types_separate) << "_";
|
||||
result << "netPRC=" << netPrecision.name() << "_";
|
||||
result << "targetDevice=" << targetDevice << "_";
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void LoopTest::SetUp() {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
SetRefMode(LayerTestsUtils::IE);
|
||||
bool execute_first_iteration;
|
||||
bool is_body_condition_const;
|
||||
bool body_condition; // works only if is_body_condition_const ==
|
||||
int64_t trip_count;
|
||||
std::vector<std::pair<std::vector<size_t>, LOOP_IN_TYPE>> inputs;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::tie(execute_first_iteration, is_body_condition_const, body_condition, trip_count, inputs, netPrecision,
|
||||
targetDevice) = this->GetParam();
|
||||
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
|
||||
// That which we iterate over
|
||||
std::vector<std::vector<size_t>> inputs_separate;
|
||||
std::vector<LOOP_IN_TYPE> types_separate;
|
||||
for (auto &el : inputs) {
|
||||
inputs_separate.push_back(el.first);
|
||||
types_separate.push_back(el.second);
|
||||
}
|
||||
// Example:
|
||||
/* auto X = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{32, 1, 10});
|
||||
auto Y = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{32, 1, 10});
|
||||
auto M = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{32, 1, 10});*/
|
||||
auto params = ngraph::builder::makeParams(ngPrc, inputs_separate);
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
const std::vector<ngraph::PartialShape> body_params_shapes(inputs_separate.size(), ngraph::PartialShape::dynamic());
|
||||
auto current_iteration = std::make_shared<ngraph::op::Parameter>(ngraph::element::i64, ngraph::Shape{1});
|
||||
|
||||
//Example:
|
||||
/* auto Xi = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||
auto Yi = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||
auto M_body = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());*/
|
||||
|
||||
ngraph::ParameterVector body_params;
|
||||
for (const auto &pshape : body_params_shapes) {
|
||||
auto paramNode = std::make_shared<ngraph::opset1::Parameter>(ngPrc, pshape);
|
||||
body_params.push_back(paramNode);
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> body_condition_const;
|
||||
if (is_body_condition_const) {
|
||||
if (body_condition) {
|
||||
body_condition_const = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
} else {
|
||||
body_condition_const = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, false);
|
||||
}
|
||||
}
|
||||
|
||||
auto trip_count_const =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{1}, trip_count);
|
||||
|
||||
std::shared_ptr<ngraph::Node> exec_condition;
|
||||
if (execute_first_iteration) {
|
||||
exec_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
} else {
|
||||
exec_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, false);
|
||||
}
|
||||
|
||||
// Body
|
||||
std::shared_ptr<ngraph::Node> Zo = body_params[0];
|
||||
for (int i = 1; i < body_params.size(); ++i) {
|
||||
Zo = body_params[i] + Zo;
|
||||
}
|
||||
|
||||
// body_params.insert(body_params.begin(), current_iteration);
|
||||
auto body = std::make_shared<ngraph::Function>(ngraph::OutputVector{body_condition_const, Zo},
|
||||
body_params);
|
||||
|
||||
auto loop = std::make_shared<ngraph::opset5::Loop>(trip_count_const, exec_condition);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{-1, 0});
|
||||
|
||||
for (int i = 0; i < body_params.size(); ++i) {
|
||||
if (types_separate[i] == LOOP_IN_TYPE::INVARIANT) {
|
||||
loop->set_invariant_input(body_params[i], params[i]);
|
||||
} else if (types_separate[i] == LOOP_IN_TYPE::MERGED) {
|
||||
// todo: support several merged inputs
|
||||
// now supported only one in this sample
|
||||
loop->set_merged_input(body_params[i], params[i], Zo);
|
||||
}
|
||||
}
|
||||
|
||||
// Output 0 is last Zo
|
||||
auto out0 = loop->get_iter_value(body_condition_const, -1);
|
||||
auto out1 = loop->get_iter_value(Zo, -1);
|
||||
// Output 1 is concat of Zos
|
||||
// start=0, stride=1, part_size=1, end=-1, axis=1
|
||||
auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1);
|
||||
|
||||
auto result0 = std::make_shared<ngraph::op::Result>(out0);
|
||||
auto result1 = std::make_shared<ngraph::op::Result>(out1);
|
||||
auto result2 = std::make_shared<ngraph::op::Result>(out2);
|
||||
function = std::make_shared<ngraph::Function>(ngraph::ResultVector{result0, result1, result2}, params, "loop");
|
||||
}
|
||||
|
||||
|
||||
TEST_P(LoopTest, CompareWithRefs) {
|
||||
Run();
|
||||
};
|
||||
} // namespace LayerTestsDefinitions
|
||||
105
ngraph/core/include/ngraph/op/loop.hpp
Normal file
105
ngraph/core/include/ngraph/op/loop.hpp
Normal file
@@ -0,0 +1,105 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph/factory_adapter.hpp"
|
||||
#include "ngraph/function.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/parameter.hpp"
|
||||
#include "ngraph/op/tensor_iterator.hpp"
|
||||
#include "ngraph/op/util/sub_graph_base.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace v5
|
||||
{
|
||||
/// \brief Iterate a body over tensors, accumulating into tensors.
|
||||
class NGRAPH_API Loop : public op::util::SubGraphOp
|
||||
{
|
||||
public:
|
||||
/// \brief Allows to define the purpose of inputs/outputs in the body
|
||||
struct SpecialBodyPorts
|
||||
{
|
||||
SpecialBodyPorts() = default;
|
||||
SpecialBodyPorts(int64_t in_current_iteration_input_idx,
|
||||
int64_t in_body_condition_output_idx)
|
||||
: current_iteration_input_idx(in_current_iteration_input_idx)
|
||||
, body_condition_output_idx(in_body_condition_output_idx)
|
||||
{
|
||||
}
|
||||
// -1 means the input is not provided, this input is optional
|
||||
int64_t current_iteration_input_idx = -1;
|
||||
// -1 means the output is not provided,
|
||||
// this output is required, throw an exception if not provided
|
||||
int64_t body_condition_output_idx = -1;
|
||||
};
|
||||
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
/// \brief Constructs a Loop operation.
|
||||
Loop() = default;
|
||||
|
||||
/// \brief Constructs a Loop operation.
|
||||
///
|
||||
/// \param trip_count Node specifies the maximum number of iterations.
|
||||
/// \param execution_condition Node determines whether to execute the first
|
||||
/// iteration or not.
|
||||
Loop(const Output<Node>& trip_count, const Output<Node>& execution_condition);
|
||||
|
||||
int64_t get_num_iterations() const { return m_num_iterations; }
|
||||
void set_sliced_input(const std::shared_ptr<Parameter>& parameter,
|
||||
const Output<Node>& value,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis) override
|
||||
{
|
||||
NGRAPH_CHECK(false,
|
||||
"Incorrect type of input. Implicit slicing is not supported in "
|
||||
"Loop operation.");
|
||||
}
|
||||
|
||||
Output<Node> get_concatenated_slices(const Output<Node>& value,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis) override;
|
||||
|
||||
void set_special_body_ports(const SpecialBodyPorts& special_body_ports)
|
||||
{
|
||||
m_special_body_ports = special_body_ports;
|
||||
}
|
||||
|
||||
SpecialBodyPorts get_special_body_ports() const { return m_special_body_ports; }
|
||||
void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
private:
|
||||
SpecialBodyPorts m_special_body_ports;
|
||||
int64_t m_num_iterations = -1; // -1 means infinity
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -293,7 +293,7 @@ namespace ngraph
|
||||
class NGRAPH_API LSTMCell : public util::RNNCellBase
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"LSTMCell", 1};
|
||||
static constexpr NodeTypeInfo type_info{"LSTMCell", 4};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
LSTMCell();
|
||||
///
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph/factory_adapter.hpp"
|
||||
#include "ngraph/function.hpp"
|
||||
#include "ngraph/op/parameter.hpp"
|
||||
#include "ngraph/op/util/sub_graph_base.hpp"
|
||||
@@ -36,273 +35,9 @@ namespace ngraph
|
||||
static constexpr NodeTypeInfo type_info{"TensorIterator", 0};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
// Forward declarations
|
||||
class SliceInputDescription;
|
||||
class MergedInputDescription;
|
||||
class InvariantInputDescription;
|
||||
|
||||
TensorIterator() = default;
|
||||
TensorIterator(const OutputVector& values);
|
||||
|
||||
/// \brief Describes a connection between a TensorIterator input and the body.
|
||||
class InputDescription
|
||||
{
|
||||
protected:
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param input_index Position of the TensorIterator input
|
||||
/// \param body_parameter_index Body parameter to receive input
|
||||
///
|
||||
InputDescription(uint64_t input_index, uint64_t body_parameter_index);
|
||||
InputDescription() = default;
|
||||
|
||||
public:
|
||||
using type_info_t = DiscreteTypeInfo;
|
||||
virtual ~InputDescription() {}
|
||||
virtual std::shared_ptr<InputDescription> copy() const = 0;
|
||||
|
||||
virtual const type_info_t& get_type_info() const = 0;
|
||||
virtual bool visit_attributes(AttributeVisitor& visitor);
|
||||
|
||||
uint64_t m_input_index{0};
|
||||
uint64_t m_body_parameter_index{0};
|
||||
};
|
||||
|
||||
///
|
||||
/// \brief Describes a body input formed from slices of an input to
|
||||
/// TensorIterator.
|
||||
///
|
||||
class NGRAPH_API SliceInputDescription : public InputDescription
|
||||
{
|
||||
public:
|
||||
static constexpr type_info_t type_info{"SliceInputDescription", 0};
|
||||
const type_info_t& get_type_info() const override { return type_info; }
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param input_index Position of the TensorIterator input
|
||||
/// \param body_parameter_index Body parameter position to receive input
|
||||
/// \param start First index for slices
|
||||
/// \param stride Step amount for slices
|
||||
/// \param part_size Width of slices
|
||||
/// \param end Last index for slices
|
||||
/// \param axis Axis being sliced
|
||||
///
|
||||
SliceInputDescription(uint64_t input_index,
|
||||
uint64_t body_parameter_index,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis);
|
||||
SliceInputDescription() = default;
|
||||
std::shared_ptr<InputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
int64_t m_start{0};
|
||||
int64_t m_stride{0};
|
||||
int64_t m_part_size{0};
|
||||
int64_t m_end{0};
|
||||
int64_t m_axis{0};
|
||||
};
|
||||
|
||||
///
|
||||
/// \brief Describes a body input initialized from a TensorIterator input on
|
||||
/// the first iteration, and then a body output thereafter.
|
||||
///
|
||||
class NGRAPH_API MergedInputDescription : public InputDescription
|
||||
{
|
||||
public:
|
||||
static constexpr type_info_t type_info{"MergedInputDescription", 0};
|
||||
const type_info_t& get_type_info() const override { return type_info; }
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param input_index Position of the TensorIterator input
|
||||
/// supplying a value to body_parameter for
|
||||
/// the initial iteration.
|
||||
/// \param body_parameter_index Body parameter position to receive input.
|
||||
/// \param body_value_index Body value to supply body_parameter for
|
||||
/// successive
|
||||
/// iterations.
|
||||
///
|
||||
MergedInputDescription(uint64_t input_index,
|
||||
uint64_t body_parameter_index,
|
||||
uint64_t body_value_index);
|
||||
MergedInputDescription() = default;
|
||||
std::shared_ptr<InputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
uint64_t m_body_value_index{0};
|
||||
};
|
||||
|
||||
class NGRAPH_API InvariantInputDescription : public InputDescription
|
||||
{
|
||||
public:
|
||||
static constexpr type_info_t type_info{"InvariantInputDescription", 0};
|
||||
const type_info_t& get_type_info() const override { return type_info; }
|
||||
InvariantInputDescription(uint64_t input_index, uint64_t body_parameter_index);
|
||||
InvariantInputDescription() = default;
|
||||
std::shared_ptr<InputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
};
|
||||
|
||||
// Forward declarations
|
||||
class ConcatOutputDescription;
|
||||
class BodyOutputDescription;
|
||||
|
||||
/// \brief Describes how a TensorIterator output is produced from the body.
|
||||
class OutputDescription
|
||||
{
|
||||
protected:
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param body_value_index A body value that produces the output
|
||||
/// \param output_index The TensorIterator output index
|
||||
///
|
||||
OutputDescription(uint64_t body_value_index, uint64_t output_index);
|
||||
OutputDescription() = default;
|
||||
|
||||
public:
|
||||
using type_info_t = DiscreteTypeInfo;
|
||||
virtual ~OutputDescription() {}
|
||||
virtual std::shared_ptr<OutputDescription> copy() const = 0;
|
||||
virtual bool visit_attributes(AttributeVisitor& visitor);
|
||||
virtual const type_info_t& get_type_info() const = 0;
|
||||
|
||||
uint64_t m_body_value_index{0};
|
||||
uint64_t m_output_index{0};
|
||||
};
|
||||
|
||||
/// \brief Produces an output by concatenating an output from each iteration
|
||||
class NGRAPH_API ConcatOutputDescription : public OutputDescription
|
||||
{
|
||||
public:
|
||||
static constexpr type_info_t type_info{"ConcatOutputDescription", 0};
|
||||
const type_info_t& get_type_info() const override { return type_info; }
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param body_value_index A body value that produces the output
|
||||
/// \param output_index The TensorIterator output index
|
||||
/// \param start First index for slices
|
||||
/// \param stride Step amount for slices
|
||||
/// \param part_size Width of slices
|
||||
/// \param end Last index for slices
|
||||
/// \param axis Axis being sliced
|
||||
///
|
||||
ConcatOutputDescription(uint64_t body_value_index,
|
||||
uint64_t output_index,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis);
|
||||
ConcatOutputDescription() = default;
|
||||
|
||||
virtual std::shared_ptr<OutputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
int64_t m_start{0};
|
||||
int64_t m_stride{0};
|
||||
int64_t m_part_size{0};
|
||||
int64_t m_end{0};
|
||||
int64_t m_axis{0};
|
||||
};
|
||||
|
||||
/// \brief Produces an output from a specific iteration
|
||||
class NGRAPH_API BodyOutputDescription : public OutputDescription
|
||||
{
|
||||
public:
|
||||
static constexpr type_info_t type_info{"BodyOutputDescription", 0};
|
||||
const type_info_t& get_type_info() const override { return type_info; }
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param body_value_index A body value that produces the output
|
||||
/// \param output_index The TensorIterator output index
|
||||
/// \param iteration which iteration (typically -1, final) will
|
||||
/// supply the value
|
||||
///
|
||||
BodyOutputDescription(uint64_t body_value_index,
|
||||
uint64_t output_index,
|
||||
int64_t iteration);
|
||||
BodyOutputDescription() = default;
|
||||
std::shared_ptr<OutputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
int64_t m_iteration{0};
|
||||
};
|
||||
|
||||
///
|
||||
/// \brief Indicate that a body parameter comes from slices of a value
|
||||
///
|
||||
/// \param parameter The parameter to receive the slices
|
||||
/// \param value The value to be sliced. This will be added as an input to
|
||||
/// TensorIterator.
|
||||
/// \param start First index on axis of the slicing
|
||||
/// \param stride Stepping of the slice
|
||||
/// \param part_size Size of the slice on axis
|
||||
/// \param end The last index on axis of the slicing
|
||||
/// \param axis The axis to slice along
|
||||
///
|
||||
void set_sliced_input(const std::shared_ptr<Parameter>& parameter,
|
||||
const Output<Node>& value,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis);
|
||||
///
|
||||
/// \brief Indicates that a body parameter has an initial value in the first
|
||||
/// iteration and computed value thereafter
|
||||
///
|
||||
/// \param[in] body_parameter The body parameter
|
||||
/// \param initial_value Value for the parameter in first iteration. This
|
||||
/// will be added as an input to TensorIterator.
|
||||
/// \param successive_value Value for the parameter in successive iterations.
|
||||
/// The value is what is active in the most recent
|
||||
/// completed iteration.
|
||||
///
|
||||
void set_merged_input(const std::shared_ptr<Parameter>& body_parameter,
|
||||
const Output<Node>& initial_value,
|
||||
const Output<Node>& successive_value);
|
||||
///
|
||||
/// \brief Indicates that a body parameter has an invariant value during
|
||||
/// iteration that may depend on values computed outside of the
|
||||
/// iteration.
|
||||
///
|
||||
/// \param body_parameter The body parameter
|
||||
/// \param value The value supplied as an input to the block
|
||||
///
|
||||
void set_invariant_input(const std::shared_ptr<Parameter>& body_parameter,
|
||||
const Output<Node>& value);
|
||||
///
|
||||
/// \brief Gets a value for a particular iteration point
|
||||
///
|
||||
/// \param body_value The value
|
||||
/// \param iteration The iteration that supplies the value. Negative values
|
||||
/// are from the last iteration.
|
||||
///
|
||||
/// \return The iterator value.
|
||||
///
|
||||
Output<Node> get_iter_value(const Output<Node>& body_value, int64_t iteration);
|
||||
///
|
||||
/// \brief Concatenates slices from all iterations
|
||||
///
|
||||
/// \param value The value supplying slice values from each iteration.
|
||||
/// \param start First index on axis of the slicing
|
||||
/// \param stride Stepping of the slice
|
||||
/// \param part_size Size of the slice on axis
|
||||
/// \param end The last index on axis of the slicing
|
||||
/// \param axis The axis to slice along
|
||||
///
|
||||
/// \return The concatenated slices.
|
||||
///
|
||||
Output<Node> get_concatenated_slices(const Output<Node>& value,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis);
|
||||
explicit TensorIterator(const OutputVector& values);
|
||||
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
@@ -310,32 +45,7 @@ namespace ngraph
|
||||
std::shared_ptr<Function> get_body() const { return m_body; }
|
||||
/// \param body set the body of the iteration
|
||||
void set_body(const std::shared_ptr<Function>& body) { m_body = body; }
|
||||
/// \return a reference to the input descriptions.
|
||||
const std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions() const
|
||||
{
|
||||
return m_input_descriptions;
|
||||
}
|
||||
/// \return a reference to the input descriptions. Can add input descriptions
|
||||
/// before
|
||||
/// validation.
|
||||
std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions()
|
||||
{
|
||||
return m_input_descriptions;
|
||||
}
|
||||
/// \return a reference to the output descriptions.
|
||||
const std::vector<std::shared_ptr<OutputDescription>>&
|
||||
get_output_descriptions() const
|
||||
{
|
||||
return m_output_descriptions;
|
||||
}
|
||||
/// \return a reference to the output descriptions. Can add output descriptions
|
||||
/// before
|
||||
/// validation.
|
||||
std::vector<std::shared_ptr<OutputDescription>>& get_output_descriptions()
|
||||
{
|
||||
return m_output_descriptions;
|
||||
}
|
||||
virtual void validate_and_infer_types() override;
|
||||
void validate_and_infer_types() override;
|
||||
void revalidate_and_infer_types_for_body_ops();
|
||||
/// \return the body of the iteration
|
||||
std::shared_ptr<Function> get_function() override;
|
||||
@@ -347,81 +57,9 @@ namespace ngraph
|
||||
}
|
||||
|
||||
private:
|
||||
// Find an input corresponding to value, adding one if necessary.
|
||||
Input<Node> input_for_value(const Output<Node>& value);
|
||||
|
||||
std::shared_ptr<Function> m_body;
|
||||
std::vector<std::shared_ptr<InputDescription>> m_input_descriptions;
|
||||
std::vector<std::shared_ptr<OutputDescription>> m_output_descriptions;
|
||||
|
||||
int64_t m_num_iterations = -1;
|
||||
};
|
||||
}
|
||||
using v0::TensorIterator;
|
||||
}
|
||||
template class NGRAPH_API FactoryRegistry<op::v0::TensorIterator::InputDescription>;
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<std::shared_ptr<op::TensorIterator::InputDescription>>
|
||||
: public FactoryAttributeAdapter<op::TensorIterator::InputDescription>
|
||||
{
|
||||
public:
|
||||
using FactoryAttributeAdapter::FactoryAttributeAdapter;
|
||||
static constexpr DiscreteTypeInfo type_info{
|
||||
"AttributeAdapter<std::shared_ptr<op::TensorIterator::InputDescription>>"
|
||||
">>",
|
||||
0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API
|
||||
AttributeAdapter<std::vector<std::shared_ptr<op::TensorIterator::InputDescription>>>
|
||||
: public VisitorAdapter
|
||||
{
|
||||
public:
|
||||
AttributeAdapter(std::vector<std::shared_ptr<op::TensorIterator::InputDescription>>& ref);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
static constexpr DiscreteTypeInfo type_info{
|
||||
"AttributeAdapter<std::vector<std::shared_ptr<op::TensorIterator::InputDescription>>"
|
||||
">>",
|
||||
0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
protected:
|
||||
std::vector<std::shared_ptr<op::TensorIterator::InputDescription>>& m_ref;
|
||||
};
|
||||
|
||||
template class NGRAPH_API FactoryRegistry<op::v0::TensorIterator::OutputDescription>;
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<std::shared_ptr<op::TensorIterator::OutputDescription>>
|
||||
: public FactoryAttributeAdapter<op::TensorIterator::OutputDescription>
|
||||
{
|
||||
public:
|
||||
using FactoryAttributeAdapter::FactoryAttributeAdapter;
|
||||
static constexpr DiscreteTypeInfo type_info{
|
||||
"AttributeAdapter<std::shared_ptr<op::TensorIterator::OutputDescription>>"
|
||||
">>",
|
||||
0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API
|
||||
AttributeAdapter<std::vector<std::shared_ptr<op::TensorIterator::OutputDescription>>>
|
||||
: public VisitorAdapter
|
||||
{
|
||||
public:
|
||||
AttributeAdapter(std::vector<std::shared_ptr<op::TensorIterator::OutputDescription>>& ref);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
static constexpr DiscreteTypeInfo type_info{
|
||||
"AttributeAdapter<std::vector<std::shared_ptr<op::TensorIterator::OutputDescription>>"
|
||||
">>",
|
||||
0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
protected:
|
||||
std::vector<std::shared_ptr<op::TensorIterator::OutputDescription>>& m_ref;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/op/parameter.hpp>
|
||||
#include "ngraph/factory_adapter.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
|
||||
namespace ngraph
|
||||
@@ -29,13 +31,383 @@ namespace ngraph
|
||||
class NGRAPH_API SubGraphOp : public Op
|
||||
{
|
||||
public:
|
||||
virtual std::shared_ptr<Function> get_function();
|
||||
/// \brief Describes a connection between a SubGraphOp input and the body.
|
||||
class InputDescription
|
||||
{
|
||||
protected:
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param input_index Position of the SubGraphOp input
|
||||
/// \param body_parameter_index Body parameter to receive input
|
||||
///
|
||||
InputDescription(uint64_t input_index, uint64_t body_parameter_index);
|
||||
InputDescription() = default;
|
||||
|
||||
public:
|
||||
using type_info_t = DiscreteTypeInfo;
|
||||
virtual ~InputDescription() = default;
|
||||
virtual std::shared_ptr<InputDescription> copy() const = 0;
|
||||
|
||||
virtual const type_info_t& get_type_info() const = 0;
|
||||
virtual bool visit_attributes(AttributeVisitor& visitor);
|
||||
|
||||
uint64_t m_input_index{0};
|
||||
uint64_t m_body_parameter_index{0};
|
||||
};
|
||||
|
||||
///
|
||||
/// \brief Describes a body input formed from slices of an input to
|
||||
/// SubGraphOp.
|
||||
///
|
||||
class NGRAPH_API SliceInputDescription : public InputDescription
|
||||
{
|
||||
public:
|
||||
static constexpr type_info_t type_info{"SliceInputDescription", 0};
|
||||
const type_info_t& get_type_info() const override { return type_info; }
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param input_index Position of the SubGraphOp input
|
||||
/// \param body_parameter_index Body parameter position to receive input
|
||||
/// \param start First index for slices
|
||||
/// \param stride Step amount for slices
|
||||
/// \param part_size Width of slices
|
||||
/// \param end Last index for slices
|
||||
/// \param axis Axis being sliced
|
||||
///
|
||||
SliceInputDescription(uint64_t input_index,
|
||||
uint64_t body_parameter_index,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis);
|
||||
SliceInputDescription() = default;
|
||||
std::shared_ptr<InputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
int64_t m_start{0};
|
||||
int64_t m_stride{0};
|
||||
int64_t m_part_size{0};
|
||||
int64_t m_end{0};
|
||||
int64_t m_axis{0};
|
||||
};
|
||||
|
||||
///
|
||||
/// \brief Describes a body input initialized from a SubGraphOp input on
|
||||
/// the first iteration, and then a body output thereafter.
|
||||
///
|
||||
class NGRAPH_API MergedInputDescription : public InputDescription
|
||||
{
|
||||
public:
|
||||
static constexpr type_info_t type_info{"MergedInputDescription", 0};
|
||||
const type_info_t& get_type_info() const override { return type_info; }
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param input_index Position of the SubGraphOp input
|
||||
/// supplying a value to body_parameter for
|
||||
/// the initial iteration.
|
||||
/// \param body_parameter_index Body parameter position to receive input.
|
||||
/// \param body_value_index Body value to supply body_parameter for
|
||||
/// successive
|
||||
/// iterations.
|
||||
///
|
||||
MergedInputDescription(uint64_t input_index,
|
||||
uint64_t body_parameter_index,
|
||||
uint64_t body_value_index);
|
||||
MergedInputDescription() = default;
|
||||
std::shared_ptr<InputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
uint64_t m_body_value_index{0};
|
||||
};
|
||||
|
||||
///
|
||||
/// \brief Describes a body input initialized from a SubGraphOp input on
|
||||
/// the first iteration, and invariant thereafter.
|
||||
///
|
||||
class NGRAPH_API InvariantInputDescription : public InputDescription
|
||||
{
|
||||
public:
|
||||
static constexpr type_info_t type_info{"InvariantInputDescription", 0};
|
||||
const type_info_t& get_type_info() const override { return type_info; }
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param input_index Position of the SubGraphOp input
|
||||
/// \param body_parameter_index Body parameter to receive input
|
||||
///
|
||||
InvariantInputDescription(uint64_t input_index, uint64_t body_parameter_index);
|
||||
InvariantInputDescription() = default;
|
||||
std::shared_ptr<InputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
};
|
||||
|
||||
/// \brief Describes how a SubGraphOp output is produced from the body.
|
||||
class OutputDescription
|
||||
{
|
||||
protected:
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param body_value_index A body value that produces the output
|
||||
/// \param output_index The SubGraphOp output index
|
||||
///
|
||||
OutputDescription(uint64_t body_value_index, uint64_t output_index);
|
||||
OutputDescription() = default;
|
||||
|
||||
public:
|
||||
using type_info_t = DiscreteTypeInfo;
|
||||
virtual ~OutputDescription() = default;
|
||||
virtual std::shared_ptr<OutputDescription> copy() const = 0;
|
||||
virtual bool visit_attributes(AttributeVisitor& visitor);
|
||||
virtual const type_info_t& get_type_info() const = 0;
|
||||
|
||||
uint64_t m_body_value_index{0};
|
||||
uint64_t m_output_index{0};
|
||||
};
|
||||
|
||||
/// \brief Produces an output by concatenating an output from each iteration
|
||||
class NGRAPH_API ConcatOutputDescription : public OutputDescription
|
||||
{
|
||||
public:
|
||||
static constexpr type_info_t type_info{"ConcatOutputDescription", 0};
|
||||
const type_info_t& get_type_info() const override { return type_info; }
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param body_value_index A body value that produces the output
|
||||
/// \param output_index The SubGraphOp output index
|
||||
/// \param start First index for slices
|
||||
/// \param stride Step amount for slices
|
||||
/// \param part_size Width of slices
|
||||
/// \param end Last index for slices
|
||||
/// \param axis Axis being sliced
|
||||
///
|
||||
ConcatOutputDescription(uint64_t body_value_index,
|
||||
uint64_t output_index,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis);
|
||||
ConcatOutputDescription() = default;
|
||||
|
||||
std::shared_ptr<OutputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
int64_t m_start{0};
|
||||
int64_t m_stride{0};
|
||||
int64_t m_part_size{0};
|
||||
int64_t m_end{0};
|
||||
int64_t m_axis{0};
|
||||
};
|
||||
|
||||
/// \brief Produces an output from a specific iteration
|
||||
class NGRAPH_API BodyOutputDescription : public OutputDescription
|
||||
{
|
||||
public:
|
||||
static constexpr type_info_t type_info{"BodyOutputDescription", 0};
|
||||
const type_info_t& get_type_info() const override { return type_info; }
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param body_value_index A body value that produces the output
|
||||
/// \param output_index The SubGraphOp output index
|
||||
/// \param iteration which iteration (typically -1, final) will
|
||||
/// supply the value
|
||||
///
|
||||
BodyOutputDescription(uint64_t body_value_index,
|
||||
uint64_t output_index,
|
||||
int64_t iteration);
|
||||
BodyOutputDescription() = default;
|
||||
std::shared_ptr<OutputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
int64_t m_iteration{0};
|
||||
};
|
||||
|
||||
virtual std::shared_ptr<Function> get_function() { return m_body; };
|
||||
virtual void set_function(const std::shared_ptr<Function>& func) { m_body = func; };
|
||||
/// \return a reference to the input descriptions.
|
||||
const std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions() const
|
||||
{
|
||||
return m_input_descriptions;
|
||||
}
|
||||
/// \return a reference to the input descriptions. Can add input descriptions
|
||||
/// before
|
||||
/// validation.
|
||||
std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions()
|
||||
{
|
||||
return m_input_descriptions;
|
||||
}
|
||||
/// \return a reference to the output descriptions.
|
||||
const std::vector<std::shared_ptr<OutputDescription>>&
|
||||
get_output_descriptions() const
|
||||
{
|
||||
return m_output_descriptions;
|
||||
}
|
||||
/// \return a reference to the output descriptions. Can add output descriptions
|
||||
/// before
|
||||
/// validation.
|
||||
std::vector<std::shared_ptr<OutputDescription>>& get_output_descriptions()
|
||||
{
|
||||
return m_output_descriptions;
|
||||
}
|
||||
|
||||
///
|
||||
/// \brief Indicate that a body parameter comes from slices of a value
|
||||
///
|
||||
/// \param parameter The parameter to receive the slices
|
||||
/// \param value The value to be sliced. This will be added as an input to
|
||||
/// SubGraphOp.
|
||||
/// \param start First index on axis of the slicing
|
||||
/// \param stride Stepping of the slice
|
||||
/// \param part_size Size of the slice on axis
|
||||
/// \param end The last index on axis of the slicing
|
||||
/// \param axis The axis to slice along
|
||||
///
|
||||
virtual void set_sliced_input(const std::shared_ptr<Parameter>& parameter,
|
||||
const Output<Node>& value,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis);
|
||||
///
|
||||
/// \brief Indicates that a body parameter has an initial value in the first
|
||||
/// iteration and computed value thereafter
|
||||
///
|
||||
/// \param[in] body_parameter The body parameter
|
||||
/// \param initial_value Value for the parameter in first iteration. This
|
||||
/// will be added as an input to Loop.
|
||||
/// \param successive_value Value for the parameter in successive iterations.
|
||||
/// The value is what is active in the most recent
|
||||
/// completed iteration.
|
||||
///
|
||||
virtual void set_merged_input(const std::shared_ptr<Parameter>& body_parameter,
|
||||
const Output<Node>& initial_value,
|
||||
const Output<Node>& successive_value);
|
||||
///
|
||||
/// \brief Indicates that a body parameter has an invariant value during
|
||||
/// iteration that may depend on values computed outside of the
|
||||
/// iteration.
|
||||
///
|
||||
/// \param body_parameter The body parameter
|
||||
/// \param value The value supplied as an input to the block
|
||||
///
|
||||
virtual void set_invariant_input(const std::shared_ptr<Parameter>& body_parameter,
|
||||
const Output<Node>& value);
|
||||
///
|
||||
/// \brief Gets a value for a particular iteration point
|
||||
///
|
||||
/// \param body_value The value
|
||||
/// \param iteration The iteration that supplies the value. Negative values
|
||||
/// are from the last iteration.
|
||||
/// Default value -1 (the last iteration).
|
||||
///
|
||||
/// \return The iterator value.
|
||||
///
|
||||
virtual Output<Node> get_iter_value(const Output<Node>& body_value,
|
||||
int64_t iteration = -1);
|
||||
///
|
||||
/// \brief Concatenates slices from all iterations
|
||||
///
|
||||
/// \param value The value supplying slice values from each iteration.
|
||||
/// \param start First index on axis of the slicing
|
||||
/// \param stride Stepping of the slice
|
||||
/// \param part_size Size of the slice on axis
|
||||
/// \param end The last index on axis of the slicing
|
||||
/// \param axis The axis to slice along
|
||||
///
|
||||
/// \return The concatenated slices.
|
||||
///
|
||||
virtual Output<Node> get_concatenated_slices(const Output<Node>& value,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis);
|
||||
|
||||
protected:
|
||||
// Find an input corresponding to value, adding one if necessary.
|
||||
Input<Node> input_for_value(const Output<Node>& value);
|
||||
|
||||
SubGraphOp() = default;
|
||||
|
||||
SubGraphOp(const OutputVector& args);
|
||||
explicit SubGraphOp(const OutputVector& args);
|
||||
|
||||
std::shared_ptr<Function> m_body;
|
||||
std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>
|
||||
m_input_descriptions;
|
||||
std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>
|
||||
m_output_descriptions;
|
||||
};
|
||||
}
|
||||
}
|
||||
template class NGRAPH_API FactoryRegistry<op::util::SubGraphOp::InputDescription>;
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::InputDescription>>
|
||||
: public FactoryAttributeAdapter<op::util::SubGraphOp::InputDescription>
|
||||
{
|
||||
public:
|
||||
using FactoryAttributeAdapter::FactoryAttributeAdapter;
|
||||
static constexpr DiscreteTypeInfo type_info{
|
||||
"AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::InputDescription>>"
|
||||
">>",
|
||||
0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API
|
||||
AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>>
|
||||
: public VisitorAdapter
|
||||
{
|
||||
public:
|
||||
explicit AttributeAdapter(
|
||||
std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>& ref);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
static constexpr DiscreteTypeInfo type_info{
|
||||
"AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>"
|
||||
">>",
|
||||
0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
protected:
|
||||
std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>& m_ref;
|
||||
};
|
||||
|
||||
template class NGRAPH_API FactoryRegistry<op::util::SubGraphOp::OutputDescription>;
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>
|
||||
: public FactoryAttributeAdapter<op::util::SubGraphOp::OutputDescription>
|
||||
{
|
||||
public:
|
||||
using FactoryAttributeAdapter::FactoryAttributeAdapter;
|
||||
static constexpr DiscreteTypeInfo type_info{
|
||||
"AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>"
|
||||
">>",
|
||||
0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API
|
||||
AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>>
|
||||
: public VisitorAdapter
|
||||
{
|
||||
public:
|
||||
explicit AttributeAdapter(
|
||||
std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>& ref);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
static constexpr DiscreteTypeInfo type_info{
|
||||
"AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>"
|
||||
">>",
|
||||
0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
protected:
|
||||
std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>& m_ref;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -83,6 +83,7 @@
|
||||
#include "ngraph/op/less_eq.hpp"
|
||||
#include "ngraph/op/log.hpp"
|
||||
#include "ngraph/op/log_softmax.hpp"
|
||||
#include "ngraph/op/loop.hpp"
|
||||
#include "ngraph/op/lrn.hpp"
|
||||
#include "ngraph/op/lstm_cell.hpp"
|
||||
#include "ngraph/op/lstm_sequence.hpp"
|
||||
|
||||
@@ -167,6 +167,7 @@ NGRAPH_OP(GatherND, ngraph::op::v5)
|
||||
NGRAPH_OP(GRUSequence, ngraph::op::v5)
|
||||
NGRAPH_OP(HSigmoid, ngraph::op::v5)
|
||||
NGRAPH_OP(LogSoftmax, ngraph::op::v5)
|
||||
NGRAPH_OP(Loop, ngraph::op::v5)
|
||||
NGRAPH_OP(LSTMSequence, ngraph::op::v5)
|
||||
NGRAPH_OP(NonMaxSuppression, ngraph::op::v5)
|
||||
NGRAPH_OP(RNNSequence, ngraph::op::v5)
|
||||
|
||||
@@ -31,6 +31,7 @@
|
||||
#include "ngraph/op/result.hpp"
|
||||
#include "ngraph/op/tensor_iterator.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
#include "ngraph/pass/manager.hpp"
|
||||
#include "ngraph/pass/visualize_tree.hpp"
|
||||
#include "ngraph/provenance.hpp"
|
||||
@@ -311,7 +312,7 @@ std::vector<std::shared_ptr<ngraph::Node>>
|
||||
// There is a friendly name for this node so copy it
|
||||
cloned_node->set_friendly_name(node->get_friendly_name());
|
||||
// TODO: workaround for shape inference, delete it after fix
|
||||
if (ngraph::as_type_ptr<ngraph::op::TensorIterator>(cloned_node))
|
||||
if (std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(cloned_node))
|
||||
{
|
||||
cloned_node->validate_and_infer_types();
|
||||
}
|
||||
@@ -379,7 +380,7 @@ std::list<std::shared_ptr<ngraph::Node>>
|
||||
// There is a friendly name for this node so copy it
|
||||
cloned_node->set_friendly_name(node->get_friendly_name());
|
||||
// TODO: workaround for shape inference, delete it after fix
|
||||
if (ngraph::as_type_ptr<ngraph::op::TensorIterator>(cloned_node))
|
||||
if (std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(cloned_node))
|
||||
{
|
||||
cloned_node->validate_and_infer_types();
|
||||
}
|
||||
|
||||
350
ngraph/core/src/op/loop.cpp
Normal file
350
ngraph/core/src/op/loop.cpp
Normal file
@@ -0,0 +1,350 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/op/loop.hpp"
|
||||
#include "ngraph/factory.hpp"
|
||||
#include "ngraph/graph_util.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
#include "ngraph/specialize_function.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v5::Loop, "Loop", 5);
|
||||
|
||||
op::v5::Loop::Loop(const Output<Node>& trip_count, const Output<Node>& execution_condition)
|
||||
{
|
||||
set_argument(0, trip_count);
|
||||
set_argument(1, execution_condition);
|
||||
}
|
||||
|
||||
bool op::v5::Loop::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("body", m_body);
|
||||
visitor.on_attribute("input_descriptions", m_input_descriptions);
|
||||
visitor.on_attribute("output_descriptions", m_output_descriptions);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void op::v5::Loop::validate_and_infer_types()
|
||||
{
|
||||
if (m_special_body_ports.current_iteration_input_idx >= 0)
|
||||
{
|
||||
const auto& cur_iter_rank = m_body->get_parameters()
|
||||
.at(m_special_body_ports.current_iteration_input_idx)
|
||||
->get_partial_shape()
|
||||
.rank();
|
||||
if (cur_iter_rank.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
cur_iter_rank.compatible(1) || cur_iter_rank.compatible(0),
|
||||
"Rank of CurrentIteration input must be equal to 0 or 1");
|
||||
}
|
||||
}
|
||||
bool zero_number_of_iter = false;
|
||||
const auto& loop_execution_condition = input_value(1);
|
||||
const auto& loop_condition_rank = loop_execution_condition.get_partial_shape().rank();
|
||||
if (loop_condition_rank.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
loop_condition_rank.compatible(1) ||
|
||||
loop_condition_rank.compatible(0),
|
||||
"Rank of ExecutionCondition input must be equal to 0 or 1");
|
||||
}
|
||||
if (const auto& cond_value = std::dynamic_pointer_cast<const ngraph::opset5::Constant>(
|
||||
loop_execution_condition.get_node_shared_ptr()))
|
||||
{
|
||||
auto val = cond_value->cast_vector<bool>();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
val.size() == 1,
|
||||
"The number of values in the Condition constant is greater than 1");
|
||||
|
||||
if (!val[0])
|
||||
{
|
||||
zero_number_of_iter = true;
|
||||
}
|
||||
}
|
||||
|
||||
bool condition_always_true = false;
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
m_special_body_ports.body_condition_output_idx >= 0,
|
||||
"Condition body output is not provided. "
|
||||
"Condition is a mandatory output of the body in Loop op.");
|
||||
const auto& body_execution_condition =
|
||||
m_body->get_results().at(m_special_body_ports.body_condition_output_idx)->input_value(0);
|
||||
const auto& body_condition_rank = body_execution_condition.get_partial_shape().rank();
|
||||
if (body_condition_rank.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
body_condition_rank.compatible(0) ||
|
||||
body_condition_rank.compatible(1),
|
||||
"Rank of BodyExecutionCondition output must be equal to 0 or 1");
|
||||
}
|
||||
if (const auto& cond_value = std::dynamic_pointer_cast<const ngraph::opset5::Constant>(
|
||||
body_execution_condition.get_node_shared_ptr()))
|
||||
{
|
||||
auto val = cond_value->cast_vector<bool>();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
val.size() == 1,
|
||||
"The number of values in the Condition constant is greater than 1");
|
||||
|
||||
if (val[0])
|
||||
{
|
||||
condition_always_true = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
m_num_iterations = 1; // condition_always_false, do_while mode
|
||||
}
|
||||
}
|
||||
|
||||
const auto& trip_count = input_value(0);
|
||||
const auto& trip_count_rank = trip_count.get_partial_shape().rank();
|
||||
if (trip_count_rank.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
trip_count_rank.compatible(1) || trip_count_rank.compatible(0),
|
||||
"Rank of TripCount input must be equal to 0 or 1");
|
||||
}
|
||||
if (const auto& trip_count_val = std::dynamic_pointer_cast<const ngraph::opset5::Constant>(
|
||||
trip_count.get_node_shared_ptr()))
|
||||
{
|
||||
auto val = trip_count_val->cast_vector<int64_t>();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
val.size() == 1,
|
||||
"The number of values in the TripCount constant is greater than 1");
|
||||
if (condition_always_true)
|
||||
m_num_iterations = val[0];
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
get_input_size() == m_input_descriptions.size() + 2,
|
||||
"Number of inputs must be the same as number of input descriptions");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
get_output_size() == m_output_descriptions.size(),
|
||||
"Number of outputs must be the same as number of output descriptions");
|
||||
|
||||
std::vector<std::shared_ptr<Node>> ends;
|
||||
|
||||
// Input
|
||||
uint64_t index_it = 2;
|
||||
for (const auto& input_description : m_input_descriptions)
|
||||
{
|
||||
auto index = input_description->m_input_index;
|
||||
NODE_VALIDATION_CHECK(this, index == index_it, "Input_index not in order");
|
||||
index_it++;
|
||||
|
||||
if (auto merged_input_description = as_type_ptr<MergedInputDescription>(input_description))
|
||||
{
|
||||
auto body_value =
|
||||
m_body->get_results().at(merged_input_description->m_body_value_index);
|
||||
ends.push_back(body_value);
|
||||
|
||||
const auto& body_value_partial_shape = body_value->get_input_partial_shape(0);
|
||||
auto body_parameter =
|
||||
m_body->get_parameters().at(merged_input_description->m_body_parameter_index);
|
||||
|
||||
auto body_param_partial_shape = body_parameter->get_partial_shape();
|
||||
auto input_partial_shape = input(index).get_partial_shape();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
body_value_partial_shape.compatible(body_param_partial_shape),
|
||||
"Iterator successive value is not compatible with body param");
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_partial_shape.compatible(body_param_partial_shape),
|
||||
"Iterator initial value is not compatible with body param");
|
||||
|
||||
if (input_partial_shape.is_static())
|
||||
{
|
||||
auto input_shape = input_partial_shape.to_shape();
|
||||
// infer type for body_parameter
|
||||
if (body_param_partial_shape.is_dynamic())
|
||||
{
|
||||
body_parameter->set_partial_shape(input_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (auto invariant_input_description =
|
||||
as_type_ptr<TensorIterator::InvariantInputDescription>(input_description))
|
||||
{
|
||||
auto body_parameter =
|
||||
m_body->get_parameters().at(invariant_input_description->m_body_parameter_index);
|
||||
|
||||
auto body_param_partial_shape = body_parameter->get_partial_shape();
|
||||
auto input_partial_shape = input(index).get_partial_shape();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_partial_shape.compatible(body_param_partial_shape),
|
||||
"Iterator initial value is not compatible with body param");
|
||||
|
||||
if (input_partial_shape.is_static())
|
||||
{
|
||||
auto input_shape = input_partial_shape.to_shape();
|
||||
// infer type for m_body_parameter
|
||||
if (body_param_partial_shape.is_dynamic())
|
||||
{
|
||||
body_parameter->set_partial_shape(input_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Body
|
||||
m_body->validate_nodes_and_infer_types();
|
||||
|
||||
// Output
|
||||
index_it = 0;
|
||||
for (const auto& output_description : m_output_descriptions)
|
||||
{
|
||||
auto index = output_description->m_output_index;
|
||||
NODE_VALIDATION_CHECK(this, index == index_it, "Output_index not in order");
|
||||
index_it++;
|
||||
|
||||
auto body_value =
|
||||
m_body->get_results().at(output_description->m_body_value_index)->input_value(0);
|
||||
|
||||
if (auto concat_output_description =
|
||||
as_type_ptr<TensorIterator::ConcatOutputDescription>(output_description))
|
||||
{
|
||||
const auto& body_value_partial_shape = body_value.get_partial_shape();
|
||||
set_output_type(index, body_value.get_element_type(), PartialShape::dynamic());
|
||||
if (body_value_partial_shape.is_static())
|
||||
{
|
||||
auto body_value_shape = body_value_partial_shape.to_shape();
|
||||
auto axis = concat_output_description->m_axis;
|
||||
|
||||
Shape out_shape{body_value_shape};
|
||||
|
||||
if (body_value_shape.empty())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
axis == 0,
|
||||
"Axis must be equal to 0 if concatenated output tensor slices are scalars. "
|
||||
"Loop output index: ",
|
||||
index);
|
||||
out_shape = Shape(1);
|
||||
}
|
||||
|
||||
if (m_num_iterations != -1)
|
||||
{
|
||||
out_shape[axis] = m_num_iterations * body_value_shape[axis];
|
||||
if (zero_number_of_iter)
|
||||
{
|
||||
out_shape.at(0) = 0;
|
||||
}
|
||||
set_output_type(index, body_value.get_element_type(), out_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (auto body_output_description =
|
||||
as_type_ptr<TensorIterator::BodyOutputDescription>(output_description))
|
||||
{
|
||||
const PartialShape& ps = body_value.get_partial_shape();
|
||||
if (ps.is_dynamic())
|
||||
{
|
||||
set_output_type(index, body_value.get_element_type(), ps);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto shape = ps.get_shape();
|
||||
if (zero_number_of_iter)
|
||||
{
|
||||
shape.at(0) = 0;
|
||||
}
|
||||
set_output_type(index, body_value.get_element_type(), shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::v5::Loop::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
// 0 - trip_count, 1 - execution condition, these inputs are not connected to the body params
|
||||
OutputVector body_params_args(new_args.begin() + 2, new_args.end());
|
||||
auto op = make_shared<op::v5::Loop>(new_args[0], new_args[1]);
|
||||
for (int idx = 2; idx < new_args.size(); ++idx)
|
||||
{
|
||||
op->set_argument(idx, new_args[idx]);
|
||||
}
|
||||
NGRAPH_CHECK(op.get(),
|
||||
op != nullptr,
|
||||
"Cannot clone ",
|
||||
description(),
|
||||
" operation with name ",
|
||||
get_friendly_name());
|
||||
op->set_output_size(m_output_descriptions.size());
|
||||
|
||||
std::vector<::ngraph::element::Type> types(m_body->get_parameters().size());
|
||||
std::vector<::ngraph::PartialShape> new_shapes(m_body->get_parameters().size());
|
||||
|
||||
for (size_t input_index = 0; input_index < new_args.size(); ++input_index)
|
||||
{
|
||||
for (auto& input_description : m_input_descriptions)
|
||||
{
|
||||
if (input_description->m_input_index == input_index)
|
||||
{
|
||||
types[input_description->m_body_parameter_index] =
|
||||
new_args[input_index].get_element_type();
|
||||
new_shapes[input_description->m_body_parameter_index] =
|
||||
new_args[input_index].get_partial_shape();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (m_special_body_ports.current_iteration_input_idx >= 0)
|
||||
{
|
||||
const auto& cur_iterations_param =
|
||||
m_body->get_parameters().at(m_special_body_ports.current_iteration_input_idx);
|
||||
body_params_args.insert(body_params_args.begin() +
|
||||
m_special_body_ports.current_iteration_input_idx,
|
||||
cur_iterations_param);
|
||||
new_shapes.at(m_special_body_ports.current_iteration_input_idx) =
|
||||
cur_iterations_param->get_partial_shape();
|
||||
types.at(m_special_body_ports.current_iteration_input_idx) =
|
||||
cur_iterations_param->get_element_type();
|
||||
}
|
||||
op->m_num_iterations = m_num_iterations;
|
||||
op->m_special_body_ports = m_special_body_ports;
|
||||
auto func = std::make_shared<Function>(m_body->get_results(), m_body->get_parameters());
|
||||
auto spec_func = specialize_function(
|
||||
func, types, new_shapes, std::vector<void*>(body_params_args.size(), nullptr));
|
||||
op->m_body = std::make_shared<Function>(spec_func->get_results(), spec_func->get_parameters());
|
||||
|
||||
for (auto& input_description : m_input_descriptions)
|
||||
{
|
||||
op->m_input_descriptions.push_back(input_description->copy());
|
||||
}
|
||||
for (auto& output_description : m_output_descriptions)
|
||||
{
|
||||
op->m_output_descriptions.push_back(output_description->copy());
|
||||
}
|
||||
return move(op);
|
||||
}
|
||||
|
||||
Output<Node> op::v5::Loop::get_concatenated_slices(const Output<Node>& value,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis)
|
||||
{
|
||||
NGRAPH_CHECK(start == 0 && stride == 1 && part_size == 1 && end == -1,
|
||||
"Invalid start, stride, part_size, or end attribute values in Loop op. "
|
||||
"Supported values for start {0}, for stride and part_size {1}, for end "
|
||||
"{-1}");
|
||||
return SubGraphOp::get_concatenated_slices(value, start, stride, part_size, end, axis);
|
||||
}
|
||||
@@ -22,285 +22,13 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
constexpr NodeTypeInfo op::v0::TensorIterator::type_info;
|
||||
|
||||
constexpr DiscreteTypeInfo op::v0::TensorIterator::SliceInputDescription::type_info;
|
||||
constexpr DiscreteTypeInfo op::v0::TensorIterator::MergedInputDescription::type_info;
|
||||
constexpr DiscreteTypeInfo op::v0::TensorIterator::InvariantInputDescription::type_info;
|
||||
|
||||
constexpr DiscreteTypeInfo op::v0::TensorIterator::BodyOutputDescription::type_info;
|
||||
constexpr DiscreteTypeInfo op::v0::TensorIterator::ConcatOutputDescription::type_info;
|
||||
|
||||
op::v0::TensorIterator::TensorIterator(const OutputVector& values)
|
||||
: op::util::SubGraphOp(values)
|
||||
{
|
||||
}
|
||||
|
||||
op::v0::TensorIterator::InputDescription::InputDescription(uint64_t input_index,
|
||||
uint64_t body_parameter_index)
|
||||
: m_input_index(input_index)
|
||||
, m_body_parameter_index(body_parameter_index)
|
||||
{
|
||||
}
|
||||
|
||||
bool op::v0::TensorIterator::InputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("input_index", m_input_index);
|
||||
visitor.on_attribute("body_parameter_index", m_body_parameter_index);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::v0::TensorIterator::SliceInputDescription::SliceInputDescription(uint64_t input_index,
|
||||
uint64_t body_parameter_index,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis)
|
||||
: InputDescription(input_index, body_parameter_index)
|
||||
, m_start(start)
|
||||
, m_stride(stride)
|
||||
, m_part_size(part_size)
|
||||
, m_end(end)
|
||||
, m_axis(axis)
|
||||
{
|
||||
}
|
||||
|
||||
shared_ptr<op::v0::TensorIterator::InputDescription>
|
||||
op::v0::TensorIterator::SliceInputDescription::copy() const
|
||||
{
|
||||
return make_shared<SliceInputDescription>(
|
||||
m_input_index, m_body_parameter_index, m_start, m_stride, m_part_size, m_end, m_axis);
|
||||
}
|
||||
|
||||
bool op::v0::TensorIterator::SliceInputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
InputDescription::visit_attributes(visitor);
|
||||
visitor.on_attribute("start", m_start);
|
||||
visitor.on_attribute("stride", m_stride);
|
||||
visitor.on_attribute("part_size", m_part_size);
|
||||
visitor.on_attribute("end", m_end);
|
||||
visitor.on_attribute("axis", m_axis);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::v0::TensorIterator::MergedInputDescription::MergedInputDescription(
|
||||
uint64_t input_index, uint64_t body_parameter_index, uint64_t body_value_index)
|
||||
: InputDescription(input_index, body_parameter_index)
|
||||
, m_body_value_index(body_value_index)
|
||||
{
|
||||
}
|
||||
|
||||
shared_ptr<op::v0::TensorIterator::InputDescription>
|
||||
op::v0::TensorIterator::MergedInputDescription::copy() const
|
||||
{
|
||||
return make_shared<MergedInputDescription>(
|
||||
m_input_index, m_body_parameter_index, m_body_value_index);
|
||||
}
|
||||
|
||||
bool op::v0::TensorIterator::MergedInputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
InputDescription::visit_attributes(visitor);
|
||||
visitor.on_attribute("body_value_index", m_body_value_index);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::v0::TensorIterator::InvariantInputDescription::InvariantInputDescription(
|
||||
uint64_t input_index, uint64_t body_parameter_index)
|
||||
: InputDescription(input_index, body_parameter_index)
|
||||
{
|
||||
}
|
||||
|
||||
shared_ptr<op::v0::TensorIterator::InputDescription>
|
||||
op::v0::TensorIterator::InvariantInputDescription::copy() const
|
||||
{
|
||||
return make_shared<InvariantInputDescription>(m_input_index, m_body_parameter_index);
|
||||
}
|
||||
|
||||
bool op::v0::TensorIterator::InvariantInputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
InputDescription::visit_attributes(visitor);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::v0::TensorIterator::OutputDescription::OutputDescription(uint64_t body_value_index,
|
||||
uint64_t output_index)
|
||||
: m_body_value_index(body_value_index)
|
||||
, m_output_index(output_index)
|
||||
{
|
||||
}
|
||||
|
||||
bool op::v0::TensorIterator::OutputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("body_value_index", m_body_value_index);
|
||||
visitor.on_attribute("output_index", m_output_index);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::v0::TensorIterator::ConcatOutputDescription::ConcatOutputDescription(uint64_t body_value_index,
|
||||
uint64_t output_index,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis)
|
||||
: OutputDescription(body_value_index, output_index)
|
||||
, m_start(start)
|
||||
, m_stride(stride)
|
||||
, m_part_size(part_size)
|
||||
, m_end(end)
|
||||
, m_axis(axis)
|
||||
{
|
||||
}
|
||||
|
||||
bool op::v0::TensorIterator::ConcatOutputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
OutputDescription::visit_attributes(visitor);
|
||||
visitor.on_attribute("start", m_start);
|
||||
visitor.on_attribute("stride", m_stride);
|
||||
visitor.on_attribute("part_size", m_part_size);
|
||||
visitor.on_attribute("end", m_end);
|
||||
visitor.on_attribute("axis", m_axis);
|
||||
return true;
|
||||
}
|
||||
|
||||
shared_ptr<op::v0::TensorIterator::OutputDescription>
|
||||
op::v0::TensorIterator::ConcatOutputDescription::copy() const
|
||||
{
|
||||
return make_shared<ConcatOutputDescription>(
|
||||
m_body_value_index, m_output_index, m_start, m_stride, m_part_size, m_end, m_axis);
|
||||
}
|
||||
|
||||
op::v0::TensorIterator::BodyOutputDescription::BodyOutputDescription(uint64_t body_value_index,
|
||||
uint64_t output_index,
|
||||
int64_t iteration)
|
||||
: OutputDescription(body_value_index, output_index)
|
||||
, m_iteration(iteration)
|
||||
{
|
||||
}
|
||||
|
||||
shared_ptr<op::v0::TensorIterator::OutputDescription>
|
||||
op::v0::TensorIterator::BodyOutputDescription::copy() const
|
||||
{
|
||||
return make_shared<BodyOutputDescription>(m_body_value_index, m_output_index, m_iteration);
|
||||
}
|
||||
|
||||
bool op::v0::TensorIterator::BodyOutputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
OutputDescription::visit_attributes(visitor);
|
||||
visitor.on_attribute("iteration", m_iteration);
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
}
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
template <>
|
||||
FactoryRegistry<op::v0::TensorIterator::InputDescription>&
|
||||
FactoryRegistry<op::v0::TensorIterator::InputDescription>::get()
|
||||
{
|
||||
static FactoryRegistry<op::v0::TensorIterator::InputDescription> registry;
|
||||
static mutex init_guard;
|
||||
if (registry.m_factory_map.size() == 0)
|
||||
{
|
||||
lock_guard<mutex> guard(init_guard);
|
||||
if (registry.m_factory_map.size() == 0)
|
||||
{
|
||||
registry.register_factory<op::v0::TensorIterator::SliceInputDescription>();
|
||||
registry.register_factory<op::v0::TensorIterator::MergedInputDescription>();
|
||||
registry.register_factory<op::v0::TensorIterator::InvariantInputDescription>();
|
||||
}
|
||||
}
|
||||
return registry;
|
||||
}
|
||||
|
||||
constexpr DiscreteTypeInfo
|
||||
AttributeAdapter<std::shared_ptr<op::TensorIterator::InputDescription>>::type_info;
|
||||
|
||||
constexpr DiscreteTypeInfo AttributeAdapter<
|
||||
std::vector<std::shared_ptr<op::TensorIterator::InputDescription>>>::type_info;
|
||||
|
||||
AttributeAdapter<std::vector<std::shared_ptr<op::TensorIterator::InputDescription>>>::
|
||||
AttributeAdapter(std::vector<std::shared_ptr<op::TensorIterator::InputDescription>>& ref)
|
||||
: m_ref(ref)
|
||||
{
|
||||
}
|
||||
|
||||
bool AttributeAdapter<std::vector<std::shared_ptr<op::TensorIterator::InputDescription>>>::
|
||||
visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
int64_t size = m_ref.size();
|
||||
visitor.on_attribute("size", size);
|
||||
if (size != m_ref.size())
|
||||
{
|
||||
m_ref.resize(size);
|
||||
}
|
||||
ostringstream index;
|
||||
for (int64_t i = 0; i < size; i++)
|
||||
{
|
||||
index.str("");
|
||||
index << i;
|
||||
visitor.on_attribute(index.str(), m_ref[i]);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
FactoryRegistry<op::v0::TensorIterator::OutputDescription>&
|
||||
FactoryRegistry<op::v0::TensorIterator::OutputDescription>::get()
|
||||
{
|
||||
static FactoryRegistry<op::v0::TensorIterator::OutputDescription> registry;
|
||||
static mutex init_guard;
|
||||
// TODO: Add a lock
|
||||
if (registry.m_factory_map.size() == 0)
|
||||
{
|
||||
lock_guard<mutex> guard(init_guard);
|
||||
if (registry.m_factory_map.size() == 0)
|
||||
{
|
||||
registry.register_factory<op::v0::TensorIterator::ConcatOutputDescription>();
|
||||
registry.register_factory<op::v0::TensorIterator::BodyOutputDescription>();
|
||||
}
|
||||
}
|
||||
return registry;
|
||||
}
|
||||
|
||||
constexpr DiscreteTypeInfo AttributeAdapter<
|
||||
std::vector<std::shared_ptr<op::TensorIterator::OutputDescription>>>::type_info;
|
||||
|
||||
constexpr DiscreteTypeInfo
|
||||
AttributeAdapter<std::shared_ptr<op::TensorIterator::OutputDescription>>::type_info;
|
||||
|
||||
AttributeAdapter<std::vector<std::shared_ptr<op::TensorIterator::OutputDescription>>>::
|
||||
AttributeAdapter(std::vector<std::shared_ptr<op::TensorIterator::OutputDescription>>& ref)
|
||||
: m_ref(ref)
|
||||
{
|
||||
}
|
||||
|
||||
bool AttributeAdapter<std::vector<std::shared_ptr<op::TensorIterator::OutputDescription>>>::
|
||||
visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
int64_t size = m_ref.size();
|
||||
visitor.on_attribute("size", size);
|
||||
if (size != m_ref.size())
|
||||
{
|
||||
m_ref.resize(size);
|
||||
}
|
||||
ostringstream index;
|
||||
for (int64_t i = 0; i < size; i++)
|
||||
{
|
||||
index.str("");
|
||||
index << i;
|
||||
visitor.on_attribute(index.str(), m_ref[i]);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
bool op::v0::TensorIterator::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("body", m_body);
|
||||
@@ -310,72 +38,6 @@ bool op::v0::TensorIterator::visit_attributes(AttributeVisitor& visitor)
|
||||
return false;
|
||||
}
|
||||
|
||||
Input<Node> op::v0::TensorIterator::input_for_value(const Output<Node>& value)
|
||||
{
|
||||
auto input_index = get_input_size();
|
||||
set_argument(input_index, value);
|
||||
return Input<Node>(this, input_index);
|
||||
}
|
||||
|
||||
void op::v0::TensorIterator::set_sliced_input(const std::shared_ptr<op::Parameter>& body_parameter,
|
||||
const Output<Node>& value,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis)
|
||||
{
|
||||
m_input_descriptions.push_back(
|
||||
make_shared<SliceInputDescription>(input_for_value(value).get_index(),
|
||||
m_body->get_parameter_index(body_parameter),
|
||||
start,
|
||||
stride,
|
||||
part_size,
|
||||
end,
|
||||
axis));
|
||||
}
|
||||
|
||||
void op::v0::TensorIterator::set_merged_input(const std::shared_ptr<Parameter>& body_parameter,
|
||||
const Output<Node>& initial_value,
|
||||
const Output<Node>& successive_value)
|
||||
{
|
||||
m_input_descriptions.push_back(
|
||||
make_shared<MergedInputDescription>(input_for_value(initial_value).get_index(),
|
||||
m_body->get_parameter_index(body_parameter),
|
||||
m_body->get_result_index(successive_value)));
|
||||
}
|
||||
|
||||
void op::v0::TensorIterator::set_invariant_input(const std::shared_ptr<Parameter>& body_parameter,
|
||||
const Output<Node>& value)
|
||||
{
|
||||
m_input_descriptions.push_back(make_shared<InvariantInputDescription>(
|
||||
input_for_value(value).get_index(), m_body->get_parameter_index(body_parameter)));
|
||||
}
|
||||
|
||||
Output<Node> op::v0::TensorIterator::get_iter_value(const Output<Node>& body_value,
|
||||
int64_t iteration)
|
||||
{
|
||||
auto output_index = get_output_size();
|
||||
m_output_descriptions.push_back(make_shared<BodyOutputDescription>(
|
||||
m_body->get_result_index(body_value), output_index, iteration));
|
||||
set_output_size(output_index + 1);
|
||||
return Output<Node>(shared_from_this(), output_index);
|
||||
}
|
||||
|
||||
Output<Node> op::v0::TensorIterator::get_concatenated_slices(const Output<Node>& body_value,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis)
|
||||
{
|
||||
auto output_index = get_output_size();
|
||||
m_output_descriptions.push_back(make_shared<ConcatOutputDescription>(
|
||||
m_body->get_result_index(body_value), output_index, start, stride, part_size, end, axis));
|
||||
set_output_size(output_index + 1);
|
||||
return Output<Node>(shared_from_this(), output_index);
|
||||
}
|
||||
|
||||
void op::v0::TensorIterator::revalidate_and_infer_types_for_body_ops()
|
||||
{
|
||||
std::stack<std::shared_ptr<Node>, std::vector<std::shared_ptr<Node>>> nodes_to_do;
|
||||
@@ -669,7 +331,3 @@ std::shared_ptr<Node>
|
||||
}
|
||||
return move(op);
|
||||
}
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
}
|
||||
|
||||
@@ -15,17 +15,345 @@
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/op/util/sub_graph_base.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
|
||||
#include "ngraph/graph_util.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr DiscreteTypeInfo op::util::SubGraphOp::SliceInputDescription::type_info;
|
||||
constexpr DiscreteTypeInfo op::util::SubGraphOp::MergedInputDescription::type_info;
|
||||
constexpr DiscreteTypeInfo op::util::SubGraphOp::InvariantInputDescription::type_info;
|
||||
|
||||
constexpr DiscreteTypeInfo op::util::SubGraphOp::BodyOutputDescription::type_info;
|
||||
constexpr DiscreteTypeInfo op::util::SubGraphOp::ConcatOutputDescription::type_info;
|
||||
|
||||
op::util::SubGraphOp::InputDescription::InputDescription(uint64_t input_index,
|
||||
uint64_t body_parameter_index)
|
||||
: m_input_index(input_index)
|
||||
, m_body_parameter_index(body_parameter_index)
|
||||
{
|
||||
}
|
||||
|
||||
bool op::util::SubGraphOp::InputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("input_index", m_input_index);
|
||||
visitor.on_attribute("body_parameter_index", m_body_parameter_index);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::util::SubGraphOp::SliceInputDescription::SliceInputDescription(uint64_t input_index,
|
||||
uint64_t body_parameter_index,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis)
|
||||
: InputDescription(input_index, body_parameter_index)
|
||||
, m_start(start)
|
||||
, m_stride(stride)
|
||||
, m_part_size(part_size)
|
||||
, m_end(end)
|
||||
, m_axis(axis)
|
||||
{
|
||||
}
|
||||
|
||||
std::shared_ptr<op::util::SubGraphOp::InputDescription>
|
||||
op::util::SubGraphOp::SliceInputDescription::copy() const
|
||||
{
|
||||
return std::make_shared<SliceInputDescription>(
|
||||
m_input_index, m_body_parameter_index, m_start, m_stride, m_part_size, m_end, m_axis);
|
||||
}
|
||||
|
||||
bool op::util::SubGraphOp::SliceInputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
InputDescription::visit_attributes(visitor);
|
||||
visitor.on_attribute("start", m_start);
|
||||
visitor.on_attribute("stride", m_stride);
|
||||
visitor.on_attribute("part_size", m_part_size);
|
||||
visitor.on_attribute("end", m_end);
|
||||
visitor.on_attribute("axis", m_axis);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::util::SubGraphOp::MergedInputDescription::MergedInputDescription(uint64_t input_index,
|
||||
uint64_t body_parameter_index,
|
||||
uint64_t body_value_index)
|
||||
: InputDescription(input_index, body_parameter_index)
|
||||
, m_body_value_index(body_value_index)
|
||||
{
|
||||
}
|
||||
|
||||
std::shared_ptr<op::util::SubGraphOp::InputDescription>
|
||||
op::util::SubGraphOp::MergedInputDescription::copy() const
|
||||
{
|
||||
return std::make_shared<MergedInputDescription>(
|
||||
m_input_index, m_body_parameter_index, m_body_value_index);
|
||||
}
|
||||
|
||||
bool op::util::SubGraphOp::MergedInputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
InputDescription::visit_attributes(visitor);
|
||||
visitor.on_attribute("body_value_index", m_body_value_index);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::util::SubGraphOp::InvariantInputDescription::InvariantInputDescription(
|
||||
uint64_t input_index, uint64_t body_parameter_index)
|
||||
: InputDescription(input_index, body_parameter_index)
|
||||
{
|
||||
}
|
||||
|
||||
std::shared_ptr<op::util::SubGraphOp::InputDescription>
|
||||
op::util::SubGraphOp::InvariantInputDescription::copy() const
|
||||
{
|
||||
return std::make_shared<InvariantInputDescription>(m_input_index, m_body_parameter_index);
|
||||
}
|
||||
|
||||
bool op::util::SubGraphOp::InvariantInputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
InputDescription::visit_attributes(visitor);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::util::SubGraphOp::OutputDescription::OutputDescription(uint64_t body_value_index,
|
||||
uint64_t output_index)
|
||||
: m_body_value_index(body_value_index)
|
||||
, m_output_index(output_index)
|
||||
{
|
||||
}
|
||||
|
||||
bool op::util::SubGraphOp::OutputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("body_value_index", m_body_value_index);
|
||||
visitor.on_attribute("output_index", m_output_index);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::util::SubGraphOp::ConcatOutputDescription::ConcatOutputDescription(uint64_t body_value_index,
|
||||
uint64_t output_index,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis)
|
||||
: OutputDescription(body_value_index, output_index)
|
||||
, m_start(start)
|
||||
, m_stride(stride)
|
||||
, m_part_size(part_size)
|
||||
, m_end(end)
|
||||
, m_axis(axis)
|
||||
{
|
||||
}
|
||||
|
||||
bool op::util::SubGraphOp::ConcatOutputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
OutputDescription::visit_attributes(visitor);
|
||||
visitor.on_attribute("start", m_start);
|
||||
visitor.on_attribute("stride", m_stride);
|
||||
visitor.on_attribute("part_size", m_part_size);
|
||||
visitor.on_attribute("end", m_end);
|
||||
visitor.on_attribute("axis", m_axis);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<op::util::SubGraphOp::OutputDescription>
|
||||
op::util::SubGraphOp::ConcatOutputDescription::copy() const
|
||||
{
|
||||
return std::make_shared<ConcatOutputDescription>(
|
||||
m_body_value_index, m_output_index, m_start, m_stride, m_part_size, m_end, m_axis);
|
||||
}
|
||||
|
||||
op::util::SubGraphOp::BodyOutputDescription::BodyOutputDescription(uint64_t body_value_index,
|
||||
uint64_t output_index,
|
||||
int64_t iteration)
|
||||
: OutputDescription(body_value_index, output_index)
|
||||
, m_iteration(iteration)
|
||||
{
|
||||
}
|
||||
|
||||
std::shared_ptr<op::util::SubGraphOp::OutputDescription>
|
||||
op::util::SubGraphOp::BodyOutputDescription::copy() const
|
||||
{
|
||||
return std::make_shared<BodyOutputDescription>(m_body_value_index, m_output_index, m_iteration);
|
||||
}
|
||||
|
||||
bool op::util::SubGraphOp::BodyOutputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
OutputDescription::visit_attributes(visitor);
|
||||
visitor.on_attribute("iteration", m_iteration);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::util::SubGraphOp::SubGraphOp(const OutputVector& args)
|
||||
: Op(args)
|
||||
{
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> op::util::SubGraphOp::get_function()
|
||||
void op::util::SubGraphOp::set_merged_input(const std::shared_ptr<Parameter>& body_parameter,
|
||||
const Output<Node>& initial_value,
|
||||
const Output<Node>& successive_value)
|
||||
{
|
||||
return nullptr;
|
||||
m_input_descriptions.push_back(std::make_shared<TensorIterator::MergedInputDescription>(
|
||||
input_for_value(initial_value).get_index(),
|
||||
m_body->get_parameter_index(body_parameter),
|
||||
m_body->get_result_index(successive_value)));
|
||||
}
|
||||
|
||||
void op::util::SubGraphOp::set_invariant_input(const std::shared_ptr<Parameter>& body_parameter,
|
||||
const Output<Node>& value)
|
||||
{
|
||||
m_input_descriptions.push_back(std::make_shared<TensorIterator::InvariantInputDescription>(
|
||||
input_for_value(value).get_index(), m_body->get_parameter_index(body_parameter)));
|
||||
}
|
||||
|
||||
Output<Node> op::util::SubGraphOp::get_iter_value(const Output<Node>& body_value, int64_t iteration)
|
||||
{
|
||||
auto output_index = get_output_size();
|
||||
m_output_descriptions.push_back(std::make_shared<BodyOutputDescription>(
|
||||
m_body->get_result_index(body_value), output_index, iteration));
|
||||
set_output_size(output_index + 1);
|
||||
return Output<Node>(shared_from_this(), output_index);
|
||||
}
|
||||
|
||||
Output<Node> op::util::SubGraphOp::get_concatenated_slices(const Output<Node>& body_value,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis)
|
||||
{
|
||||
auto output_index = get_output_size();
|
||||
m_output_descriptions.push_back(std::make_shared<ConcatOutputDescription>(
|
||||
m_body->get_result_index(body_value), output_index, start, stride, part_size, end, axis));
|
||||
set_output_size(output_index + 1);
|
||||
return Output<Node>(shared_from_this(), output_index);
|
||||
}
|
||||
|
||||
void op::util::SubGraphOp::set_sliced_input(const std::shared_ptr<Parameter>& parameter,
|
||||
const Output<Node>& value,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis)
|
||||
{
|
||||
m_input_descriptions.push_back(
|
||||
std::make_shared<SliceInputDescription>(input_for_value(value).get_index(),
|
||||
m_body->get_parameter_index(parameter),
|
||||
start,
|
||||
stride,
|
||||
part_size,
|
||||
end,
|
||||
axis));
|
||||
}
|
||||
|
||||
Input<Node> op::util::SubGraphOp::input_for_value(const Output<Node>& value)
|
||||
{
|
||||
auto input_index = get_input_size();
|
||||
set_argument(input_index, value);
|
||||
return Input<Node>(this, input_index);
|
||||
}
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
template <>
|
||||
FactoryRegistry<op::util::SubGraphOp::InputDescription>&
|
||||
FactoryRegistry<op::util::SubGraphOp::InputDescription>::get()
|
||||
{
|
||||
static FactoryRegistry<op::util::SubGraphOp::InputDescription> registry;
|
||||
static std::mutex init_guard;
|
||||
if (registry.m_factory_map.size() == 0)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(init_guard);
|
||||
if (registry.m_factory_map.size() == 0)
|
||||
{
|
||||
registry.register_factory<op::util::SubGraphOp::SliceInputDescription>();
|
||||
registry.register_factory<op::util::SubGraphOp::MergedInputDescription>();
|
||||
registry.register_factory<op::util::SubGraphOp::InvariantInputDescription>();
|
||||
}
|
||||
}
|
||||
return registry;
|
||||
}
|
||||
|
||||
constexpr DiscreteTypeInfo
|
||||
AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::InputDescription>>::type_info;
|
||||
|
||||
constexpr DiscreteTypeInfo AttributeAdapter<
|
||||
std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>>::type_info;
|
||||
|
||||
AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>>::
|
||||
AttributeAdapter(std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>& ref)
|
||||
: m_ref(ref)
|
||||
{
|
||||
}
|
||||
|
||||
bool AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>>::
|
||||
visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
int64_t size = m_ref.size();
|
||||
visitor.on_attribute("size", size);
|
||||
if (size != m_ref.size())
|
||||
{
|
||||
m_ref.resize(size);
|
||||
}
|
||||
std::ostringstream index;
|
||||
for (int64_t i = 0; i < size; i++)
|
||||
{
|
||||
index.str("");
|
||||
index << i;
|
||||
visitor.on_attribute(index.str(), m_ref[i]);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
FactoryRegistry<op::util::SubGraphOp::OutputDescription>&
|
||||
FactoryRegistry<op::util::SubGraphOp::OutputDescription>::get()
|
||||
{
|
||||
static FactoryRegistry<op::util::SubGraphOp::OutputDescription> registry;
|
||||
static std::mutex init_guard;
|
||||
// TODO: Add a lock
|
||||
if (registry.m_factory_map.size() == 0)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(init_guard);
|
||||
if (registry.m_factory_map.size() == 0)
|
||||
{
|
||||
registry.register_factory<op::util::SubGraphOp::ConcatOutputDescription>();
|
||||
registry.register_factory<op::util::SubGraphOp::BodyOutputDescription>();
|
||||
}
|
||||
}
|
||||
return registry;
|
||||
}
|
||||
|
||||
constexpr DiscreteTypeInfo AttributeAdapter<
|
||||
std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>>::type_info;
|
||||
|
||||
constexpr DiscreteTypeInfo
|
||||
AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>::type_info;
|
||||
|
||||
AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>>::
|
||||
AttributeAdapter(std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>& ref)
|
||||
: m_ref(ref)
|
||||
{
|
||||
}
|
||||
|
||||
bool AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>>::
|
||||
visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
int64_t size = m_ref.size();
|
||||
visitor.on_attribute("size", size);
|
||||
if (size != m_ref.size())
|
||||
{
|
||||
m_ref.resize(size);
|
||||
}
|
||||
std::ostringstream index;
|
||||
for (int64_t i = 0; i < size; i++)
|
||||
{
|
||||
index.str("");
|
||||
index << i;
|
||||
visitor.on_attribute(index.str(), m_ref[i]);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -138,6 +138,7 @@ set(SRC
|
||||
type_prop/lrn.cpp
|
||||
type_prop/lstm_cell.cpp
|
||||
type_prop/lstm_sequence.cpp
|
||||
type_prop/loop.cpp
|
||||
type_prop/matmul.cpp
|
||||
type_prop/max_pool.cpp
|
||||
type_prop/mish.cpp
|
||||
@@ -183,6 +184,7 @@ set(SRC
|
||||
type_prop/swish.cpp
|
||||
type_prop/reduce_prod.cpp
|
||||
type_prop/reduce_sum.cpp
|
||||
type_prop/ti.cpp
|
||||
type_prop/tile.cpp
|
||||
type_prop/top_k.cpp
|
||||
type_prop/transpose.cpp
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
#include "util/ndarray.hpp"
|
||||
#include "util/test_tools.hpp"
|
||||
|
||||
@@ -368,3 +369,70 @@ TEST(copy, tanh)
|
||||
{
|
||||
ASSERT_TRUE(check_unary<op::Tanh>());
|
||||
}
|
||||
|
||||
TEST(copy, loop)
|
||||
{
|
||||
// That which we iterate over
|
||||
auto X = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
auto Y = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
auto M = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
auto current_iteration = make_shared<opset5::Parameter>(element::i64, Shape{});
|
||||
auto Xi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto M_body = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto body_condition =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{}, true);
|
||||
|
||||
auto trip_count =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{}, 10);
|
||||
auto exec_condition =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{}, true);
|
||||
// Body
|
||||
auto sum = make_shared<ngraph::opset5::Add>(Xi, Yi);
|
||||
auto Zo = make_shared<ngraph::opset5::Multiply>(sum, M_body);
|
||||
auto body = make_shared<ngraph::Function>(OutputVector{Zo, body_condition},
|
||||
ParameterVector{Xi, current_iteration, Yi, M_body});
|
||||
|
||||
auto loop = make_shared<opset5::Loop>(trip_count, exec_condition);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{1, 1});
|
||||
|
||||
loop->set_invariant_input(Xi, X);
|
||||
loop->set_invariant_input(Yi, Y);
|
||||
loop->set_merged_input(M_body, M, Zo);
|
||||
|
||||
// Output 0 is last Zo
|
||||
auto out0 = loop->get_iter_value(body_condition, -1);
|
||||
auto out1 = loop->get_iter_value(Zo, -1);
|
||||
// Output 1 is concat of Zos
|
||||
// start=0, stride=1, part_size=1, end=-1, axis=1
|
||||
auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1);
|
||||
loop->validate_and_infer_types();
|
||||
// That which we iterate over
|
||||
auto X_new = make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5});
|
||||
auto Y_new = make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5});
|
||||
auto M_new = make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5});
|
||||
OutputVector new_args = {trip_count, exec_condition, X_new, Y_new, M_new};
|
||||
auto loop_copy = loop->clone_with_new_inputs(new_args);
|
||||
|
||||
auto node_cast = std::dynamic_pointer_cast<opset5::Loop>(loop_copy);
|
||||
ASSERT_NE(node_cast, nullptr);
|
||||
ASSERT_TRUE(nullptr != loop_copy);
|
||||
EXPECT_EQ(loop->get_num_iterations(), node_cast->get_num_iterations());
|
||||
EXPECT_EQ(loop->get_special_body_ports().body_condition_output_idx,
|
||||
node_cast->get_special_body_ports().body_condition_output_idx);
|
||||
EXPECT_EQ(loop->get_special_body_ports().current_iteration_input_idx,
|
||||
node_cast->get_special_body_ports().current_iteration_input_idx);
|
||||
ASSERT_TRUE(new_args == loop_copy->input_values());
|
||||
|
||||
loop_copy->validate_and_infer_types();
|
||||
Shape out0_shape{};
|
||||
Shape out1_shape{3, 2, 5};
|
||||
Shape out2_shape{3, 20, 5};
|
||||
EXPECT_EQ(loop_copy->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(loop_copy->get_output_shape(1), out1_shape);
|
||||
EXPECT_EQ(loop_copy->get_output_shape(2), out2_shape);
|
||||
}
|
||||
753
ngraph/test/type_prop/loop.cpp
Normal file
753
ngraph/test/type_prop/loop.cpp
Normal file
@@ -0,0 +1,753 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
// trip_count = 10
|
||||
// execution_condition = true
|
||||
// body_condition = true
|
||||
// all shapes are static, 10 iterations will be executed
|
||||
TEST(type_prop, loop_operation_for_mode_10_iter_static_shapes)
|
||||
{
|
||||
// That which we iterate over
|
||||
auto X = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
auto Y = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
auto M = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
auto current_iteration = make_shared<opset5::Parameter>(element::i64, Shape{1});
|
||||
auto Xi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto M_body = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto body_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
|
||||
auto trip_count =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{1}, 10);
|
||||
auto exec_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
// Body
|
||||
auto sum = make_shared<ngraph::opset5::Add>(Xi, Yi);
|
||||
auto Zo = make_shared<ngraph::opset5::Multiply>(sum, M_body);
|
||||
auto body = make_shared<ngraph::Function>(OutputVector{body_condition, Zo},
|
||||
ParameterVector{current_iteration, Xi, Yi, M_body});
|
||||
|
||||
auto loop = make_shared<opset5::Loop>(trip_count, exec_condition);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{-1, 0});
|
||||
|
||||
loop->set_invariant_input(Xi, X);
|
||||
loop->set_invariant_input(Yi, Y);
|
||||
loop->set_merged_input(M_body, M, Zo);
|
||||
|
||||
// check input descriptors
|
||||
for (auto& desc : loop->get_input_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "InvariantInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::InvariantInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "SliceInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::SliceInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "MergedInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::MergedInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Output 0 is last Zo
|
||||
auto out0 = loop->get_iter_value(body_condition, -1);
|
||||
auto out1 = loop->get_iter_value(Zo, -1);
|
||||
// Output 1 is concat of Zos
|
||||
// start=0, stride=1, part_size=1, end=-1, axis=1
|
||||
auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1);
|
||||
|
||||
// check output descriptors
|
||||
for (auto& desc : loop->get_output_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "ConcatOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::ConcatOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "BodyOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::BodyOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
auto result0 = make_shared<opset5::Result>(out0);
|
||||
auto result1 = make_shared<opset5::Result>(out1);
|
||||
auto result2 = make_shared<opset5::Result>(out2);
|
||||
Shape out0_shape{1};
|
||||
Shape out1_shape{32, 1, 10};
|
||||
Shape out2_shape{32, 10, 10};
|
||||
|
||||
auto results = ResultVector{result0, result1, result2};
|
||||
auto f = make_shared<Function>(results, ParameterVector{X, Y, M});
|
||||
EXPECT_EQ(result0->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(result1->get_output_shape(0), out1_shape);
|
||||
EXPECT_EQ(result2->get_output_shape(0), out2_shape);
|
||||
|
||||
EXPECT_EQ(loop->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(loop->get_output_shape(1), out1_shape);
|
||||
EXPECT_EQ(loop->get_output_shape(2), out2_shape);
|
||||
}
|
||||
|
||||
// trip_count = 10
|
||||
// execution_condition = true
|
||||
// body_condition = false
|
||||
// will be executed only 1 iteration, all shapes are static
|
||||
TEST(type_prop, loop_operation_dowhile_mode_1_iter_static_shapes)
|
||||
{
|
||||
// That which we iterate over
|
||||
auto X = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
auto Y = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
auto M = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
auto current_iteration = make_shared<opset5::Parameter>(element::i64, Shape{1});
|
||||
auto Xi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto M_body = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto body_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, false);
|
||||
|
||||
auto trip_count =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{1}, 10);
|
||||
auto exec_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
// Body
|
||||
auto sum = make_shared<ngraph::opset5::Add>(Xi, Yi);
|
||||
auto Zo = make_shared<ngraph::opset5::Multiply>(sum, M_body);
|
||||
auto body = make_shared<ngraph::Function>(OutputVector{body_condition, Zo},
|
||||
ParameterVector{current_iteration, Xi, Yi, M_body});
|
||||
|
||||
auto loop = make_shared<opset5::Loop>(trip_count, exec_condition);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{-1, 0});
|
||||
|
||||
loop->set_invariant_input(Xi, X);
|
||||
loop->set_invariant_input(Yi, Y);
|
||||
loop->set_merged_input(M_body, M, Zo);
|
||||
|
||||
// check input descriptors
|
||||
for (auto& desc : loop->get_input_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "InvariantInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::InvariantInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "SliceInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::SliceInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "MergedInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::MergedInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Output 0 is last Zo
|
||||
auto out0 = loop->get_iter_value(body_condition, -1);
|
||||
auto out1 = loop->get_iter_value(Zo, -1);
|
||||
// Output 1 is concat of Zos
|
||||
// start=0, stride=1, part_size=1, end=-1, axis=1
|
||||
auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1);
|
||||
|
||||
// check output descriptors
|
||||
for (auto& desc : loop->get_output_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "ConcatOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::ConcatOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "BodyOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::BodyOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
auto result0 = make_shared<opset5::Result>(out0);
|
||||
auto result1 = make_shared<opset5::Result>(out1);
|
||||
auto result2 = make_shared<opset5::Result>(out2);
|
||||
Shape out0_shape{1};
|
||||
Shape out1_shape{32, 1, 10};
|
||||
Shape out2_shape{32, 1, 10};
|
||||
|
||||
auto results = ResultVector{result0, result1, result2};
|
||||
auto f = make_shared<Function>(results, ParameterVector{X, Y, M});
|
||||
EXPECT_EQ(result0->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(result1->get_output_shape(0), out1_shape);
|
||||
EXPECT_EQ(result2->get_output_shape(0), out2_shape);
|
||||
|
||||
EXPECT_EQ(loop->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(loop->get_output_shape(1), out1_shape);
|
||||
EXPECT_EQ(loop->get_output_shape(2), out2_shape);
|
||||
}
|
||||
|
||||
// trip_count = 10
|
||||
// execution_condition = true
|
||||
// body_condition is not a Constant
|
||||
// concat output is not provided, another outputs will be static
|
||||
TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_static_shapes)
|
||||
{
|
||||
// That which we iterate over
|
||||
auto X = make_shared<opset5::Parameter>(element::f32, Shape{1});
|
||||
auto Y = make_shared<opset5::Parameter>(element::f32, Shape{1});
|
||||
auto M = make_shared<opset5::Parameter>(element::f32, Shape{1});
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
auto current_iteration = make_shared<opset5::Parameter>(element::i64, Shape{1});
|
||||
auto Xi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto M_body = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto condition_const =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{1}, 10);
|
||||
auto body_condition = std::make_shared<ngraph::opset5::Greater>(M_body, condition_const);
|
||||
|
||||
auto trip_count =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{1}, 10);
|
||||
auto exec_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
// Body
|
||||
auto sum = make_shared<ngraph::opset5::Add>(Xi, Yi);
|
||||
auto Zo = make_shared<ngraph::opset5::Multiply>(sum, M_body);
|
||||
auto body = make_shared<ngraph::Function>(OutputVector{body_condition, Zo},
|
||||
ParameterVector{Xi, Yi, M_body});
|
||||
|
||||
auto loop = make_shared<opset5::Loop>(trip_count, exec_condition);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{-1, 0});
|
||||
|
||||
loop->set_invariant_input(Xi, X);
|
||||
loop->set_invariant_input(Yi, Y);
|
||||
loop->set_merged_input(M_body, M, Zo);
|
||||
|
||||
// check input descriptors
|
||||
for (auto& desc : loop->get_input_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "InvariantInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::InvariantInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "SliceInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::SliceInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "MergedInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::MergedInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Output 0 is last Zo
|
||||
auto out0 = loop->get_iter_value(body_condition, -1);
|
||||
auto out1 = loop->get_iter_value(Zo, -1);
|
||||
|
||||
// check output descriptors
|
||||
for (auto& desc : loop->get_output_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "ConcatOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::ConcatOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "BodyOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::BodyOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
auto result0 = make_shared<opset5::Result>(out0);
|
||||
auto result1 = make_shared<opset5::Result>(out1);
|
||||
Shape out0_shape{1};
|
||||
Shape out1_shape{1};
|
||||
|
||||
auto results = ResultVector{result0, result1};
|
||||
auto f = make_shared<Function>(results, ParameterVector{X, Y, M});
|
||||
EXPECT_EQ(result0->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(result1->get_output_shape(0), out1_shape);
|
||||
|
||||
EXPECT_EQ(loop->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(loop->get_output_shape(1), out1_shape);
|
||||
}
|
||||
|
||||
// trip_count = 10
|
||||
// execution_condition = true
|
||||
// body_condition is not a Constant
|
||||
// concat output will be dynamic, another outputs are static
|
||||
TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_dynamic_shapes)
|
||||
{
|
||||
// That which we iterate over
|
||||
auto X = make_shared<opset5::Parameter>(element::f32, Shape{1});
|
||||
auto Y = make_shared<opset5::Parameter>(element::f32, Shape{1});
|
||||
auto M = make_shared<opset5::Parameter>(element::f32, Shape{1});
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
auto current_iteration = make_shared<opset5::Parameter>(element::i64, Shape{1});
|
||||
auto Xi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto M_body = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto condition_const =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{1}, 10);
|
||||
auto body_condition = std::make_shared<ngraph::opset5::Greater>(M_body, condition_const);
|
||||
|
||||
auto trip_count =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{1}, 10);
|
||||
auto exec_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
// Body
|
||||
auto sum = make_shared<ngraph::opset5::Add>(Xi, Yi);
|
||||
auto Zo = make_shared<ngraph::opset5::Multiply>(sum, M_body);
|
||||
auto body = make_shared<ngraph::Function>(OutputVector{body_condition, Zo},
|
||||
ParameterVector{current_iteration, Xi, Yi, M_body});
|
||||
|
||||
auto loop = make_shared<opset5::Loop>(trip_count, exec_condition);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{-1, 0});
|
||||
|
||||
loop->set_invariant_input(Xi, X);
|
||||
loop->set_invariant_input(Yi, Y);
|
||||
loop->set_merged_input(M_body, M, Zo);
|
||||
|
||||
// check input descriptors
|
||||
for (auto& desc : loop->get_input_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "InvariantInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::InvariantInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "SliceInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::SliceInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "MergedInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::MergedInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Output 0 is last Zo
|
||||
auto out0 = loop->get_iter_value(body_condition, -1);
|
||||
auto out1 = loop->get_iter_value(Zo, -1);
|
||||
auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1);
|
||||
|
||||
// check output descriptors
|
||||
for (auto& desc : loop->get_output_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "ConcatOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::ConcatOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "BodyOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::BodyOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
auto result0 = make_shared<opset5::Result>(out0);
|
||||
auto result1 = make_shared<opset5::Result>(out1);
|
||||
auto result2 = make_shared<opset5::Result>(out2);
|
||||
Shape out0_shape{1};
|
||||
Shape out1_shape{1};
|
||||
PartialShape out2_shape{PartialShape::dynamic()};
|
||||
|
||||
auto results = ResultVector{result0, result1};
|
||||
auto f = make_shared<Function>(results, ParameterVector{X, Y, M});
|
||||
EXPECT_EQ(result0->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(result1->get_output_shape(0), out1_shape);
|
||||
EXPECT_EQ(result2->get_output_partial_shape(0), out2_shape);
|
||||
|
||||
EXPECT_EQ(loop->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(loop->get_output_shape(1), out1_shape);
|
||||
EXPECT_EQ(loop->get_output_partial_shape(2), out2_shape);
|
||||
}
|
||||
|
||||
// trip_count = -1
|
||||
// execution_condition = true
|
||||
// body_condition = true
|
||||
// concat output will be dynamic, another outputs are static
|
||||
TEST(type_prop, loop_operation_infinite_loop_mode_dynamic_iter_dynamic_shapes)
|
||||
{
|
||||
// That which we iterate over
|
||||
auto X = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
auto Y = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
auto M = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
auto current_iteration = make_shared<opset5::Parameter>(element::i64, Shape{1});
|
||||
auto Xi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto M_body = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto body_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
|
||||
auto trip_count =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{1}, -1);
|
||||
auto exec_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
// Body
|
||||
auto sum = make_shared<ngraph::opset5::Add>(Xi, Yi);
|
||||
auto Zo = make_shared<ngraph::opset5::Multiply>(sum, M_body);
|
||||
auto body = make_shared<ngraph::Function>(OutputVector{body_condition, Zo},
|
||||
ParameterVector{current_iteration, Xi, Yi, M_body});
|
||||
|
||||
auto loop = make_shared<opset5::Loop>(trip_count, exec_condition);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{-1, 0});
|
||||
|
||||
loop->set_invariant_input(Xi, X);
|
||||
loop->set_invariant_input(Yi, Y);
|
||||
loop->set_merged_input(M_body, M, Zo);
|
||||
|
||||
// check input descriptors
|
||||
for (auto& desc : loop->get_input_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "InvariantInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::InvariantInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "SliceInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::SliceInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "MergedInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::MergedInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Output 0 is last Zo
|
||||
auto out0 = loop->get_iter_value(body_condition, -1);
|
||||
auto out1 = loop->get_iter_value(Zo, -1);
|
||||
// Output 1 is concat of Zos
|
||||
// start=0, stride=1, part_size=1, end=-1, axis=1
|
||||
auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1);
|
||||
|
||||
// check output descriptors
|
||||
for (auto& desc : loop->get_output_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "ConcatOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::ConcatOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "BodyOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::BodyOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
auto result0 = make_shared<opset5::Result>(out0);
|
||||
auto result1 = make_shared<opset5::Result>(out1);
|
||||
auto result2 = make_shared<opset5::Result>(out2);
|
||||
Shape out0_shape{1};
|
||||
Shape out1_shape{32, 1, 10};
|
||||
PartialShape out2_shape{PartialShape::dynamic()};
|
||||
|
||||
auto results = ResultVector{result0, result1, result2};
|
||||
auto f = make_shared<Function>(results, ParameterVector{X, Y, M});
|
||||
EXPECT_EQ(result0->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(result1->get_output_shape(0), out1_shape);
|
||||
EXPECT_EQ(result2->get_output_partial_shape(0), out2_shape);
|
||||
|
||||
EXPECT_EQ(loop->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(loop->get_output_shape(1), out1_shape);
|
||||
EXPECT_EQ(loop->get_output_partial_shape(2), out2_shape);
|
||||
}
|
||||
|
||||
// SpecialBodyPorts (1, 1) <- test specific
|
||||
// trip_count = 10
|
||||
// execution_condition = true
|
||||
// body_condition = true
|
||||
// all shapes are static, 10 iterations will be executed
|
||||
TEST(type_prop, loop_operation_for_mode_10_iter_static_shapes_special_body_ports)
|
||||
{
|
||||
// That which we iterate over
|
||||
auto X = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
auto Y = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
auto M = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
auto current_iteration = make_shared<opset5::Parameter>(element::i64, Shape{1});
|
||||
auto Xi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto M_body = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto body_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
|
||||
auto trip_count =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{1}, 10);
|
||||
auto exec_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
// Body
|
||||
auto sum = make_shared<ngraph::opset5::Add>(Xi, Yi);
|
||||
auto Zo = make_shared<ngraph::opset5::Multiply>(sum, M_body);
|
||||
auto body = make_shared<ngraph::Function>(OutputVector{Zo, body_condition},
|
||||
ParameterVector{Xi, current_iteration, Yi, M_body});
|
||||
|
||||
auto loop = make_shared<opset5::Loop>(trip_count, exec_condition);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{1, 1});
|
||||
|
||||
loop->set_invariant_input(Xi, X);
|
||||
loop->set_invariant_input(Yi, Y);
|
||||
loop->set_merged_input(M_body, M, Zo);
|
||||
|
||||
// check input descriptors
|
||||
for (auto& desc : loop->get_input_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "InvariantInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::InvariantInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "SliceInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::SliceInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "MergedInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::MergedInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Output 0 is last Zo
|
||||
auto out0 = loop->get_iter_value(body_condition, -1);
|
||||
auto out1 = loop->get_iter_value(Zo, -1);
|
||||
// Output 1 is concat of Zos
|
||||
// start=0, stride=1, part_size=1, end=-1, axis=1
|
||||
auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1);
|
||||
|
||||
// check output descriptors
|
||||
for (auto& desc : loop->get_output_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "ConcatOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::ConcatOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "BodyOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::BodyOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
auto result0 = make_shared<opset5::Result>(out0);
|
||||
auto result1 = make_shared<opset5::Result>(out1);
|
||||
auto result2 = make_shared<opset5::Result>(out2);
|
||||
Shape out0_shape{1};
|
||||
Shape out1_shape{32, 1, 10};
|
||||
Shape out2_shape{32, 10, 10};
|
||||
|
||||
auto results = ResultVector{result0, result1, result2};
|
||||
auto f = make_shared<Function>(results, ParameterVector{X, Y, M});
|
||||
EXPECT_EQ(result0->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(result1->get_output_shape(0), out1_shape);
|
||||
EXPECT_EQ(result2->get_output_shape(0), out2_shape);
|
||||
|
||||
EXPECT_EQ(loop->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(loop->get_output_shape(1), out1_shape);
|
||||
EXPECT_EQ(loop->get_output_shape(2), out2_shape);
|
||||
}
|
||||
|
||||
// Scalars instead of 1d tensors with 1 element <-- test specific
|
||||
// trip_count = 10
|
||||
// execution_condition = true
|
||||
// body_condition = true
|
||||
// all shapes are static, 10 iterations will be executed
|
||||
TEST(type_prop, loop_operation_for_mode_10_iter_static_shapes_special_body_ports_scalars)
|
||||
{
|
||||
// That which we iterate over
|
||||
auto X = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
auto Y = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
auto M = make_shared<opset5::Parameter>(element::f32, Shape{32, 1, 10});
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
auto current_iteration = make_shared<opset5::Parameter>(element::i64, Shape{});
|
||||
auto Xi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto M_body = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto body_condition =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{}, true);
|
||||
|
||||
auto trip_count =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{}, 10);
|
||||
auto exec_condition =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{}, true);
|
||||
// Body
|
||||
auto sum = make_shared<ngraph::opset5::Add>(Xi, Yi);
|
||||
auto Zo = make_shared<ngraph::opset5::Multiply>(sum, M_body);
|
||||
auto body = make_shared<ngraph::Function>(OutputVector{Zo, body_condition},
|
||||
ParameterVector{Xi, current_iteration, Yi, M_body});
|
||||
|
||||
auto loop = make_shared<opset5::Loop>(trip_count, exec_condition);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{1, 1});
|
||||
|
||||
loop->set_invariant_input(Xi, X);
|
||||
loop->set_invariant_input(Yi, Y);
|
||||
loop->set_merged_input(M_body, M, Zo);
|
||||
|
||||
// check input descriptors
|
||||
for (auto& desc : loop->get_input_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "InvariantInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::InvariantInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "SliceInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::SliceInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "MergedInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::MergedInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Output 0 is last Zo
|
||||
auto out0 = loop->get_iter_value(body_condition, -1);
|
||||
auto out1 = loop->get_iter_value(Zo, -1);
|
||||
// Output 1 is concat of Zos
|
||||
// start=0, stride=1, part_size=1, end=-1, axis=1
|
||||
auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1);
|
||||
|
||||
// check output descriptors
|
||||
for (auto& desc : loop->get_output_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "ConcatOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::ConcatOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "BodyOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::BodyOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
auto result0 = make_shared<opset5::Result>(out0);
|
||||
auto result1 = make_shared<opset5::Result>(out1);
|
||||
auto result2 = make_shared<opset5::Result>(out2);
|
||||
Shape out0_shape{};
|
||||
Shape out1_shape{32, 1, 10};
|
||||
Shape out2_shape{32, 10, 10};
|
||||
|
||||
auto results = ResultVector{result0, result1, result2};
|
||||
auto f = make_shared<Function>(results, ParameterVector{X, Y, M});
|
||||
EXPECT_EQ(result0->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(result1->get_output_shape(0), out1_shape);
|
||||
EXPECT_EQ(result2->get_output_shape(0), out2_shape);
|
||||
|
||||
EXPECT_EQ(loop->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(loop->get_output_shape(1), out1_shape);
|
||||
EXPECT_EQ(loop->get_output_shape(2), out2_shape);
|
||||
}
|
||||
204
ngraph/test/type_prop/ti.cpp
Normal file
204
ngraph/test/type_prop/ti.cpp
Normal file
@@ -0,0 +1,204 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(type_prop, tensor_iterator_lstm)
|
||||
{
|
||||
// That which we iterate over
|
||||
const size_t N = 32; // Batch size
|
||||
const size_t L = 10; // Sequence length
|
||||
const size_t I = 8; // Input size
|
||||
const size_t H = 32; // Hidden size
|
||||
auto SENT = make_shared<op::Parameter>(element::f32, Shape{N, L, I});
|
||||
|
||||
auto H_init = make_shared<op::Parameter>(element::f32, Shape{N, 1, H});
|
||||
auto C_init = make_shared<op::Parameter>(element::f32, Shape{N, 1, H});
|
||||
|
||||
auto W = make_shared<op::Parameter>(element::f32, Shape{4 * H, I});
|
||||
auto R = make_shared<op::Parameter>(element::f32, Shape{4 * H, H});
|
||||
auto H_t = make_shared<op::Parameter>(element::f32, Shape{N, 1, H});
|
||||
auto C_t = make_shared<op::Parameter>(element::f32, Shape{N, 1, H});
|
||||
|
||||
// Body
|
||||
auto X = make_shared<op::Parameter>(element::f32, Shape{N, 1, I});
|
||||
auto W_body = make_shared<op::Parameter>(element::f32, Shape{4 * H, I});
|
||||
auto R_body = make_shared<op::Parameter>(element::f32, Shape{4 * H, H});
|
||||
auto LSTM_cell = make_shared<opset5::LSTMCell>(
|
||||
make_shared<op::Reshape>(X, AxisVector{0, 1, 2}, Shape{N, I}),
|
||||
make_shared<op::Reshape>(H_t, AxisVector{0, 1, 2}, Shape{N, H}),
|
||||
make_shared<op::Reshape>(C_t, AxisVector{0, 1, 2}, Shape{N, H}),
|
||||
W_body,
|
||||
R_body,
|
||||
H);
|
||||
auto H_o = make_shared<op::Reshape>(LSTM_cell->output(0), AxisVector{0, 1}, Shape{N, 1, H});
|
||||
auto C_o = make_shared<op::Reshape>(LSTM_cell->output(1), AxisVector{0, 1}, Shape{N, 1, H});
|
||||
auto body = make_shared<ngraph::Function>(OutputVector{H_o, C_o},
|
||||
ParameterVector{X, H_t, C_t, W_body, R_body});
|
||||
|
||||
auto tensor_iterator = make_shared<op::TensorIterator>();
|
||||
tensor_iterator->set_body(body);
|
||||
// start=0, stride=1, part_size=1, end=39, axis=1
|
||||
tensor_iterator->set_sliced_input(X, SENT, 0, 1, 1, -1, 1);
|
||||
// H_t is Hinit on the first iteration, Ho after that
|
||||
tensor_iterator->set_merged_input(H_t, H_init, H_o);
|
||||
tensor_iterator->set_merged_input(C_t, C_init, C_o);
|
||||
tensor_iterator->set_invariant_input(W_body, W);
|
||||
tensor_iterator->set_invariant_input(R_body, R);
|
||||
|
||||
// Output 0 is last Ho, result 0 of body
|
||||
auto out0 = tensor_iterator->get_iter_value(H_o, -1);
|
||||
// Output 1 is last Co, result 1 of body
|
||||
auto out1 = tensor_iterator->get_iter_value(C_o, -1);
|
||||
|
||||
auto results = ResultVector{make_shared<op::Result>(out0), make_shared<op::Result>(out1)};
|
||||
auto f = make_shared<Function>(results, ParameterVector{SENT, H_init, C_init, W, R});
|
||||
}
|
||||
|
||||
TEST(type_prop, tensor_iterator_2_slice_inputs_part_size_2)
|
||||
{
|
||||
// That which we iterate over
|
||||
auto X = make_shared<op::Parameter>(element::f32, Shape{32, 40, 10});
|
||||
auto Y = make_shared<op::Parameter>(element::f32, Shape{32, 40, 10});
|
||||
auto M = make_shared<op::Parameter>(element::f32, Shape{32, 2, 10});
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
auto Xi = make_shared<op::Parameter>(element::f32, Shape{32, 2, 10});
|
||||
auto Yi = make_shared<op::Parameter>(element::f32, Shape{32, 2, 10});
|
||||
auto M_body = make_shared<op::Parameter>(element::f32, Shape{32, 2, 10});
|
||||
|
||||
// Body
|
||||
auto Zo = (Xi + Yi) * M_body;
|
||||
auto body = make_shared<ngraph::Function>(OutputVector{Zo}, ParameterVector{Xi, Yi, M_body});
|
||||
|
||||
auto tensor_iterator = make_shared<op::TensorIterator>();
|
||||
tensor_iterator->set_body(body);
|
||||
// The Xi are the elements of Xseq
|
||||
// start=0, stride=2, part_size=2, end=39, axis=1
|
||||
tensor_iterator->set_sliced_input(Xi, X, 0, 2, 2, 39, 1);
|
||||
// The Yi are the elements of Yseq
|
||||
// start=0, stride=2, part_size=2, end=-1, axis=1
|
||||
tensor_iterator->set_sliced_input(Yi, Y, 0, 2, 2, -1, 1);
|
||||
tensor_iterator->set_invariant_input(M_body, M);
|
||||
|
||||
// Output 0 is last Zo
|
||||
auto out0 = tensor_iterator->get_iter_value(Zo, -1);
|
||||
// Output 1 is concat of Zos
|
||||
// start=0, stride=2, part_size=2, end=39, axis=1
|
||||
auto out1 = tensor_iterator->get_concatenated_slices(Zo, 0, 2, 2, 39, 1);
|
||||
|
||||
auto result0 = make_shared<op::Result>(out0);
|
||||
auto result1 = make_shared<op::Result>(out1);
|
||||
Shape out0_shape{32, 2, 10};
|
||||
Shape out1_shape{32, 40, 10};
|
||||
|
||||
auto results = ResultVector{result0, result1};
|
||||
auto f = make_shared<Function>(results, ParameterVector{X, Y, M});
|
||||
EXPECT_EQ(result0->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(result1->get_output_shape(0), out1_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, tensor_iterator_2_slice_inputs_part_size_2_dynamic)
|
||||
{
|
||||
// That which we iterate over
|
||||
auto X = make_shared<op::Parameter>(element::f32, Shape{32, 40, 10});
|
||||
auto Y = make_shared<op::Parameter>(element::f32, Shape{32, 40, 10});
|
||||
auto M = make_shared<op::Parameter>(element::f32, Shape{32, 2, 10});
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
auto Xi = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yi = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto M_body = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
|
||||
// Body
|
||||
auto Zo = (Xi + Yi) * M_body;
|
||||
auto body = make_shared<ngraph::Function>(OutputVector{Zo}, ParameterVector{Xi, Yi, M_body});
|
||||
|
||||
auto tensor_iterator = make_shared<op::TensorIterator>();
|
||||
tensor_iterator->set_body(body);
|
||||
// The Xi are the elements of Xseq
|
||||
// start=0, stride=2, part_size=2, end=38, axis=1
|
||||
tensor_iterator->set_sliced_input(Xi, X, 0, 2, 2, 38, 1);
|
||||
// The Yi are the elements of Yseq
|
||||
// start=0, stride=2, part_size=2, end=-2, axis=1
|
||||
tensor_iterator->set_sliced_input(Yi, Y, 0, 2, 2, -2, 1);
|
||||
tensor_iterator->set_invariant_input(M_body, M);
|
||||
|
||||
// check input descriptors
|
||||
for (auto& desc : tensor_iterator->get_input_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "InvariantInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::op::TensorIterator::InvariantInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "SliceInputDescription") == 0)
|
||||
{
|
||||
auto input_desc = as_type_ptr<ngraph::op::TensorIterator::SliceInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "MergedInputDescription") == 0)
|
||||
{
|
||||
auto input_desc = as_type_ptr<ngraph::op::TensorIterator::MergedInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Output 0 is last Zo
|
||||
auto out0 = tensor_iterator->get_iter_value(Zo, -1);
|
||||
// Output 1 is concat of Zos
|
||||
// start=0, stride=2, part_size=2, end=38, axis=1
|
||||
auto out1 = tensor_iterator->get_concatenated_slices(Zo, 0, 2, 2, 38, 1);
|
||||
|
||||
// check output descriptors
|
||||
for (auto& desc : tensor_iterator->get_output_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "ConcatOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::op::TensorIterator::ConcatOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "BodyOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc = as_type_ptr<ngraph::op::TensorIterator::BodyOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
auto result0 = make_shared<op::Result>(out0);
|
||||
auto result1 = make_shared<op::Result>(out1);
|
||||
Shape out0_shape{32, 2, 10};
|
||||
Shape out1_shape{32, 38, 10};
|
||||
|
||||
auto results = ResultVector{result0, result1};
|
||||
auto f = make_shared<Function>(results, ParameterVector{X, Y, M});
|
||||
EXPECT_EQ(result0->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(result1->get_output_shape(0), out1_shape);
|
||||
|
||||
EXPECT_EQ(body->get_results()[0]->get_output_shape(0), out0_shape);
|
||||
}
|
||||
Reference in New Issue
Block a user