From f60b46f3d453bb16d5298056807c7a5375f3e7c4 Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Thu, 3 Sep 2020 16:00:46 +0300 Subject: [PATCH] Fixed visitor for Interpolate-1 and Interpolate-4 (#2051) * Fixed visitor for Interpolate-1 and Interpolate-4 * Code style fix * Remove unnecessary changes * Fixed compilation on Linux for Atttribute visitor of vector * Added unit test for IE IR Reader for Interpolate-4 * Updated unit test for IR Reader for Interpolate-4 * Updated unit test --- .../src/readers/ir_reader/ie_ir_parser.cpp | 35 ---- .../src/readers/ir_reader/ie_ir_parser.hpp | 15 ++ .../ngraph_reader/interpolate_tests.cpp | 180 ++++++++++++++++++ ngraph/core/include/ngraph/axis_set.hpp | 1 + ngraph/core/src/op/interpolate.cpp | 4 +- 5 files changed, 198 insertions(+), 37 deletions(-) diff --git a/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp b/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp index 8e8ac98c5af..e291ba062cf 100644 --- a/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp +++ b/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp @@ -429,7 +429,6 @@ std::shared_ptr V10Parser::createNode(const std::vector>("BinaryConvolution"), std::make_shared>("GRN"), std::make_shared>("HardSigmoid"), - std::make_shared>("Interpolate"), std::make_shared>("Log"), std::make_shared>("SquaredDifference"), std::make_shared>("Less"), @@ -1384,40 +1383,6 @@ std::shared_ptr V10Parser::LayerCreator::cr return std::make_shared(inputs[0], inputs[1]); } -// Interpolate layer -template <> -std::shared_ptr V10Parser::LayerCreator::createLayer( - const ngraph::OutputVector& inputs, const pugi::xml_node& node, std::istream& binStream, - const GenericLayerParams& layerParsePrms) { - checkParameters(inputs, layerParsePrms, 2); - - pugi::xml_node dn = node.child("data"); - - if (dn.empty()) - THROW_IE_EXCEPTION << "Cannot read parameter for " << getType() << " layer with name: " << layerParsePrms.name; - - ngraph::op::v0::InterpolateAttrs attrs; - for (auto& axis : getParameters(dn, "axes")) { - attrs.axes.insert(axis); - } - - std::set available_modes {"linear", "nearest", "cubic", "area"}; - attrs.mode = GetStrAttr(dn, "mode"); - if (!available_modes.count(attrs.mode)) { - THROW_IE_EXCEPTION << "Interpolate mode: " << attrs.mode << " is unsupported!"; - } - attrs.align_corners = GetIntAttr(dn, "align_corners", 1); - attrs.antialias = GetIntAttr(dn, "antialias", 0); - for (auto& pad : getParameters(dn, "pads_begin")) { - attrs.pads_begin.push_back(pad); - } - for (auto& pad : getParameters(dn, "pads_end")) { - attrs.pads_end.push_back(pad); - } - - return std::make_shared(inputs[0], inputs[1], attrs); -} - // Abs layer template <> std::shared_ptr V10Parser::LayerCreator::createLayer( diff --git a/inference-engine/src/readers/ir_reader/ie_ir_parser.hpp b/inference-engine/src/readers/ir_reader/ie_ir_parser.hpp index bb03fca30de..0d270eed7a2 100644 --- a/inference-engine/src/readers/ir_reader/ie_ir_parser.hpp +++ b/inference-engine/src/readers/ir_reader/ie_ir_parser.hpp @@ -217,6 +217,21 @@ private: std::vector shape; if (!getParameters(node.child("data"), name, shape)) return; static_cast(*a) = ngraph::Strides(shape); +#ifdef __APPLE__ + } else if (auto a = ngraph::as_type>>(&adapter)) { + std::vector result; + if (!getParameters(node.child("data"), name, result)) return; + static_cast&>(*a) = result; +#else + } else if (auto a = ngraph::as_type>>(&adapter)) { + std::vector result; + if (!getParameters(node.child("data"), name, result)) return; + a->set(result); +#endif + } else if (auto a = ngraph::as_type>(&adapter)) { + std::vector axes; + if (!getParameters(node.child("data"), name, axes)) return; + static_cast(*a) = ngraph::AxisSet(axes); } else if (auto a = ngraph::as_type>(&adapter)) { if (!getStrAttribute(node.child("data"), name, val)) return; static_cast(*a) = ngraph::as_enum(val); diff --git a/inference-engine/tests/functional/inference_engine/ngraph_reader/interpolate_tests.cpp b/inference-engine/tests/functional/inference_engine/ngraph_reader/interpolate_tests.cpp index d6feb7b3477..a45abfea5f5 100644 --- a/inference-engine/tests/functional/inference_engine/ngraph_reader/interpolate_tests.cpp +++ b/inference-engine/tests/functional/inference_engine/ngraph_reader/interpolate_tests.cpp @@ -220,3 +220,183 @@ TEST_F(NGraphReaderTests, ReadInterpolate2Network) { data[3] = 60; }); } + +TEST_F(NGraphReaderTests, ReadInterpolate4Network) { + std::string model = R"V0G0N( + + + + + + + 1 + 2 + 300 + 300 + + + + + + + + 2 + + + + + + + + 2 + + + + + + + + 2 + + + + + + + + 1 + 2 + 300 + 300 + + + 2 + + + 2 + + + 2 + + + + + 9 + 12 + 600 + 900 + + + + + + + 9 + 12 + 600 + 900 + + + + + + + + + + + + +)V0G0N"; + std::string modelV5 = R"V0G0N( + + + + + + 1 + 2 + 300 + 300 + + + + + + + 2 + + + + + + + + + + 2 + + + + + + + + + + 2 + + + + + + + + + + + 1 + 2 + 300 + 300 + + + 2 + + + 2 + + + 2 + + + + + 9 + 12 + 600 + 900 + + + + + + + + + + + +)V0G0N"; + compareIRs(model, modelV5, 24, [](Blob::Ptr& weights) { + auto *data = weights->buffer().as(); + data[0] = 600; + data[1] = 900; + data[4] = 2; + data[5] = 3; + + auto *fdata = weights->buffer().as(); + fdata[2] = 2.0; + fdata[3] = 2.0; + }); +} \ No newline at end of file diff --git a/ngraph/core/include/ngraph/axis_set.hpp b/ngraph/core/include/ngraph/axis_set.hpp index aad365f8190..7e518ab78e0 100644 --- a/ngraph/core/include/ngraph/axis_set.hpp +++ b/ngraph/core/include/ngraph/axis_set.hpp @@ -60,6 +60,7 @@ namespace ngraph void set(const std::vector& value) override; static constexpr DiscreteTypeInfo type_info{"AttributeAdapter", 0}; const DiscreteTypeInfo& get_type_info() const override { return type_info; } + operator AxisSet&() { return m_ref; } protected: AxisSet& m_ref; std::vector m_buffer; diff --git a/ngraph/core/src/op/interpolate.cpp b/ngraph/core/src/op/interpolate.cpp index 4f263aba368..b327c98de89 100644 --- a/ngraph/core/src/op/interpolate.cpp +++ b/ngraph/core/src/op/interpolate.cpp @@ -38,10 +38,10 @@ op::v0::Interpolate::Interpolate(const Output& image, bool op::v0::Interpolate::visit_attributes(AttributeVisitor& visitor) { - visitor.on_attribute("axes", m_attrs.axes); - visitor.on_attribute("mode", m_attrs.mode); visitor.on_attribute("align_corners", m_attrs.align_corners); visitor.on_attribute("antialias", m_attrs.antialias); + visitor.on_attribute("axes", m_attrs.axes); + visitor.on_attribute("mode", m_attrs.mode); visitor.on_attribute("pads_begin", m_attrs.pads_begin); visitor.on_attribute("pads_end", m_attrs.pads_end); return true;