[IE][VPU][Tests] Support DTS for Select (#3604)
* Support DTS + binary eltwise tests refactoring (avoid code duplication)
This commit is contained in:
parent
5f9ef0cf26
commit
f2f5e99f9f
@ -100,6 +100,7 @@ const Transformations& getDefaultTransformations() {
|
||||
{ngraph::opset3::Maximum::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Minimum::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset3::Less::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset5::Select::type_info, dynamicToStaticShapeBinaryEltwise},
|
||||
{ngraph::opset5::NonMaxSuppression::type_info, dynamicToStaticNonMaxSuppression},
|
||||
{ngraph::opset3::NonZero::type_info, dynamicToStaticShapeNonZero},
|
||||
{ngraph::opset3::TopK::type_info, dynamicToStaticShapeTopK},
|
||||
|
@ -10,32 +10,36 @@
|
||||
|
||||
#include "ngraph/graph_util.hpp"
|
||||
#include "ngraph/opsets/opset3.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
#include <ngraph/ops.hpp>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace vpu {
|
||||
|
||||
void dynamicToStaticShapeBinaryEltwise(std::shared_ptr<ngraph::Node> eltwise) {
|
||||
const auto lhsRank = eltwise->input_value(0).get_partial_shape().rank();
|
||||
const auto rhsRank = eltwise->input_value(1).get_partial_shape().rank();
|
||||
namespace {
|
||||
|
||||
void processBinaryEltwise(std::shared_ptr<ngraph::Node> eltwise, size_t lhsIndex, size_t rhsIndex) {
|
||||
const auto lhsRank = eltwise->input_value(lhsIndex).get_partial_shape().rank();
|
||||
const auto rhsRank = eltwise->input_value(rhsIndex).get_partial_shape().rank();
|
||||
|
||||
const auto copied = eltwise->copy_with_new_inputs(eltwise->input_values());
|
||||
|
||||
const auto lhsDSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(eltwise->input_value(0).get_node_shared_ptr());
|
||||
const auto rhsDSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(eltwise->input_value(1).get_node_shared_ptr());
|
||||
const auto lhsDSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(eltwise->input_value(lhsIndex).get_node_shared_ptr());
|
||||
const auto rhsDSR = ngraph::as_type_ptr<ngraph::vpu::op::DynamicShapeResolver>(eltwise->input_value(rhsIndex).get_node_shared_ptr());
|
||||
|
||||
VPU_THROW_UNLESS(lhsDSR || rhsDSR, "DynamicToStaticShape transformation for {} of type {} expects at least one DSR as input",
|
||||
eltwise->get_friendly_name(), eltwise->get_type_info());
|
||||
eltwise->get_friendly_name(), eltwise->get_type_info());
|
||||
if (lhsDSR && rhsDSR) {
|
||||
VPU_THROW_UNLESS(lhsDSR->get_input_element_type(1) == rhsDSR->get_input_element_type(1),
|
||||
"DynamicToStaticShape transformation for {} of type {} expects equal shapes data types, actual {} vs {}",
|
||||
eltwise->get_friendly_name(), eltwise->get_type_info(),
|
||||
lhsDSR->get_input_element_type(1), rhsDSR->get_input_element_type(1));
|
||||
"DynamicToStaticShape transformation for {} of type {} expects equal shapes data types, actual {} vs {}",
|
||||
eltwise->get_friendly_name(), eltwise->get_type_info(),
|
||||
lhsDSR->get_input_element_type(1), rhsDSR->get_input_element_type(1));
|
||||
}
|
||||
const auto shapeElementType = lhsDSR ? lhsDSR->get_input_element_type(1) : rhsDSR->get_input_element_type(1);
|
||||
|
||||
auto lhsInput = lhsDSR ? lhsDSR->input_value(1) : shapeToConstant(shapeElementType, eltwise->get_input_shape(0));
|
||||
auto rhsInput = rhsDSR ? rhsDSR->input_value(1) : shapeToConstant(shapeElementType, eltwise->get_input_shape(1));
|
||||
auto lhsInput = lhsDSR ? lhsDSR->input_value(1) : shapeToConstant(shapeElementType, eltwise->get_input_shape(lhsIndex));
|
||||
auto rhsInput = rhsDSR ? rhsDSR->input_value(1) : shapeToConstant(shapeElementType, eltwise->get_input_shape(rhsIndex));
|
||||
|
||||
const auto diff = std::abs(lhsRank.get_length() - rhsRank.get_length());
|
||||
if (diff) {
|
||||
@ -51,4 +55,17 @@ void dynamicToStaticShapeBinaryEltwise(std::shared_ptr<ngraph::Node> eltwise) {
|
||||
ngraph::replace_node(std::move(eltwise), std::move(outDSR));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void dynamicToStaticShapeBinaryEltwise(std::shared_ptr<ngraph::Node> eltwise) {
|
||||
if (eltwise->get_type_info() == ngraph::opset5::Select::type_info) {
|
||||
processBinaryEltwise(eltwise, 1, 2);
|
||||
} else {
|
||||
VPU_THROW_UNLESS(eltwise->get_input_size() == 2,
|
||||
"DynamicToStaticShape transformation for {} of type {} expects two inputs while {} were provided",
|
||||
eltwise->get_friendly_name(), eltwise->get_type_info(), eltwise->get_input_size());
|
||||
processBinaryEltwise(eltwise, 0, 1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vpu
|
||||
|
@ -3,6 +3,7 @@
|
||||
//
|
||||
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <ngraph/shape.hpp>
|
||||
#include <ngraph/type/element_type.hpp>
|
||||
|
||||
@ -17,24 +18,31 @@
|
||||
|
||||
namespace {
|
||||
|
||||
enum class TestShapeTypes {
|
||||
ALL_DYNAMIC,
|
||||
SINGLE_DSR
|
||||
};
|
||||
|
||||
using DataType = ngraph::element::Type_t;
|
||||
using DataDims = ngraph::Shape;
|
||||
using refFunction = std::function<std::shared_ptr<ngraph::Function> (const DataType&, const ngraph::NodeTypeInfo&, const DataDims&, const DataDims&)>;
|
||||
using refFunction = std::function<std::shared_ptr<ngraph::Function> (
|
||||
const DataType&, const ngraph::NodeTypeInfo&, const DataDims&, const DataDims&, TestShapeTypes)>;
|
||||
using EltwiseParams = std::tuple<DataDims, DataDims, refFunction>;
|
||||
|
||||
class DynamicToStaticShapeEltwise: public CommonTestUtils::TestsCommon, public testing::WithParamInterface<std::tuple<ngraph::element::Type_t,
|
||||
ngraph::NodeTypeInfo, EltwiseParams>> {
|
||||
ngraph::NodeTypeInfo, EltwiseParams, TestShapeTypes>> {
|
||||
public:
|
||||
void SetUp() override {
|
||||
const auto& dataType = std::get<0>(GetParam());
|
||||
const auto& eltwiseType = std::get<1>(GetParam());
|
||||
const auto& eltwiseParams = std::get<2>(GetParam());
|
||||
const auto& testShapeTypes = std::get<3>(GetParam());
|
||||
|
||||
const auto& input0_shape = std::get<0>(eltwiseParams);
|
||||
const auto& input1_shape = std::get<1>(eltwiseParams);
|
||||
const auto& input0Shape = std::get<0>(eltwiseParams);
|
||||
const auto& input1Shape = std::get<1>(eltwiseParams);
|
||||
|
||||
ngraph::helpers::CompareFunctions(*transform(dataType, eltwiseType, input0_shape, input1_shape),
|
||||
*std::get<2>(eltwiseParams)(dataType, eltwiseType, input0_shape, input1_shape));
|
||||
ngraph::helpers::CompareFunctions(*transform(dataType, eltwiseType, input0Shape, input1Shape, testShapeTypes),
|
||||
*std::get<2>(eltwiseParams)(dataType, eltwiseType, input0Shape, input1Shape, testShapeTypes));
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -42,27 +50,36 @@ protected:
|
||||
const ngraph::element::Type_t& dataType,
|
||||
const ngraph::NodeTypeInfo& eltwiseType,
|
||||
const ngraph::Shape& dataDims0,
|
||||
const ngraph::Shape& dataDims1) const {
|
||||
const ngraph::Shape& dataDims1,
|
||||
TestShapeTypes testShapeTypes) const {
|
||||
const auto input0 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims0);
|
||||
const auto input1 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims1);
|
||||
|
||||
const auto input0_dsr = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims0.size()});
|
||||
const auto input1_dsr = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims1.size()});
|
||||
const auto input0Dims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims0.size()});
|
||||
const auto dsr0 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input0, input0Dims);
|
||||
|
||||
const auto dsr0 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input0, input0_dsr);
|
||||
const auto dsr1 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input1, input1_dsr);
|
||||
ngraph::ParameterVector params{input0, input1, input0Dims};
|
||||
|
||||
const auto eltwise = ngraph::helpers::getNodeSharedPtr(eltwiseType, {dsr0, dsr1});
|
||||
std::shared_ptr<ngraph::Node> eltwiseInput1 = input1;
|
||||
if (testShapeTypes == TestShapeTypes::ALL_DYNAMIC) {
|
||||
const auto input1Dims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64,
|
||||
ngraph::Shape{dataDims1.size()});
|
||||
eltwiseInput1 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input1, input1Dims);
|
||||
params.push_back(input1Dims);
|
||||
}
|
||||
|
||||
const auto eltwise = buildEltwise(eltwiseType, {dsr0, eltwiseInput1}, params, testShapeTypes);
|
||||
|
||||
const auto function = std::make_shared<ngraph::Function>(
|
||||
ngraph::NodeVector{eltwise},
|
||||
ngraph::ParameterVector{input0, input1, input0_dsr, input1_dsr},
|
||||
params,
|
||||
"Actual");
|
||||
|
||||
eltwise->set_output_type(0, eltwise->get_input_element_type(0), ngraph::PartialShape::dynamic(eltwise->get_output_partial_shape(0).rank()));
|
||||
|
||||
const auto transformations = vpu::Transformations{{eltwiseType, vpu::dynamicToStaticShapeBinaryEltwise}};
|
||||
vpu::DynamicToStaticShape(transformations).run_on_function(function);
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
@ -72,26 +89,39 @@ public:
|
||||
const ngraph::element::Type_t& dataType,
|
||||
const ngraph::NodeTypeInfo& eltwiseType,
|
||||
const ngraph::Shape& dataDims0,
|
||||
const ngraph::Shape& dataDims1) {
|
||||
const ngraph::Shape& dataDims1,
|
||||
TestShapeTypes testShapeTypes) {
|
||||
// Data flow subgraph
|
||||
const auto input0 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims0);
|
||||
const auto input1 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims1);
|
||||
|
||||
const auto input0_dsr = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims0.size()});
|
||||
const auto input1_dsr = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims1.size()});
|
||||
const auto input0Dims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims0.size()});
|
||||
const auto dsr0 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input0, input0Dims);
|
||||
|
||||
const auto dsr0 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input0, input0_dsr);
|
||||
const auto dsr1 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input1, input1_dsr);
|
||||
ngraph::ParameterVector params{input0, input1, input0Dims};
|
||||
|
||||
const auto eltwise = ngraph::helpers::getNodeSharedPtr(eltwiseType, {dsr0, dsr1});
|
||||
std::shared_ptr<ngraph::Node> dims;
|
||||
if (testShapeTypes == TestShapeTypes:: ALL_DYNAMIC) {
|
||||
params.push_back(std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims1.size()}));
|
||||
dims = params.back();
|
||||
} else {
|
||||
dims = ngraph::opset3::Constant::create(ngraph::element::i64, {dataDims1.size()}, dataDims1);
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> eltwiseInput1 = input1;
|
||||
if (testShapeTypes == TestShapeTypes::ALL_DYNAMIC) {
|
||||
eltwiseInput1 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input1, dims);
|
||||
}
|
||||
|
||||
const auto eltwise = buildEltwise(eltwiseType, {dsr0, eltwiseInput1}, params, testShapeTypes);
|
||||
|
||||
// Shape infer subgraph
|
||||
const auto maximum = std::make_shared<ngraph::opset3::Maximum>(input0_dsr, input1_dsr);
|
||||
const auto maximum = std::make_shared<ngraph::opset3::Maximum>(input0Dims, dims);
|
||||
const auto dsr_final = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(eltwise, maximum);
|
||||
|
||||
const auto function = std::make_shared<ngraph::Function>(
|
||||
ngraph::NodeVector{dsr_final},
|
||||
ngraph::ParameterVector{input0, input1, input0_dsr, input1_dsr},
|
||||
params,
|
||||
"Actual");
|
||||
|
||||
return function;
|
||||
@ -102,28 +132,41 @@ public:
|
||||
const ngraph::element::Type_t& dataType,
|
||||
const ngraph::NodeTypeInfo& eltwiseType,
|
||||
const ngraph::Shape& dataDims0,
|
||||
const ngraph::Shape& dataDims1) {
|
||||
const ngraph::Shape& dataDims1,
|
||||
TestShapeTypes testShapeTypes) {
|
||||
// Data flow subgraph
|
||||
const auto input0 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims0);
|
||||
const auto input1 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims1);
|
||||
|
||||
const auto input0_dsr = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims0.size()});
|
||||
const auto input1_dsr = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims1.size()});
|
||||
const auto input0Dims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims0.size()});
|
||||
const auto dsr0 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input0, input0Dims);
|
||||
|
||||
const auto dsr0 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input0, input0_dsr);
|
||||
const auto dsr1 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input1, input1_dsr);
|
||||
ngraph::ParameterVector params{input0, input1, input0Dims};
|
||||
|
||||
const auto eltwise = ngraph::helpers::getNodeSharedPtr(eltwiseType, {dsr0, dsr1});
|
||||
std::shared_ptr<ngraph::Node> dims;
|
||||
if (testShapeTypes == TestShapeTypes::ALL_DYNAMIC) {
|
||||
params.push_back(std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims1.size()}));
|
||||
dims = params.back();
|
||||
} else {
|
||||
dims = ngraph::opset3::Constant::create(ngraph::element::i64, {dataDims1.size()}, dataDims1);
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> eltwiseInput1 = input1;
|
||||
if (testShapeTypes == TestShapeTypes::ALL_DYNAMIC) {
|
||||
eltwiseInput1 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input1, dims);
|
||||
}
|
||||
|
||||
const auto eltwise = buildEltwise(eltwiseType, {dsr0, eltwiseInput1}, params, testShapeTypes);
|
||||
|
||||
// Shape infer subgraph
|
||||
const auto broadcast_const = ngraph::opset3::Constant::create(ngraph::element::i64, {dataDims1.size() - dataDims0.size()}, {1});
|
||||
const auto concat = std::make_shared<ngraph::opset3::Concat>(ngraph::OutputVector{broadcast_const, input0_dsr}, 0);
|
||||
const auto maximum = std::make_shared<ngraph::opset3::Maximum>(concat, input1_dsr);
|
||||
const auto dsr_final = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(eltwise, maximum);
|
||||
const auto broadcastConst = ngraph::opset3::Constant::create(ngraph::element::i64, {dataDims1.size() - dataDims0.size()}, {1});
|
||||
const auto concat = std::make_shared<ngraph::opset3::Concat>(ngraph::OutputVector{broadcastConst, input0Dims}, 0);
|
||||
const auto maximum = std::make_shared<ngraph::opset3::Maximum>(concat, dims);
|
||||
const auto dsrFinal = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(eltwise, maximum);
|
||||
|
||||
const auto function = std::make_shared<ngraph::Function>(
|
||||
ngraph::NodeVector{dsr_final},
|
||||
ngraph::ParameterVector{input0, input1, input0_dsr, input1_dsr},
|
||||
ngraph::NodeVector{dsrFinal},
|
||||
params,
|
||||
"Actual");
|
||||
|
||||
return function;
|
||||
@ -134,166 +177,68 @@ public:
|
||||
const ngraph::element::Type_t& dataType,
|
||||
const ngraph::NodeTypeInfo& eltwiseType,
|
||||
const ngraph::Shape& dataDims0,
|
||||
const ngraph::Shape& dataDims1) {
|
||||
const ngraph::Shape& dataDims1,
|
||||
TestShapeTypes testShapeTypes) {
|
||||
// Data flow subgraph
|
||||
const auto input0 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims0);
|
||||
const auto input1 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims1);
|
||||
|
||||
const auto input0_dsr = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims0.size()});
|
||||
const auto input1_dsr = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims1.size()});
|
||||
const auto input0Dims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims0.size()});
|
||||
const auto dsr0 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input0, input0Dims);
|
||||
|
||||
const auto dsr0 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input0, input0_dsr);
|
||||
const auto dsr1 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input1, input1_dsr);
|
||||
ngraph::ParameterVector params{input0, input1, input0Dims};
|
||||
|
||||
const auto eltwise = ngraph::helpers::getNodeSharedPtr(eltwiseType, {dsr0, dsr1});
|
||||
std::shared_ptr<ngraph::Node> dims;
|
||||
if (testShapeTypes == TestShapeTypes::ALL_DYNAMIC) {
|
||||
params.push_back(std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims1.size()}));
|
||||
dims = params.back();
|
||||
} else {
|
||||
dims = ngraph::opset3::Constant::create(ngraph::element::i64, {dataDims1.size()}, dataDims1);
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> eltwiseInput1 = input1;
|
||||
if (testShapeTypes == TestShapeTypes::ALL_DYNAMIC) {
|
||||
eltwiseInput1 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input1, dims);
|
||||
}
|
||||
|
||||
const auto eltwise = buildEltwise(eltwiseType, {dsr0, eltwiseInput1}, params, testShapeTypes);
|
||||
|
||||
// Shape infer subgraph
|
||||
const auto broadcast_const = ngraph::opset3::Constant::create(ngraph::element::i64, {dataDims0.size() - dataDims1.size()}, {1});
|
||||
const auto concat = std::make_shared<ngraph::opset3::Concat>(ngraph::OutputVector{broadcast_const, input1_dsr}, 0);
|
||||
const auto maximum = std::make_shared<ngraph::opset3::Maximum>(input0_dsr, concat);
|
||||
const auto dsr_final = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(eltwise, maximum);
|
||||
const auto broadcastConst = ngraph::opset3::Constant::create(ngraph::element::i64, {dataDims0.size() - dataDims1.size()}, {1});
|
||||
const auto concat = std::make_shared<ngraph::opset3::Concat>(ngraph::OutputVector{broadcastConst, dims}, 0);
|
||||
const auto maximum = std::make_shared<ngraph::opset3::Maximum>(input0Dims, concat);
|
||||
const auto dsrFinal = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(eltwise, maximum);
|
||||
|
||||
const auto function = std::make_shared<ngraph::Function>(
|
||||
ngraph::NodeVector{dsr_final},
|
||||
ngraph::ParameterVector{input0, input1, input0_dsr, input1_dsr},
|
||||
ngraph::NodeVector{dsrFinal},
|
||||
params,
|
||||
"Actual");
|
||||
|
||||
return function;
|
||||
}
|
||||
};
|
||||
|
||||
class DynamicToStaticShapeEltwiseSingleDSR: public CommonTestUtils::TestsCommon, public testing::WithParamInterface<std::tuple<ngraph::element::Type_t,
|
||||
ngraph::NodeTypeInfo, EltwiseParams>> {
|
||||
public:
|
||||
void SetUp() override {
|
||||
const auto& dataType = std::get<0>(GetParam());
|
||||
const auto& eltwiseType = std::get<1>(GetParam());
|
||||
const auto& eltwiseParams = std::get<2>(GetParam());
|
||||
|
||||
const auto& input0_shape = std::get<0>(eltwiseParams);
|
||||
const auto& input1_shape = std::get<1>(eltwiseParams);
|
||||
|
||||
ngraph::helpers::CompareFunctions(*transform(dataType, eltwiseType, input0_shape, input1_shape),
|
||||
*std::get<2>(eltwiseParams)(dataType, eltwiseType, input0_shape, input1_shape));
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<const ngraph::Function> transform(
|
||||
const ngraph::element::Type_t& dataType,
|
||||
const ngraph::NodeTypeInfo& eltwiseType,
|
||||
const ngraph::Shape& dataDims0,
|
||||
const ngraph::Shape& dataDims1) const {
|
||||
const auto input0 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims0);
|
||||
const auto input1 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims1);
|
||||
|
||||
const auto input0_dsr = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims0.size()});
|
||||
|
||||
const auto dsr0 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input0, input0_dsr);
|
||||
|
||||
const auto eltwise = ngraph::helpers::getNodeSharedPtr(eltwiseType, {dsr0, input1});
|
||||
|
||||
const auto function = std::make_shared<ngraph::Function>(
|
||||
ngraph::NodeVector{eltwise},
|
||||
ngraph::ParameterVector{input0, input1, input0_dsr},
|
||||
"Actual");
|
||||
|
||||
eltwise->set_output_type(0, eltwise->get_input_element_type(0), ngraph::PartialShape::dynamic(eltwise->get_output_partial_shape(0).rank()));
|
||||
|
||||
const auto transformations = vpu::Transformations{{eltwiseType, vpu::dynamicToStaticShapeBinaryEltwise}};
|
||||
vpu::DynamicToStaticShape(transformations).run_on_function(function);
|
||||
return function;
|
||||
}
|
||||
|
||||
public:
|
||||
private:
|
||||
static
|
||||
std::shared_ptr<ngraph::Function> reference_simple(
|
||||
const ngraph::element::Type_t& dataType,
|
||||
std::shared_ptr<ngraph::Node> buildEltwise(
|
||||
const ngraph::NodeTypeInfo& eltwiseType,
|
||||
const ngraph::Shape& dataDims0,
|
||||
const ngraph::Shape& dataDims1) {
|
||||
// Data flow subgraph
|
||||
const auto input0 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims0);
|
||||
const auto input1 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims1);
|
||||
|
||||
const auto input0_dsr = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims0.size()});
|
||||
const auto input1_const = ngraph::opset3::Constant::create(ngraph::element::i64, {dataDims1.size()}, dataDims1);
|
||||
|
||||
const auto dsr0 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input0, input0_dsr);
|
||||
|
||||
const auto eltwise = ngraph::helpers::getNodeSharedPtr(eltwiseType, {dsr0, input1});
|
||||
|
||||
// Shape infer subgraph
|
||||
const auto maximum = std::make_shared<ngraph::opset3::Maximum>(input0_dsr, input1_const);
|
||||
const auto dsr_final = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(eltwise, maximum);
|
||||
|
||||
const auto function = std::make_shared<ngraph::Function>(
|
||||
ngraph::NodeVector{dsr_final},
|
||||
ngraph::ParameterVector{input0, input1, input0_dsr},
|
||||
"Actual");
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
static
|
||||
std::shared_ptr<ngraph::Function> reference_broadcast_left(
|
||||
const ngraph::element::Type_t& dataType,
|
||||
const ngraph::NodeTypeInfo& eltwiseType,
|
||||
const ngraph::Shape& dataDims0,
|
||||
const ngraph::Shape& dataDims1) {
|
||||
// Data flow subgraph
|
||||
const auto input0 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims0);
|
||||
const auto input1 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims1);
|
||||
|
||||
const auto input0_dsr = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims0.size()});
|
||||
const auto input1_const = ngraph::opset3::Constant::create(ngraph::element::i64, {dataDims1.size()}, dataDims1);
|
||||
|
||||
const auto dsr0 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input0, input0_dsr);
|
||||
|
||||
const auto eltwise = ngraph::helpers::getNodeSharedPtr(eltwiseType, {dsr0, input1});
|
||||
|
||||
// Shape infer subgraph
|
||||
const auto broadcast_const = ngraph::opset3::Constant::create(ngraph::element::i64, {dataDims1.size() - dataDims0.size()}, {1});
|
||||
const auto concat = std::make_shared<ngraph::opset3::Concat>(ngraph::OutputVector{broadcast_const, input0_dsr}, 0);
|
||||
const auto maximum = std::make_shared<ngraph::opset3::Maximum>(concat, input1_const);
|
||||
const auto dsr_final = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(eltwise, maximum);
|
||||
|
||||
const auto function = std::make_shared<ngraph::Function>(
|
||||
ngraph::NodeVector{dsr_final},
|
||||
ngraph::ParameterVector{input0, input1, input0_dsr},
|
||||
"Actual");
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
static
|
||||
std::shared_ptr<ngraph::Function> reference_broadcast_right(
|
||||
const ngraph::element::Type_t& dataType,
|
||||
const ngraph::NodeTypeInfo& eltwiseType,
|
||||
const ngraph::Shape& dataDims0,
|
||||
const ngraph::Shape& dataDims1) {
|
||||
// Data flow subgraph
|
||||
const auto input0 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims0);
|
||||
const auto input1 = std::make_shared<ngraph::opset3::Parameter>(dataType, dataDims1);
|
||||
|
||||
const auto input0_dsr = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{dataDims0.size()});
|
||||
const auto input1_const = ngraph::opset3::Constant::create(ngraph::element::i64, {dataDims1.size()}, dataDims1);
|
||||
|
||||
const auto dsr0 = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(input0, input0_dsr);
|
||||
|
||||
const auto eltwise = ngraph::helpers::getNodeSharedPtr(eltwiseType, {dsr0, input1});
|
||||
|
||||
// Shape infer subgraph
|
||||
const auto broadcast_const = ngraph::opset3::Constant::create(ngraph::element::i64, {dataDims0.size() - dataDims1.size()}, {1});
|
||||
const auto concat = std::make_shared<ngraph::opset3::Concat>(ngraph::OutputVector{broadcast_const, input1_const}, 0);
|
||||
const auto maximum = std::make_shared<ngraph::opset3::Maximum>(input0_dsr, concat);
|
||||
const auto dsr_final = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(eltwise, maximum);
|
||||
|
||||
const auto function = std::make_shared<ngraph::Function>(
|
||||
ngraph::NodeVector{dsr_final},
|
||||
ngraph::ParameterVector{input0, input1, input0_dsr},
|
||||
"Actual");
|
||||
|
||||
return function;
|
||||
const ngraph::OutputVector& inputs,
|
||||
ngraph::ParameterVector& params,
|
||||
TestShapeTypes testShapeTypes) {
|
||||
if (eltwiseType == ngraph::opset5::Select::type_info) {
|
||||
params.push_back(std::make_shared<ngraph::opset3::Parameter>(
|
||||
ngraph::element::boolean,
|
||||
ngraph::Shape{inputs.front().get_shape()}));
|
||||
std::shared_ptr<ngraph::Node> condInput = params.back();
|
||||
if (testShapeTypes == TestShapeTypes::ALL_DYNAMIC) {
|
||||
params.push_back(std::make_shared<ngraph::opset3::Parameter>(
|
||||
ngraph::element::i64,
|
||||
ngraph::Shape{static_cast<size_t>(inputs.front().get_partial_shape().rank().get_length())}));
|
||||
condInput = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(condInput, params.back());
|
||||
}
|
||||
return ngraph::helpers::getNodeSharedPtr(eltwiseType, {condInput, inputs[0], inputs[1]});
|
||||
} else {
|
||||
return ngraph::helpers::getNodeSharedPtr(eltwiseType, inputs);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -317,37 +262,14 @@ INSTANTIATE_TEST_CASE_P(smoke_EltwiseBroadcast, DynamicToStaticShapeEltwise, tes
|
||||
ngraph::opset3::Subtract::type_info,
|
||||
ngraph::opset3::Maximum::type_info,
|
||||
ngraph::opset3::Minimum::type_info,
|
||||
ngraph::opset3::Less::type_info),
|
||||
ngraph::opset3::Less::type_info,
|
||||
ngraph::opset5::Select::type_info),
|
||||
testing::Values(
|
||||
EltwiseParams{DataDims{1000}, DataDims{1}, DynamicToStaticShapeEltwise::reference_simple},
|
||||
EltwiseParams{DataDims{1000, 1, 1}, DataDims{1000, 1, 1}, DynamicToStaticShapeEltwise::reference_simple},
|
||||
EltwiseParams{DataDims{2, 1000}, DataDims{3, 1, 1}, DynamicToStaticShapeEltwise::reference_broadcast_left},
|
||||
EltwiseParams{DataDims{1000, 64}, DataDims{1}, DynamicToStaticShapeEltwise::reference_broadcast_right})));
|
||||
EltwiseParams{DataDims{1000, 64}, DataDims{1}, DynamicToStaticShapeEltwise::reference_broadcast_right}),
|
||||
testing::Values(TestShapeTypes::ALL_DYNAMIC, TestShapeTypes::SINGLE_DSR)
|
||||
));
|
||||
|
||||
TEST_P(DynamicToStaticShapeEltwiseSingleDSR, CompareFunctions) {
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_EltwiseBroadcastSingleDSR, DynamicToStaticShapeEltwiseSingleDSR, testing::Combine(
|
||||
testing::Values(
|
||||
ngraph::element::f16,
|
||||
ngraph::element::f32,
|
||||
ngraph::element::i32,
|
||||
ngraph::element::i64,
|
||||
ngraph::element::u8),
|
||||
testing::Values(
|
||||
ngraph::opset3::Add::type_info,
|
||||
ngraph::opset3::Divide::type_info,
|
||||
ngraph::opset3::Equal::type_info,
|
||||
ngraph::opset3::Greater::type_info,
|
||||
ngraph::opset3::Power::type_info,
|
||||
ngraph::opset3::Multiply::type_info,
|
||||
ngraph::opset3::Subtract::type_info,
|
||||
ngraph::opset3::Maximum::type_info,
|
||||
ngraph::opset3::Minimum::type_info,
|
||||
ngraph::opset3::Less::type_info),
|
||||
testing::Values(
|
||||
EltwiseParams{DataDims{1000}, DataDims{1}, DynamicToStaticShapeEltwiseSingleDSR::reference_simple},
|
||||
EltwiseParams{DataDims{1000, 1, 1}, DataDims{1000, 1, 1}, DynamicToStaticShapeEltwiseSingleDSR::reference_simple},
|
||||
EltwiseParams{DataDims{2, 1000}, DataDims{3, 1, 1}, DynamicToStaticShapeEltwiseSingleDSR::reference_broadcast_left},
|
||||
EltwiseParams{DataDims{1000, 64}, DataDims{1}, DynamicToStaticShapeEltwiseSingleDSR::reference_broadcast_right})));
|
||||
} // namespace
|
@ -37,7 +37,10 @@ protected:
|
||||
const auto inputSubgraph0 = createInputSubgraphWithDSR(inDataType, inDataShapes.lhs);
|
||||
const auto inputSubgraph1 = createInputSubgraphWithDSR(inDataType, inDataShapes.rhs);
|
||||
|
||||
const auto eltwise = ngraph::helpers::getNodeSharedPtr(eltwiseType, {inputSubgraph0, inputSubgraph1});
|
||||
const auto eltwise = eltwiseType == ngraph::opset5::Select::type_info ?
|
||||
ngraph::helpers::getNodeSharedPtr(eltwiseType, {createInputSubgraphWithDSR(
|
||||
ngraph::element::boolean, inDataShapes.lhs), inputSubgraph0, inputSubgraph1}) :
|
||||
ngraph::helpers::getNodeSharedPtr(eltwiseType, {inputSubgraph0, inputSubgraph1});
|
||||
|
||||
return eltwise;
|
||||
}
|
||||
@ -56,7 +59,10 @@ protected:
|
||||
const auto inputSubgraph0 = createInputSubgraphWithDSR(inDataType, inDataShapes.lhs);
|
||||
const auto input1 = createParameter(inDataType, inDataShapes.rhs.shape);
|
||||
|
||||
const auto eltwise = ngraph::helpers::getNodeSharedPtr(eltwiseType, {inputSubgraph0, input1});
|
||||
const auto eltwise = eltwiseType == ngraph::opset5::Select::type_info ?
|
||||
ngraph::helpers::getNodeSharedPtr(eltwiseType, {createParameter(
|
||||
ngraph::element::boolean, inDataShapes.rhs.shape), inputSubgraph0, input1}) :
|
||||
ngraph::helpers::getNodeSharedPtr(eltwiseType, {inputSubgraph0, input1});
|
||||
|
||||
return eltwise;
|
||||
}
|
||||
@ -70,6 +76,7 @@ static const std::vector<ngraph::NodeTypeInfo> binaryEltwiseTypeVector = {
|
||||
ngraph::opset3::Equal::type_info,
|
||||
ngraph::opset3::Greater::type_info,
|
||||
ngraph::opset3::Power::type_info,
|
||||
ngraph::opset5::Select::type_info,
|
||||
};
|
||||
|
||||
static const std::set<ngraph::NodeTypeInfo> doNotSupportI32 = {
|
||||
|
Loading…
Reference in New Issue
Block a user