From bd949c6bafbb44c66bd49a1201f711d94b70b1c3 Mon Sep 17 00:00:00 2001 From: Gabriele Galiero Casay Date: Tue, 9 Mar 2021 20:58:25 +0100 Subject: [PATCH] BinaryConvolution Reference Implementation (#4278) * Add BinaryConvolution unit tests. * Changed types to u1. * Add BIN precision handling in TestCase class. * Refactored validate and infer types to enhance dynamic shape inference * Add type_prop test to cover invalid op cases and dynamic shapes * Fix style * Disable check for float type of data batch input * Add type_prop test for incompatible input channels in inputs * Disable backend unit tests * Fix style * Add reference implementation * Add backend tests * Add single layer tests * Add check for float element type of batch data input * Refactor backend test cases to compare with regular convolution * Add serialization tests * Clean up * Add 1D and 3D tests into op_eval * Changes in reference implementation to improve readability * Add ticket information for todo tasks * Fix implementation misbehavior for filter channels * Add backend unit tests to cover strides, dilations, padding, channels and batches * Add end of line into files * Change name of type_prop unit tests * Simplified lambda to get spatial dimensions of filters * Add comment to support filters input as Parameter * Add namespace details for BinaryConvolution utility functions * Address review comments Co-authored-by: jdanieck --- .../single_layer/binary_convolution.cpp | 63 ++ .../single_layer_tests/binary_convolution.cpp | 72 ++ .../skip_tests_config.cpp | 2 + .../single_layer_tests/binary_convolution.hpp | 14 + .../single_layer/binary_convolution.hpp | 48 ++ .../src/single_layer/binary_convolution.cpp | 79 +++ .../runtime/reference/binary_convolution.hpp | 224 ++++++ ngraph/core/src/op/binary_convolution.cpp | 130 ++-- .../tests/test_ngraph/test_create_op.py | 2 +- ngraph/test/CMakeLists.txt | 2 + ngraph/test/backend/binary_convolution.in.cpp | 654 ++++++++++++++++++ ngraph/test/op_eval/binary_convolution.cpp | 245 +++++++ ngraph/test/runtime/ie/unit_test.manifest | 11 + .../runtime/interpreter/evaluates_map.cpp | 50 ++ .../runtime/interpreter/opset_int_tbl.hpp | 1 + .../runtime/interpreter/unit_test.manifest | 2 +- ngraph/test/type_prop/binary_convolution.cpp | 323 ++++++++- ngraph/test/util/engine/ie_engines.cpp | 2 +- 18 files changed, 1860 insertions(+), 64 deletions(-) create mode 100644 inference-engine/tests/functional/inference_engine/serialization/single_layer/binary_convolution.cpp create mode 100644 inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/binary_convolution.cpp create mode 100644 inference-engine/tests/functional/plugin/shared/include/single_layer_tests/binary_convolution.hpp create mode 100644 inference-engine/tests/functional/shared_test_classes/include/shared_test_classes/single_layer/binary_convolution.hpp create mode 100644 inference-engine/tests/functional/shared_test_classes/src/single_layer/binary_convolution.cpp create mode 100644 ngraph/core/reference/include/ngraph/runtime/reference/binary_convolution.hpp create mode 100644 ngraph/test/backend/binary_convolution.in.cpp create mode 100644 ngraph/test/op_eval/binary_convolution.cpp diff --git a/inference-engine/tests/functional/inference_engine/serialization/single_layer/binary_convolution.cpp b/inference-engine/tests/functional/inference_engine/serialization/single_layer/binary_convolution.cpp new file mode 100644 index 00000000000..960c52728f1 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/serialization/single_layer/binary_convolution.cpp @@ -0,0 +1,63 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "shared_test_classes/single_layer/binary_convolution.hpp" + +using namespace LayerTestsDefinitions; + +namespace { +TEST_P(BinaryConvolutionLayerTest, Serialize) { + Serialize(); +} + +const std::vector netPrecisions = { + InferenceEngine::Precision::FP32, InferenceEngine::Precision::FP16}; +const std::vector> kernels = {{3, 5}}; +const std::vector> strides = {{1, 3}}; +const std::vector> padBegins = {{0, 3}}; +const std::vector> padEnds = {{0, 3}}; +const std::vector> dilations = {{3, 1}}; +const std::vector numOutChannels = {5}; +const std::vector pad_values = {0, 1}; + +const auto binConv2DParams_ExplicitPadding = ::testing::Combine( + ::testing::ValuesIn(kernels), ::testing::ValuesIn(strides), + ::testing::ValuesIn(padBegins), ::testing::ValuesIn(padEnds), + ::testing::ValuesIn(dilations), ::testing::ValuesIn(numOutChannels), + ::testing::Values(ngraph::op::PadType::EXPLICIT), + ::testing::ValuesIn(pad_values)); +const auto binConv2DParams_AutoPadValid = ::testing::Combine( + ::testing::ValuesIn(kernels), ::testing::ValuesIn(strides), + ::testing::Values(std::vector({0, 0})), + ::testing::Values(std::vector({0, 0})), + ::testing::ValuesIn(dilations), ::testing::ValuesIn(numOutChannels), + ::testing::Values(ngraph::op::PadType::VALID), + ::testing::ValuesIn(pad_values)); + +INSTANTIATE_TEST_CASE_P( + smoke_BinaryConvolution2D_Serialization_ExplicitPadding, BinaryConvolutionLayerTest, + ::testing::Combine( + binConv2DParams_ExplicitPadding, ::testing::ValuesIn(netPrecisions), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::Values(std::vector({1, 3, 30, 30})), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + BinaryConvolutionLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P( + smoke_BinaryConvolution2D__Serialization_AutoPadValid, BinaryConvolutionLayerTest, + ::testing::Combine( + binConv2DParams_AutoPadValid, ::testing::ValuesIn(netPrecisions), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::Values(std::vector({1, 3, 30, 30})), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + BinaryConvolutionLayerTest::getTestCaseName); +} // namespace diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/binary_convolution.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/binary_convolution.cpp new file mode 100644 index 00000000000..10d43813d47 --- /dev/null +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/binary_convolution.cpp @@ -0,0 +1,72 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "single_layer_tests/binary_convolution.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; + +namespace { + +const std::vector netPrecisions = { + InferenceEngine::Precision::FP32, InferenceEngine::Precision::FP16}; + +/* ============= 2D Binary Convolution ============= */ +const std::vector> kernels = {{3, 3}, {3, 5}}; +const std::vector> strides = {{1, 1}, {1, 3}}; +const std::vector> padsBegin = {{0, 0}, {0, 3}}; +const std::vector> padsEnd = {{0, 0}, {0, 3}}; +const std::vector> dilations = {{1, 1}, {3, 1}}; +const std::vector numOutChannels = {1, 5}; +const std::vector padValues = {0, 1}; + +const auto binConv2DParams_ExplicitPadding = ::testing::Combine( + ::testing::ValuesIn(kernels), + ::testing::ValuesIn(strides), + ::testing::ValuesIn(padsBegin), + ::testing::ValuesIn(padsEnd), + ::testing::ValuesIn(dilations), + ::testing::ValuesIn(numOutChannels), + ::testing::Values(ngraph::op::PadType::EXPLICIT), + ::testing::ValuesIn(padValues)); + +const auto binConv2DParams_ValidPadding = ::testing::Combine( + ::testing::ValuesIn(kernels), + ::testing::ValuesIn(strides), + ::testing::Values(std::vector({0, 0})), + ::testing::Values(std::vector({0, 0})), + ::testing::ValuesIn(dilations), + ::testing::ValuesIn(numOutChannels), + ::testing::Values(ngraph::op::PadType::VALID), + ::testing::ValuesIn(padValues)); + +INSTANTIATE_TEST_CASE_P( + smoke_BinaryConvolution2D_ExplicitPadding, BinaryConvolutionLayerTest, + ::testing::Combine( + binConv2DParams_ExplicitPadding, + ::testing::ValuesIn(netPrecisions), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::Values(std::vector({1, 3, 30, 30})), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + BinaryConvolutionLayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P( + smoke_BinaryConvolution2D_AutoPadValid, BinaryConvolutionLayerTest, + ::testing::Combine( + binConv2DParams_ValidPadding, + ::testing::ValuesIn(netPrecisions), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::Values(std::vector({1, 3, 30, 30})), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + BinaryConvolutionLayerTest::getTestCaseName); + +} // namespace diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/skip_tests_config.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/skip_tests_config.cpp index b9555d14390..016cb795e46 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/skip_tests_config.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/skip_tests_config.cpp @@ -57,6 +57,8 @@ std::vector disabledTestPatterns() { R"(.*decomposition1_batch=5_hidden_size=10_input_size=30_.*tanh.relu.*_clip=0_linear_before_reset=1.*_targetDevice=CPU_.*)", // Skip platforms that do not support BF16 (i.e. sse, avx, avx2) R"(.*BF16.*(jit_avx(?!5)|jit_sse).*)", + // TODO: Incorrect blob sizes for node BinaryConvolution_X + R"(.*BinaryConvolutionLayerTest.*)" }; if (!InferenceEngine::with_cpu_x86_avx512_core()) { diff --git a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/binary_convolution.hpp b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/binary_convolution.hpp new file mode 100644 index 00000000000..811736401fd --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/binary_convolution.hpp @@ -0,0 +1,14 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "shared_test_classes/single_layer/binary_convolution.hpp" + +namespace LayerTestsDefinitions { + +TEST_P(BinaryConvolutionLayerTest, CompareWithRefs) { + Run(); +} +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/shared_test_classes/include/shared_test_classes/single_layer/binary_convolution.hpp b/inference-engine/tests/functional/shared_test_classes/include/shared_test_classes/single_layer/binary_convolution.hpp new file mode 100644 index 00000000000..157aa3faa0c --- /dev/null +++ b/inference-engine/tests/functional/shared_test_classes/include/shared_test_classes/single_layer/binary_convolution.hpp @@ -0,0 +1,48 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "shared_test_classes/base/layer_test_utils.hpp" +#include "ngraph_functions/builders.hpp" +#include "ngraph_functions/utils/ngraph_helpers.hpp" + +namespace LayerTestsDefinitions { + +using binConvSpecificParams = std::tuple< + InferenceEngine::SizeVector, // Kernel size + InferenceEngine::SizeVector, // Strides + std::vector, // Pads begin + std::vector, // Pads end + InferenceEngine::SizeVector, // Dilations + size_t, // Num Output channels + ngraph::op::PadType, // Padding type + float>; // Padding value + +using binaryConvolutionTestParamsSet = std::tuple< + binConvSpecificParams, // + InferenceEngine::Precision, // Network precision + InferenceEngine::Precision, // Input precision + InferenceEngine::Precision, // Output precision + InferenceEngine::Layout, // Input layout + InferenceEngine::Layout, // Output layout + InferenceEngine::SizeVector, // Input shape + LayerTestsUtils::TargetDevice>; // Device name + +class BinaryConvolutionLayerTest : public testing::WithParamInterface, + virtual public LayerTestsUtils::LayerTestsCommon { +public: + static std::string getTestCaseName(testing::TestParamInfo obj); + InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo &info) const override; + +protected: + void SetUp() override; +}; + +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/shared_test_classes/src/single_layer/binary_convolution.cpp b/inference-engine/tests/functional/shared_test_classes/src/single_layer/binary_convolution.cpp new file mode 100644 index 00000000000..c2b3045f629 --- /dev/null +++ b/inference-engine/tests/functional/shared_test_classes/src/single_layer/binary_convolution.cpp @@ -0,0 +1,79 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "shared_test_classes/single_layer/binary_convolution.hpp" + +namespace LayerTestsDefinitions { + +std::string BinaryConvolutionLayerTest::getTestCaseName(testing::TestParamInfo obj) { + binConvSpecificParams binConvParams; + InferenceEngine::Precision netPrecision; + InferenceEngine::Precision inPrc, outPrc; + InferenceEngine::Layout inLayout, outLayout; + InferenceEngine::SizeVector inputShape; + std::string targetDevice; + + std::tie(binConvParams, netPrecision, inPrc, outPrc, inLayout, outLayout, inputShape, targetDevice) = obj.param; + + ngraph::op::PadType padType; + InferenceEngine::SizeVector kernel, stride, dilation; + std::vector padBegin, padEnd; + size_t convOutChannels; + float padValue; + std::tie(kernel, stride, padBegin, padEnd, dilation, convOutChannels, padType, padValue) = binConvParams; + + std::ostringstream result; + result << "IS=" << CommonTestUtils::vec2str(inputShape) << "_"; + result << "KS=" << CommonTestUtils::vec2str(kernel) << "_"; + result << "S=" << CommonTestUtils::vec2str(stride) << "_"; + result << "PB=" << CommonTestUtils::vec2str(padBegin) << "_"; + result << "PE=" << CommonTestUtils::vec2str(padEnd) << "_"; + result << "D=" << CommonTestUtils::vec2str(dilation) << "_"; + result << "O=" << convOutChannels << "_"; + result << "AP=" << padType << "_"; + result << "PV=" << padValue << "_"; + result << "netPRC=" << netPrecision.name() << "_"; + result << "inPRC=" << inPrc.name() << "_"; + result << "outPRC=" << outPrc.name() << "_"; + result << "inL=" << inLayout << "_"; + result << "outL=" << outLayout << "_"; + result << "trgDev=" << targetDevice; + return result.str(); +} + +InferenceEngine::Blob::Ptr BinaryConvolutionLayerTest::GenerateInput(const InferenceEngine::InputInfo &info) const { + InferenceEngine::Blob::Ptr blobPtr; + const std::string name = info.name(); + // there is no input generation for filters since CPU implementation uses Constant + // TODO: enable filters input generation as Parameter when supported (Issue 50148) + if (name == "a_data_batch") { + blobPtr = FuncTestUtils::createAndFillBlob(info.getTensorDesc(), 1, 0, 1, 7235346); + } + return blobPtr; +} + +void BinaryConvolutionLayerTest::SetUp() { + binConvSpecificParams binConvParams; + InferenceEngine::Precision netPrecision; + InferenceEngine::SizeVector inputShape; + + std::tie(binConvParams, netPrecision, inPrc, outPrc, inLayout, outLayout, inputShape, targetDevice) = + this->GetParam(); + + ngraph::op::PadType padType; + InferenceEngine::SizeVector kernelSize, strides, dilations; + std::vector padsBegin, padsEnd; + size_t numOutChannels; + float padValue; + std::tie(kernelSize, strides, padsBegin, padsEnd, dilations, numOutChannels, padType, padValue) = binConvParams; + auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + auto params = ngraph::builder::makeParams(ngPrc, {{"a_data_batch", inputShape}}); + // TODO: refactor build BinaryConvolution op to accept filters input as Parameter + auto binConv = ngraph::builder::makeBinaryConvolution(params[0], kernelSize, strides, padsBegin, padsEnd, dilations, padType, numOutChannels, + padValue); + ngraph::ResultVector results{std::make_shared(binConv)}; + function = std::make_shared(results, params, "BinaryConvolution"); +} + +} // namespace LayerTestsDefinitions diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/binary_convolution.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/binary_convolution.hpp new file mode 100644 index 00000000000..853b1032b53 --- /dev/null +++ b/ngraph/core/reference/include/ngraph/runtime/reference/binary_convolution.hpp @@ -0,0 +1,224 @@ +//***************************************************************************** +// Copyright 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include "ngraph/runtime/reference/convolution.hpp" +#include "ngraph/shape.hpp" + +namespace ngraph +{ + namespace runtime + { + namespace reference + { + namespace details + { + inline uint8_t extract_bit(uint8_t val, uint8_t bit) + { + return (uint8_t)((val >> bit) & 0x01); + } + + template + inline bool xnor(T a, T b) + { + return a == b; + } + + template + void binary_convolve_3D_channels(const ConvolutionParams& p, + const T_IN* batch, + const Shape& batch_shape, + const T_F* filter, + const Shape& filter_shape, + T_IN*& out, + const float pad_value) + { + const int n_bits = 8; + const int input_size_z = batch_shape[1]; + const int input_size_y = batch_shape[2]; + const int input_size_x = batch_shape[3]; + const int filter_size_z = filter_shape[1]; + const int filter_size_y = filter_shape[2]; + const int filter_size_x = filter_shape[3]; + const int dilated_filter_size_z = + filter_size_z + (filter_size_z - 1) * (p.dilation[0] - 1); + const int dilated_filter_size_y = + filter_size_y + (filter_size_y - 1) * (p.dilation[1] - 1); + const int dilated_filter_size_x = + filter_size_x + (filter_size_x - 1) * (p.dilation[2] - 1); + + const Shape input_channel_shape(++batch_shape.begin(), batch_shape.end()); + const size_t input_channel_size = shape_size(input_channel_shape); + const Shape filter_channel_shape(++filter_shape.begin(), filter_shape.end()); + const size_t filter_channel_size = shape_size(filter_channel_shape); + const T_IN bit_count = static_cast(filter_channel_size); + + for (int i_z = -p.pads_begin[0]; + i_z <= (p.pads_end[0] + input_size_z - dilated_filter_size_z); + i_z += p.strides[0]) + { + for (int i_y = -p.pads_begin[1]; + i_y <= (p.pads_end[1] + input_size_y - dilated_filter_size_y); + i_y += p.strides[1]) + { + for (int i_x = -p.pads_begin[2]; + i_x <= (p.pads_end[2] + input_size_x - dilated_filter_size_x); + i_x += p.strides[2]) + { + auto input_channel = batch; + size_t filter_channels_count = filter_shape[0]; + int filter_count = 0; + T_IN sum = 0; + while (filter_channels_count--) + { + T_IN popcount = 0; + for (int f_z = 0; f_z < filter_size_z; ++f_z) + { + for (int f_y = 0; f_y < filter_size_y; ++f_y) + { + for (int f_x = 0; f_x < filter_size_x; ++f_x) + { + int rel_i_z = i_z + (f_z * p.dilation[0]); + int rel_i_y = i_y + (f_y * p.dilation[1]); + int rel_i_x = i_x + (f_x * p.dilation[2]); + + bool padding = + !(in_range(rel_i_x, {0, input_size_x}) && + in_range(rel_i_y, {0, input_size_y}) && + in_range(rel_i_z, {0, input_size_z})); + int i_buf_idx = + (rel_i_z * input_size_y * input_size_x) + + (rel_i_y * input_size_x) + rel_i_x; + + T_IN in_val = padding + ? static_cast(pad_value) + : static_cast( + input_channel[i_buf_idx]); + + int f_buf_idx = + (f_z * filter_size_y * filter_size_x) + + (f_y * filter_size_x) + f_x; + + int f_byte_idx = + (f_buf_idx + filter_count) / n_bits; + int bit_idx = (n_bits - 1) - + ((f_buf_idx + filter_count) % n_bits); + uint8_t f_val = + extract_bit(filter[f_byte_idx], bit_idx); + + if (xnor(in_val, static_cast(f_val))) + { + popcount += static_cast(1); + } + } + } + } + input_channel += input_channel_size; + filter_count += filter_channel_size; + sum += (2 * popcount - bit_count); + } + *out = sum; + ++out; + } + } + } + } + } + + void validate_convolution_parameters(const Shape& in_shape, + const Shape& f_shape, + const Strides& strides, + const Strides& dilations, + const CoordinateDiff& pads_begin, + const CoordinateDiff& pads_end) + { + // this implementation supports 1D, 2D and 3D convolutions + NGRAPH_CHECK(in_shape.size() >= 3 && in_shape.size() <= 5, + "Unsupported input rank: ", + in_shape); + + NGRAPH_CHECK(in_shape.size() == f_shape.size(), + "Incompatible input ranks: ", + in_shape.size(), + " and ", + f_shape.size()); + + const auto spatial_dims = in_shape.size() - 2; + NGRAPH_CHECK(strides.size() == spatial_dims, + "Strides not definied for all and only spatial dimensions"); + + NGRAPH_CHECK(dilations.size() == spatial_dims, + "Dilations not defined for all and only spatial dimensions"); + + NGRAPH_CHECK((pads_begin.size() == pads_end.size()) && + (pads_begin.size() == spatial_dims), + "Pads not defined for all and only spatial dimensions"); + } + + template + void binary_convolution(const T_IN* in, + const T_F* f, + T_IN* out, + const Shape& in_shape, + const Shape& f_shape, + const Shape& out_shape, + const Strides& strides, + const Strides& dilations, + const CoordinateDiff& pads_begin, + const CoordinateDiff& pads_end, + const float pad_value) + { + validate_convolution_parameters( + in_shape, f_shape, strides, dilations, pads_begin, pads_end); + + // here we are converting all param types to int's to avoid arithmetic issues + // (e.g signed + unsigned) in indexes calculation later + ConvolutionParams params{strides, dilations, pads_begin, pads_end}; + + // here we are extending spatial dimensions to 3D, because we are going to use 3D + // convolution implementation to convolve also in 1D & 2D case + Shape input_shape{in_shape}; + Shape filters_shape{f_shape}; + if (in_shape.size() < 5) + { + extend_to_3D(params, input_shape, filters_shape); + } + + const size_t batches_count = input_shape[in_batch_axis]; + const Shape batch_shape(++input_shape.begin(), input_shape.end()); + const size_t batch_size = shape_size(batch_shape); + + const size_t filters_count = filters_shape[filter_out_ch_axis]; + const Shape filter_shape(++filters_shape.begin(), filters_shape.end()); + const size_t filter_size = shape_size(filter_shape); + + auto batch = in; + for (size_t batch_idx = 0; batch_idx < batches_count; ++batch_idx) + { + auto filter = f; + for (size_t f_idx = 0; f_idx < filters_count; ++f_idx) + { + details::binary_convolve_3D_channels( + params, batch, batch_shape, filter, filter_shape, out, pad_value); + filter += filter_size; + } + batch += batch_size; + } + } + } // namespace reference + } // namespace runtime +} // namespace ngraph diff --git a/ngraph/core/src/op/binary_convolution.cpp b/ngraph/core/src/op/binary_convolution.cpp index c695fb21541..a9928853dc3 100644 --- a/ngraph/core/src/op/binary_convolution.cpp +++ b/ngraph/core/src/op/binary_convolution.cpp @@ -73,80 +73,116 @@ op::v1::BinaryConvolution::BinaryConvolution(const Output& data, void op::v1::BinaryConvolution::validate_and_infer_types() { NGRAPH_OP_SCOPE(v1_BinaryConvolution_validate_and_infer_types); - const PartialShape& data_batch_shape = get_input_partial_shape(0); + const PartialShape& data_batch_pshape = get_input_partial_shape(0); element::Type data_batch_et = get_input_element_type(0); - const PartialShape& filters_shape = get_input_partial_shape(1); + const PartialShape& filters_pshape = get_input_partial_shape(1); - PartialShape result_shape = PartialShape::dynamic(); - if (data_batch_shape.rank().is_static()) - { - result_shape = - std::vector(data_batch_shape.rank().get_length(), Dimension::dynamic()); + NODE_VALIDATION_CHECK(this, + data_batch_et.is_real(), + "Data batch element type must be float point. Got: ", + data_batch_et); - if (data_batch_shape.rank().get_length() > 1) - { - result_shape[0] = data_batch_shape[0]; // batch size - } + // TODO: Add NodeValidationCheck to filters et once u1 is supported in nGraph Python API + // (#49517) - if (filters_shape.rank().is_static() && filters_shape.rank().get_length() > 1) - { - result_shape[1] = filters_shape[0]; // filter channel size - } - } + NODE_VALIDATION_CHECK(this, + data_batch_pshape.rank().compatible(filters_pshape.rank()), + "Shapes for data batch and filters must have same rank. Got: ", + data_batch_pshape, + "and ", + filters_pshape); if (m_strides.size() == 0) { - m_strides = conv_default_strides(this, data_batch_shape, filters_shape); + m_strides = conv_default_strides(this, data_batch_pshape, filters_pshape); } if (m_dilations.size() == 0) { - m_dilations = conv_default_strides(this, data_batch_shape, filters_shape); + m_dilations = conv_default_strides(this, data_batch_pshape, filters_pshape); } if (m_pads_begin.size() == 0) { - m_pads_begin = conv_default_padding(this, data_batch_shape, filters_shape); + m_pads_begin = conv_default_padding(this, data_batch_pshape, filters_pshape); } if (m_pads_end.size() == 0) { - m_pads_end = conv_default_padding(this, data_batch_shape, filters_shape); + m_pads_end = conv_default_padding(this, data_batch_pshape, filters_pshape); } - if (m_auto_pad == PadType::SAME_UPPER || m_auto_pad == PadType::SAME_LOWER) + PartialShape result_shape = PartialShape::dynamic(); + if (data_batch_pshape.rank().is_static() || filters_pshape.rank().is_static()) { - bool auto_padding_applied = false; - if (filters_shape.is_static()) + const bool is_data_batch_ps_static = data_batch_pshape.rank().is_static(); + const auto output_ps_rank = + is_data_batch_ps_static ? data_batch_pshape.rank() : filters_pshape.rank(); + const auto num_spatial_dims = output_ps_rank.get_length() - 2; + + NODE_VALIDATION_CHECK(this, + m_strides.size() == num_spatial_dims, + "Strides should be defined for all and only spatial features."); + + NODE_VALIDATION_CHECK(this, + m_dilations.size() == num_spatial_dims, + "Dilations should be defined for all and only spatial features."); + + NODE_VALIDATION_CHECK(this, + m_pads_begin.size() == num_spatial_dims && + m_pads_end.size() == num_spatial_dims, + "Pads should be defined for all and only spatial features."); + + result_shape = std::vector(output_ps_rank.get_length(), Dimension::dynamic()); + if (data_batch_pshape.rank().is_static()) { - m_pads_begin.clear(); - m_pads_end.clear(); - auto filter_shape = filters_shape.to_shape(); - filter_shape.erase(filter_shape.begin(), filter_shape.begin() + 2); // Remove {O,I} - auto_padding_applied = try_apply_auto_padding(data_batch_shape, - filter_shape, - m_strides, - m_dilations, - m_auto_pad, - m_pads_end, - m_pads_begin); + result_shape[0] = data_batch_pshape[0]; // batch size } - if (!auto_padding_applied) + if (filters_pshape.rank().is_static()) { - set_output_type(0, data_batch_et, result_shape); - return; + result_shape[1] = filters_pshape[0]; // filter channel size } + if (m_auto_pad == PadType::SAME_UPPER || m_auto_pad == PadType::SAME_LOWER) + { + bool auto_padding_applied = false; + if (filters_pshape.rank().is_static() && filters_pshape.rank().get_length() > 2) + { + m_pads_begin.clear(); + m_pads_end.clear(); + + const PartialShape filter_spatial_shape = [filters_pshape]() { + vector filter_dims{filters_pshape}; + filter_dims.erase(filter_dims.begin(), filter_dims.begin() + 2); // Remove {O,I} + return PartialShape{filter_dims}; + }(); + + if (filter_spatial_shape.is_static()) + { + auto_padding_applied = try_apply_auto_padding(data_batch_pshape, + filter_spatial_shape.to_shape(), + m_strides, + m_dilations, + m_auto_pad, + m_pads_end, + m_pads_begin); + } + } + if (!auto_padding_applied) + { + set_output_type(0, data_batch_et, result_shape); + return; + } + } + + result_shape = infer_convolution_forward(this, + data_batch_pshape, + Strides(num_spatial_dims, 1), + m_pads_begin, + m_pads_end, + filters_pshape, + m_strides, + m_dilations); } - - result_shape = infer_convolution_forward(this, - data_batch_shape, - Strides(data_batch_shape.rank().get_length() - 2, 1), - m_pads_begin, - m_pads_end, - filters_shape, - m_strides, - m_dilations); - set_output_type(0, data_batch_et, result_shape); } diff --git a/ngraph/python/tests/test_ngraph/test_create_op.py b/ngraph/python/tests/test_ngraph/test_create_op.py index eda402201df..c6096d84e36 100644 --- a/ngraph/python/tests/test_ngraph/test_create_op.py +++ b/ngraph/python/tests/test_ngraph/test_create_op.py @@ -36,7 +36,7 @@ integral_np_types = [ ] -@pytest.mark.parametrize("dtype", np_types) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_binary_convolution(dtype): strides = np.array([1, 1]) pads_begin = np.array([0, 0]) diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 8dd5469cef9..70f38ddc433 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -71,6 +71,7 @@ set(SRC ngraph_api.cpp node_input_output.cpp op.cpp + op_eval/binary_convolution.cpp op_eval/bucketize.cpp op_eval/floor_mod.cpp op_eval/gelu.cpp @@ -268,6 +269,7 @@ set(MULTI_TEST_SRC backend/convert.in.cpp backend/convert_like.in.cpp backend/convolution.in.cpp + backend/binary_convolution.in.cpp backend/cos.in.cpp backend/cosh.in.cpp backend/ctc_greedy_decoder.in.cpp diff --git a/ngraph/test/backend/binary_convolution.in.cpp b/ngraph/test/backend/binary_convolution.in.cpp new file mode 100644 index 00000000000..eb34975065e --- /dev/null +++ b/ngraph/test/backend/binary_convolution.in.cpp @@ -0,0 +1,654 @@ +//***************************************************************************** +// Copyright 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include "gtest/gtest.h" +#include "ngraph/ngraph.hpp" +#include "ngraph/runtime/tensor.hpp" +#include "runtime/backend.hpp" +#include "util/all_close.hpp" +#include "util/all_close_f.hpp" +#include "util/test_control.hpp" + +using namespace std; +using namespace ngraph; + +static string s_manifest = "${MANIFEST}"; + +template +static void BinaryConvolutionTest(const std::vector& inputs, + const Shape inputs_shape, + const std::vector& filters, + const Shape filters_shape, + const std::vector& outputs, + const Shape outputs_shape, + const Strides& strides, + const CoordinateDiff& padding, + const Strides& dilations, + const float pad_value = 0.0f) +{ + const CoordinateDiff pads_begin{padding}; + const CoordinateDiff pads_end{padding}; + const op::PadType auto_pad{op::PadType::EXPLICIT}; + + auto inputs_param = make_shared(element::from(), inputs_shape); + auto filters_const = make_shared(element::u1, filters_shape, &filters[0]); + auto bin_conv = make_shared( + inputs_param, + filters_const, + strides, + pads_begin, + pads_end, + dilations, + op::v1::BinaryConvolution::BinaryConvolutionMode::XNOR_POPCOUNT, + pad_value, + auto_pad); + auto f = make_shared(bin_conv, ParameterVector{inputs_param}); + + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + + auto input_tensor = backend->create_tensor(element::from(), inputs_shape); + copy_data(input_tensor, inputs); + auto result = backend->create_tensor(element::from(), outputs_shape); + + auto handle = backend->compile(f); + handle->call_with_validate({result}, {input_tensor}); + EXPECT_TRUE(test::all_close_f((outputs), read_vector(result), MIN_FLOAT_TOLERANCE_BITS)); +} + +template +static void ConvolutionTest(const std::vector& inputs, + const Shape inputs_shape, + const std::vector& filters, + const Shape filters_shape, + const std::vector& outputs, + const Shape outputs_shape, + const Strides& strides, + const CoordinateDiff& padding, + const Strides& dilations) +{ + const CoordinateDiff pads_begin{padding}; + const CoordinateDiff pads_end{padding}; + const op::PadType auto_pad{op::PadType::EXPLICIT}; + + auto inputs_param = make_shared(element::from(), inputs_shape); + auto filters_param = make_shared(element::from(), filters_shape); + auto conv = make_shared( + inputs_param, filters_param, strides, pads_begin, pads_end, dilations, auto_pad); + auto f = make_shared(conv, ParameterVector{inputs_param, filters_param}); + + auto backend = runtime::Backend::create("${BACKEND_NAME}"); + + auto input_tensor = backend->create_tensor(element::from(), inputs_shape); + copy_data(input_tensor, inputs); + auto filters_tensor = backend->create_tensor(element::from(), filters_shape); + copy_data(filters_tensor, filters); + auto result = backend->create_tensor(element::from(), outputs_shape); + + auto handle = backend->compile(f); + handle->call_with_validate({result}, {input_tensor, filters_tensor}); + EXPECT_TRUE(test::all_close_f((outputs), read_vector(result), MIN_FLOAT_TOLERANCE_BITS)); +} + +// clang-format off +// --------------------- 2D convolution ------------------------------------------ +NGRAPH_TEST(${BACKEND_NAME}, bin_convolution_2D_1batch_1channel) +{ + const Strides strides{1, 1}; + const CoordinateDiff padding{0, 0}; + const Strides dilations{1, 1}; + + const Shape inputs_shape{1, 1, 4, 4}; + const std::vector inputs_conv{1.0f, -1.0f, -1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, 1.0f, + 1.0f, -1.0f, 1.0f, 1.0f}; + + const std::vector inputs_bin_conv{1.0f, 0.0f, 0.0f, 1.0f, + 1.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 1.0f}; + + const Shape filters_shape{1, 1, 3, 3}; + const std::vector filters_conv{1.0f, -1.0f, 1.0f, + -1.0f, 1.0f, -1.0f, + 1.0f, -1.0f, 1.0f}; + + const std::vector filters_bin_conv{0xAA, 0x80}; // 10101010 10000000 + + const Shape outputs_shape{1, 1, 2, 2}; + const std::vector outputs{1.0f, 1.0f, + 3.0f, -1.0f}; + + BinaryConvolutionTest( + inputs_bin_conv, + inputs_shape, + filters_bin_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding, + dilations); + + ConvolutionTest( + inputs_conv, + inputs_shape, + filters_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding, + dilations); +} + +NGRAPH_TEST(${BACKEND_NAME}, bin_convolution_2D_1batch_1channel_padding_pad_val_0) +{ + const Strides strides{1, 1}; + const Strides dilations{1, 1}; + + const CoordinateDiff padding_conv{0, 0}; + const Shape inputs_conv_shape{1, 1, 6, 6}; + const std::vector inputs_conv{-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, + -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, + -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + + const CoordinateDiff padding_bin_conv{1, 1}; + const Shape inputs_bin_conv_shape{1, 1, 4, 4}; + const std::vector inputs_bin_conv{1.0f, 0.0f, 0.0f, 1.0f, + 1.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 1.0f}; + + const Shape filters_shape{1, 1, 3, 3}; + const std::vector filters_conv{1.0f, -1.0f, 1.0f, + -1.0f, 1.0f, -1.0f, + 1.0f, -1.0f, 1.0f}; + + const std::vector filters_bin_conv{0xAA, 0x80}; // 10101010 10000000 + + const Shape outputs_shape{1, 1, 4, 4}; + const std::vector outputs{1.0f, -3.0f, -1.0f, 1.0f, + -3.0f, 1.0f, 1.0f, -5.0f, + -3.0f, 3.0f, -1.0f, 1.0f, + 1.0f, -5.0f, 1.0f, -3.0f}; + + BinaryConvolutionTest( + inputs_bin_conv, + inputs_bin_conv_shape, + filters_bin_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding_bin_conv, + dilations); + + ConvolutionTest( + inputs_conv, + inputs_conv_shape, + filters_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding_conv, + dilations); +} + +NGRAPH_TEST(${BACKEND_NAME}, bin_convolution_2D_1batch_1channel_padding_pad_val_1) +{ + const Strides strides{1, 1}; + const Strides dilations{1, 1}; + const float pad_value = 1.0f; + + const CoordinateDiff padding_conv{0, 0}; + const Shape inputs_conv_shape{1, 1, 6, 6}; + const std::vector inputs_conv{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, + 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + + const CoordinateDiff padding_bin_conv{1, 1}; + const Shape inputs_bin_conv_shape{1, 1, 4, 4}; + const std::vector inputs_bin_conv{1.0f, 0.0f, 0.0f, 1.0f, + 1.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 1.0f}; + + const Shape filters_shape{1, 1, 3, 3}; + const std::vector filters_conv{1.0f, -1.0f, 1.0f, + -1.0f, 1.0f, -1.0f, + 1.0f, -1.0f, 1.0f}; + + const std::vector filters_bin_conv{0xAA, 0x80}; // 10101010 10000000 + + const Shape outputs_shape{1, 1, 4, 4}; + const std::vector outputs{3.0f, -1.0f, 1.0f, 3.0f, + -1.0f, 1.0f, 1.0f, -3.0f, + -1.0f, 3.0f, -1.0f, 3.0f, + 3.0f, -3.0f, 3.0f, -1.0f,}; + + BinaryConvolutionTest( + inputs_bin_conv, + inputs_bin_conv_shape, + filters_bin_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding_bin_conv, + dilations, + pad_value); + + ConvolutionTest( + inputs_conv, + inputs_conv_shape, + filters_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding_conv, + dilations); +} + +NGRAPH_TEST(${BACKEND_NAME}, bin_convolution_2D_1batch_1channel_stride) +{ + const Strides strides{2, 2}; + const CoordinateDiff padding{0, 0}; + const Strides dilations{1, 1}; + + const Shape inputs_shape{1, 1, 5, 5}; + const std::vector inputs_conv{-1.0f, 1.0f, 1.0f, -1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, + -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, + -1.0f, -1.0f, 1.0f, 1.0f, 1.0f}; + + const std::vector inputs_bin_conv{0.0f, 1.0f, 1.0f, 0.0f, 1.0f, + 1.0f, 1.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 1.0f, + 1.0f, 1.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 1.0f, 1.0f}; + + const Shape filters_shape{1, 1, 3, 3}; + const std::vector filters_conv{-1.0f, -1.0f, 1.0f, + -1.0f, 1.0f, 1.0f, + 1.0f, -1.0f, -1.0f}; + const std::vector filters_bin_conv{0x2E, 0x00}; // 00101110 00000000 + + const Shape outputs_shape{1, 1, 2, 2}; + const std::vector outputs{-1.0f, 3.0f, + 1.0f, 1.0f}; + + BinaryConvolutionTest( + inputs_bin_conv, + inputs_shape, + filters_bin_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding, + dilations); + + ConvolutionTest( + inputs_conv, + inputs_shape, + filters_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding, + dilations); +} + +NGRAPH_TEST(${BACKEND_NAME}, bin_convolution_2D_1batch_1channel_dilation) +{ + const Strides strides{1, 1}; + const CoordinateDiff padding{0, 0}; + const Strides dilations{2, 2}; + + const Shape inputs_shape{1, 1, 7, 7}; + const std::vector inputs_conv{1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, + -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, + -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, + -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, + 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f}; + + const std::vector inputs_bin_conv{1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 1.0f, 1.0f, + 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, + 1.0f, 1.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f}; + + const Shape filters_shape{1, 1, 3, 3}; + const std::vector filters_conv{-1.0f, 1.0f, 1.0f, + -1.0f, 1.0f, -1.0f, + 1.0f, 1.0f, -1.0f}; + const std::vector filters_bin_conv{0x6B, 0x00}; // 01101011 00000000 + + const Shape outputs_shape{1, 1, 3, 3}; + const std::vector outputs{-5.0f, -3.0f, -5.0f, + 5.0f, 1.0f, 3.0f, + -1.0f, -1.0f, 3.0f}; + + BinaryConvolutionTest( + inputs_bin_conv, + inputs_shape, + filters_bin_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding, + dilations); + + ConvolutionTest( + inputs_conv, + inputs_shape, + filters_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding, + dilations); +} + +NGRAPH_TEST(${BACKEND_NAME}, + bin_convolution_2D_1batch_1channel_strides_dilation_padding_pad_val_0) +{ + const Strides strides{2, 2}; + const Strides dilations{2, 2}; + + const CoordinateDiff padding_conv{0, 0}; + const Shape inputs_conv_shape{1, 1, 11, 11}; + const std::vector inputs_conv{ + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f,-1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + + const CoordinateDiff padding_bin_conv{2, 2}; + const Shape inputs_bin_conv_shape{1, 1, 7, 7}; + const std::vector inputs_bin_conv{1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 1.0f, 1.0f, + 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, + 1.0f, 1.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f}; + + const Shape filters_shape{1, 1, 3, 3}; + const std::vector filters_conv{-1.0f, 1.0f, 1.0f, + -1.0f, 1.0f, -1.0f, + 1.0f, 1.0f, -1.0f}; + const std::vector filters_bin_conv{0x6B, 0x00}; // 01101011 00000000 + + const Shape outputs_shape{1, 1, 4, 4}; + const std::vector outputs{1.0f, 1.0f, -1.0f, 1.0f, + 1.0f, -5.0f, -5.0f, 5.0f, + 3.0f, -1.0f, 3.0f, 3.0f, + -1.0f, -1.0f, 3.0f, -3.0f}; + + BinaryConvolutionTest( + inputs_bin_conv, + inputs_bin_conv_shape, + filters_bin_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding_bin_conv, + dilations); + + ConvolutionTest( + inputs_conv, + inputs_conv_shape, + filters_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding_conv, + dilations); +} + +NGRAPH_TEST(${BACKEND_NAME}, + bin_convolution_2D_1batch_1channel_strides_dilation_padding_pad_val_1) +{ + const Strides strides{2, 2}; + const Strides dilations{2, 2}; + const float pad_value = 1.0f; + + const CoordinateDiff padding_conv{0, 0}; + const Shape inputs_conv_shape{1, 1, 11, 11}; + const std::vector inputs_conv{ + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + + const CoordinateDiff padding_bin_conv{2, 2}; + const Shape inputs_bin_conv_shape{1, 1, 7, 7}; + const std::vector inputs_bin_conv{1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 0.0f, 1.0f, 1.0f, + 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, + 1.0f, 1.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f}; + + const Shape filters_shape{1, 1, 3, 3}; + const std::vector filters_conv{-1.0f, 1.0f, 1.0f, + -1.0f, 1.0f, -1.0f, + 1.0f, 1.0f, -1.0f}; + const std::vector filters_bin_conv{0x6B, 0x00}; // 01101011 00000000 + + const Shape outputs_shape{1, 1, 4, 4}; + const std::vector outputs{3.0f, 3.0f, 1.0f, -1.0f, + -1.0f, -5.0f, -5.0f, 3.0f, + 1.0f, -1.0f, 3.0f, 1.0f, + -3.0f, 1.0f, 5.0f, -1.0f}; + + BinaryConvolutionTest( + inputs_bin_conv, + inputs_bin_conv_shape, + filters_bin_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding_bin_conv, + dilations, + pad_value); + + ConvolutionTest( + inputs_conv, + inputs_conv_shape, + filters_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding_conv, + dilations); +} + +NGRAPH_TEST(${BACKEND_NAME}, bin_convolution_2D_1batch_2channel) +{ + const Strides strides{1, 1}; + const CoordinateDiff padding{0, 0}; + const Strides dilations{1, 1}; + + const Shape inputs_shape{1, 2, 4, 4}; + const std::vector inputs_conv{ + // channel 1 + 1.0f, -1.0f, -1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, 1.0f, + 1.0f, -1.0f, 1.0f, 1.0f, + // channel 2 + -1.0f, 1.0f, 1.0f, -1.0f, + -1.0f, -1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, -1.0f, + -1.0f, 1.0f, -1.0f, -1.0f}; + const std::vector inputs_bin_conv{ + // channel 1 + 1.0f, 0.0f, 0.0f, 1.0f, + 1.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 1.0f, + // channel 2 + 0.0f, 1.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f}; + + const Shape filters_shape{1, 2, 3, 3}; + const std::vector filters_conv{ + // channel 1 + 1.0f, -1.0f, 1.0f, + -1.0f, 1.0f, -1.0f, + 1.0f, -1.0f, 1.0f, + // channel 2 + -1.0f, 1.0f, -1.0f, + 1.0f, -1.0f, 1.0f, + -1.0f, 1.0f, -1.0f}; + // 10101010 10101010 10000000 + const std::vector filters_bin_conv{0xAA, 0xAA, 0x80}; + + const Shape outputs_shape{1, 1, 2, 2}; + const std::vector outputs{2.0f, 2.0f, + 6.0f, -2.0f}; + + BinaryConvolutionTest( + inputs_bin_conv, + inputs_shape, + filters_bin_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding, + dilations); + + ConvolutionTest( + inputs_conv, + inputs_shape, + filters_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding, + dilations); +} + +NGRAPH_TEST(${BACKEND_NAME}, bin_convolution_2D_2batch_1channel) +{ + const Strides strides{1, 1}; + const CoordinateDiff padding{0, 0}; + const Strides dilations{1, 1}; + + const Shape inputs_shape{2, 1, 4, 4}; + const std::vector inputs_conv{ + // batch 1 + 1.0f, -1.0f, -1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, 1.0f, + 1.0f, -1.0f, 1.0f, 1.0f, + // batch 2 + -1.0f, -1.0f, -1.0f, -1.0f, + 1.0f, 1.0f, 1.0f, -1.0f, + 1.0f, 1.0f, -1.0f, 1.0f, + 1.0f, -1.0f, 1.0f, -1.0f}; + const std::vector inputs_bin_conv{ + // batch 1 + 1.0f, 0.0f, 0.0f, 1.0f, + 1.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 1.0f, + // batch 2 + 0.0f, 0.0f, 0.0f, 0.0f, + 1.0f, 1.0f, 1.0f, 0.0f, + 1.0f, 1.0f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.0f}; + + const Shape filters_shape{1, 1, 3, 3}; + const std::vector filters_conv{1.0f, -1.0f, 1.0f, + -1.0f, 1.0f, -1.0f, + 1.0f, -1.0f, 1.0f}; + const std::vector filters_bin_conv{0xAA, 0x80}; // 10101010 10000000 + + const Shape outputs_shape{2, 1, 2, 2}; + const std::vector outputs{ + // batch 1 + 1.0f, 1.0f, + 3.0f, -1.0f, + // batch 2 + -3.0f, 3.0f, + 5.0f, -7.0f}; + + BinaryConvolutionTest( + inputs_bin_conv, + inputs_shape, + filters_bin_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding, + dilations); + + ConvolutionTest( + inputs_conv, + inputs_shape, + filters_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding, + dilations); +} +// clang-format on diff --git a/ngraph/test/op_eval/binary_convolution.cpp b/ngraph/test/op_eval/binary_convolution.cpp new file mode 100644 index 00000000000..5ced1f7d0f6 --- /dev/null +++ b/ngraph/test/op_eval/binary_convolution.cpp @@ -0,0 +1,245 @@ +//***************************************************************************** +// Copyright 2021 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include "gtest/gtest.h" +#include "ngraph/ngraph.hpp" +#include "ngraph/runtime/tensor.hpp" +#include "runtime/backend.hpp" +#include "util/all_close.hpp" +#include "util/all_close_f.hpp" +#include "util/test_control.hpp" + +using namespace std; +using namespace ngraph; + +static string s_manifest = "${MANIFEST}"; + +template +static void BinaryConvolutionTest(const std::vector& inputs, + const Shape inputs_shape, + const std::vector& filters, + const Shape filters_shape, + const std::vector& outputs, + const Shape outputs_shape, + const Strides& strides, + const CoordinateDiff& padding, + const Strides& dilations) +{ + const CoordinateDiff pads_begin{padding}; + const CoordinateDiff pads_end{padding}; + const op::PadType auto_pad{op::PadType::EXPLICIT}; + float pad_value = 0; + + auto inputs_param = make_shared(element::from(), inputs_shape); + auto filters_const = make_shared(element::u1, filters_shape, &filters[0]); + auto bin_conv = make_shared( + inputs_param, + filters_const, + strides, + pads_begin, + pads_end, + dilations, + op::v1::BinaryConvolution::BinaryConvolutionMode::XNOR_POPCOUNT, + pad_value, + auto_pad); + auto f = make_shared(bin_conv, ParameterVector{inputs_param}); + + auto backend = runtime::Backend::create("INTERPRETER"); + + auto input_tensor = backend->create_tensor(element::from(), inputs_shape); + copy_data(input_tensor, inputs); + auto result = backend->create_tensor(element::from(), outputs_shape); + + auto handle = backend->compile(f); + handle->call_with_validate({result}, {input_tensor}); + EXPECT_TRUE(test::all_close_f((outputs), read_vector(result), MIN_FLOAT_TOLERANCE_BITS)); +} + +template +static void ConvolutionTest(const std::vector& inputs, + const Shape inputs_shape, + const std::vector& filters, + const Shape filters_shape, + const std::vector& outputs, + const Shape outputs_shape, + const Strides& strides, + const CoordinateDiff& padding, + const Strides& dilations) +{ + const CoordinateDiff pads_begin{padding}; + const CoordinateDiff pads_end{padding}; + const op::PadType auto_pad{op::PadType::EXPLICIT}; + + auto inputs_param = make_shared(element::from(), inputs_shape); + auto filters_param = make_shared(element::from(), filters_shape); + auto conv = make_shared( + inputs_param, filters_param, strides, pads_begin, pads_end, dilations, auto_pad); + auto f = make_shared(conv, ParameterVector{inputs_param, filters_param}); + + auto backend = runtime::Backend::create("INTERPRETER"); + + auto input_tensor = backend->create_tensor(element::from(), inputs_shape); + copy_data(input_tensor, inputs); + auto filters_tensor = backend->create_tensor(element::from(), filters_shape); + copy_data(filters_tensor, filters); + auto result = backend->create_tensor(element::from(), outputs_shape); + + auto handle = backend->compile(f); + handle->call_with_validate({result}, {input_tensor, filters_tensor}); + EXPECT_TRUE(test::all_close_f((outputs), read_vector(result), MIN_FLOAT_TOLERANCE_BITS)); +} + +// --------------------- 1D convolution ------------------------------------------ +TEST(op_eval, bin_convolution_1D_1batch_1channel_no_padding) +{ + const Strides strides{1}; + const CoordinateDiff padding{0}; + const Strides dilations{1}; + + const Shape inputs_shape{1, 1, 5}; + const std::vector inputs_conv{1.0f, -1.0f, -1.0f, 1.0f, -1.0f}; + const std::vector inputs_bin_conv{1.0f, 0.0f, 0.0f, 1.0f, 0.0f}; + + const Shape filters_shape{1, 1, 3}; + const std::vector filters_conv{1.0f, -1.0f, 1.0f}; + const std::vector filters_bin_conv{0xA0}; // 1010 0000 + + const Shape outputs_shape{1, 1, 3}; + const std::vector outputs{1.0f, 1.0f, -3.0f}; + + BinaryConvolutionTest(inputs_bin_conv, + inputs_shape, + filters_bin_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding, + dilations); + + ConvolutionTest(inputs_conv, + inputs_shape, + filters_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding, + dilations); +} + +// --------------------- 3D convolution ------------------------------------------ +// clang-format off +NGRAPH_TEST(op_eval, bin_convolution_3D_1batch_1channel_no_padding) +{ + const Strides strides{1, 1, 1}; + const CoordinateDiff padding{0, 0, 0}; + const Strides dilations{1, 1, 1}; + + const Shape inputs_shape{1, 1, 4, 4, 4}; + const std::vector inputs_conv{ + // depth: 1 + 1.0f, -1.0f, -1.0f, 1.0f, + -1.0f, 1.0f, -1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, 1.0f, + 1.0f, -1.0f, 1.0f, -1.0f, + // depth: 2 + -1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, -1.0f, -1.0f, 1.0f, + -1.0f, 1.0f, -1.0f, 1.0f, + 1.0f, -1.0f, 1.0f, -1.0f, + // depth: 3 + 1.0f, 1.0f, 1.0f, -1.0f, + -1.0f, 1.0f, -1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, 1.0f, + 1.0f, -1.0f, 1.0f, -1.0f, + // depth: 4 + 1.0f, -1.0f, 1.0f, -1.0f, + 1.0f, 1.0f, -1.0f, 1.0f, + -1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, 1.0f + }; + const std::vector inputs_bin_conv{ + // depth: 1 + 1.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 1.0f, 0.0f, 1.0f, + 1.0f, 1.0f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.0f, + // depth: 2 + 0.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 1.0f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.0f, + // depth: 3 + 1.0f, 1.0f, 1.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 1.0f, + 1.0f, 1.0f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.0f, + // depth: 4 + 1.0f, 0.0f, 1.0f, 0.0f, + 1.0f, 1.0f, 0.0f, 1.0f, + 0.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 0.0f, 1.0f + }; + + const Shape filters_shape{1, 1, 3, 3, 3}; + const std::vector filters_conv{ + // depth: 1 + 1.0f, -1.0f, 1.0f, + -1.0f, 1.0f, -1.0f, + 1.0f, -1.0f, 1.0f, + // depth: 2 + -1.0f, 1.0f, 1.0f, + 1.0f, -1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, + // depth: 3 + 1.0f, 1.0f, -1.0f, + -1.0f, 1.0f, -1.0f, + 1.0f, 1.0f, 1.0f}; + const std::vector filters_bin_conv{0xAA, 0xBB, 0xB2, 0xE0}; + + const Shape outputs_shape{1, 1, 2, 2, 2}; + const std::vector outputs{ + // depth: 1 + 13.0f, 3.0f, + -3.0f, -3.0f, + // depth: 2 + -3.0f, 5.0f, + 11.0f, -3.0f}; + + BinaryConvolutionTest( + inputs_bin_conv, + inputs_shape, + filters_bin_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding, + dilations); + + ConvolutionTest( + inputs_conv, + inputs_shape, + filters_conv, + filters_shape, + outputs, + outputs_shape, + strides, + padding, + dilations); +} +// clang-format off diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest index c512ca11e45..487e8d30627 100644 --- a/ngraph/test/runtime/ie/unit_test.manifest +++ b/ngraph/test/runtime/ie/unit_test.manifest @@ -1592,3 +1592,14 @@ IE_CPU.onnx_mvn_v6 # The test randomly fails in CI (MSVC2019 Debug) onnx_model_experimental_detectron_generate_proposals_single_image + +# Issue 49621: Incorrect blob sizes for node BinaryConvolution_X +bin_convolution_2D_1batch_1channel +bin_convolution_2D_1batch_1channel_padding_pad_val_0 +bin_convolution_2D_1batch_1channel_padding_pad_val_1 +bin_convolution_2D_1batch_1channel_stride +bin_convolution_2D_1batch_1channel_dilation +bin_convolution_2D_1batch_1channel_strides_dilation_padding_pad_val_0 +bin_convolution_2D_1batch_1channel_strides_dilation_padding_pad_val_1 +bin_convolution_2D_1batch_2channel +bin_convolution_2D_2batch_1channel diff --git a/ngraph/test/runtime/interpreter/evaluates_map.cpp b/ngraph/test/runtime/interpreter/evaluates_map.cpp index 1cd0f01ee04..5168dce1bcc 100644 --- a/ngraph/test/runtime/interpreter/evaluates_map.cpp +++ b/ngraph/test/runtime/interpreter/evaluates_map.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -214,6 +215,55 @@ namespace return true; } + namespace bin_conv_v1 + { + template + inline void evaluate(const shared_ptr& op, + const HostTensorVector& outputs, + const HostTensorVector& inputs) + { + using T_IN = typename element_type_traits::value_type; + using T_F = typename element_type_traits::value_type; + + const auto in_data_ptr = inputs[0]->get_data_ptr(); + const auto filter_data_ptr = inputs[1]->get_data_ptr(); + auto out_data_ptr = outputs[0]->get_data_ptr(); + const auto in_shape = inputs[0]->get_shape(); + const auto filter_shape = inputs[1]->get_shape(); + const auto out_shape = outputs[0]->get_shape(); + + runtime::reference::binary_convolution(in_data_ptr, + filter_data_ptr, + out_data_ptr, + in_shape, + filter_shape, + out_shape, + op->get_strides(), + op->get_dilations(), + op->get_pads_begin(), + op->get_pads_end(), + op->get_pad_value()); + } + } // bin_conv_v1 + + template + bool evaluate(const shared_ptr& op, + const HostTensorVector& outputs, + const HostTensorVector& inputs) + { + switch (inputs[1]->get_element_type()) + { + case element::Type_t::u1: + bin_conv_v1::evaluate(op, outputs, inputs); + break; + default: + throw std::runtime_error( + "BinaryConvolution supports only u1 element type for filters input"); + break; + } + return true; + } + template bool evaluate(const shared_ptr& op, const HostTensorVector& outputs, diff --git a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp index 16453b5e413..5d0cc48d5b9 100644 --- a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp +++ b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp @@ -49,6 +49,7 @@ NGRAPH_OP(TensorIterator, op::v0) NGRAPH_OP(ROIPooling, op::v0) NGRAPH_OP(AvgPool, op::v1) +NGRAPH_OP(BinaryConvolution, ngraph::op::v1) NGRAPH_OP(ConvertLike, op::v1) NGRAPH_OP(Convolution, ngraph::op::v1) NGRAPH_OP(ConvolutionBackpropData, ngraph::op::v1) diff --git a/ngraph/test/runtime/interpreter/unit_test.manifest b/ngraph/test/runtime/interpreter/unit_test.manifest index eaa8b24263c..b33d89103a9 100644 --- a/ngraph/test/runtime/interpreter/unit_test.manifest +++ b/ngraph/test/runtime/interpreter/unit_test.manifest @@ -170,4 +170,4 @@ INTERPRETER.onnx_model_experimental_detectron_generate_proposals_single_image INTERPRETER.onnx_model_experimental_detectron_prior_grid_generator # Interpreter backend doesn't implement evaluate method for OP ExperimentalDetectronROIFeatureExtractor -INTERPRETER.onnx_model_experimental_detectron_roi_feature_extractor \ No newline at end of file +INTERPRETER.onnx_model_experimental_detectron_roi_feature_extractor diff --git a/ngraph/test/type_prop/binary_convolution.cpp b/ngraph/test/type_prop/binary_convolution.cpp index d6518a655c8..98ec779d24d 100644 --- a/ngraph/test/type_prop/binary_convolution.cpp +++ b/ngraph/test/type_prop/binary_convolution.cpp @@ -21,7 +21,7 @@ using namespace std; using namespace ngraph; -TEST(type_prop, binary_conv_v1_partial_auto_padding_same) +TEST(type_prop, bin_convolution_auto_padding_same) { const PartialShape data_batch_shape{1, 1, 5, 5}; const PartialShape filters_shape{1, 1, 3, 3}; @@ -34,7 +34,7 @@ TEST(type_prop, binary_conv_v1_partial_auto_padding_same) const auto auto_pad = op::PadType::SAME_LOWER; auto data_batch = make_shared(element::f32, data_batch_shape); - auto filters = make_shared(element::f32, filters_shape); + auto filters = make_shared(element::u1, filters_shape); auto conv = make_shared( data_batch, filters, strides, pads_begin, pads_end, dilations, mode, pad_value, auto_pad); @@ -44,10 +44,10 @@ TEST(type_prop, binary_conv_v1_partial_auto_padding_same) ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{1, 1})); } -TEST(type_prop, binary_conv_v1_partial_auto_padding_same_nc_dims_dynamic_same_lower) +TEST(type_prop, bin_convolution_auto_padding_same_lower_spatial_dims_static) { const PartialShape data_batch_shape{Dimension::dynamic(), Dimension::dynamic(), 5, 5}; - const PartialShape filters_shape{1, 1, 3, 3}; + const PartialShape filters_shape{Dimension::dynamic(), Dimension::dynamic(), 3, 3}; Strides strides{1, 1}; CoordinateDiff pads_begin{0, 0}; CoordinateDiff pads_end{0, 0}; @@ -57,20 +57,21 @@ TEST(type_prop, binary_conv_v1_partial_auto_padding_same_nc_dims_dynamic_same_lo const auto auto_pad = op::PadType::SAME_LOWER; auto data_batch = make_shared(element::f32, data_batch_shape); - auto filters = make_shared(element::f32, filters_shape); + auto filters = make_shared(element::u1, filters_shape); auto conv = make_shared( data_batch, filters, strides, pads_begin, pads_end, dilations, mode, pad_value, auto_pad); - ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme({Dimension::dynamic(), 1, 5, 5})); + ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme( + {Dimension::dynamic(), Dimension::dynamic(), 5, 5})); ASSERT_EQ(conv->get_pads_begin(), (CoordinateDiff{1, 1})); ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{1, 1})); } -TEST(type_prop, binary_conv_v1_partial_auto_padding_same_nc_dims_dynamic_same_upper) +TEST(type_prop, bin_convolution_auto_padding_same_upper_spatial_dims_static) { const PartialShape data_batch_shape{Dimension::dynamic(), Dimension::dynamic(), 5, 5}; - const PartialShape filters_shape{1, 1, 2, 2}; + const PartialShape filters_shape{Dimension::dynamic(), Dimension::dynamic(), 2, 2}; Strides strides{1, 1}; CoordinateDiff pads_begin{0, 0}; CoordinateDiff pads_end{0, 0}; @@ -80,20 +81,21 @@ TEST(type_prop, binary_conv_v1_partial_auto_padding_same_nc_dims_dynamic_same_up const auto auto_pad = op::PadType::SAME_UPPER; auto data_batch = make_shared(element::f32, data_batch_shape); - auto filters = make_shared(element::f32, filters_shape); + auto filters = make_shared(element::u1, filters_shape); auto conv = make_shared( data_batch, filters, strides, pads_begin, pads_end, dilations, mode, pad_value, auto_pad); - ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme({Dimension::dynamic(), 1, 5, 5})); + ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme( + {Dimension::dynamic(), Dimension::dynamic(), 5, 5})); ASSERT_EQ(conv->get_pads_begin(), (CoordinateDiff{0, 0})); ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{1, 1})); } -TEST(type_prop, binary_conv_v1_partial_auto_padding_same_spatial_dims_dynamic) +TEST(type_prop, bin_convolution_auto_padding_same_data_batch_spatial_dims_dynamic) { const PartialShape data_batch_shape{1, 1, Dimension::dynamic(), 5}; - const PartialShape filters_shape{1, 1, 3, 3}; + const PartialShape filters_shape{Dimension::dynamic(), 1, 3, 3}; Strides strides{1, 1}; CoordinateDiff pads_begin{0, 0}; CoordinateDiff pads_end{0, 0}; @@ -103,12 +105,305 @@ TEST(type_prop, binary_conv_v1_partial_auto_padding_same_spatial_dims_dynamic) const auto auto_pad = op::PadType::SAME_LOWER; auto data_batch = make_shared(element::f32, data_batch_shape); - auto filters = make_shared(element::f32, filters_shape); + auto filters = make_shared(element::u1, filters_shape); auto conv = make_shared( data_batch, filters, strides, pads_begin, pads_end, dilations, mode, pad_value, auto_pad); - ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme({1, 1, Dimension::dynamic(), 5})); + ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme( + {1, Dimension::dynamic(), Dimension::dynamic(), 5})); ASSERT_EQ(conv->get_pads_begin(), (CoordinateDiff{0, 1})); ASSERT_EQ(conv->get_pads_end(), (CoordinateDiff{0, 1})); } + +TEST(type_prop, bin_convolution_dyn_data_batch) +{ + const auto mode = op::v1::BinaryConvolution::BinaryConvolutionMode::XNOR_POPCOUNT; + const float pad_value = 1.0f; + const auto auto_pad = op::PadType::EXPLICIT; + + const auto data_batch = make_shared(element::f32, PartialShape::dynamic()); + const auto filters = make_shared(element::u1, PartialShape{1, 1, 3}); + const auto bin_conv = make_shared(data_batch, + filters, + Strides{}, + CoordinateDiff{}, + CoordinateDiff{}, + Strides{}, + mode, + pad_value, + auto_pad); + ASSERT_TRUE(bin_conv->get_output_partial_shape(0).rank().same_scheme(Rank{3})); + ASSERT_TRUE(bin_conv->get_output_partial_shape(0).same_scheme( + PartialShape{Dimension::dynamic(), 1, Dimension::dynamic()})); +} + +TEST(type_prop, bin_convolution_dyn_filters) +{ + const auto mode = op::v1::BinaryConvolution::BinaryConvolutionMode::XNOR_POPCOUNT; + const float pad_value = 1.0f; + const auto auto_pad = op::PadType::EXPLICIT; + + const auto data_batch = make_shared(element::f32, PartialShape{1, 1, 5, 5}); + const auto filters = make_shared(element::u1, PartialShape::dynamic()); + const auto bin_conv = make_shared(data_batch, + filters, + Strides{}, + CoordinateDiff{}, + CoordinateDiff{}, + Strides{}, + mode, + pad_value, + auto_pad); + ASSERT_TRUE(bin_conv->get_output_partial_shape(0).rank().same_scheme(Rank{4})); + ASSERT_TRUE(bin_conv->get_output_partial_shape(0).same_scheme( + PartialShape{1, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()})); +} + +TEST(type_prop, bin_convolution_dyn_data_batch_and_filters) +{ + const auto mode = op::v1::BinaryConvolution::BinaryConvolutionMode::XNOR_POPCOUNT; + const float pad_value = 1.0f; + const auto auto_pad = op::PadType::EXPLICIT; + + const auto data_batch = make_shared(element::f32, PartialShape::dynamic()); + const auto filters = make_shared(element::u1, PartialShape::dynamic()); + const auto bin_conv = make_shared(data_batch, + filters, + Strides{}, + CoordinateDiff{}, + CoordinateDiff{}, + Strides{}, + mode, + pad_value, + auto_pad); + ASSERT_TRUE(bin_conv->get_output_partial_shape(0).rank().is_dynamic()); + ASSERT_TRUE(bin_conv->get_output_partial_shape(0).same_scheme(PartialShape::dynamic())); +} + +TEST(type_prop, bin_convolution_invalid_inputs_et) +{ + const auto mode = op::v1::BinaryConvolution::BinaryConvolutionMode::XNOR_POPCOUNT; + const float pad_value = 1.0f; + const auto auto_pad = op::PadType::EXPLICIT; + try + { + const auto data_batch = make_shared(element::i32, PartialShape{1, 1, 5, 5}); + const auto filters = make_shared(element::u1, PartialShape{1, 1, 3, 3}); + const auto bin_conv = make_shared(data_batch, + filters, + Strides{}, + CoordinateDiff{}, + CoordinateDiff{}, + Strides{}, + mode, + pad_value, + auto_pad); + // data batch element type must be float point + FAIL() << "Incompatible element type of data batch input not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), "Data batch element type must be float point"); + } + catch (...) + { + FAIL() << "Data batch element type validation check failed for unexpected reason"; + } + // TODO: Add test with check filters element type once u1 is supported in nGraph Python API + // (#49517) +} + +TEST(type_prop, bin_convolution_incompatible_input_channels) +{ + const auto mode = op::v1::BinaryConvolution::BinaryConvolutionMode::XNOR_POPCOUNT; + const float pad_value = 1.0f; + const auto auto_pad = op::PadType::EXPLICIT; + + auto data_batch = make_shared(element::f32, PartialShape{1, 1, 5, 5}); + auto filters = make_shared(element::u1, PartialShape{1, 2, 3, 3}); + + try + { + auto conv = make_shared(data_batch, + filters, + Strides{}, + CoordinateDiff{}, + CoordinateDiff{}, + Strides{}, + mode, + pad_value, + auto_pad); + FAIL() << "Incompatible input channel dimension in data batch and filters not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch channel count")); + } + catch (...) + { + FAIL() << "Data batch and filters input channel count validation check failed for " + "unexpected reason"; + } +} + +TEST(type_prop, bin_convolution_invalid_input_ranks) +{ + const auto mode = op::v1::BinaryConvolution::BinaryConvolutionMode::XNOR_POPCOUNT; + const float pad_value = 1.0f; + const auto auto_pad = op::PadType::EXPLICIT; + + // data partial shape provided is rank 4 (Conv2D) + // filter partial shape provided is rank 5 (Conv3D) + try + { + const auto data_batch = make_shared(element::f32, PartialShape{1, 1, 5, 5}); + const auto filters = make_shared(element::u1, PartialShape{1, 1, 3, 3, 3}); + const auto bin_conv = make_shared(data_batch, + filters, + Strides{}, + CoordinateDiff{}, + CoordinateDiff{}, + Strides{}, + mode, + pad_value, + auto_pad); + // data batch and filters have incompatible ranks + FAIL() << "Incompatible input ranks not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + "Shapes for data batch and filters must have same rank."); + } + catch (...) + { + FAIL() << "Rank validation check of inputs failed for unexpected reason"; + } + + // data partial shape provided is rank 5 (Conv3D) + // filter partial shape provided is rank 4 (Conv2D) + try + { + const auto data_batch = + make_shared(element::f32, PartialShape{1, 1, 5, 5, 5}); + const auto filters = make_shared(element::u1, PartialShape{1, 1, 3, 3}); + const auto bin_conv = make_shared(data_batch, + filters, + Strides{}, + CoordinateDiff{}, + CoordinateDiff{}, + Strides{}, + mode, + pad_value, + auto_pad); + // data batch and filters have incompatible ranks + FAIL() << "Incompatible input ranks not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + "Shapes for data batch and filters must have same rank."); + } + catch (...) + { + FAIL() << "Rank validation check of inputs failed for unexpected reason"; + } +} + +TEST(type_prop, bin_convolution_invalid_spatial_dims_parameters) +{ + Strides strides_1d{1}; + Strides strides_3d{1, 1, 1}; + + Strides dilations_2d{1, 1}; + Strides dilations_3d{1, 1, 1}; + + CoordinateDiff pads_end_2d{0, 0}; + CoordinateDiff pads_begin_3d{0, 0, 0}; + + const auto mode = op::v1::BinaryConvolution::BinaryConvolutionMode::XNOR_POPCOUNT; + const float pad_value = 1.0f; + const auto auto_pad = op::PadType::EXPLICIT; + + try + { + const auto data_batch = make_shared(element::f32, PartialShape{1, 1, 5, 5}); + const auto filters = make_shared(element::u1, PartialShape{1, 1, 3, 3}); + const auto bin_conv = make_shared(data_batch, + filters, + strides_3d, + CoordinateDiff{}, + CoordinateDiff{}, + dilations_2d, + mode, + pad_value, + auto_pad); + // Strides have incompatible number of spatial dimensions + FAIL() << "Incompatible stride number of spatial dimensions not detected."; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string("Strides should be defined for all and only spatial features.")); + } + catch (...) + { + FAIL() << "Strides validation check failed for unexpected reason."; + } + + try + { + const auto data_batch = make_shared(element::f32, PartialShape{1, 1, 5}); + const auto filters = make_shared(element::u1, PartialShape{1, 1, 3}); + const auto bin_conv = make_shared(data_batch, + filters, + strides_1d, + CoordinateDiff{}, + CoordinateDiff{}, + dilations_2d, + mode, + pad_value, + auto_pad); + // Dilations have incompatible number of spatial dimensions + FAIL() << "Incompatible dilations number of spatial dimensions not detected."; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string("Dilations should be defined for all and only spatial features.")); + } + catch (...) + { + FAIL() << "Dilations validation check failed for unexpected reason."; + } + + try + { + const auto data_batch = + make_shared(element::f32, PartialShape{1, 1, 5, 5, 5}); + const auto filters = make_shared(element::u1, PartialShape{1, 1, 3, 3, 3}); + const auto bin_conv = make_shared(data_batch, + filters, + strides_3d, + pads_begin_3d, + pads_end_2d, + dilations_3d, + mode, + pad_value, + auto_pad); + // Pads have incompatible number of spatial dimensions + FAIL() << "Incompatible pads number of spatial dimensions not detected."; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), std::string("Pads should be defined for all and only spatial features.")); + } + catch (...) + { + FAIL() << "Pads validation check failed for unexpected reason."; + } +} diff --git a/ngraph/test/util/engine/ie_engines.cpp b/ngraph/test/util/engine/ie_engines.cpp index 03ff866c3b8..68acc3fbb2c 100644 --- a/ngraph/test/util/engine/ie_engines.cpp +++ b/ngraph/test/util/engine/ie_engines.cpp @@ -148,7 +148,7 @@ namespace case element::Type_t::u16: return InferenceEngine::Precision::U16; break; case element::Type_t::u32: return InferenceEngine::Precision::U32; break; case element::Type_t::u64: return InferenceEngine::Precision::U64; break; - case element::Type_t::u1: throw std::runtime_error("unsupported type"); + case element::Type_t::u1: return InferenceEngine::Precision::BIN; break; case element::Type_t::undefined: throw std::runtime_error("unsupported type"); case element::Type_t::dynamic: throw std::runtime_error("unsupported type"); }