[IE] Fix UNITY build (#2799)

This commit is contained in:
Vladislav Vinogradov 2020-10-23 19:21:02 +03:00 committed by GitHub
parent 33371ca1ac
commit d846969a1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 46 additions and 39 deletions

View File

@ -119,6 +119,8 @@ bool FakeQuantizeTransformation::transform(TransformationContext& context, ngrap
return true;
}
namespace fq {
static std::shared_ptr<Node> updateShape(std::shared_ptr<Node> op, const Shape& targetShape) {
const Shape shape = op->get_output_shape(0);
if ((shape.size() < targetShape.size()) && (shape.size() > 1ul)) {
@ -154,8 +156,10 @@ static std::shared_ptr<opset1::Constant> getConstant(const std::shared_ptr<Node>
return as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(0));
}
} // namespace fq
bool FakeQuantizeTransformation::checkElementwise(const std::shared_ptr<Node>& eltwise) {
std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
std::shared_ptr<opset1::Constant> constant = fq::getConstant(eltwise);
if (constant == nullptr) {
return false;
}
@ -178,7 +182,7 @@ bool FakeQuantizeTransformation::checkElementwise(const std::shared_ptr<Node>& e
}
}
return getData(eltwise) != nullptr;
return fq::getData(eltwise) != nullptr;
}
std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwise(
@ -189,7 +193,7 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
std::shared_ptr<Node> inputLowConst = fakeQuantize->get_input_node_shared_ptr(1);
std::shared_ptr<Node> inputHightConst = fakeQuantize->get_input_node_shared_ptr(2);
std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
std::shared_ptr<opset1::Constant> constant = fq::getConstant(eltwise);
if (is_type<opset1::Multiply>(eltwise) && checkElementwise(eltwise)) {
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
constant :
@ -203,8 +207,8 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
}
}
inputLowConst = updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_shape(0));
inputLowConst = fq::updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = fq::updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_shape(0));
} else if (is_type<opset1::Divide>(eltwise) && checkElementwise(eltwise)) {
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
constant :
@ -218,18 +222,18 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
}
}
inputLowConst = updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_shape(0));
inputLowConst = fq::updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = fq::updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_shape(0));
} else if (is_type<opset1::Subtract>(eltwise) && checkElementwise(eltwise)) {
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
constant :
fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
inputLowConst = updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_shape(0));
inputLowConst = fq::updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = fq::updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_shape(0));
} else if (is_type<opset1::Add>(eltwise) && checkElementwise(eltwise)) {
if (is_type<opset1::Convolution>(getData(eltwise)) ||
is_type<opset1::GroupConvolution>(getData(eltwise))) {
if (is_type<opset1::Convolution>(fq::getData(eltwise)) ||
is_type<opset1::GroupConvolution>(fq::getData(eltwise))) {
return nullptr;
}
@ -237,8 +241,8 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
constant :
fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
inputLowConst = updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_shape(0));
inputLowConst = fq::updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = fq::updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_shape(0));
} else if (is_type<opset1::Convert>(eltwise)) {
// issue #40611
if ((eltwise->input(0).get_element_type() == element::i32) && (eltwise->output(0).get_element_type() == element::f32)) {
@ -249,7 +253,7 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
}
std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
getData(eltwise),
fq::getData(eltwise),
inputLowConst,
inputHightConst,
fakeQuantize->input_value(3),

View File

@ -24,6 +24,8 @@ bool FuseFakeQuantizeTransformation::transform(TransformationContext& context, n
return true;
}
namespace fuse_fq {
std::shared_ptr<Node> updateShape(std::shared_ptr<Node> op, const Shape& targetShape) {
const Shape shape = op->get_output_shape(0);
if ((shape.size() < targetShape.size()) && (shape.size() > 1ul)) {
@ -86,6 +88,8 @@ bool eltwiseWithConstant(const std::shared_ptr<Node>& eltwise) {
return getData(eltwise) != nullptr;
}
} // namespace fuse_fq
std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
TransformationContext& context,
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const {
@ -94,31 +98,31 @@ std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
std::shared_ptr<Node> inputLowConst = fakeQuantize->get_input_node_shared_ptr(1);
std::shared_ptr<Node> inputHightConst = fakeQuantize->get_input_node_shared_ptr(2);
std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
if (is_type<opset1::Multiply>(eltwise) && eltwiseWithConstant(eltwise)) {
std::shared_ptr<opset1::Constant> constant = fuse_fq::getConstant(eltwise);
if (is_type<opset1::Multiply>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
constant :
fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
inputLowConst = updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_shape(0));
} else if (is_type<opset1::Divide>(eltwise) && eltwiseWithConstant(eltwise)) {
inputLowConst = fuse_fq::updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = fuse_fq::updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_shape(0));
} else if (is_type<opset1::Divide>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
constant :
fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
inputLowConst = updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_shape(0));
} else if (is_type<opset1::Subtract>(eltwise) && eltwiseWithConstant(eltwise)) {
inputLowConst = fuse_fq::updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = fuse_fq::updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_shape(0));
} else if (is_type<opset1::Subtract>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
constant :
fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
inputLowConst = updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_shape(0));
} else if (is_type<opset1::Add>(eltwise) && eltwiseWithConstant(eltwise)) {
if (is_type<opset1::Convolution>(getData(eltwise)) ||
is_type<opset1::GroupConvolution>(getData(eltwise))) {
inputLowConst = fuse_fq::updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = fuse_fq::updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_shape(0));
} else if (is_type<opset1::Add>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
if (is_type<opset1::Convolution>(fuse_fq::getData(eltwise)) ||
is_type<opset1::GroupConvolution>(fuse_fq::getData(eltwise))) {
return nullptr;
}
@ -126,8 +130,8 @@ std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
constant :
fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
inputLowConst = updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_shape(0));
inputLowConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_shape(0));
inputHightConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_shape(0));
} else if (is_type<opset1::Convert>(eltwise)) {
// issue #40611
if ((eltwise->input(0).get_element_type() == element::i32) && (eltwise->output(0).get_element_type() == element::f32)) {
@ -138,7 +142,7 @@ std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
}
std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
getData(eltwise),
fuse_fq::getData(eltwise),
inputLowConst,
inputHightConst,
fakeQuantize->input_value(3),

View File

@ -18,7 +18,7 @@ using namespace ngraph;
using namespace ngraph::pass;
using namespace ngraph::pass::low_precision;
namespace {
namespace mvn {
template<typename T>
std::shared_ptr<ngraph::op::Constant> createNewScalesConst(const ngraph::op::Constant& originalConst) {
@ -33,7 +33,7 @@ std::shared_ptr<ngraph::op::Constant> createNewScalesConst(const ngraph::op::Con
return ngraph::op::Constant::create(type, originalConst.get_shape(), newData);
}
} // namespace
} // namespace mvn
bool MVNTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
if (!LayerTransformation::canBeTransformed(context, operation)) {
@ -93,11 +93,11 @@ bool MVNTransformation::transform(TransformationContext &context, ngraph::patter
if (normalizeVariance) {
switch (type) {
case ngraph::element::Type_t::f16: {
newScalesConst = createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f16>::value_type>(*scalesConst);
newScalesConst = mvn::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f16>::value_type>(*scalesConst);
break;
}
case ngraph::element::Type_t::f32: {
newScalesConst = createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f32>::value_type>(*scalesConst);
newScalesConst = mvn::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f32>::value_type>(*scalesConst);
break;
}
default: {

View File

@ -17,7 +17,7 @@ using namespace ngraph;
using namespace ngraph::pass;
using namespace ngraph::pass::low_precision;
namespace {
namespace normalize_l2 {
template<typename T>
std::shared_ptr<ngraph::op::Constant> createNewScalesConst(const ngraph::op::Constant& originalConst) {
@ -32,7 +32,7 @@ std::shared_ptr<ngraph::op::Constant> createNewScalesConst(const ngraph::op::Con
return ngraph::op::Constant::create(type, originalConst.get_shape(), newData);
}
} // namespace
} // namespace normalize_l2
bool NormalizeL2Transformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
if (!LayerTransformation::canBeTransformed(context, operation)) {
@ -106,11 +106,11 @@ bool NormalizeL2Transformation::transform(TransformationContext &context, ngraph
const auto type = scalesConst->get_output_element_type(0);
switch (type) {
case ngraph::element::Type_t::f16: {
newScalesConst = createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f16>::value_type>(*scalesConst);
newScalesConst = normalize_l2::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f16>::value_type>(*scalesConst);
break;
}
case ngraph::element::Type_t::f32: {
newScalesConst = createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f32>::value_type>(*scalesConst);
newScalesConst = normalize_l2::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f32>::value_type>(*scalesConst);
break;
}
default: {

View File

@ -44,7 +44,6 @@ addIeTargetTest(
)
ie_faster_build(${TARGET_NAME}
UNITY
PCH PRIVATE "precomp.hpp"
)