Implement ConvtoBinaryConv transformation (#4038)

This commit is contained in:
Mateusz Tabaka 2021-03-24 11:06:26 +01:00 committed by GitHub
parent b99345c320
commit e7b9b021ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 974 additions and 0 deletions

View File

@ -0,0 +1,79 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
// Forward declaration; the class body follows below under its qualified name
// ngraph::pass::BinarizeWeights.
class TRANSFORMATIONS_API BinarizeWeights;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief This transformation converts weights to -1/+1 form
* and applies normalization factors to output low/high and after Convolution.
* For example, following graph
*
* .... .... out_low out_high weights .. .. out_low out_high
* | | | | | | | | |
* +--------------------------+ +--------------------------+
* | FakeQuantize (levels==2) | | FakeQuantize (levels==2) |
* | (on activations) | | (on weights) |
* +--------------------------+ +--------------------------+
* | |
* | |
* ----------------- -------------------
* | |
* v v
* +-------------+
* | Convolution |
* +-------------+
* |
* v
*
* is transformed to:
*
* normalized normalized
* .... .... out_low out_high
* | | | |
* +--------------------------+ +--------------------------+
* | FakeQuantize (levels==2) | | Constant |
* | (on activations) | | (with converted weights) |
* +--------------------------+ +--------------------------+
* | |
* | |
* ----------------- -------------------
* | |
* v v
* +-------------+
* | Convolution |
* +-------------+
* |
* v
* +------------+ +---------------------------------------------------------------+
* | Multiply | <---| Constant (normalization factor coming from FQ on activations) |
* +------------+ +---------------------------------------------------------------+
* |
* v
* +------------+ +-----------------------------------------------------------+
* | Multiply | <---| Constant (normalization factor coming from FQ on weights) |
* +------------+ +------------------------------------------------------------
* |
* v
*
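* Normalization factors are chosen based on the output_high value.
* If it is zero, the norm factor is equal to output_low; otherwise it is equal to output_high.
* NOTE: the implementation in binarize_weights.cpp currently always uses output_high as the factor.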
*/
class ngraph::pass::BinarizeWeights : public ngraph::pass::MatcherPass {
public:
// Run-time type information declaration; the matching NGRAPH_RTTI_DEFINITION lives in the .cpp file.
NGRAPH_RTTI_DECLARATION;
// Builds the Convolution/FakeQuantize pattern and registers the rewrite callback.
BinarizeWeights();
};

View File

@ -0,0 +1,77 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
// Forward declaration; the class body follows below under its qualified name
// ngraph::pass::ConvToBinaryConv.
class TRANSFORMATIONS_API ConvToBinaryConv;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief This transformation converts Convolution to BinaryConvolution under following conditions:
* - first input to Convolution is a FakeQuantize with levels==2 with output low,high being either (0, 1) or (-1, 1)
* - second input (weights) is a constant with values -1 or 1
* The transformation also converts weights to binary Constant (with 'u1' type)
* For example, when output_low is equal to 0 and output_high is equal to 1, following graph
*
* .... .... out_low out_high
* | | | |
* +--------------------------+ +-------------------------------------+
* | FakeQuantize (levels==2) | | Constant |
* | (on activations) | | (weights containing -1 or 1 values) |
* +--------------------------+ +-------------------------------------+
* | |
* | |
* ----------------- -------------------
* | |
* v v
* +-------------+
* | Convolution |
* +-------------+
* |
* v
* is transformed to:
*
* .... .... out_low out_high
* | | | |
* +--------------------------+ +---------------------------------+
* | FakeQuantize (levels==2) | | Constant (with u1 type) |
* | (on activations) | | (with u1 type - binary weights) |
* +--------------------------+ +---------------------------------+
* | |
* | |
* ----------------- -------------------
* | |
* v v
* +-------------------+
* | BinaryConvolution |
* +-------------------+
* |
* v
* +------------+ +----------------------------------------------------+
* | | | Constant |
* | Add | <---| (weights from original graph, |
* | | | sum-reduced over [1,..., len(weights.shape)] axes |
* +------------+ +----------------------------------------------------+
* |
* v
* +------------+ +-----+
* | Multiply | <---| 0.5 |
* +------------+ +-----+
* |
* v
*/
class ngraph::pass::ConvToBinaryConv : public ngraph::pass::MatcherPass {
public:
// Run-time type information declaration; the matching NGRAPH_RTTI_DEFINITION lives in the .cpp file.
NGRAPH_RTTI_DECLARATION;
// Builds the FakeQuantize -> Convolution pattern and registers the rewrite callback.
ConvToBinaryConv();
};

View File

@ -0,0 +1,188 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/binarize_weights.hpp"
#include "itt.hpp"
#include <memory>
#include <vector>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(pass::BinarizeWeights, "BinarizeWeights", 0);
// Two-level quantization of a single value.
//
// Values at or below `input_low` map to `output_low`, values strictly above
// `input_high` map to `output_high`. Anything in between is snapped to the
// nearest end of the input range (std::round of its relative position) and
// mapped onto the corresponding output value.
static float quantize(float f, float input_low, float input_high, float output_low, float output_high) {
    const bool at_or_below_range = f <= input_low;
    const bool above_range = f > input_high;
    if (at_or_below_range || above_range) {
        return at_or_below_range ? output_low : output_high;
    }

    // Relative position of `f` inside (input_low, input_high], rounded to 0 or 1.
    const float level = std::round((f - input_low) / (input_high - input_low));
    return level * (output_high - output_low) + output_low;
}
// Applies quantize() to every weight and returns the results in the same order.
//
// weights_shape          - shape of the weights tensor; axis 0 is used as the channel axis
//                          when selecting a per-channel range.
// weights                - flattened weights, shape_size(weights_shape) elements.
// input_low_high_shape   - shape shared by input_low / input_high.
// output_low_high_shape  - shape shared by output_low / output_high.
//
// Each low/high pair must hold either a single element (one range for the whole tensor)
// or weights_shape[0] elements (one range per slice along axis 0); anything else trips
// NGRAPH_CHECK.
//
// The range index is derived from the element count of the range shape rather than from
// its first dimension, so rank-0 (scalar) range shapes are handled instead of reading
// shape[0] of an empty Shape.
static std::vector<float> quantize_weights(const Shape& weights_shape, const std::vector<float>& weights,
                                           Shape input_low_high_shape, const std::vector<float>& input_low, const std::vector<float>& input_high,
                                           Shape output_low_high_shape, const std::vector<float>& output_low, const std::vector<float>& output_high) {
    const size_t input_range_count = shape_size(input_low_high_shape);
    const size_t output_range_count = shape_size(output_low_high_shape);
    NGRAPH_CHECK(input_range_count == 1 || input_range_count == weights_shape[0]);
    NGRAPH_CHECK(output_range_count == 1 || output_range_count == weights_shape[0]);

    // Number of weights that belong to one slice along axis 0.
    size_t out_feat_off = 1;
    for (size_t i = 1; i < weights_shape.size(); i++)
        out_feat_off *= weights_shape[i];

    const size_t weights_count = shape_size(weights_shape);
    std::vector<float> out;
    out.reserve(weights_count);

    // Maps a flat weight index to the index of the range to use: always 0 for a single
    // range, the axis-0 index for per-channel ranges.
    auto get_idx = [out_feat_off] (size_t i, size_t range_count) -> size_t {
        return (i / out_feat_off) % range_count;
    };

    for (size_t i = 0; i < weights_count; i++) {
        const size_t in_idx = get_idx(i, input_range_count);
        const size_t out_idx = get_idx(i, output_range_count);
        out.push_back(quantize(weights[i], input_low[in_idx], input_high[in_idx], output_low[out_idx], output_high[out_idx]));
    }
    return out;
}
// Registers a matcher for
//     Convolution(FakeQuantize(activations), FakeQuantize(Constant weights))
// where every FakeQuantize range input is a Constant and each FakeQuantize has a single consumer.
//
// When both FakeQuantize nodes have levels == 2, the activations output range is either
// (0, high) or (-high, high) and the weights output range is (-high, high), the callback:
//   * rescales the activations output range to (0, 1) or (-1, 1),
//   * replaces the weights FakeQuantize with a Constant holding weights quantized to -1/+1,
//   * multiplies the Convolution result by the original activations and weights output_high.
//
// The callback declines (returns false, graph untouched) when the conditions above do not hold,
// and additionally when a low/high pair has mismatching element counts or when an output_high
// contains zero - normalization divides by output_high, so those inputs would otherwise lead to
// out-of-bounds reads or NaN/inf constants.
pass::BinarizeWeights::BinarizeWeights() {
    MATCHER_SCOPE(BinarizeWeights);
    auto activations_fq_pattern = pattern::wrap_type<opset5::FakeQuantize>(
            {pattern::any_input(),
             pattern::wrap_type<opset5::Constant>(),
             pattern::wrap_type<opset5::Constant>(),
             pattern::wrap_type<opset5::Constant>(),
             pattern::wrap_type<opset5::Constant>()},
            pattern::consumers_count(1));
    auto weights_fq_pattern = pattern::wrap_type<opset5::FakeQuantize>(
            {pattern::wrap_type<opset5::Constant>(),
             pattern::wrap_type<opset5::Constant>(),
             pattern::wrap_type<opset5::Constant>(),
             pattern::wrap_type<opset5::Constant>(),
             pattern::wrap_type<opset5::Constant>()},
            pattern::consumers_count(1));
    auto conv_pattern = pattern::wrap_type<opset5::Convolution>({activations_fq_pattern, weights_fq_pattern});

    matcher_pass_callback callback = [=](pattern::Matcher &m) {
        auto conv = std::dynamic_pointer_cast<opset5::Convolution>(m.get_match_root());
        if (!conv)
            return false;
        auto activations_fq = std::dynamic_pointer_cast<opset5::FakeQuantize>(conv->input_value(0).get_node_shared_ptr());
        if (!activations_fq || activations_fq->get_levels() != 2)
            return false;
        auto weights_fq = std::dynamic_pointer_cast<opset5::FakeQuantize>(conv->input_value(1).get_node_shared_ptr());
        if (!weights_fq || weights_fq->get_levels() != 2)
            return false;
        auto weights_const = std::dynamic_pointer_cast<opset5::Constant>(weights_fq->input_value(0).get_node_shared_ptr());
        if (!weights_const)
            return false;

        // A low/high pair can be normalized only if both vectors have the same length
        // (they are indexed in lockstep below) and no output_high is zero (it is a divisor).
        auto can_normalize = [] (const std::vector<float>& output_low,
                                 const std::vector<float>& output_high) -> bool {
            if (output_low.size() != output_high.size())
                return false;
            for (float high : output_high) {
                if (high == 0.0f)
                    return false;
            }
            return true;
        };

        // Returns {all output_low == 0, all output_low == -output_high}.
        // Callers must ensure both vectors have the same size.
        auto check_output_low_high = [] (const std::vector<float>& output_low,
                                         const std::vector<float>& output_high) -> std::tuple<bool, bool> {
            bool output_low_is_zero = true;
            bool output_low_high_are_opposite = true;
            for (size_t i = 0; i < output_low.size(); i++) {
                output_low_is_zero = output_low_is_zero && output_low[i] == 0.0f;
                output_low_high_are_opposite = output_low_high_are_opposite && output_low[i] == -output_high[i];
            }
            return std::tuple<bool, bool>{output_low_is_zero, output_low_high_are_opposite};
        };

        auto activations_output_low_const = std::dynamic_pointer_cast<opset5::Constant>(activations_fq->input_value(3).get_node_shared_ptr());
        auto activations_output_high_const = std::dynamic_pointer_cast<opset5::Constant>(activations_fq->input_value(4).get_node_shared_ptr());
        if (!activations_output_low_const || !activations_output_high_const)
            return false;

        // Check output low and high on activations FQ first
        bool act_out_low_is_zero = false;
        bool act_out_low_high_are_opposite = false;
        auto activations_output_low = activations_output_low_const->cast_vector<float>();
        auto activations_output_high = activations_output_high_const->cast_vector<float>();
        if (!can_normalize(activations_output_low, activations_output_high))
            return false;
        std::tie(act_out_low_is_zero, act_out_low_high_are_opposite) = check_output_low_high(activations_output_low,
                                                                                             activations_output_high);
        if (!(act_out_low_high_are_opposite || act_out_low_is_zero))
            return false;

        auto weights_input_low_const = std::dynamic_pointer_cast<opset5::Constant>(weights_fq->input_value(1).get_node_shared_ptr());
        auto weights_input_high_const = std::dynamic_pointer_cast<opset5::Constant>(weights_fq->input_value(2).get_node_shared_ptr());
        if (!weights_input_low_const || !weights_input_high_const)
            return false;
        // quantize_weights() indexes input_low and input_high with the same index, derived from
        // the input_low shape only, so both constants must hold the same number of elements.
        if (shape_size(weights_input_low_const->get_shape()) != shape_size(weights_input_high_const->get_shape()))
            return false;

        auto weights_output_low_const = std::dynamic_pointer_cast<opset5::Constant>(weights_fq->input_value(3).get_node_shared_ptr());
        auto weights_output_high_const = std::dynamic_pointer_cast<opset5::Constant>(weights_fq->input_value(4).get_node_shared_ptr());
        if (!weights_output_low_const || !weights_output_high_const)
            return false;

        // Check output low and high on weights FQ
        bool weights_out_low_high_are_opposite = false;
        auto weights_output_low = weights_output_low_const->cast_vector<float>();
        auto weights_output_high = weights_output_high_const->cast_vector<float>();
        if (!can_normalize(weights_output_low, weights_output_high))
            return false;
        std::tie(std::ignore, weights_out_low_high_are_opposite) = check_output_low_high(weights_output_low,
                                                                                         weights_output_high);
        if (!weights_out_low_high_are_opposite)
            return false;

        // Normalize output low and high to either (0, 1) or (-1, 1).
        // output_high is guaranteed non-zero by can_normalize() above.
        auto normalize_output_low_high = [] (std::vector<float>& output_low, std::vector<float>& output_high) {
            for (size_t i = 0; i < output_low.size(); i++) {
                output_low[i] /= output_high[i];
                output_high[i] = 1.0f;
            }
        };
        normalize_output_low_high(activations_output_low, activations_output_high);
        normalize_output_low_high(weights_output_low, weights_output_high);

        // Choose additional normalization factor that has to be put after Convolution:
        // the original (un-normalized) output_high constants.
        // NOTE(review): moving the activations factor past the Convolution looks valid only for a
        // per-tensor factor - confirm that per-channel activation ranges cannot reach this point.
        const std::shared_ptr<Node>& activations_norm_factor = activations_output_high_const;
        const std::shared_ptr<Node>& weights_norm_factor = weights_output_high_const;

        // Create new FQ on activations with new output low/high
        auto output_low_normalized = op::Constant::create(element::f32, activations_output_low_const->get_shape(), activations_output_low);
        output_low_normalized->set_friendly_name(activations_output_low_const->get_friendly_name());
        auto output_high_normalized = op::Constant::create(element::f32, activations_output_high_const->get_shape(), activations_output_high);
        output_high_normalized->set_friendly_name(activations_output_high_const->get_friendly_name());
        auto new_activations_fq = activations_fq->clone_with_new_inputs({activations_fq->input_value(0),
                                                                         activations_fq->input_value(1),
                                                                         activations_fq->input_value(2),
                                                                         output_low_normalized,
                                                                         output_high_normalized});
        new_activations_fq->set_friendly_name(activations_fq->get_friendly_name());

        // Quantize weights - here we get rid of FQ on weights and create a constant with quantized weights
        auto weights = weights_const->cast_vector<float>();
        auto weights_input_low = weights_input_low_const->cast_vector<float>();
        auto weights_input_high = weights_input_high_const->cast_vector<float>();
        auto quantized_weights = quantize_weights(weights_const->get_shape(), weights,
                                                  weights_input_low_const->get_shape(), weights_input_low, weights_input_high,
                                                  weights_output_low_const->get_shape(), weights_output_low, weights_output_high);
        auto quantized_weights_const = op::Constant::create(element::f32, weights_const->get_shape(), quantized_weights);
        quantized_weights_const->set_friendly_name(weights_const->get_friendly_name());
        auto new_conv = conv->clone_with_new_inputs({new_activations_fq, quantized_weights_const});
        new_conv->set_friendly_name(conv->get_friendly_name());

        // Target shape {-1, 1, ..., 1} with one trailing 1 per weights dimension beyond the first two.
        std::vector<int64_t> norm_factor_shape = {-1};
        for (size_t i = 2; i < weights_const->get_shape().size(); i++)
            norm_factor_shape.push_back(1);
        auto norm_factor_shape_const = opset5::Constant::create(element::i64, Shape{norm_factor_shape.size()}, norm_factor_shape);

        auto activations_norm_factor_reshaped = std::make_shared<opset5::Reshape>(activations_norm_factor, norm_factor_shape_const, false);
        auto mul = std::make_shared<opset5::Multiply>(new_conv, activations_norm_factor_reshaped);
        auto weights_norm_factor_reshaped = std::make_shared<opset5::Reshape>(weights_norm_factor, norm_factor_shape_const, false);
        auto mul2 = std::make_shared<opset5::Multiply>(mul, weights_norm_factor_reshaped);

        copy_runtime_info({activations_fq, weights_fq, conv},
                          {new_activations_fq, new_conv, activations_norm_factor_reshaped, mul, weights_norm_factor_reshaped, mul2});
        replace_node(conv, mul2);
        return true;
    };

    auto m = std::make_shared<pattern::Matcher>(conv_pattern, matcher_name);
    this->register_matcher(m, callback);
}

View File

@ -33,6 +33,8 @@
#include "transformations/common_optimizations/eliminate_unsqueeze_gather.hpp"
#include "transformations/common_optimizations/softmax_fusion.hpp"
#include "transformations/common_optimizations/mvn_fusion.hpp"
#include "transformations/common_optimizations/binarize_weights.hpp"
#include "transformations/common_optimizations/conv_to_binary_conv.hpp"
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
@ -106,6 +108,8 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
manager.register_pass<ngraph::pass::ConvertInterpolate1ToInterpolate4, false>();
manager.register_pass<ngraph::pass::BinarizeWeights>();
manager.register_pass<ngraph::pass::ConvToBinaryConv>();
auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
decomp->add_matcher<ngraph::pass::Gelu7Downgrade>();

View File

@ -0,0 +1,130 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/conv_to_binary_conv.hpp"
#include "itt.hpp"
#include <memory>
#include <vector>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/validation_util.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvToBinaryConv, "ConvToBinaryConv", 0);
// Packs weights into a bit array, eight weights per byte.
//
// Weight number k sets bit (k % 8) of byte (k / 8) when it is exactly 1.0f; any other
// value leaves the bit cleared. The result holds ceil(weights.size() / 8) bytes, with the
// unused high bits of the last byte left at zero.
static std::vector<uint8_t> binarize_weights(const std::vector<float>& weights) {
    const size_t bits_per_byte = 8;
    const size_t byte_count = (weights.size() + bits_per_byte - 1) / bits_per_byte;

    std::vector<uint8_t> packed(byte_count, 0);
    for (size_t idx = 0; idx < weights.size(); ++idx) {
        if (weights[idx] == 1.0f)
            packed[idx / bits_per_byte] |= static_cast<uint8_t>(1u << (idx % bits_per_byte));
    }
    return packed;
}
// Registers a matcher for Convolution(FakeQuantize(...), Constant weights) where the
// FakeQuantize output low/high inputs are Constants and the FakeQuantize has one consumer.
//
// The callback rewrites the Convolution into a BinaryConvolution (XNOR_POPCOUNT) when
//   * the FakeQuantize has levels == 2,
//   * its output_high is all 1 and its output_low is all 0 or all -1,
//   * every weight is exactly -1 or 1.
// For output_low == 0 the BinaryConvolution uses pad value -1 and is followed by
// Add(sum of the original weights over all axes but the first) and Multiply(0.5).
// For output_low == -1 the BinaryConvolution uses pad value 0 and replaces the Convolution directly.
ngraph::pass::ConvToBinaryConv::ConvToBinaryConv() {
    MATCHER_SCOPE(ConvToBinaryConv);
    auto fq_pattern = ngraph::pattern::wrap_type<opset5::FakeQuantize>(
            {ngraph::pattern::any_input(),
             ngraph::pattern::any_input(),
             ngraph::pattern::any_input(),
             ngraph::pattern::wrap_type<opset5::Constant>(),
             ngraph::pattern::wrap_type<opset5::Constant>()},
            pattern::consumers_count(1));
    auto conv_pattern = ngraph::pattern::wrap_type<opset5::Convolution>({fq_pattern, ngraph::pattern::wrap_type<opset5::Constant>()});

    ngraph::matcher_pass_callback callback = [=](pattern::Matcher &m) {
        auto conv = std::dynamic_pointer_cast<opset5::Convolution>(m.get_match_root());
        if (!conv)
            return false;

        auto fq = std::dynamic_pointer_cast<opset5::FakeQuantize>(conv->input_value(0).get_node_shared_ptr());
        if (!fq || fq->get_levels() != 2)
            return false;

        // True when every element of `values` compares equal to `expected`.
        auto all_equal_to = [] (const std::vector<float>& values, float expected) -> bool {
            return std::all_of(values.begin(), values.end(), [expected] (float f) -> bool { return f == expected; });
        };

        auto output_low_constant = std::dynamic_pointer_cast<opset5::Constant>(fq->input_value(3).get_node_shared_ptr());
        if (!output_low_constant)
            return false;
        const auto output_low = output_low_constant->cast_vector<float>();
        const bool output_low_is_zero = all_equal_to(output_low, 0.0f);
        const bool output_low_is_minus_one = all_equal_to(output_low, -1.0f);

        auto output_high_constant = std::dynamic_pointer_cast<opset5::Constant>(fq->input_value(4).get_node_shared_ptr());
        if (!output_high_constant)
            return false;
        const auto output_high = output_high_constant->cast_vector<float>();
        const bool output_high_is_one = all_equal_to(output_high, 1.0f);

        // Only (0, 1) and (-1, 1) output ranges are supported.
        if (!output_high_is_one || (!output_low_is_zero && !output_low_is_minus_one))
            return false;

        auto weights_constant = std::dynamic_pointer_cast<opset5::Constant>(conv->input_value(1).get_node_shared_ptr());
        if (!weights_constant)
            return false;
        auto weights = weights_constant->cast_vector<float>();
        const bool weights_are_binary = std::all_of(weights.begin(), weights.end(),
                                                    [] (float f) -> bool { return f == -1.0f || f == 1.0f; });
        if (!weights_are_binary)
            return false;

        auto bin_weights = binarize_weights(weights);
        auto bin_weights_constant = std::make_shared<opset5::Constant>(element::u1, weights_constant->get_shape(), bin_weights.data());

        // The two supported ranges differ only in the pad value and in the post-processing.
        const float pad_value = output_low_is_zero ? -1.0f : 0.0f;
        auto new_conv = std::make_shared<opset5::BinaryConvolution>(conv->input_value(0), bin_weights_constant,
                                                                    conv->get_strides(),
                                                                    conv->get_pads_begin(),
                                                                    conv->get_pads_end(),
                                                                    conv->get_dilations(),
                                                                    opset5::BinaryConvolution::BinaryConvolutionMode::XNOR_POPCOUNT,
                                                                    pad_value,
                                                                    conv->get_auto_pad());
        new_conv->set_friendly_name(conv->get_friendly_name());

        if (!output_low_is_zero) {
            // (-1, 1) range: the BinaryConvolution replaces the Convolution as is.
            copy_runtime_info(conv, new_conv);
            replace_node(conv, new_conv);
            return true;
        }

        // (0, 1) range: (BinaryConvolution + sum of weights) * 0.5.
        // Reduce over every axis except the first; reshape the sums to {-1, 1, ..., 1}.
        const size_t weights_rank = weights_constant->get_shape().size();
        std::vector<int64_t> axes;
        std::vector<int64_t> weights_reduced_shape = {-1};
        for (size_t i = 1; i < weights_rank; i++) {
            axes.push_back(static_cast<int64_t>(i));
            if (i >= 2)
                weights_reduced_shape.push_back(1);
        }

        auto original_weights = op::Constant::create(element::f32, weights_constant->get_shape(), weights);
        auto reduction_axes = op::Constant::create(element::i64, Shape{axes.size()}, axes);
        auto weights_reduced = std::make_shared<opset5::ReduceSum>(original_weights, reduction_axes, false);
        auto reduced_shape_constant = op::Constant::create(element::i64,
                                                           Shape{weights_reduced_shape.size()},
                                                           weights_reduced_shape);
        std::shared_ptr<Node> weights_reduced_reshaped = std::make_shared<opset5::Reshape>(weights_reduced,
                                                                                            reduced_shape_constant,
                                                                                            false);
        weights_reduced_reshaped = ngraph::get_constant_from_source(weights_reduced_reshaped);

        auto add = std::make_shared<opset5::Add>(new_conv, weights_reduced_reshaped);
        auto mul = std::make_shared<opset5::Multiply>(add, op::Constant::create(element::f32, Shape{}, {0.5}));
        copy_runtime_info(conv, {new_conv, add, mul});
        replace_node(conv, mul);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(conv_pattern, matcher_name);
    this->register_matcher(m, callback);
}

View File

@ -0,0 +1,263 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <transformations/common_optimizations/binarize_weights.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
// Positive case: activations FakeQuantize with output range (0, 0.7), weights FakeQuantize
// with output range (-0.2, 0.2). BinarizeWeights is expected to normalize the activations
// range to (0, 1), fold the weights FakeQuantize into a -1/+1 Constant and append two
// Multiply nodes carrying the original output_high values (0.7 and 0.2).
TEST(TransformationTests, BinarizeWeightsActivationsOutputLowZero) {
    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
    {
        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 2, 2});
        auto act_in_low = opset5::Constant::create(element::f32, Shape{1}, {1.0f});
        auto act_in_high = opset5::Constant::create(element::f32, Shape{1}, {3.0f});
        auto act_out_low = opset5::Constant::create(element::f32, Shape{1}, {0.0f});
        auto act_out_high = opset5::Constant::create(element::f32, Shape{1}, {0.7f});
        auto act_fq = std::make_shared<opset5::FakeQuantize>(data, act_in_low, act_in_high, act_out_low, act_out_high, 2);
        auto weights = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {-1, 0, 2});
        auto weights_in_low = opset5::Constant::create(element::f32, Shape{1}, {-1.0f});
        auto weights_in_high = opset5::Constant::create(element::f32, Shape{1}, {2.0f});
        auto weights_out_low = opset5::Constant::create(element::f32, Shape{1}, {-0.2f});
        auto weights_out_high = opset5::Constant::create(element::f32, Shape{1}, {0.2f});
        auto weights_fq = std::make_shared<opset5::FakeQuantize>(weights, weights_in_low, weights_in_high, weights_out_low, weights_out_high, 2);
        auto conv = std::make_shared<opset5::Convolution>(act_fq, weights_fq, Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1});

        f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data});

        pass::Manager m;
        m.register_pass<pass::InitNodeInfo>();
        m.register_pass<pass::BinarizeWeights>();
        m.register_pass<pass::ConstantFolding>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 2, 2});
        auto act_in_low = opset5::Constant::create(element::f32, Shape{1}, {1.0f});
        auto act_in_high = opset5::Constant::create(element::f32, Shape{1}, {3.0f});
        auto act_out_low = opset5::Constant::create(element::f32, Shape{1}, {0.0f});
        auto act_out_high = opset5::Constant::create(element::f32, Shape{1}, {1.0f});
        auto act_fq = std::make_shared<opset5::FakeQuantize>(data, act_in_low, act_in_high, act_out_low, act_out_high, 2);
        // Weights {-1, 0, 2} quantized with input range [-1, 2]: -1 is at the lower bound,
        // 0 sits at relative position 1/3 and rounds to the low level, 2 rounds to the high
        // level, giving {-1, -1, 1}.
        auto weights = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {-1, -1, 1});
        auto conv = std::make_shared<opset5::Convolution>(act_fq, weights, Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1});
        auto mul = std::make_shared<opset5::Multiply>(conv, opset5::Constant::create(element::f32, Shape{1, 1, 1}, {0.7f}));
        auto mul2 = std::make_shared<opset5::Multiply>(mul, opset5::Constant::create(element::f32, Shape{1, 1, 1}, {0.2f}));

        f_ref = std::make_shared<Function>(NodeVector{mul2}, ParameterVector{data});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Positive case: activations FakeQuantize with symmetric output range (-0.7, 0.7), weights
// FakeQuantize with output range (-0.2, 0.2). BinarizeWeights is expected to normalize the
// activations range to (-1, 1), fold the weights FakeQuantize into a -1/+1 Constant and
// append two Multiply nodes carrying the original output_high values (0.7 and 0.2).
TEST(TransformationTests, BinarizeWeightsActivationsOutputLowNegative) {
    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
    {
        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 2, 2});
        auto act_in_low = opset5::Constant::create(element::f32, Shape{1}, {1.0f});
        auto act_in_high = opset5::Constant::create(element::f32, Shape{1}, {3.0f});
        auto act_out_low = opset5::Constant::create(element::f32, Shape{1}, {-0.7f});
        auto act_out_high = opset5::Constant::create(element::f32, Shape{1}, {0.7f});
        auto act_fq = std::make_shared<opset5::FakeQuantize>(data, act_in_low, act_in_high, act_out_low, act_out_high, 2);
        auto weights = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {-1, 0, 2});
        auto weights_in_low = opset5::Constant::create(element::f32, Shape{1}, {-1.0f});
        auto weights_in_high = opset5::Constant::create(element::f32, Shape{1}, {2.0f});
        auto weights_out_low = opset5::Constant::create(element::f32, Shape{1}, {-0.2f});
        auto weights_out_high = opset5::Constant::create(element::f32, Shape{1}, {0.2f});
        auto weights_fq = std::make_shared<opset5::FakeQuantize>(weights, weights_in_low, weights_in_high, weights_out_low, weights_out_high, 2);
        auto conv = std::make_shared<opset5::Convolution>(act_fq, weights_fq, Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1});

        f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data});

        pass::Manager m;
        m.register_pass<pass::InitNodeInfo>();
        m.register_pass<pass::BinarizeWeights>();
        m.register_pass<pass::ConstantFolding>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 2, 2});
        auto act_in_low = opset5::Constant::create(element::f32, Shape{1}, {1.0f});
        auto act_in_high = opset5::Constant::create(element::f32, Shape{1}, {3.0f});
        // (-0.7, 0.7) divided by output_high 0.7 becomes (-1, 1).
        auto act_out_low = opset5::Constant::create(element::f32, Shape{1}, {-1.0f});
        auto act_out_high = opset5::Constant::create(element::f32, Shape{1}, {1.0f});
        auto act_fq = std::make_shared<opset5::FakeQuantize>(data, act_in_low, act_in_high, act_out_low, act_out_high, 2);
        // Weights {-1, 0, 2} quantized with input range [-1, 2]: -1 is at the lower bound,
        // 0 sits at relative position 1/3 and rounds to the low level, 2 rounds to the high
        // level, giving {-1, -1, 1}.
        auto weights = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {-1, -1, 1});
        auto conv = std::make_shared<opset5::Convolution>(act_fq, weights, Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1});
        auto mul = std::make_shared<opset5::Multiply>(conv, opset5::Constant::create(element::f32, Shape{1, 1, 1}, {0.7f}));
        auto mul2 = std::make_shared<opset5::Multiply>(mul, opset5::Constant::create(element::f32, Shape{1, 1, 1}, {0.2f}));

        f_ref = std::make_shared<Function>(NodeVector{mul2}, ParameterVector{data});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Negative case: the activations FakeQuantize is built with levels == 3, so the function
// after running BinarizeWeights is compared against an identically built, untouched graph.
TEST(TransformationTests, NegativeBinarizeWeightsInvalidLevels) {
    // The transformed function and the reference are built from the same recipe.
    auto make_function = [] () -> std::shared_ptr<Function> {
        auto input = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 2, 2});

        auto a_in_lo = opset5::Constant::create(element::f32, Shape{1}, {1.0f});
        auto a_in_hi = opset5::Constant::create(element::f32, Shape{1}, {3.0f});
        auto a_out_lo = opset5::Constant::create(element::f32, Shape{1}, {-0.7f});
        auto a_out_hi = opset5::Constant::create(element::f32, Shape{1}, {0.7f});
        auto activations_fq = std::make_shared<opset5::FakeQuantize>(input, a_in_lo, a_in_hi, a_out_lo, a_out_hi, 3);

        auto kernel = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {-1, 0, 2});
        auto w_in_lo = opset5::Constant::create(element::f32, Shape{1}, {-1.0f});
        auto w_in_hi = opset5::Constant::create(element::f32, Shape{1}, {2.0f});
        auto w_out_lo = opset5::Constant::create(element::f32, Shape{1}, {-0.2f});
        auto w_out_hi = opset5::Constant::create(element::f32, Shape{1}, {0.2f});
        auto kernel_fq = std::make_shared<opset5::FakeQuantize>(kernel, w_in_lo, w_in_hi, w_out_lo, w_out_hi, 2);

        auto convolution = std::make_shared<opset5::Convolution>(activations_fq, kernel_fq,
                                                                 Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1});
        return std::make_shared<Function>(NodeVector{convolution}, ParameterVector{input});
    };

    std::shared_ptr<Function> f = make_function();
    {
        pass::Manager manager;
        manager.register_pass<pass::InitNodeInfo>();
        manager.register_pass<pass::BinarizeWeights>();
        manager.register_pass<pass::ConstantFolding>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    std::shared_ptr<Function> f_ref = make_function();

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Negative case: the activations output range is (0.2, 0.7) - output_low is neither zero nor
// -output_high - so the function after running BinarizeWeights is compared against an
// identically built, untouched graph.
TEST(TransformationTests, NegativeBinarizeWeightsInvalidActivationsOutputLowHigh) {
    // The transformed function and the reference are built from the same recipe.
    auto make_function = [] () -> std::shared_ptr<Function> {
        auto input = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 2, 2});

        auto a_in_lo = opset5::Constant::create(element::f32, Shape{1}, {1.0f});
        auto a_in_hi = opset5::Constant::create(element::f32, Shape{1}, {3.0f});
        auto a_out_lo = opset5::Constant::create(element::f32, Shape{1}, {0.2f});
        auto a_out_hi = opset5::Constant::create(element::f32, Shape{1}, {0.7f});
        auto activations_fq = std::make_shared<opset5::FakeQuantize>(input, a_in_lo, a_in_hi, a_out_lo, a_out_hi, 2);

        auto kernel = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {-1, 0, 2});
        auto w_in_lo = opset5::Constant::create(element::f32, Shape{1}, {-1.0f});
        auto w_in_hi = opset5::Constant::create(element::f32, Shape{1}, {2.0f});
        auto w_out_lo = opset5::Constant::create(element::f32, Shape{1}, {-0.2f});
        auto w_out_hi = opset5::Constant::create(element::f32, Shape{1}, {0.2f});
        auto kernel_fq = std::make_shared<opset5::FakeQuantize>(kernel, w_in_lo, w_in_hi, w_out_lo, w_out_hi, 2);

        auto convolution = std::make_shared<opset5::Convolution>(activations_fq, kernel_fq,
                                                                 Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1});
        return std::make_shared<Function>(NodeVector{convolution}, ParameterVector{input});
    };

    std::shared_ptr<Function> f = make_function();
    {
        pass::Manager manager;
        manager.register_pass<pass::InitNodeInfo>();
        manager.register_pass<pass::BinarizeWeights>();
        manager.register_pass<pass::ConstantFolding>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    std::shared_ptr<Function> f_ref = make_function();

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Negative BinarizeWeights case: the weights FakeQuantize is built with output_low 0.0 and
// output_high 0.2 (activations FakeQuantize: 0.0 / 0.7). After InitNodeInfo, BinarizeWeights and
// ConstantFolding have run, the graph must still compare equal (per compare_functions) to an
// identically built graph that no pass has touched.
TEST(TransformationTests, NegativeBinarizeWeightsInvalidOutputLowHigh) {
    // Single-element f32 constant holding `value`.
    const auto scalar = [](float value) -> std::shared_ptr<opset5::Constant> {
        return opset5::Constant::create(element::f32, Shape{1}, {value});
    };

    // FakeQuantize(input) and FakeQuantize(kernel), both with levels == 2, feeding a Convolution.
    const auto build_graph = [&scalar]() -> std::shared_ptr<Function> {
        auto input = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 2, 2});
        auto input_in_lo = scalar(1.0f);
        auto input_in_hi = scalar(3.0f);
        auto input_out_lo = scalar(0.0f);
        auto input_out_hi = scalar(0.7f);
        auto quantized_input = std::make_shared<opset5::FakeQuantize>(
            input, input_in_lo, input_in_hi, input_out_lo, input_out_hi, 2);

        auto kernel = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {-1, 0, 2});
        auto kernel_in_lo = scalar(-1.0f);
        auto kernel_in_hi = scalar(2.0f);
        auto kernel_out_lo = scalar(0.0f);
        auto kernel_out_hi = scalar(0.2f);
        auto quantized_kernel = std::make_shared<opset5::FakeQuantize>(
            kernel, kernel_in_lo, kernel_in_hi, kernel_out_lo, kernel_out_hi, 2);

        auto convolution = std::make_shared<opset5::Convolution>(
            quantized_input, quantized_kernel,
            Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1});
        return std::make_shared<Function>(NodeVector{convolution}, ParameterVector{input});
    };

    // Graph under test: run the pass pipeline on it.
    std::shared_ptr<Function> f = build_graph();
    {
        pass::Manager manager;
        manager.register_pass<pass::InitNodeInfo>();
        manager.register_pass<pass::BinarizeWeights>();
        manager.register_pass<pass::ConstantFolding>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference: the very same graph, never handed to the pass manager.
    std::shared_ptr<Function> f_ref = build_graph();

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}

View File

@ -0,0 +1,233 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <transformations/common_optimizations/conv_to_binary_conv.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
// Positive ConvToBinaryConv case: an activations FakeQuantize (levels == 2) with output range
// [0, 1] feeds a Convolution whose constant weights are {-1, 1, 1}. The expected graph is
// FakeQuantize -> BinaryConvolution (XNOR_POPCOUNT, -1) -> Add -> Multiply.
TEST(TransformationTests, ConvToBinaryConvOutputLowZeroOutputHighOne) {
    // Single-element f32 constant holding `value`.
    const auto scalar = [](float value) -> std::shared_ptr<opset5::Constant> {
        return opset5::Constant::create(element::f32, Shape{1}, {value});
    };

    // Activations branch; built by one helper so that the tested graph and the reference
    // cannot drift apart.
    const auto make_act_fq = [&scalar](const std::shared_ptr<opset5::Parameter>& input)
            -> std::shared_ptr<opset5::FakeQuantize> {
        auto act_in_low = scalar(1.0f);
        auto act_in_high = scalar(3.0f);
        auto act_out_low = scalar(0.0f);
        auto act_out_high = scalar(1.0f);
        return std::make_shared<opset5::FakeQuantize>(input, act_in_low, act_in_high, act_out_low, act_out_high, 2);
    };

    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
    {
        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 2, 2});
        auto act_fq = make_act_fq(data);
        auto weights = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {-1, 1, 1});
        auto conv = std::make_shared<opset5::Convolution>(act_fq, weights, Strides{1, 1}, CoordinateDiff{0, 0},
                                                          CoordinateDiff{0, 0}, Strides{1, 1}, op::PadType::EXPLICIT);
        f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data});

        pass::Manager m;
        m.register_pass<pass::InitNodeInfo>();
        m.register_pass<pass::ConvToBinaryConv>();
        m.register_pass<pass::ConstantFolding>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 2, 2});
        auto act_fq = make_act_fq(data);
        // 6 == 0b110: presumably one bit per weight for {-1, +1, +1} - TODO confirm packing order.
        uint8_t weights_val = 6;
        auto weights = std::make_shared<opset5::Constant>(element::u1, Shape{1, 3, 1, 1}, &weights_val);
        auto conv = std::make_shared<opset5::BinaryConvolution>(act_fq, weights, Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0},
                Strides{1, 1}, opset5::BinaryConvolution::BinaryConvolutionMode::XNOR_POPCOUNT, -1, op::PadType::EXPLICIT);
        // With activations a in {0, 1} and a' = 2a - 1 in {-1, +1}:
        //     conv(a, w) == (conv(a', w) + sum(w)) * 0.5
        // For the single output channel sum(w) = -1 + 1 + 1 = 1, hence the offset 1.0 and scale 0.5.
        // NOTE(review): values derived from the identity above; confirm against the pass before
        // enabling constant-value comparison in compare_functions.
        auto add = std::make_shared<opset5::Add>(conv, opset5::Constant::create(element::f32, Shape{1, 1, 1}, {1.0f}));
        auto mul = std::make_shared<opset5::Multiply>(add, opset5::Constant::create(element::f32, Shape{}, {0.5f}));
        f_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{data});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Positive ConvToBinaryConv case: an activations FakeQuantize (levels == 2) with output range
// [-1, 1] feeds a Convolution whose constant weights are {-1, 1, 1}. The expected graph is
// FakeQuantize -> BinaryConvolution (XNOR_POPCOUNT, 0) with no trailing Add/Multiply.
TEST(TransformationTests, ConvToBinaryConvOutputLowMinusOneOutputHighOne) {
    // Single-element f32 constant holding `value`.
    const auto scalar = [](float value) -> std::shared_ptr<opset5::Constant> {
        return opset5::Constant::create(element::f32, Shape{1}, {value});
    };

    // Activations branch with output_low == -1 and output_high == 1. Both graphs use this one
    // helper, so the reference describes the same quantization as the graph under test.
    // NOTE(review): assumes the pass leaves the FakeQuantize node itself untouched - confirm.
    const auto make_act_fq = [&scalar](const std::shared_ptr<opset5::Parameter>& input)
            -> std::shared_ptr<opset5::FakeQuantize> {
        auto act_in_low = scalar(1.0f);
        auto act_in_high = scalar(3.0f);
        auto act_out_low = scalar(-1.0f);
        auto act_out_high = scalar(1.0f);
        return std::make_shared<opset5::FakeQuantize>(input, act_in_low, act_in_high, act_out_low, act_out_high, 2);
    };

    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
    {
        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 2, 2});
        auto act_fq = make_act_fq(data);
        auto weights = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {-1, 1, 1});
        auto conv = std::make_shared<opset5::Convolution>(act_fq, weights, Strides{1, 1}, CoordinateDiff{0, 0},
                                                          CoordinateDiff{0, 0}, Strides{1, 1}, op::PadType::EXPLICIT);
        f = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data});

        pass::Manager m;
        m.register_pass<pass::InitNodeInfo>();
        m.register_pass<pass::ConvToBinaryConv>();
        m.register_pass<pass::ConstantFolding>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        auto data = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 2, 2});
        auto act_fq = make_act_fq(data);
        // 6 == 0b110: presumably one bit per weight for {-1, +1, +1} - TODO confirm packing order.
        uint8_t weights_val = 6;
        auto weights = std::make_shared<opset5::Constant>(element::u1, Shape{1, 3, 1, 1}, &weights_val);
        auto conv = std::make_shared<opset5::BinaryConvolution>(act_fq, weights, Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0},
                Strides{1, 1}, opset5::BinaryConvolution::BinaryConvolutionMode::XNOR_POPCOUNT, 0, op::PadType::EXPLICIT);
        f_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Negative ConvToBinaryConv case: the Convolution weights are {-1, 2, 3}, while the activations
// FakeQuantize has levels == 2 and output range [-1, 1]. After the pass pipeline has run, the
// graph must still compare equal (per compare_functions) to an identically built, untouched graph.
TEST(TransformationTests, NegativeConvToBinaryConvInvalidWeights) {
    // Single-element f32 constant holding `value`.
    const auto scalar = [](float value) -> std::shared_ptr<opset5::Constant> {
        return opset5::Constant::create(element::f32, Shape{1}, {value});
    };

    // FakeQuantize(input) -> Convolution with constant f32 weights.
    const auto build_graph = [&scalar]() -> std::shared_ptr<Function> {
        auto input = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 2, 2});
        auto in_lo = scalar(1.0f);
        auto in_hi = scalar(3.0f);
        auto out_lo = scalar(-1.0f);
        auto out_hi = scalar(1.0f);
        auto quantized_input = std::make_shared<opset5::FakeQuantize>(input, in_lo, in_hi, out_lo, out_hi, 2);
        auto kernel = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {-1, 2, 3});
        auto convolution = std::make_shared<opset5::Convolution>(
            quantized_input, kernel,
            Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1},
            op::PadType::EXPLICIT);
        return std::make_shared<Function>(NodeVector{convolution}, ParameterVector{input});
    };

    // Graph under test: run the pass pipeline on it.
    std::shared_ptr<Function> f = build_graph();
    {
        pass::Manager manager;
        manager.register_pass<pass::InitNodeInfo>();
        manager.register_pass<pass::ConvToBinaryConv>();
        manager.register_pass<pass::ConstantFolding>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference: the very same graph, never handed to the pass manager.
    std::shared_ptr<Function> f_ref = build_graph();

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
// Negative ConvToBinaryConv case: the activations FakeQuantize is created with levels == 3
// (output range [-1, 1], weights {-1, 1, 1}). After the pass pipeline has run, the graph must
// still compare equal (per compare_functions) to an identically built, untouched graph.
TEST(TransformationTests, NegativeConvToBinaryConvInvalidLevels) {
    // Single-element f32 constant holding `value`.
    const auto scalar = [](float value) -> std::shared_ptr<opset5::Constant> {
        return opset5::Constant::create(element::f32, Shape{1}, {value});
    };

    // FakeQuantize(input) -> Convolution with constant f32 weights.
    const auto build_graph = [&scalar]() -> std::shared_ptr<Function> {
        auto input = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 2, 2});
        auto in_lo = scalar(1.0f);
        auto in_hi = scalar(3.0f);
        auto out_lo = scalar(-1.0f);
        auto out_hi = scalar(1.0f);
        auto quantized_input = std::make_shared<opset5::FakeQuantize>(input, in_lo, in_hi, out_lo, out_hi, 3);
        auto kernel = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {-1, 1, 1});
        auto convolution = std::make_shared<opset5::Convolution>(
            quantized_input, kernel,
            Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1},
            op::PadType::EXPLICIT);
        return std::make_shared<Function>(NodeVector{convolution}, ParameterVector{input});
    };

    // Graph under test: run the pass pipeline on it.
    std::shared_ptr<Function> f = build_graph();
    {
        pass::Manager manager;
        manager.register_pass<pass::InitNodeInfo>();
        manager.register_pass<pass::ConvToBinaryConv>();
        manager.register_pass<pass::ConstantFolding>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference: the very same graph, never handed to the pass manager.
    std::shared_ptr<Function> f_ref = build_graph();

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
// Negative ConvToBinaryConv case: the activations FakeQuantize (levels == 2) has output_low -2.0
// and output_high 1.0, with weights {-1, 1, 1}. After the pass pipeline has run, the graph must
// still compare equal (per compare_functions) to an identically built, untouched graph.
TEST(TransformationTests, NegativeConvToBinaryConvOutputLowHigh) {
    // Single-element f32 constant holding `value`.
    const auto scalar = [](float value) -> std::shared_ptr<opset5::Constant> {
        return opset5::Constant::create(element::f32, Shape{1}, {value});
    };

    // FakeQuantize(input) -> Convolution with constant f32 weights.
    const auto build_graph = [&scalar]() -> std::shared_ptr<Function> {
        auto input = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 3, 2, 2});
        auto in_lo = scalar(1.0f);
        auto in_hi = scalar(3.0f);
        auto out_lo = scalar(-2.0f);
        auto out_hi = scalar(1.0f);
        auto quantized_input = std::make_shared<opset5::FakeQuantize>(input, in_lo, in_hi, out_lo, out_hi, 2);
        auto kernel = opset5::Constant::create(element::f32, Shape{1, 3, 1, 1}, {-1, 1, 1});
        auto convolution = std::make_shared<opset5::Convolution>(
            quantized_input, kernel,
            Strides{1, 1}, CoordinateDiff{0, 0}, CoordinateDiff{0, 0}, Strides{1, 1},
            op::PadType::EXPLICIT);
        return std::make_shared<Function>(NodeVector{convolution}, ParameterVector{input});
    };

    // Graph under test: run the pass pipeline on it.
    std::shared_ptr<Function> f = build_graph();
    {
        pass::Manager manager;
        manager.register_pass<pass::InitNodeInfo>();
        manager.register_pass<pass::ConvToBinaryConv>();
        manager.register_pass<pass::ConstantFolding>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference: the very same graph, never handed to the pass manager.
    std::shared_ptr<Function> f_ref = build_graph();

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}