Resolved problems with ssd_resnet34_mlperf_opset10 (#3487)
* Resolved problems with ssd_resnet34_1200 * removed debug code * Added correct handling onnx nodes from parent graph scope * removed unnecessary include * fixed calcution index to replace * fixed LoopParentParametersUsedInBody test * added set_friendly_name * apply Unsqueeze for each concatenated Loop output * added handling trip count with value max_int * merge from upstream/master * update xfail list * added checking is trip_count is constant
This commit is contained in:
parent
c6bfac6e05
commit
0b05653d7a
@ -256,4 +256,134 @@ TEST(SmartReshapeTests, LoopDynamicParameters) {
|
||||
// concat output
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[2]->get_output_partial_shape(0).compatible({32, 10, 10}));
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[3]->get_output_partial_shape(0).compatible({32, 1, 1}));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SmartReshapeTests, LoopParentParametersUsedInBody) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr);
|
||||
{
|
||||
// That which we iterate over
|
||||
auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Y = std::make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto add_Y = std::make_shared<opset5::Add>(Y,
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{0.f}));
|
||||
auto M = std::make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
X->set_friendly_name("X");
|
||||
Y->set_friendly_name("Y");
|
||||
M->set_friendly_name("M");
|
||||
|
||||
// Set up the cell body, a function from (Xi, add_Y) -> (Zo)
|
||||
// Body parameters
|
||||
auto current_iteration = std::make_shared<opset5::Parameter>(element::i64, Shape{});
|
||||
auto Xi = std::make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yi = std::make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto M_body = std::make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto body_condition =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{}, true);
|
||||
|
||||
auto trip_count =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{}, 10);
|
||||
auto exec_condition =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{}, true);
|
||||
// Body
|
||||
auto sum = std::make_shared<ngraph::opset5::Add>(Xi, Yi);
|
||||
auto Zo = std::make_shared<ngraph::opset5::Multiply>(sum, M_body);
|
||||
auto body = std::make_shared<ngraph::Function>(OutputVector{Zo, body_condition, sum},
|
||||
ParameterVector{Xi, current_iteration, Yi, M_body});
|
||||
|
||||
auto loop = std::make_shared<opset5::Loop>(trip_count, exec_condition);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{1, 1});
|
||||
|
||||
loop->set_sliced_input(Xi, X, 0, 1, 1, -1, 2);
|
||||
loop->set_merged_input(M_body, M, Zo);
|
||||
// Set invariant input which uses parameter from parent graph
|
||||
loop->set_invariant_input(Yi, add_Y);
|
||||
|
||||
// Output 0 is last Zo
|
||||
auto out0 = loop->get_iter_value(body_condition, -1);
|
||||
auto out1 = loop->get_iter_value(Zo, -1);
|
||||
// Output 1 is concat of Zos
|
||||
// start=0, stride=1, part_size=1, end=-1, axis=1
|
||||
auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1);
|
||||
auto out3 = loop->get_iter_value(sum, -1);
|
||||
|
||||
f = std::make_shared<Function>(OutputVector{out0, out1, out2, out3}, ParameterVector{X, Y, M});
|
||||
}
|
||||
|
||||
InferenceEngine::CNNNetwork network(f);
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[0]->get_output_partial_shape(0).compatible({}));
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[1]->get_output_partial_shape(0).compatible(PartialShape::dynamic()));
|
||||
// concat output
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[2]->get_output_partial_shape(0).compatible(PartialShape::dynamic()));
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[3]->get_output_partial_shape(0).compatible(PartialShape::dynamic()));
|
||||
|
||||
ASSERT_NO_THROW(network.reshape({{"X", {4, 3, 2}}, {"Y", {4, 3, 2}}, {"M", {4, 3, 2}}}));
|
||||
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[0]->get_output_partial_shape(0).compatible({}));
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[1]->get_output_partial_shape(0).compatible({4, 3, 2}));
|
||||
// concat output
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[2]->get_output_partial_shape(0).compatible({4, 30, 2}));
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[3]->get_output_partial_shape(0).compatible({4, 3, 2}));
|
||||
}
|
||||
|
||||
TEST(SmartReshapeTests, TensorIteratorParentParameterUsedInBody) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr);
|
||||
{
|
||||
// That which we iterate over
|
||||
auto X = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 1});
|
||||
auto Y = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 1});
|
||||
auto add_Y = std::make_shared<opset5::Add>(Y,
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{}, std::vector<float>{0.f}));
|
||||
auto M = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 1});
|
||||
X->set_friendly_name("X");
|
||||
Y->set_friendly_name("Y");
|
||||
M->set_friendly_name("M");
|
||||
|
||||
// Set up the cell body, a function from (Xi, add_Y) -> (Zo)
|
||||
// Body parameters
|
||||
auto Xi = std::make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yi = std::make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto M_body = std::make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto body_condition =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{}, true);
|
||||
|
||||
// Body
|
||||
auto sum = std::make_shared<ngraph::opset5::Add>(Xi, Yi);
|
||||
auto Zo = std::make_shared<ngraph::opset5::Multiply>(sum, M_body);
|
||||
auto body = std::make_shared<ngraph::Function>(OutputVector{Zo, body_condition, sum},
|
||||
ParameterVector{Xi, Yi, M_body});
|
||||
|
||||
auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
|
||||
tensor_iterator->set_function(body);
|
||||
|
||||
tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 2);
|
||||
tensor_iterator->set_merged_input(M_body, M, Zo);
|
||||
// Set invariant input which uses parameter from parent graph
|
||||
tensor_iterator->set_invariant_input(Yi, add_Y);
|
||||
|
||||
// Output 0 is last Zo
|
||||
auto out0 = tensor_iterator->get_iter_value(body_condition, -1);
|
||||
auto out1 = tensor_iterator->get_iter_value(Zo, -1);
|
||||
// Output 1 is concat of Zos
|
||||
// start=0, stride=1, part_size=1, end=-1, axis=1
|
||||
auto out2 = tensor_iterator->get_concatenated_slices(Zo, 0, 1, 1, -1, 1);
|
||||
auto out3 = tensor_iterator->get_iter_value(sum, -1);
|
||||
|
||||
f = std::make_shared<Function>(OutputVector{out0, out1, out2, out3}, ParameterVector{X, Y, M});
|
||||
}
|
||||
|
||||
InferenceEngine::CNNNetwork network(f);
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[0]->get_output_partial_shape(0).compatible({}));
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[1]->get_output_partial_shape(0).compatible({1, 1, 1}));
|
||||
// concat output (seq len = 1, so it means num_iter = 1)
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[2]->get_output_partial_shape(0).compatible({1, 1, 1}));
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[3]->get_output_partial_shape(0).compatible({1, 1, 1}));
|
||||
|
||||
ASSERT_NO_THROW(network.reshape({{"X", {32, 1, 10}}, {"Y", {1, 1, 1}}, {"M", {32, 1, 10}}}));
|
||||
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[0]->get_output_partial_shape(0).compatible({}));
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[1]->get_output_partial_shape(0).compatible({32, 1, 10}));
|
||||
// concat output
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[2]->get_output_partial_shape(0).compatible({32, 10, 10}));
|
||||
ASSERT_TRUE(network.getFunction()->get_results()[3]->get_output_partial_shape(0).compatible({32, 1, 1}));
|
||||
}
|
||||
|
@ -247,15 +247,22 @@ void op::v5::Loop::validate_and_infer_types()
|
||||
as_type_ptr<TensorIterator::ConcatOutputDescription>(output_description))
|
||||
{
|
||||
const auto& body_value_partial_shape = body_value.get_partial_shape();
|
||||
set_output_type(index, body_value.get_element_type(), PartialShape::dynamic());
|
||||
if (body_value_partial_shape.is_static())
|
||||
if (body_value_partial_shape.rank().is_dynamic())
|
||||
{
|
||||
set_output_type(index, body_value.get_element_type(), PartialShape::dynamic());
|
||||
}
|
||||
else
|
||||
{
|
||||
auto body_value_shape = body_value_partial_shape.to_shape();
|
||||
auto axis = concat_output_description->m_axis;
|
||||
|
||||
Shape out_shape{body_value_shape};
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axis < body_value_partial_shape.rank().get_length(),
|
||||
"Concatenation axis must be less than sliced output rank");
|
||||
|
||||
if (body_value_shape.empty())
|
||||
PartialShape out_shape{body_value_partial_shape};
|
||||
|
||||
if (body_value_partial_shape.is_static() &&
|
||||
ngraph::is_scalar(body_value_partial_shape.to_shape()))
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
@ -266,23 +273,23 @@ void op::v5::Loop::validate_and_infer_types()
|
||||
out_shape = Shape(1);
|
||||
}
|
||||
|
||||
if (m_num_iterations != -1)
|
||||
if (m_num_iterations != -1 && body_value_partial_shape[axis].is_static())
|
||||
{
|
||||
out_shape[axis] = m_num_iterations * body_value_shape[axis];
|
||||
out_shape[axis] =
|
||||
m_num_iterations * body_value_partial_shape[axis].get_length();
|
||||
if (zero_number_of_iter)
|
||||
{
|
||||
out_shape.at(0) = 0;
|
||||
out_shape[0] = 0;
|
||||
}
|
||||
set_output_type(index, body_value.get_element_type(), out_shape);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
set_output_type(index,
|
||||
body_value.get_element_type(),
|
||||
PartialShape::dynamic(body_value.get_partial_shape().rank()));
|
||||
else
|
||||
{
|
||||
out_shape[axis] = Dimension::dynamic();
|
||||
}
|
||||
set_output_type(index, body_value.get_element_type(), out_shape);
|
||||
}
|
||||
}
|
||||
|
||||
else if (auto body_output_description =
|
||||
as_type_ptr<TensorIterator::BodyOutputDescription>(output_description))
|
||||
{
|
||||
|
@ -67,10 +67,10 @@ namespace ngraph
|
||||
|
||||
protected:
|
||||
ParameterVector m_parameters;
|
||||
std::unique_ptr<GraphCache> m_cache;
|
||||
|
||||
private:
|
||||
const ONNX_NAMESPACE::GraphProto* m_graph_proto;
|
||||
std::unique_ptr<GraphCache> m_cache;
|
||||
std::vector<Node> m_nodes;
|
||||
std::vector<ValueInfo> m_inputs;
|
||||
std::vector<ValueInfo> m_outputs;
|
||||
@ -91,6 +91,13 @@ namespace ngraph
|
||||
Subgraph(const ONNX_NAMESPACE::GraphProto& proto,
|
||||
Model& model,
|
||||
const Graph& parent_graph);
|
||||
|
||||
/// \brief Return outputs which are on the edge the subgraph and the parent graph.
|
||||
/// \return Vector of edge nodes from parent scope.
|
||||
const std::vector<Output<ngraph::Node>> get_outputs_from_parent() const;
|
||||
|
||||
private:
|
||||
std::vector<Output<ngraph::Node>> m_outputs_from_parent;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& outs, const Graph& graph)
|
||||
|
@ -25,6 +25,17 @@ namespace ngraph
|
||||
{
|
||||
namespace onnx_import
|
||||
{
|
||||
/// \brief Enum which determines scope (visibility) of nodes in GraphCache.
|
||||
enum class NodeScope
|
||||
{
|
||||
// in parent graph scope
|
||||
ParentGraph = 1,
|
||||
// in subgraph scope
|
||||
SubGraph,
|
||||
// not available at all
|
||||
Lack
|
||||
};
|
||||
|
||||
/// \brief GraphCache stores and provides access to ONNX graph initializers.
|
||||
class GraphCache
|
||||
{
|
||||
@ -53,6 +64,16 @@ namespace ngraph
|
||||
/// \return true if the node named `name` exist in the cache, false otherwise.
|
||||
virtual bool contains(const std::string& name) const;
|
||||
|
||||
/// \brief Return NodeScope enum which determines scope of the node.
|
||||
/// \note If the method is called on GraphCache the ParentGraph enum
|
||||
/// value is retunred always.
|
||||
///
|
||||
/// \param[in] name The name of the node.
|
||||
///
|
||||
/// \return SubGraph if node belongs to SubgraphCache, ParentGraph if
|
||||
/// is avalible in parent_graph_cache, otherwise Lack
|
||||
virtual NodeScope node_scope(const std::string& name) const;
|
||||
|
||||
private:
|
||||
std::map<std::string, Output<ngraph::Node>> m_graph_cache_map;
|
||||
};
|
||||
@ -82,6 +103,14 @@ namespace ngraph
|
||||
/// (subgraph or parent graph), false otherwise.
|
||||
bool contains(const std::string& name) const override;
|
||||
|
||||
/// \brief Return NodeScope enum which determines scope of the node.
|
||||
///
|
||||
/// \param[in] name The name of the node.
|
||||
///
|
||||
/// \return SubGraph if the node belongs to SubgraphCache, ParentGraph if
|
||||
/// is avalible in parent_graph_cache, otherwise Lack
|
||||
NodeScope node_scope(const std::string& name) const override;
|
||||
|
||||
private:
|
||||
const GraphCache* m_parent_graph_cache;
|
||||
};
|
||||
|
@ -315,39 +315,48 @@ namespace ngraph
|
||||
model,
|
||||
std::unique_ptr<SubgraphCache>(new SubgraphCache(parent_graph.get_graph_cache())))
|
||||
{
|
||||
std::vector<std::shared_ptr<ngraph::Node>> subgraph_root_nodes;
|
||||
const auto& outputs = as_result_vector(get_ng_outputs());
|
||||
for (auto& out : outputs)
|
||||
// find all nodes on edge parent graph-subgraph
|
||||
// (it means input of node from parent graph, output from subgraph)
|
||||
for (const auto& node_proto : proto.node())
|
||||
{
|
||||
subgraph_root_nodes.push_back(out);
|
||||
}
|
||||
const auto& params = get_ng_parameters();
|
||||
for (auto& param : params)
|
||||
{
|
||||
subgraph_root_nodes.push_back(param);
|
||||
}
|
||||
const auto subgraph_nodes = topological_sort(subgraph_root_nodes);
|
||||
|
||||
const auto& parent_graph_parameters = parent_graph.get_ng_parameters();
|
||||
for (const auto& node : subgraph_nodes)
|
||||
{
|
||||
if (op::is_parameter(node))
|
||||
int input_index = 0;
|
||||
for (const auto& in_name : node_proto.input())
|
||||
{
|
||||
const auto sub_it = std::find(m_parameters.begin(), m_parameters.end(), node);
|
||||
// not present as subgraph parameter
|
||||
if (sub_it == m_parameters.end())
|
||||
if (m_cache->node_scope(in_name) == NodeScope::ParentGraph)
|
||||
{
|
||||
const auto parent_it = std::find(
|
||||
parent_graph_parameters.begin(), parent_graph_parameters.end(), node);
|
||||
if (parent_it != m_parameters.end())
|
||||
const auto& from_parent_node = m_cache->get_node(in_name);
|
||||
// constants are skipped
|
||||
if (!ngraph::is_type<ngraph::op::Constant>(
|
||||
from_parent_node.get_node_shared_ptr()))
|
||||
{
|
||||
m_parameters.push_back(*parent_it);
|
||||
for (const auto& out_name : node_proto.output())
|
||||
{
|
||||
if (m_cache->node_scope(out_name) == NodeScope::SubGraph)
|
||||
{
|
||||
auto out_node_to_replace_input = m_cache->get_node(out_name);
|
||||
auto new_param = std::make_shared<ngraph::op::Parameter>(
|
||||
from_parent_node.get_element_type(),
|
||||
from_parent_node.get_partial_shape());
|
||||
// replace input from parent scope with parameter
|
||||
out_node_to_replace_input.get_node()
|
||||
->input(input_index)
|
||||
.replace_source_output(new_param);
|
||||
m_parameters.push_back(new_param);
|
||||
m_outputs_from_parent.push_back(from_parent_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
++input_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<Output<ngraph::Node>> Subgraph::get_outputs_from_parent() const
|
||||
{
|
||||
return m_outputs_from_parent;
|
||||
}
|
||||
|
||||
} // namespace onnx_import
|
||||
|
||||
} // namespace ngraph
|
||||
|
@ -43,6 +43,11 @@ namespace ngraph
|
||||
return (m_graph_cache_map.count(name) > 0);
|
||||
}
|
||||
|
||||
NodeScope GraphCache::node_scope(const std::string& name) const
|
||||
{
|
||||
return contains(name) ? NodeScope::ParentGraph : NodeScope::Lack;
|
||||
}
|
||||
|
||||
SubgraphCache::SubgraphCache(const GraphCache& parent_graph_cache)
|
||||
: m_parent_graph_cache{&parent_graph_cache}
|
||||
{
|
||||
@ -71,5 +76,21 @@ namespace ngraph
|
||||
return GraphCache::contains(name) || m_parent_graph_cache->contains(name);
|
||||
}
|
||||
|
||||
NodeScope SubgraphCache::node_scope(const std::string& name) const
|
||||
{
|
||||
if (GraphCache::contains(name))
|
||||
{
|
||||
return NodeScope::SubGraph;
|
||||
}
|
||||
else if (m_parent_graph_cache->contains(name))
|
||||
{
|
||||
return NodeScope::ParentGraph;
|
||||
}
|
||||
else
|
||||
{
|
||||
return NodeScope::Lack;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
||||
|
@ -87,7 +87,12 @@ namespace ngraph
|
||||
|
||||
// optional inputs
|
||||
Output<ngraph::Node> trip_count;
|
||||
if (ngraph::op::is_null(ng_inputs.at(0))) // trip count skipped
|
||||
// trip count skipped or has value max(int64_t) means infinitive loop
|
||||
if (ngraph::op::is_null(ng_inputs.at(0)) ||
|
||||
(ngraph::op::is_constant(ng_inputs.at(0).get_node_shared_ptr()) &&
|
||||
as_type_ptr<default_opset::Constant>(ng_inputs.at(0).get_node_shared_ptr())
|
||||
->cast_vector<int64_t>()[0] ==
|
||||
std::numeric_limits<int64_t>::max()))
|
||||
{
|
||||
// -1 means infinite Loop
|
||||
trip_count = ngraph::op::Constant::create(ngraph::element::i64, {1}, {-1});
|
||||
@ -132,17 +137,13 @@ namespace ngraph
|
||||
const int64_t concat_axis = 0;
|
||||
const auto concat_axis_const =
|
||||
ngraph::op::Constant::create(ngraph::element::i64, {1}, {concat_axis});
|
||||
// provide scalar handing for scan outputs
|
||||
// add dimension along which scan outputs will be concatenated
|
||||
for (size_t i = loop_carried_dependencies.size() + 1; i < body_outputs.size();
|
||||
++i)
|
||||
{
|
||||
auto body_output_shape = body_outputs[i].get_partial_shape();
|
||||
if (body_output_shape.is_static() &&
|
||||
ngraph::is_scalar(body_output_shape.to_shape()))
|
||||
{
|
||||
body_outputs[i] = std::make_shared<default_opset::Unsqueeze>(
|
||||
body_outputs[i], concat_axis_const);
|
||||
}
|
||||
const auto& body_output_shape = body_outputs[i].get_partial_shape();
|
||||
body_outputs[i] = std::make_shared<default_opset::Unsqueeze>(
|
||||
body_outputs[i], concat_axis_const);
|
||||
}
|
||||
|
||||
const auto& body_loop_out_cond = body_outputs.at(0).get_node_shared_ptr();
|
||||
@ -193,6 +194,22 @@ namespace ngraph
|
||||
final_values.push_back(loop->get_iter_value(*body_outputs_it++, -1));
|
||||
}
|
||||
|
||||
const auto& outputs_from_parent = body_graph.get_outputs_from_parent();
|
||||
CHECK_VALID_NODE(node,
|
||||
std::distance(body_inputs_it, body_inputs.end()) ==
|
||||
outputs_from_parent.size(),
|
||||
"Expected number of invariant parameters is"
|
||||
" not equal number of provided outputs from parent scope");
|
||||
|
||||
// Set-up parameters from parent graph which are not changed during Loop's
|
||||
// iterations
|
||||
for (auto out_from_parent_it = outputs_from_parent.begin();
|
||||
body_inputs_it != body_inputs.end();
|
||||
++body_inputs_it, ++out_from_parent_it)
|
||||
{
|
||||
loop->set_invariant_input(*body_inputs_it, *out_from_parent_it);
|
||||
}
|
||||
|
||||
// Set-up scan outputs
|
||||
OutputVector scan_outputs;
|
||||
for (; body_outputs_it != body_outputs.end(); body_outputs_it++)
|
||||
|
@ -29,16 +29,6 @@
|
||||
|
||||
namespace
|
||||
{
|
||||
/// \return Parse node attribute value for axis and adjust for negative value if needed.
|
||||
std::int64_t get_axis(const ngraph::onnx_import::Node& node)
|
||||
{
|
||||
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
|
||||
|
||||
const auto data = node.get_ng_inputs().at(0);
|
||||
const auto data_rank = data.get_partial_shape().rank();
|
||||
return ngraph::normalize_axis(node.get_description(), axis, data_rank);
|
||||
}
|
||||
|
||||
/// \return Return the second input to the TopK node reshaped to a scalar.
|
||||
ngraph::Output<ngraph::Node> get_k(const ngraph::onnx_import::Node& node)
|
||||
{
|
||||
@ -64,7 +54,7 @@ namespace ngraph
|
||||
auto data = node.get_ng_inputs().at(0);
|
||||
std::int64_t k{node.get_attribute_value<std::int64_t>("k")};
|
||||
auto k_node = default_opset::Constant::create(element::i64, Shape{}, {k});
|
||||
auto axis = get_axis(node);
|
||||
const std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
|
||||
|
||||
std::shared_ptr<ngraph::Node> top_k = std::make_shared<default_opset::TopK>(
|
||||
data,
|
||||
@ -84,7 +74,7 @@ namespace ngraph
|
||||
{
|
||||
auto data = node.get_ng_inputs().at(0);
|
||||
auto k = get_k(node);
|
||||
auto axis = get_axis(node);
|
||||
const std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
|
||||
|
||||
std::shared_ptr<ngraph::Node> top_k = std::make_shared<default_opset::TopK>(
|
||||
data,
|
||||
@ -107,7 +97,7 @@ namespace ngraph
|
||||
auto k = get_k(node);
|
||||
|
||||
// Process attributes
|
||||
const auto axis = get_axis(node);
|
||||
const std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
|
||||
const auto largest = node.get_attribute_value<std::int64_t>("largest", 1);
|
||||
const auto sorted = node.get_attribute_value<std::int64_t>("sorted", 1);
|
||||
|
||||
|
@ -137,10 +137,8 @@ xfail_issue_38714 = xfail_test(reason="RuntimeError: While validating ONNX node
|
||||
"Argument element types are inconsistent.")
|
||||
xfail_issue_43742 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
|
||||
"If")
|
||||
xfail_issue_43439 = xfail_test(reason="Check 'tensor_rank.is_static()' failed at "
|
||||
"ngraph/core/src/validation_util.cpp:884:"
|
||||
"map_1/while/select_bboxes/sort_bboxes_10/TopKV2 "
|
||||
"Rank must be static in order to normalize negative axis=-1")
|
||||
xfail_issue_45457 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v5::Loop"
|
||||
"Not constant termination condition body output is not supported")
|
||||
xfail_issue_38715 = xfail_test(reason="RuntimeError: While validating ONNX node '<Node(OneHot): y>':"
|
||||
"While validating node 'v1::OneHot OneHot_<number>"
|
||||
"(Convert_13525[0]:i64{3}, depth[0]:f32{},"
|
||||
|
@ -29,7 +29,7 @@ from tests import (
|
||||
xfail_issue_38701,
|
||||
xfail_issue_43742,
|
||||
xfail_issue_43380,
|
||||
xfail_issue_43439,
|
||||
xfail_issue_45457,
|
||||
xfail_issue_39684,
|
||||
xfail_issue_40957,
|
||||
xfail_issue_39685,
|
||||
@ -152,7 +152,6 @@ if len(zoo_models) > 0:
|
||||
|
||||
# Model MSFT
|
||||
(xfail_issue_43742, "test_MSFT_opset10_mlperf_ssd_mobilenet_300_ssd_mobilenet_v1_coco_2018_01_28_cpu"),
|
||||
(xfail_issue_43439, "test_MSFT_opset10_mlperf_ssd_resnet34_1200_ssd_resnet34_mAP_20.2_cpu"),
|
||||
(xfail_issue_37957, "test_MSFT_opset10_mask_rcnn_keras_mask_rcnn_keras_cpu"),
|
||||
]
|
||||
for test_case in import_xfail_list:
|
||||
@ -178,6 +177,7 @@ if len(zoo_models) > 0:
|
||||
(xfail_issue_38084, "test_onnx_model_zoo_vision_object_detection_segmentation_mask_rcnn_model_MaskRCNN_10_mask_rcnn_R_50_FPN_1x_cpu"),
|
||||
(xfail_issue_38084, "test_onnx_model_zoo_vision_object_detection_segmentation_faster_rcnn_model_FasterRCNN_10_faster_rcnn_R_50_FPN_1x_cpu"),
|
||||
(xfail_issue_43380, "test_onnx_model_zoo_vision_object_detection_segmentation_tiny_yolov3_model_tiny_yolov3_11_yolov3_tiny_cpu"),
|
||||
(xfail_issue_45457, "test_MSFT_opset10_mlperf_ssd_resnet34_1200_ssd_resnet34_mAP_20.2_cpu"),
|
||||
|
||||
# Model MSFT
|
||||
(xfail_issue_37973, "test_MSFT_opset7_tf_inception_v2_model_cpu"),
|
||||
|
@ -0,0 +1,166 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
name: "basic loop"
|
||||
node {
|
||||
input: "trip_count"
|
||||
input: ""
|
||||
input: "a_init"
|
||||
output: "a_final"
|
||||
output: "a_values"
|
||||
op_type: "Loop"
|
||||
attribute {
|
||||
name: "body"
|
||||
g {
|
||||
node {
|
||||
input: "a_in"
|
||||
input: "b"
|
||||
output: "current_a"
|
||||
name: "loop_body_add"
|
||||
op_type: "Add"
|
||||
}
|
||||
node {
|
||||
input: "cond_in"
|
||||
output: "cond_out"
|
||||
name: "cond_identity"
|
||||
op_type: "Identity"
|
||||
}
|
||||
node {
|
||||
input: "current_a"
|
||||
output: "a_out"
|
||||
name: "output_accumulator"
|
||||
op_type: "Identity"
|
||||
}
|
||||
name: "simple add"
|
||||
input {
|
||||
name: "i"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "cond_in"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "a_in"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "cond_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "current_a"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "a_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
}
|
||||
initializer {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 3
|
||||
name: "trip_count"
|
||||
}
|
||||
input {
|
||||
name: "a_init"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "b"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "a_final"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "a_values"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 11
|
||||
}
|
@ -0,0 +1,177 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
name: "basic loop"
|
||||
node {
|
||||
input: "trip_count"
|
||||
input: "cond_in"
|
||||
input: "a_init"
|
||||
output: "a_final"
|
||||
output: "a_values"
|
||||
op_type: "Loop"
|
||||
attribute {
|
||||
name: "body"
|
||||
g {
|
||||
node {
|
||||
input: "a_in"
|
||||
input: "b"
|
||||
output: "current_a"
|
||||
name: "loop_body_add"
|
||||
op_type: "Add"
|
||||
}
|
||||
node {
|
||||
input: "i"
|
||||
input: "threshold"
|
||||
output: "cond_out"
|
||||
name: "condition_calc"
|
||||
op_type: "Less"
|
||||
}
|
||||
node {
|
||||
input: "current_a"
|
||||
output: "a_out"
|
||||
name: "output_accumulator"
|
||||
op_type: "Identity"
|
||||
}
|
||||
name: "simple add"
|
||||
initializer {
|
||||
dims: 1
|
||||
dims: 2
|
||||
data_type: 1
|
||||
float_data: 1
|
||||
float_data: 1
|
||||
name: "b"
|
||||
}
|
||||
input {
|
||||
name: "i"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "cond"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "a_in"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "cond_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "current_a"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "a_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
}
|
||||
initializer {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 5
|
||||
name: "threshold"
|
||||
}
|
||||
initializer {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 9223372036854775807
|
||||
name: "trip_count"
|
||||
}
|
||||
input {
|
||||
name: "cond_in"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "a_init"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "a_final"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "a_values"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 11
|
||||
}
|
@ -0,0 +1,181 @@
|
||||
ir_version: 6
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
name: "basic loop"
|
||||
node {
|
||||
input: "parent_input"
|
||||
input: "scale"
|
||||
name: "mul_node"
|
||||
op_type: "Mul"
|
||||
output: "b"
|
||||
}
|
||||
node {
|
||||
input: "parent_input"
|
||||
input: "b"
|
||||
name: "parent_add_node"
|
||||
op_type: "Add"
|
||||
output: "c"
|
||||
}
|
||||
node {
|
||||
input: "trip_count"
|
||||
input: "cond_in"
|
||||
input: "a_init"
|
||||
output: "a_final"
|
||||
output: "a_values"
|
||||
op_type: "Loop"
|
||||
attribute {
|
||||
name: "body"
|
||||
g {
|
||||
name: "simple add"
|
||||
node {
|
||||
input: "b"
|
||||
input: "a_in"
|
||||
output: "current_a"
|
||||
name: "loop_body_add"
|
||||
op_type: "Add"
|
||||
}
|
||||
node {
|
||||
input: "cond"
|
||||
output: "cond_out"
|
||||
name: "cond_identity"
|
||||
op_type: "Identity"
|
||||
}
|
||||
node {
|
||||
input: "current_a"
|
||||
output: "a_out"
|
||||
name: "output_accumulator"
|
||||
op_type: "Identity"
|
||||
}
|
||||
input {
|
||||
name: "i"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "cond"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "a_in"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "cond_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "current_a"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "a_out"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type: GRAPH
|
||||
}
|
||||
}
|
||||
initializer {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 3
|
||||
name: "trip_count"
|
||||
}
|
||||
initializer {
|
||||
dims: 1
|
||||
data_type: 9
|
||||
int32_data: 00000001
|
||||
name: "cond_in"
|
||||
}
|
||||
initializer {
|
||||
dims: 1
|
||||
data_type: 1
|
||||
float_data: 2
|
||||
name: "scale"
|
||||
}
|
||||
|
||||
input {
|
||||
name: "a_init"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "parent_input"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "a_final"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "a_values"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "c"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 11
|
||||
}
|
@ -0,0 +1,97 @@
|
||||
ir_version: 5
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
input: "k"
|
||||
output: "values"
|
||||
output: "indices"
|
||||
op_type: "TopK"
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: -1
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "largest"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "sorted"
|
||||
i: 1
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "test_top_k"
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "k"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
initializer {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 3
|
||||
name: "k"
|
||||
}
|
||||
output {
|
||||
name: "values"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "indices"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 11
|
||||
}
|
@ -2308,6 +2308,20 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_top_k_opset_11_const_k_smallest)
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_top_k_opset_11_const_k_smallest_negative_axis)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/top_k_opset_11_const_k_smallest_negative_axis.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_input<float>({0, 1, 2, 3, 4, 5, 6, 7, 11, 10, 9, 8});
|
||||
|
||||
test_case.add_expected_output<float>(Shape{3, 3}, {0, 1, 2, 4, 5, 6, 8, 9, 10}); // values
|
||||
test_case.add_expected_output<std::int64_t>(Shape{3, 3},
|
||||
{0, 1, 2, 0, 1, 2, 3, 2, 1}); // indices
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_acosh)
|
||||
{
|
||||
auto function =
|
||||
|
@ -60,14 +60,14 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_add)
|
||||
EXPECT_EQ(function->get_output_shape(0), (Shape{1, 2}));
|
||||
EXPECT_EQ(function->get_output_element_type(1), ngraph::element::f32);
|
||||
EXPECT_TRUE(function->get_output_partial_shape(1).is_static());
|
||||
EXPECT_EQ(function->get_output_shape(1), (Shape{3, 2}));
|
||||
EXPECT_EQ(function->get_output_shape(1), (Shape{3, 1, 2}));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// a_init
|
||||
test_case.add_input<float>({0.f, 0.f});
|
||||
|
||||
test_case.add_expected_output<float>(Shape{1, 2}, {3.f, 3.f});
|
||||
test_case.add_expected_output<float>(Shape{3, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f});
|
||||
test_case.add_expected_output<float>(Shape{3, 1, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
@ -89,7 +89,24 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_no_identity_termination_co
|
||||
|
||||
test_case.add_expected_output<float>(Shape{1, 2}, {6.f, 6.f});
|
||||
test_case.add_expected_output<float>(
|
||||
Shape{6, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f});
|
||||
Shape{6, 1, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_trip_count_max_int)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/loop/loop_2d_add_trip_count_max_int.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine, TestCaseType::DYNAMIC>(function);
|
||||
// termination condition
|
||||
test_case.add_input<bool>({true});
|
||||
// a_init
|
||||
test_case.add_input<float>({0.f, 0.f});
|
||||
|
||||
test_case.add_expected_output<float>(Shape{1, 2}, {6.f, 6.f});
|
||||
test_case.add_expected_output<float>(
|
||||
Shape{6, 1, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
@ -140,7 +157,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_const_no_identity_terminat
|
||||
test_case.add_input<float>({0.f, 0.f});
|
||||
|
||||
test_case.add_expected_output<float>(Shape{1, 2}, {4.f, 4.f});
|
||||
test_case.add_expected_output<float>(Shape{4, 2}, {1, 1, 2, 2, 3, 3, 4, 4});
|
||||
test_case.add_expected_output<float>(Shape{4, 1, 2}, {1, 1, 2, 2, 3, 3, 4, 4});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
@ -182,7 +199,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_both_cond_and_trip_count_a
|
||||
|
||||
test_case.add_expected_output<float>(Shape{1, 2}, {6.f, 6.f});
|
||||
test_case.add_expected_output<float>(
|
||||
Shape{6, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f});
|
||||
Shape{6, 1, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
@ -220,7 +237,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_initializer_from_parent_s
|
||||
test_case.add_input<float>({0.f, 0.f});
|
||||
|
||||
test_case.add_expected_output<float>(Shape{1, 2}, {6.f, 6.f});
|
||||
test_case.add_expected_output<float>(Shape{3, 2}, {2.f, 2.f, 4.f, 4.f, 6.f, 6.f});
|
||||
test_case.add_expected_output<float>(Shape{3, 1, 2}, {2.f, 2.f, 4.f, 4.f, 6.f, 6.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
@ -234,7 +251,26 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_node_from_parent_scope)
|
||||
test_case.add_input<float>({0.f, 0.f});
|
||||
|
||||
test_case.add_expected_output<float>(Shape{1, 2}, {12.f, 12.f});
|
||||
test_case.add_expected_output<float>(Shape{3, 2}, {4.f, 4.f, 8.f, 8.f, 12.f, 12.f});
|
||||
test_case.add_expected_output<float>(Shape{3, 1, 2}, {4.f, 4.f, 8.f, 8.f, 12.f, 12.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME},
|
||||
onnx_controlflow_loop_add_node_from_parent_scope_used_in_parent_and_in_body)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(file_util::path_join(
|
||||
SERIALIZED_ZOO,
|
||||
"onnx/loop/loop_add_node_from_parent_scope_used_in_parent_and_in_body.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// a_init
|
||||
test_case.add_input<float>({0.f, 0.f});
|
||||
// parent_input
|
||||
test_case.add_input<float>({3.f});
|
||||
|
||||
test_case.add_expected_output<float>(Shape{1, 2}, {18.f, 18.f});
|
||||
test_case.add_expected_output<float>(Shape{3, 1, 2}, {6.f, 6.f, 12.f, 12.f, 18.f, 18.f});
|
||||
test_case.add_expected_output<float>(Shape{1}, {9.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
@ -268,7 +304,23 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_value_the_same_node_from_
|
||||
test_case.add_input<float>({0.f, 0.f});
|
||||
|
||||
test_case.add_expected_output<float>(Shape{1, 2}, {3.f, 3.f});
|
||||
test_case.add_expected_output<float>(Shape{3, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f});
|
||||
test_case.add_expected_output<float>(Shape{3, 1, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_add_input_from_parent_graph)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/loop/loop_2d_add_input_from_parent_graph.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
// a_init
|
||||
test_case.add_input<float>({0.f, 0.f});
|
||||
// b input
|
||||
test_case.add_input<float>({1.f, 1.f});
|
||||
|
||||
test_case.add_expected_output<float>(Shape{1, 2}, {3.f, 3.f});
|
||||
test_case.add_expected_output<float>(Shape{3, 1, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
@ -321,7 +373,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_add_const_cond)
|
||||
test_case.add_input<float>({0.f, 0.f});
|
||||
|
||||
test_case.add_expected_output<float>(Shape{1, 2}, {3.f, 3.f});
|
||||
test_case.add_expected_output<float>(Shape{3, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f});
|
||||
test_case.add_expected_output<float>(Shape{3, 1, 2}, {1.f, 1.f, 2.f, 2.f, 3.f, 3.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
@ -379,8 +431,9 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_2d_trip_count_and_cond_skippe
|
||||
EXPECT_TRUE(function->get_output_partial_shape(0).is_static());
|
||||
EXPECT_EQ(function->get_output_shape(0), (Shape{1, 2}));
|
||||
EXPECT_EQ(function->get_output_element_type(1), ngraph::element::f32);
|
||||
// scan_outputs shape is not know if trip_count and termination condition is not determined
|
||||
EXPECT_TRUE(function->get_output_partial_shape(1).rank().is_dynamic());
|
||||
EXPECT_TRUE(function->get_output_partial_shape(1).rank().is_static());
|
||||
EXPECT_EQ(function->get_output_partial_shape(1).rank(), 3);
|
||||
EXPECT_EQ(function->get_output_partial_shape(1), (PartialShape{Dimension::dynamic(), 1, 2}));
|
||||
}
|
||||
|
||||
// infinitive loop execution
|
||||
|
@ -71,6 +71,7 @@ onnx_model_split_equal_parts_2d
|
||||
onnx_model_split_variable_parts_2d
|
||||
onnx_top_k_opset_10_const_k
|
||||
onnx_top_k_opset_11_const_k_smallest
|
||||
onnx_top_k_opset_11_const_k_smallest_negative_axis
|
||||
split_1d
|
||||
split_2d_axis_0
|
||||
split_2d_axis_1
|
||||
@ -1520,6 +1521,7 @@ IE_GPU.onnx_model_fake_quantize_nonconst_inputs_infer
|
||||
# Not supported dynamic shapes cases for Loop
|
||||
onnx_controlflow_loop_2d_no_identity_termination_cond
|
||||
onnx_controlflow_loop_2d_no_identity_termination_cond_false
|
||||
onnx_controlflow_loop_2d_trip_count_max_int
|
||||
onnx_controlflow_loop_2d_const_no_identity_termination_cond
|
||||
onnx_controlflow_loop_2d_both_cond_and_trip_count_as_inputs
|
||||
onnx_controlflow_loop_no_variadic_inputs_and_outputs
|
||||
|
@ -334,7 +334,7 @@ TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_static_shapes
|
||||
// trip_count = 10
|
||||
// execution_condition = true
|
||||
// body_condition is not a Constant
|
||||
// concat output will be dynamic, another outputs are static
|
||||
// concat output has only dynamic rank, another outputs are static
|
||||
TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_dynamic_shapes)
|
||||
{
|
||||
// That which we iterate over
|
||||
@ -397,7 +397,7 @@ TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_dynamic_shape
|
||||
// Output 0 is last Zo
|
||||
auto out0 = loop->get_iter_value(body_condition, -1);
|
||||
auto out1 = loop->get_iter_value(Zo, -1);
|
||||
auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1);
|
||||
auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 0);
|
||||
|
||||
// check output descriptors
|
||||
for (auto& desc : loop->get_output_descriptions())
|
||||
@ -422,9 +422,9 @@ TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_dynamic_shape
|
||||
auto result2 = make_shared<opset5::Result>(out2);
|
||||
Shape out0_shape{1};
|
||||
Shape out1_shape{1};
|
||||
PartialShape out2_shape{PartialShape::dynamic()};
|
||||
PartialShape out2_shape{PartialShape::dynamic(1)};
|
||||
|
||||
auto results = ResultVector{result0, result1};
|
||||
auto results = ResultVector{result0, result1, result2};
|
||||
auto f = make_shared<Function>(results, ParameterVector{X, Y, M});
|
||||
EXPECT_EQ(result0->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(result1->get_output_shape(0), out1_shape);
|
||||
@ -435,6 +435,176 @@ TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_dynamic_shape
|
||||
EXPECT_EQ(loop->get_output_partial_shape(2), out2_shape);
|
||||
}
|
||||
|
||||
// trip_count = 10
|
||||
// execution_condition = true
|
||||
// body_condition is not a Constant
|
||||
// inputs have partially known shape
|
||||
// concat output has dynamic dimension on axis position, another outputs are static
|
||||
TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_partially_dynamic_shapes)
|
||||
{
|
||||
// That which we iterate over
|
||||
auto X =
|
||||
make_shared<opset5::Parameter>(element::f32, PartialShape{1, 2, 3, Dimension::dynamic()});
|
||||
auto Y =
|
||||
make_shared<opset5::Parameter>(element::f32, PartialShape{1, 2, 3, Dimension::dynamic()});
|
||||
auto M = make_shared<opset5::Parameter>(element::f32, Shape{1});
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
auto current_iteration = make_shared<opset5::Parameter>(element::i64, Shape{1});
|
||||
auto Xi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto M_body = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto condition_const =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{1}, 10);
|
||||
auto body_condition = std::make_shared<ngraph::opset5::Greater>(M_body, condition_const);
|
||||
|
||||
auto trip_count =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{1}, 10);
|
||||
auto exec_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
// Body
|
||||
auto sum = make_shared<ngraph::opset5::Add>(Xi, Yi);
|
||||
auto Zo = make_shared<ngraph::opset5::Multiply>(sum, M_body);
|
||||
auto body = make_shared<ngraph::Function>(OutputVector{body_condition, Zo},
|
||||
ParameterVector{current_iteration, Xi, Yi, M_body});
|
||||
|
||||
auto loop = make_shared<opset5::Loop>(trip_count, exec_condition);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{-1, 0});
|
||||
|
||||
loop->set_invariant_input(Xi, X);
|
||||
loop->set_invariant_input(Yi, Y);
|
||||
loop->set_merged_input(M_body, M, Zo);
|
||||
|
||||
// check input descriptors
|
||||
for (auto& desc : loop->get_input_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "InvariantInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::InvariantInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "SliceInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::SliceInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "MergedInputDescription") == 0)
|
||||
{
|
||||
auto input_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::MergedInputDescription>(desc);
|
||||
EXPECT_NE(input_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Output 0 is last Zo
|
||||
auto out0 = loop->get_iter_value(body_condition, -1);
|
||||
auto out1 = loop->get_iter_value(Zo, -1);
|
||||
// axis=1 so sliced output on this dimension will be dynamic
|
||||
auto out2 = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, 1);
|
||||
|
||||
// check output descriptors
|
||||
for (auto& desc : loop->get_output_descriptions())
|
||||
{
|
||||
auto type_info = desc->get_type_info();
|
||||
if (std::strcmp(type_info.name, "ConcatOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::ConcatOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
else if (std::strcmp(type_info.name, "BodyOutputDescription") == 0)
|
||||
{
|
||||
auto output_desc =
|
||||
as_type_ptr<ngraph::opset5::TensorIterator::BodyOutputDescription>(desc);
|
||||
EXPECT_NE(output_desc, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
auto result0 = make_shared<opset5::Result>(out0);
|
||||
auto result1 = make_shared<opset5::Result>(out1);
|
||||
auto result2 = make_shared<opset5::Result>(out2);
|
||||
Shape out0_shape{1};
|
||||
PartialShape out1_shape{1, 2, 3, Dimension::dynamic()};
|
||||
PartialShape out2_shape{1, Dimension::dynamic(), 3, Dimension::dynamic()};
|
||||
|
||||
auto results = ResultVector{result0, result1, result2};
|
||||
auto f = make_shared<Function>(results, ParameterVector{X, Y, M});
|
||||
EXPECT_EQ(result0->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(result1->get_output_partial_shape(0), out1_shape);
|
||||
EXPECT_EQ(result2->get_output_partial_shape(0), out2_shape);
|
||||
|
||||
EXPECT_EQ(loop->get_output_shape(0), out0_shape);
|
||||
EXPECT_EQ(loop->get_output_partial_shape(1), out1_shape);
|
||||
EXPECT_EQ(loop->get_output_partial_shape(2), out2_shape);
|
||||
}
|
||||
|
||||
// trip_count = 10
|
||||
// execution_condition = true
|
||||
// body_condition is not a Constant
|
||||
// inputs have partially known shape
|
||||
// Axis of silced output is set as incorrect
|
||||
TEST(type_prop, loop_operation_for_and_condition_mode_dynamic_iter_incorrect_sliced_output_axis)
|
||||
{
|
||||
// That which we iterate over
|
||||
auto X =
|
||||
make_shared<opset5::Parameter>(element::f32, PartialShape{1, 2, 3, Dimension::dynamic()});
|
||||
auto Y =
|
||||
make_shared<opset5::Parameter>(element::f32, PartialShape{1, 2, 3, Dimension::dynamic()});
|
||||
auto M = make_shared<opset5::Parameter>(element::f32, Shape{1});
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
auto current_iteration = make_shared<opset5::Parameter>(element::i64, Shape{1});
|
||||
auto Xi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto M_body = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto condition_const =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{1}, 10);
|
||||
auto body_condition = std::make_shared<ngraph::opset5::Greater>(M_body, condition_const);
|
||||
|
||||
auto trip_count =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{1}, 10);
|
||||
auto exec_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
// Body
|
||||
auto sum = make_shared<ngraph::opset5::Add>(Xi, Yi);
|
||||
auto Zo = make_shared<ngraph::opset5::Multiply>(sum, M_body);
|
||||
auto body = make_shared<ngraph::Function>(OutputVector{body_condition, Zo},
|
||||
ParameterVector{current_iteration, Xi, Yi, M_body});
|
||||
|
||||
auto loop = make_shared<opset5::Loop>(trip_count, exec_condition);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{-1, 0});
|
||||
|
||||
loop->set_invariant_input(Xi, X);
|
||||
loop->set_invariant_input(Yi, Y);
|
||||
loop->set_merged_input(M_body, M, Zo);
|
||||
|
||||
const auto sliced_output_axis = 4;
|
||||
auto out = loop->get_concatenated_slices(Zo, 0, 1, 1, -1, sliced_output_axis);
|
||||
|
||||
auto result = make_shared<opset5::Result>(out);
|
||||
try
|
||||
{
|
||||
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{X, Y, M});
|
||||
FAIL() << "Loop was created with incorrect axis of concatenated slices output.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(), std::string("Concatenation axis must be less than sliced output rank"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Construction loop operator failed for unexpected reason.";
|
||||
}
|
||||
}
|
||||
|
||||
// trip_count = -1
|
||||
// execution_condition = true
|
||||
// body_condition = true
|
||||
@ -527,7 +697,7 @@ TEST(type_prop, loop_operation_infinite_loop_mode_dynamic_iter_dynamic_shapes)
|
||||
auto result2 = make_shared<opset5::Result>(out2);
|
||||
Shape out0_shape{1};
|
||||
Shape out1_shape{32, 1, 10};
|
||||
PartialShape out2_shape{PartialShape::dynamic()};
|
||||
PartialShape out2_shape{32, Dimension::dynamic(), 10};
|
||||
|
||||
auto results = ResultVector{result0, result1, result2};
|
||||
auto f = make_shared<Function>(results, ParameterVector{X, Y, M});
|
||||
|
Loading…
Reference in New Issue
Block a user