[IE] Fix UNITY build (#2799)
This commit is contained in:
parent
33371ca1ac
commit
d846969a1c
@ -119,6 +119,8 @@ bool FakeQuantizeTransformation::transform(TransformationContext& context, ngrap
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace fq {
|
||||
|
||||
static std::shared_ptr<Node> updateShape(std::shared_ptr<Node> op, const Shape& targetShape) {
|
||||
const Shape shape = op->get_output_shape(0);
|
||||
if ((shape.size() < targetShape.size()) && (shape.size() > 1ul)) {
|
||||
@ -154,8 +156,10 @@ static std::shared_ptr<opset1::Constant> getConstant(const std::shared_ptr<Node>
|
||||
return as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(0));
|
||||
}
|
||||
|
||||
} // namespace fq
|
||||
|
||||
bool FakeQuantizeTransformation::checkElementwise(const std::shared_ptr<Node>& eltwise) {
|
||||
std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
|
||||
std::shared_ptr<opset1::Constant> constant = fq::getConstant(eltwise);
|
||||
if (constant == nullptr) {
|
||||
return false;
|
||||
}
|
||||
@ -178,7 +182,7 @@ bool FakeQuantizeTransformation::checkElementwise(const std::shared_ptr<Node>& e
|
||||
}
|
||||
}
|
||||
|
||||
return getData(eltwise) != nullptr;
|
||||
return fq::getData(eltwise) != nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwise(
|
||||
@ -189,7 +193,7 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
|
||||
std::shared_ptr<Node> inputLowConst = fakeQuantize->get_input_node_shared_ptr(1);
|
||||
std::shared_ptr<Node> inputHightConst = fakeQuantize->get_input_node_shared_ptr(2);
|
||||
|
||||
std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
|
||||
std::shared_ptr<opset1::Constant> constant = fq::getConstant(eltwise);
|
||||
if (is_type<opset1::Multiply>(eltwise) && checkElementwise(eltwise)) {
|
||||
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
|
||||
constant :
|
||||
@ -203,8 +207,8 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
|
||||
}
|
||||
}
|
||||
|
||||
inputLowConst = updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputLowConst = fq::updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = fq::updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
} else if (is_type<opset1::Divide>(eltwise) && checkElementwise(eltwise)) {
|
||||
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
|
||||
constant :
|
||||
@ -218,18 +222,18 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
|
||||
}
|
||||
}
|
||||
|
||||
inputLowConst = updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputLowConst = fq::updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = fq::updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
} else if (is_type<opset1::Subtract>(eltwise) && checkElementwise(eltwise)) {
|
||||
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
|
||||
constant :
|
||||
fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
|
||||
|
||||
inputLowConst = updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputLowConst = fq::updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = fq::updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
} else if (is_type<opset1::Add>(eltwise) && checkElementwise(eltwise)) {
|
||||
if (is_type<opset1::Convolution>(getData(eltwise)) ||
|
||||
is_type<opset1::GroupConvolution>(getData(eltwise))) {
|
||||
if (is_type<opset1::Convolution>(fq::getData(eltwise)) ||
|
||||
is_type<opset1::GroupConvolution>(fq::getData(eltwise))) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -237,8 +241,8 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
|
||||
constant :
|
||||
fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
|
||||
|
||||
inputLowConst = updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputLowConst = fq::updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = fq::updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
} else if (is_type<opset1::Convert>(eltwise)) {
|
||||
// issue #40611
|
||||
if ((eltwise->input(0).get_element_type() == element::i32) && (eltwise->output(0).get_element_type() == element::f32)) {
|
||||
@ -249,7 +253,7 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
|
||||
}
|
||||
|
||||
std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
|
||||
getData(eltwise),
|
||||
fq::getData(eltwise),
|
||||
inputLowConst,
|
||||
inputHightConst,
|
||||
fakeQuantize->input_value(3),
|
||||
|
@ -24,6 +24,8 @@ bool FuseFakeQuantizeTransformation::transform(TransformationContext& context, n
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace fuse_fq {
|
||||
|
||||
std::shared_ptr<Node> updateShape(std::shared_ptr<Node> op, const Shape& targetShape) {
|
||||
const Shape shape = op->get_output_shape(0);
|
||||
if ((shape.size() < targetShape.size()) && (shape.size() > 1ul)) {
|
||||
@ -86,6 +88,8 @@ bool eltwiseWithConstant(const std::shared_ptr<Node>& eltwise) {
|
||||
return getData(eltwise) != nullptr;
|
||||
}
|
||||
|
||||
} // namespace fuse_fq
|
||||
|
||||
std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
|
||||
TransformationContext& context,
|
||||
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const {
|
||||
@ -94,31 +98,31 @@ std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
|
||||
std::shared_ptr<Node> inputLowConst = fakeQuantize->get_input_node_shared_ptr(1);
|
||||
std::shared_ptr<Node> inputHightConst = fakeQuantize->get_input_node_shared_ptr(2);
|
||||
|
||||
std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
|
||||
if (is_type<opset1::Multiply>(eltwise) && eltwiseWithConstant(eltwise)) {
|
||||
std::shared_ptr<opset1::Constant> constant = fuse_fq::getConstant(eltwise);
|
||||
if (is_type<opset1::Multiply>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
|
||||
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
|
||||
constant :
|
||||
fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
|
||||
|
||||
inputLowConst = updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
} else if (is_type<opset1::Divide>(eltwise) && eltwiseWithConstant(eltwise)) {
|
||||
inputLowConst = fuse_fq::updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = fuse_fq::updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
} else if (is_type<opset1::Divide>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
|
||||
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
|
||||
constant :
|
||||
fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
|
||||
|
||||
inputLowConst = updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
} else if (is_type<opset1::Subtract>(eltwise) && eltwiseWithConstant(eltwise)) {
|
||||
inputLowConst = fuse_fq::updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = fuse_fq::updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
} else if (is_type<opset1::Subtract>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
|
||||
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
|
||||
constant :
|
||||
fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
|
||||
|
||||
inputLowConst = updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
} else if (is_type<opset1::Add>(eltwise) && eltwiseWithConstant(eltwise)) {
|
||||
if (is_type<opset1::Convolution>(getData(eltwise)) ||
|
||||
is_type<opset1::GroupConvolution>(getData(eltwise))) {
|
||||
inputLowConst = fuse_fq::updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = fuse_fq::updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
} else if (is_type<opset1::Add>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
|
||||
if (is_type<opset1::Convolution>(fuse_fq::getData(eltwise)) ||
|
||||
is_type<opset1::GroupConvolution>(fuse_fq::getData(eltwise))) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -126,8 +130,8 @@ std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
|
||||
constant :
|
||||
fold<opset1::Convert>(constant, eltwise->get_output_element_type(0));
|
||||
|
||||
inputLowConst = updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputLowConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_shape(0));
|
||||
inputHightConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_shape(0));
|
||||
} else if (is_type<opset1::Convert>(eltwise)) {
|
||||
// issue #40611
|
||||
if ((eltwise->input(0).get_element_type() == element::i32) && (eltwise->output(0).get_element_type() == element::f32)) {
|
||||
@ -138,7 +142,7 @@ std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
|
||||
}
|
||||
|
||||
std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
|
||||
getData(eltwise),
|
||||
fuse_fq::getData(eltwise),
|
||||
inputLowConst,
|
||||
inputHightConst,
|
||||
fakeQuantize->input_value(3),
|
||||
|
@ -18,7 +18,7 @@ using namespace ngraph;
|
||||
using namespace ngraph::pass;
|
||||
using namespace ngraph::pass::low_precision;
|
||||
|
||||
namespace {
|
||||
namespace mvn {
|
||||
|
||||
template<typename T>
|
||||
std::shared_ptr<ngraph::op::Constant> createNewScalesConst(const ngraph::op::Constant& originalConst) {
|
||||
@ -33,7 +33,7 @@ std::shared_ptr<ngraph::op::Constant> createNewScalesConst(const ngraph::op::Con
|
||||
return ngraph::op::Constant::create(type, originalConst.get_shape(), newData);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mvn
|
||||
|
||||
bool MVNTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
|
||||
if (!LayerTransformation::canBeTransformed(context, operation)) {
|
||||
@ -93,11 +93,11 @@ bool MVNTransformation::transform(TransformationContext &context, ngraph::patter
|
||||
if (normalizeVariance) {
|
||||
switch (type) {
|
||||
case ngraph::element::Type_t::f16: {
|
||||
newScalesConst = createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f16>::value_type>(*scalesConst);
|
||||
newScalesConst = mvn::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f16>::value_type>(*scalesConst);
|
||||
break;
|
||||
}
|
||||
case ngraph::element::Type_t::f32: {
|
||||
newScalesConst = createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f32>::value_type>(*scalesConst);
|
||||
newScalesConst = mvn::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f32>::value_type>(*scalesConst);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
|
@ -17,7 +17,7 @@ using namespace ngraph;
|
||||
using namespace ngraph::pass;
|
||||
using namespace ngraph::pass::low_precision;
|
||||
|
||||
namespace {
|
||||
namespace normalize_l2 {
|
||||
|
||||
template<typename T>
|
||||
std::shared_ptr<ngraph::op::Constant> createNewScalesConst(const ngraph::op::Constant& originalConst) {
|
||||
@ -32,7 +32,7 @@ std::shared_ptr<ngraph::op::Constant> createNewScalesConst(const ngraph::op::Con
|
||||
return ngraph::op::Constant::create(type, originalConst.get_shape(), newData);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace normalize_l2
|
||||
|
||||
bool NormalizeL2Transformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
|
||||
if (!LayerTransformation::canBeTransformed(context, operation)) {
|
||||
@ -106,11 +106,11 @@ bool NormalizeL2Transformation::transform(TransformationContext &context, ngraph
|
||||
const auto type = scalesConst->get_output_element_type(0);
|
||||
switch (type) {
|
||||
case ngraph::element::Type_t::f16: {
|
||||
newScalesConst = createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f16>::value_type>(*scalesConst);
|
||||
newScalesConst = normalize_l2::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f16>::value_type>(*scalesConst);
|
||||
break;
|
||||
}
|
||||
case ngraph::element::Type_t::f32: {
|
||||
newScalesConst = createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f32>::value_type>(*scalesConst);
|
||||
newScalesConst = normalize_l2::createNewScalesConst<ngraph::element_type_traits<ngraph::element::Type_t::f32>::value_type>(*scalesConst);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
|
@ -44,7 +44,6 @@ addIeTargetTest(
|
||||
)
|
||||
|
||||
ie_faster_build(${TARGET_NAME}
|
||||
UNITY
|
||||
PCH PRIVATE "precomp.hpp"
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user