[GNA] Fixed concat axis checks (#18281)

* Fixed concat axis checks

* clang fix

* Enabled the right concat axis check

* is_concat_supported based on ngraph

* fixed concat patterns checks

* code style fix

* removed unused helpers

---------

Co-authored-by: Marcin Kacprzak <marcin.kacprzak@intel.com>
Author: Mikhail Ryzhov
Committed: 2023-07-07 15:33:47 +02:00 (by GitHub)
Parent: 33b457b097
Commit: df014637c9
6 changed files with 102 additions and 187 deletions

File: limitations.cpp

@@ -864,13 +864,59 @@ bool Limitations::is_split_supported(const std::shared_ptr<ov::Node>& node, bool
     return is_aligned;
 }
 
-bool Limitations::is_concat_supported(const std::shared_ptr<const ov::Node>& node) {
+bool Limitations::is_concat_supported(const std::shared_ptr<const ov::Node>& node, bool is_exception_allowed) {
     OPENVINO_ASSERT(node, "Concat node is empty!");
     auto concat_node = std::dynamic_pointer_cast<const Concat>(node);
-    const ov::Shape& output_shape = concat_node->get_output_shape(0);
+    const ov::Shape& concat_shape_out = concat_node->get_output_shape(0);
     auto axis = concat_node->get_axis();
-    return graph_utils::get_first_valuable_dim_id(output_shape) == axis;
+
+    std::function<bool(std::shared_ptr<ov::Node>)> is_skipped_layer = [](std::shared_ptr<ov::Node> node) {
+        return graph_utils::is_non_functional(node) || graph_utils::is_split(node) || graph_utils::is_copy(node) ||
+               graph_utils::is_activation(node);
+    };
+
+    size_t skipped_ops_count = 0;
+    bool is_interleaved = false;
+    for (size_t i = 0; i < concat_node->inputs().size(); ++i) {
+        auto concat_input =
+            graph_utils::get_prev_node_skipping_certain(concat_node->get_input_node_shared_ptr(i), is_skipped_layer);
+        if (ov::op::util::is_parameter(concat_input) || ov::op::util::is_constant(concat_input)) {
+            skipped_ops_count++;
+        }
+        const ov::Shape concat_input_shape = concat_input->get_output_shape(0);
+        // the graph compiler changes the concat axis if one of the inputs is an interleaved layer output
+        if (graph_utils::squeeze_shape(concat_input_shape).size() >= 2 && graph_utils::is_interleaved(concat_input)) {
+            is_interleaved = true;
+        }
+    }
+
+    bool is_supported = false;
+    if (skipped_ops_count == concat_node->inputs().size()) {
+        is_supported = true;
+    } else if (is_interleaved) {
+        // TODO: extend the interleaved-layer detection patterns once the migration to ngraph is finished.
+        // make an interleaved shape
+        ov::Shape tr_shape(concat_shape_out);
+        std::rotate(tr_shape.begin(), tr_shape.begin() + 1, tr_shape.end());
+        // make an interleaved order
+        std::vector<size_t> tr_order(concat_shape_out.size());
+        std::iota(tr_order.begin(), tr_order.end(), 0);
+        std::rotate(tr_order.begin(), tr_order.begin() + 1, tr_order.end());
+        const int64_t tr_axis = std::distance(tr_order.begin(), std::find(tr_order.begin(), tr_order.end(), axis));
+        is_supported = graph_utils::get_first_valuable_dim_id(tr_shape) == tr_axis;
+    } else {
+        is_supported = graph_utils::get_first_valuable_dim_id(concat_shape_out) == axis;
+    }
+
+    if (!is_supported && is_exception_allowed) {
+        THROW_GNA_EXCEPTION << concat_node->get_friendly_name()
+                            << " Unsupported concatenation axis=" << concat_node->get_axis()
+                            << " for input dimensions: " << concat_node->get_input_shape(0);
+    }
+    return is_supported;
 }
 
 bool Limitations::is_forward_transposed_concat_supported(const std::shared_ptr<const ov::Node>& node,
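The interleaved branch above is the core of the fix: when a concat input comes from a MatMul/FullyConnected (interleaved) producer, the graph compiler effectively rotates the layout, so the check rotates both the output shape and the axis order left by one and only then applies the first-non-trivial-dimension rule. A minimal standalone sketch of that remapping (illustrative only; first_valuable_dim below is a hypothetical stand-in for graph_utils::get_first_valuable_dim_id):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Hypothetical stand-in for graph_utils::get_first_valuable_dim_id:
// index of the first dimension greater than 1.
static int64_t first_valuable_dim(const std::vector<size_t>& shape) {
    for (size_t i = 0; i < shape.size(); ++i) {
        if (shape[i] > 1)
            return static_cast<int64_t>(i);
    }
    return 0;
}

int main() {
    std::vector<size_t> shape{1, 64, 2};  // concat output shape
    int64_t axis = 1;                     // concat axis in the original layout

    // Rotate the shape left by one, as the interleaved branch does.
    std::vector<size_t> tr_shape(shape);
    std::rotate(tr_shape.begin(), tr_shape.begin() + 1, tr_shape.end());

    // Build the rotated order and find where the original axis landed.
    std::vector<size_t> tr_order(shape.size());
    std::iota(tr_order.begin(), tr_order.end(), 0);
    std::rotate(tr_order.begin(), tr_order.begin() + 1, tr_order.end());
    const int64_t tr_axis = std::distance(tr_order.begin(), std::find(tr_order.begin(), tr_order.end(), axis));

    // {1, 64, 2} -> tr_shape {64, 2, 1}, tr_order {1, 2, 0}, tr_axis 0:
    // the remapped axis is the first non-trivial dimension, so this concat passes.
    std::cout << std::boolalpha << (first_valuable_dim(tr_shape) == tr_axis) << '\n';  // true
}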
@@ -967,6 +1013,8 @@ bool Limitations::is_op_supported(const std::shared_ptr<ov::Node>& node,
         return SupportedElementTypes::IsConstantTypeSupported(node->get_element_type(), is_exception_allowed);
     } else if (auto conv = std::dynamic_pointer_cast<ov::intel_gna::op::GNAConvolution>(node)) {
         return is_conv_supported(conv, gna_precision, is_exception_allowed);
+    } else if (auto concat = std::dynamic_pointer_cast<Concat>(node)) {
+        return is_concat_supported(concat, is_exception_allowed);
     } else if (auto fully_connected = std::dynamic_pointer_cast<ngraph::op::FullyConnected>(node)) {
         return is_fc_supported(fully_connected, is_exception_allowed);
     } else if (ov::intel_gna::graph_utils::is_pooling(node)) {
@@ -999,9 +1047,13 @@ void Limitations::check_all_ops_supported(const std::shared_ptr<ov::Model>& mode
     std::stringstream error;
     // Walk through the transformed model
     for (auto& op : model->get_ops()) {
-        if (!is_op_supported(op, gna_precision, true)) {
-            error << "The plugin does not support layer " << op->get_friendly_name() << " (type " << op->get_type_name()
-                  << ")!" << std::endl;
+        try {
+            if (!is_op_supported(op, gna_precision, true)) {
+                error << "The plugin does not support layer " << op->get_friendly_name() << " (type "
+                      << op->get_type_name() << ")!" << std::endl;
+            }
+        } catch (const InferenceEngine::GeneralError& e) {
+            error << e.what() << std::endl;
         }
     }
     if (!error.str().empty()) {
@@ -1013,144 +1065,6 @@ bool Limitations::use_only_16bit_convolution_weights() const {
     return m_use_only_16bit_conv_weights;
 }
 
 IE_SUPPRESS_DEPRECATED_START
-bool Limitations::validate_concat_axis(const InferenceEngine::CNNLayerPtr layer, std::string& errMessage) {
-    LayerInfo info(layer);
-    auto concat_layer = info.as<InferenceEngine::ConcatLayer*>();
-    IE_ASSERT(concat_layer);
-    auto dims_size = concat_layer->insData[0].lock()->getDims().size();
-    auto in_dims = concat_layer->insData[0].lock()->getDims();
-    auto concat_axis = concat_layer->_axis;
-
-    if (dims_size >= 2) {
-        InferenceEngine::CNNLayerPtr prev_layer, pre_prev_layer;
-        // Skip all convolutions in this check, they will be handled during concat primitive creation
-        auto isFusableWithConv = [](InferenceEngine::CNNLayerPtr ptr) {
-            return (LayerInfo(ptr).isFusableWithConv() || LayerInfo(ptr).isNonFunctional() ||
-                    (LayerInfo(ptr).isPermute() &&
-                     ((ptr->input()->getLayout() == InferenceEngine::Layout::NCHW &&
-                       ptr->GetParamAsInts("order") ==
-                           permute::GetPermuteOrder(InferenceEngine::Layout::NCHW, InferenceEngine::Layout::NHWC)) ||
-                      (ptr->input()->getLayout() == InferenceEngine::Layout::CHW &&
-                       ptr->GetParamAsInts("order") == std::vector<int32_t>{0, 2, 1} /* NCW to NWC */))));
-        };
-        for (size_t input_idx = 0; input_idx != concat_layer->insData.size(); input_idx++) {
-            prev_layer =
-                InferenceEngine::CNNNetPrevLayerSkipCertain(layer, static_cast<int>(input_idx), isFusableWithConv);
-            if (prev_layer && LayerInfo(prev_layer).isConvolution())
-                return true;
-        }
-
-        // Look for trivial cases which will be flattened later;
-        // for an explanation of what is meant by a trivial case,
-        // see the FlattenTrivialConcatPass comments.
-        // TODO: detection of trivial cases could be moved to one common place
-        // when all transformations are migrated to ngraph
-        bool is_not_trivial_concat = false;
-
-        // Concatenation of consts and input parameters only is supported, even if the first dimension of an input
-        // parameter > 1
-        bool concat_all_const_or_inputs = false;
-
-        // If concat axis > 0, detect any dimension > 1 before the concat axis
-        if (concat_axis > 0) {
-            for (unsigned int axis = 0; axis < concat_axis; axis++) {
-                if (in_dims[axis] > 1) {
-                    is_not_trivial_concat = true;
-                    break;
-                }
-            }
-            // If concat axis == 0, detect any preceding functional layer's input
-            // with 0'th dimension > 1, but take into account that some layers need to be skipped
-        } else {
-            concat_all_const_or_inputs = true;
-
-            for (size_t input_idx = 0; input_idx != concat_layer->insData.size(); input_idx++) {
-                if (concat_layer->insData[input_idx].lock()->getDims()[0] != 1) {
-                    // First we're checking concat input layers
-                    prev_layer = InferenceEngine::CNNNetPrevLayerSkipCertain(
-                        concat_layer,
-                        static_cast<int>(input_idx),
-                        [](InferenceEngine::CNNLayerPtr ptr) {
-                            return LayerInfo(ptr).isNonFunctional() || LayerInfo(ptr).isFakeQuantize();
-                        });
-                    IE_ASSERT(prev_layer);
-
-                    if ((LayerInfo(prev_layer).isInput() && prev_layer->outData[0]->getDims()[0] == 1) ||
-                        LayerInfo(prev_layer).isConst()) {
-                        continue;
-                    } else if ((LayerInfo(prev_layer).isInput() && prev_layer->outData[0]->getDims()[0] != 1)) {
-                        is_not_trivial_concat = true;
-                        break;
-                    }
-
-                    // If it's still not clear whether the concat is supported,
-                    // we're moving one more layer back to see the dimensions
-                    pre_prev_layer = InferenceEngine::CNNNetPrevLayerSkipCertain(
-                        prev_layer,
-                        0,
-                        [](InferenceEngine::CNNLayerPtr ptr) {
-                            return LayerInfo(ptr).isNonFunctional() || LayerInfo(ptr).isFakeQuantize() ||
-                                   LayerInfo(ptr).isSplit();
-                        });
-                    IE_ASSERT(pre_prev_layer);
-
-                    if (LayerInfo(pre_prev_layer).isConst()) {
-                        continue;
-                    } else if (LayerInfo(pre_prev_layer).isPermute()) {
-                        continue;
-                    }
-
-                    concat_all_const_or_inputs = false;
-
-                    if (LayerInfo(pre_prev_layer).isInput() && pre_prev_layer->outData[0]->getDims()[0] == 1)
-                        continue;
-
-                    if (pre_prev_layer->outData[0]->getDims()[0] != 1) {
-                        is_not_trivial_concat = true;
-                        break;
-                    }
-                }
-            }
-        }
-
-        // This is a trivial concat or it isn't a 'not trivial one' :-)
-        // so it can be flattened and we're allowing it
-        if (!is_not_trivial_concat || concat_all_const_or_inputs)
-            return true;
-
-        // For interleaved inputs start checking from axis 1
-        // and allow concatenation on axis 0 only when all other dimensions = 1
-        std::rotate(in_dims.begin(), in_dims.begin() + 1, in_dims.end());
-        concat_axis == 0 ? concat_axis = static_cast<unsigned int>(dims_size - 1) : concat_axis--;
-
-        // Looking for any axis with dimension > 1 before the concatenation axis;
-        // in general such concatenation is unsupported
-        auto end_dim = in_dims.begin() + concat_axis;
-        auto unsupported_concat_axis = std::find_if(in_dims.begin(), end_dim, [](const size_t& in_dim) {
-            return (in_dim > 1);
-        });
-
-        if (unsupported_concat_axis != end_dim) {
-            auto dims = concat_layer->insData[0].lock()->getDims();
-            std::ostringstream in_dims_oss;
-            std::copy(dims.begin(), std::prev(dims.end()), std::ostream_iterator<size_t>(in_dims_oss, ","));
-            if (!dims.empty()) {
-                in_dims_oss << dims.back();
-            }
-            errMessage = "[ WARNING ] Topology with layer: " + layer->name + ", type: " + layer->type +
-                         ", and concatenation axis(" + std::to_string(concat_layer->_axis) +
-                         ") for input dimensions(" + in_dims_oss.str() + ") not supported\n";
-            return false;
-        }
-    }
-    return true;
-}
-
 bool Limitations::validate_conv_concat_axis(const InferenceEngine::ConcatLayer* concat_layer) {
     IE_ASSERT(concat_layer);
     auto dims_size = concat_layer->insData[0].lock()->getDims().size();
@@ -1250,10 +1164,6 @@ bool Limitations::are_layers_supported(InferenceEngine::CNNNetwork& network, std
                                  ", and batch size(" + std::to_string(output_batch_size) + ") not supported";
                     check_result = false;
                 }
-            } else if (info.isConcat()) {
-                if (!validate_concat_axis(layer, errMessage)) {
-                    THROW_GNA_EXCEPTION << errMessage;
-                }
             }
         },
         false);
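For reference, the concat_axis > 0 branch of the deleted legacy check reduces to one rule: a concat is "trivial" (flattenable by FlattenTrivialConcatPass) when every dimension before the concat axis equals 1. A compact sketch of just that rule (illustrative only; the legacy axis == 0 branch additionally inspected producer layers):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// True when all dimensions before the concat axis are 1, i.e. the concat
// can be flattened to a 1D concatenation.
static bool is_trivial_concat(const std::vector<size_t>& dims, size_t axis) {
    return std::all_of(dims.begin(), dims.begin() + axis, [](size_t d) { return d == 1; });
}

int main() {
    std::cout << std::boolalpha
              << is_trivial_concat({1, 1, 8, 2}, 2) << ' '    // true: flattenable
              << is_trivial_concat({1, 4, 8, 2}, 2) << '\n';  // false: dim 4 > 1 before the axis
}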

File: limitations.hpp

@@ -234,7 +234,7 @@ public:
     bool is_pooling_supported(const std::shared_ptr<ov::intel_gna::op::GNAMaxPool> max_pool,
                               bool is_exception_allowed = false);
-    static bool is_concat_supported(const std::shared_ptr<const ov::Node>& node);
+    static bool is_concat_supported(const std::shared_ptr<const ov::Node>& node, bool is_exception_allowed);
     static bool is_forward_transposed_concat_supported(const std::shared_ptr<const ov::Node>& node,
                                                        const AxisVector& order);
     static bool is_backward_transposed_concat_supported(const std::shared_ptr<const ov::Node>& node,
@@ -306,10 +306,6 @@
     size_t get_memory_alignment_bytes(const target::DeviceVersion& target) const;
 
-    IE_SUPPRESS_DEPRECATED_START
-    static bool validate_concat_axis(const InferenceEngine::CNNLayerPtr layer, std::string& errMessage);
-    IE_SUPPRESS_DEPRECATED_END
-
     bool m_use_only_16bit_conv_weights = false;
     size_t m_mem_alignment = 0;
     std::shared_ptr<cnn2d::AbstractValidator> m_cnn_validator;

File: graph_utils.hpp

@@ -244,6 +244,34 @@ inline bool is_activation(const std::shared_ptr<ngraph::Node>& node) noexcept {
     return is_activation(node.get());
 }
 
+inline bool is_non_functional(const std::shared_ptr<ov::Node>& node) {
+    return std::dynamic_pointer_cast<ov::opset12::Reshape>(node) != nullptr ||
+           std::dynamic_pointer_cast<ov::opset12::Squeeze>(node) != nullptr ||
+           std::dynamic_pointer_cast<ov::opset12::Unsqueeze>(node) != nullptr ||
+           std::dynamic_pointer_cast<ov::opset12::FakeQuantize>(node) != nullptr;
+}
+
+inline bool is_copy(const std::shared_ptr<ov::Node>& node) {
+    return std::dynamic_pointer_cast<ov::intel_gna::op::Copy>(node) != nullptr;
+}
+
+inline bool is_matmul(const std::shared_ptr<ov::Node>& node) {
+    return std::dynamic_pointer_cast<ov::opset12::MatMul>(node) != nullptr;
+}
+
+inline bool is_fully_connected(const std::shared_ptr<ov::Node>& node) {
+    return std::dynamic_pointer_cast<ngraph::op::FullyConnected>(node) != nullptr;
+}
+
+inline bool is_split(const std::shared_ptr<ov::Node>& node) {
+    return std::dynamic_pointer_cast<ov::opset12::Split>(node) != nullptr ||
+           std::dynamic_pointer_cast<ov::opset12::VariadicSplit>(node) != nullptr;
+}
+
+inline bool is_interleaved(const std::shared_ptr<ov::Node>& node) {
+    return is_matmul(node) || is_fully_connected(node);
+}
+
 inline bool is_gna_precision_agnostic(std::shared_ptr<ngraph::Node> node) {
     return ((std::dynamic_pointer_cast<ngraph::opset9::VariadicSplit>(node) != nullptr) ||
             (std::dynamic_pointer_cast<ngraph::opset9::Split>(node) != nullptr) ||
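These predicates feed the is_skipped_layer lambda in is_concat_supported, which walks past non-functional, split, copy, and activation nodes to find each concat input's functional producer. A toy model of that walk (hypothetical Node type and skip_certain helper, purely for illustration; the real code uses graph_utils::get_prev_node_skipping_certain on ov::Node):

#include <functional>
#include <iostream>
#include <memory>
#include <string>

struct Node {
    std::string type;
    std::shared_ptr<Node> producer;  // single input, for simplicity
};

// Walk the producer chain upward while the predicate says "skip this node".
static std::shared_ptr<Node> skip_certain(std::shared_ptr<Node> node,
                                          const std::function<bool(const std::shared_ptr<Node>&)>& skip) {
    while (node && node->producer && skip(node)) {
        node = node->producer;
    }
    return node;
}

int main() {
    auto param = std::make_shared<Node>(Node{"Parameter", nullptr});
    auto reshape = std::make_shared<Node>(Node{"Reshape", param});
    auto relu = std::make_shared<Node>(Node{"Relu", reshape});

    auto skipped = [](const std::shared_ptr<Node>& n) {
        return n->type == "Reshape" || n->type == "Relu";  // non-functional / activation
    };
    // The concat input chain Relu -> Reshape resolves to the Parameter,
    // which would count toward skipped_ops_count in is_concat_supported.
    std::cout << skip_certain(relu, skipped)->type << '\n';  // Parameter
}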

File: concat_restrictions.cpp (GNA functional tests)

@@ -69,7 +69,7 @@ struct ReLUConcatAxis {
         return std::make_shared<ngraph::Function>(results, params, getName());
     }
     static const char* getMatch() {
-        return "type: Concat, and concatenation axis(";
+        return "Unsupported concatenation axis";
     }
 };
@@ -115,7 +115,7 @@ struct MatmulConcatAxis {
         return std::make_shared<ngraph::Function>(results, params, getName());
     }
     static const char* getMatch() {
-        return "type: Concat, and concatenation axis(";
+        return "Unsupported concatenation axis";
    }
 };
@@ -203,7 +203,7 @@ struct ConvNHWCConcatAxis {
         return std::make_shared<ngraph::Function>(results, params, getName());
     }
     static const char* getMatch() {
-        return "type: Concat, and concatenation axis(";
+        return "Unsupported concatenation axis";
     }
 };
@@ -261,7 +261,7 @@ struct ConvConcatNHWCAxis {
         return std::make_shared<ngraph::Function>(results, params, getName());
     }
     static const char* getMatch() {
-        return "type: Concat, and concatenation axis(";
+        return "Unsupported concatenation axis";
     }
 };
@@ -337,7 +337,7 @@ struct ConvConcatConcatNHWCAxis {
         return std::make_shared<ngraph::Function>(results, params, getName());
    }
     static const char* getMatch() {
-        return "type: Concat, and concatenation axis(";
+        return "Unsupported concatenation axis";
     }
 };
@@ -500,12 +500,12 @@ using ConvConcatConcatNHWCRestrictionsNeg = ConcatRestrictions<ConvConcatConcatN
 using ConvConcatConcatNHWCRestrictionsPos = ConcatRestrictions<ConvConcatConcatNHWCAxis>;
 using TransposeTransposeConcatPos = ConcatRestrictions<TransposeTransposeConcat>;
 
-TEST_P(ReLUConcatRestrictionsNeg, CompareWithRefImpl) {
-    ExpectLoadNetworkToThrow(getMatch());
-};
-
-// TODO: this test is left for the future, when the GNA plugin handles the const transposition required for concats
-// with interleaved layers
+// TODO: those tests are left for the future, when the GNA plugin handles the const transposition required for
+// concats with interleaved layers
+// TEST_P(ReLUConcatRestrictionsNeg, CompareWithRefImpl) {
+//     ExpectLoadNetworkToThrow(getMatch());
+// };
 //
 // TEST_P(ReLUConcatRestrictionsPos, CompareWithRefImpl) {
 //     Run();
 // };

File: permute_concat_concat_permute.cpp (GNA functional tests)

@@ -9,8 +9,7 @@
 using namespace SubgraphTestsDefinitions;
 
 namespace {
-std::vector<std::vector<size_t>> inputs1{{{1, 8}}, {{8, 1}}};
-std::vector<std::vector<size_t>> inputs2{{{16, 2}}, {{8, 2}}};
+std::vector<std::vector<size_t>> inputs1{{{1, 8}}, {{8, 1}}, {{16, 2}}, {{8, 2}}};
 
 std::vector<InferenceEngine::Precision> netPrecisions = {
     InferenceEngine::Precision::FP32,
@@ -24,10 +23,4 @@ INSTANTIATE_TEST_SUITE_P(smoke_permute_concat_concat_permute,
                                             ::testing::Values(CommonTestUtils::DEVICE_GNA)),
                          PermuteConcatConcatPermute::getTestCaseName);
-
-INSTANTIATE_TEST_SUITE_P(smoke_permute_concat_concat_permute,
-                         PermuteConcatConcatPermuteNeg,
-                         ::testing::Combine(::testing::ValuesIn(inputs2),
-                                            ::testing::ValuesIn(netPrecisions),
-                                            ::testing::Values(CommonTestUtils::DEVICE_GNA)),
-                         PermuteConcatConcatPermute::getTestCaseName);
 }  // namespace

File: permute_concat_permute.cpp (GNA functional tests)

@@ -9,14 +9,9 @@
 using namespace SubgraphTestsDefinitions;
 
 namespace {
-std::vector<std::vector<std::vector<size_t>>> inputs{
-    {{1, 8}, {1, 0}, {1, 0}},
-};
-
-std::vector<std::vector<std::vector<size_t>>> inputsNeg{
-    {{32, 2}, {1, 0}, {1, 0}},
-    {{8, 2}, {1, 0}, {1, 0}},
-};
+std::vector<std::vector<std::vector<size_t>>> inputs{{{1, 8}, {1, 0}, {1, 0}},
+                                                     {{32, 2}, {1, 0}, {1, 0}},
+                                                     {{8, 2}, {1, 0}, {1, 0}}};
 
 std::vector<InferenceEngine::Precision> netPrecisions = {
     InferenceEngine::Precision::FP32,
@@ -30,11 +25,4 @@ INSTANTIATE_TEST_SUITE_P(smoke_permute_concat_permute,
                                             ::testing::Values(CommonTestUtils::DEVICE_GNA)),
                          PermuteConcatPermute::getTestCaseName);
-
-INSTANTIATE_TEST_SUITE_P(smoke_permute_concat_permute,
-                         PermuteConcatPermuteNeg,
-                         ::testing::Combine(::testing::ValuesIn(inputsNeg),
-                                            ::testing::ValuesIn(netPrecisions),
-                                            ::testing::Values(CommonTestUtils::DEVICE_GNA)),
-                         PermuteConcatPermute::getTestCaseName);
 }  // namespace