Corrected Concat checks

This commit is contained in:
Mikhail Ryzhov 2023-03-23 16:09:59 +01:00
parent 3c5aa8c53d
commit 1163b926ee

View File

@ -688,22 +688,6 @@ static bool ValidateConcatAxis(const InferenceEngine::CNNLayerPtr layer, std::st
if (dims_size >= 2) {
InferenceEngine::CNNLayerPtr prev_layer, pre_prev_layer;
// Skip all convolutions in this check, they will be handled during concat primitive creation
auto isFusableWithConv = [](InferenceEngine::CNNLayerPtr ptr) {
return (LayerInfo(ptr).isFusableWithConv() || LayerInfo(ptr).isNonFunctional() ||
(LayerInfo(ptr).isPermute() &&
((ptr->input()->getLayout() == InferenceEngine::Layout::NCHW &&
ptr->GetParamAsInts("order") ==
permute::GetPermuteOrder(InferenceEngine::Layout::NCHW, InferenceEngine::Layout::NHWC)) ||
(ptr->input()->getLayout() == InferenceEngine::Layout::CHW &&
ptr->GetParamAsInts("order") == std::vector<int32_t>{0, 2, 1} /* NCW to NWC */))));
};
for (auto input_idx = 0; input_idx != concat_layer->insData.size(); input_idx++) {
prev_layer = InferenceEngine::CNNNetPrevLayerSkipCertain(layer, input_idx, isFusableWithConv);
if (prev_layer && LayerInfo(prev_layer).isConvolution())
return true;
}
// Look for trivial cases which will be flattened later
// for explanation of what is meant by trivial case,
@ -783,10 +767,6 @@ static bool ValidateConcatAxis(const InferenceEngine::CNNLayerPtr layer, std::st
if (!is_not_trivial_concat || concat_all_const_or_inputs)
return true;
// For interleaved inputs start checking from axis 1
// and allow concatenation on axis 0 only when all other dimesions = 1
std::rotate(in_dims.begin(), in_dims.begin() + 1, in_dims.end());
concat_axis == 0 ? concat_axis = static_cast<unsigned int>(dims_size - 1) : concat_axis--;
// Looking for any axis with dimension > 1 before concatentaion axis;
// in general such concatenation is unsupported