Revise squared difference (#5091)
* clean SquaredDifference op from legacy FusedOp operation * add backend and type_prop tests for squared_difference operation * apply requested changes * fix op name * revert date change * update squared_difference op info with ngraph RTTI * change test since squred_diff should be binop
This commit is contained in:
parent
c852634055
commit
6e7a8af2be
@ -4,11 +4,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
|
||||||
#include "ngraph/op/op.hpp"
|
|
||||||
#include "ngraph/op/util/fused_op.hpp"
|
|
||||||
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph
|
||||||
{
|
{
|
||||||
@ -19,40 +15,30 @@ namespace ngraph
|
|||||||
/// \brief Calculates an element-wise squared difference between two tensors
|
/// \brief Calculates an element-wise squared difference between two tensors
|
||||||
///
|
///
|
||||||
/// y[i] = (x1[i] - x2[i])^2
|
/// y[i] = (x1[i] - x2[i])^2
|
||||||
class NGRAPH_API SquaredDifference : public ngraph::op::util::FusedOp
|
class NGRAPH_API SquaredDifference : public util::BinaryElementwiseArithmetic
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
static constexpr NodeTypeInfo type_info{"SquaredDifference", 0};
|
NGRAPH_RTTI_DECLARATION;
|
||||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
|
||||||
SquaredDifference();
|
/// \brief Constrcuts an uninitialized squared difference operation
|
||||||
|
SquaredDifference()
|
||||||
|
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
|
||||||
|
{
|
||||||
|
}
|
||||||
/// \brief Constructs the squared difference operation.
|
/// \brief Constructs the squared difference operation.
|
||||||
///
|
///
|
||||||
/// \param x1 First input tensor
|
/// \param x1 First input tensor
|
||||||
/// \param x2 Second input tensor
|
/// \param x2 Second input tensor
|
||||||
/// \param auto_broadcast Auto broadcast specification
|
/// \param auto_broadcast Auto broadcast specification
|
||||||
SquaredDifference(
|
SquaredDifference(const Output<Node>& x1,
|
||||||
const Output<Node>& x1,
|
const Output<Node>& x2,
|
||||||
const Output<Node>& x2,
|
const AutoBroadcastSpec& auto_broadcast =
|
||||||
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastType::NUMPY);
|
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
|
||||||
|
|
||||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
|
||||||
virtual OutputVector decompose_op() const override;
|
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
virtual std::shared_ptr<Node>
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
|
|
||||||
const AutoBroadcastSpec& get_autob() const override { return m_autobroadcast; }
|
|
||||||
void set_autob(const AutoBroadcastSpec& auto_broadcast)
|
|
||||||
{
|
|
||||||
m_autobroadcast = auto_broadcast;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
AutoBroadcastSpec m_autobroadcast;
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
using v0::SquaredDifference;
|
using v0::SquaredDifference;
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
|
||||||
|
@ -4,55 +4,28 @@
|
|||||||
|
|
||||||
#include "ngraph/op/squared_difference.hpp"
|
#include "ngraph/op/squared_difference.hpp"
|
||||||
#include "itt.hpp"
|
#include "itt.hpp"
|
||||||
#include "ngraph/attribute_visitor.hpp"
|
|
||||||
#include "ngraph/node.hpp"
|
|
||||||
#include "ngraph/op/multiply.hpp"
|
|
||||||
#include "ngraph/op/subtract.hpp"
|
|
||||||
#include "ngraph/op/util/fused_op.hpp"
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
// ------------------------------ v0 -------------------------------------------
|
||||||
|
|
||||||
constexpr NodeTypeInfo op::SquaredDifference::type_info;
|
NGRAPH_RTTI_DEFINITION(op::SquaredDifference,
|
||||||
|
"SquaredDifference",
|
||||||
|
0,
|
||||||
|
util::BinaryElementwiseArithmetic);
|
||||||
|
|
||||||
op::SquaredDifference::SquaredDifference()
|
op::SquaredDifference::SquaredDifference(const Output<Node>& arg0,
|
||||||
: FusedOp()
|
const Output<Node>& arg1,
|
||||||
, m_autobroadcast(AutoBroadcastType::NUMPY)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
op::SquaredDifference::SquaredDifference(const Output<Node>& x1,
|
|
||||||
const Output<Node>& x2,
|
|
||||||
const AutoBroadcastSpec& auto_broadcast)
|
const AutoBroadcastSpec& auto_broadcast)
|
||||||
: FusedOp({x1, x2})
|
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
|
||||||
, m_autobroadcast(auto_broadcast)
|
|
||||||
{
|
{
|
||||||
constructor_validate_and_infer_types();
|
constructor_validate_and_infer_types();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ngraph::op::v0::SquaredDifference::visit_attributes(AttributeVisitor& visitor)
|
|
||||||
{
|
|
||||||
NGRAPH_OP_SCOPE(v0_SquaredDifference_visit_attributes);
|
|
||||||
visitor.on_attribute("auto_broadcast", m_autobroadcast);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
OutputVector op::SquaredDifference::decompose_op() const
|
|
||||||
{
|
|
||||||
const auto x1 = input_value(0);
|
|
||||||
const auto x2 = input_value(1);
|
|
||||||
|
|
||||||
const auto difference = make_shared<op::v1::Subtract>(x1, x2, m_autobroadcast);
|
|
||||||
|
|
||||||
return {make_shared<op::v1::Multiply>(difference, difference)};
|
|
||||||
}
|
|
||||||
|
|
||||||
shared_ptr<Node> op::SquaredDifference::clone_with_new_inputs(const OutputVector& new_args) const
|
shared_ptr<Node> op::SquaredDifference::clone_with_new_inputs(const OutputVector& new_args) const
|
||||||
{
|
{
|
||||||
NGRAPH_OP_SCOPE(v0_SquaredDifference_clone_with_new_inputs);
|
NGRAPH_OP_SCOPE(v0_SquaredDifference_clone_with_new_inputs);
|
||||||
check_new_args_count(this, new_args);
|
check_new_args_count(this, new_args);
|
||||||
|
return make_shared<op::SquaredDifference>(new_args.at(0), new_args.at(1), this->get_autob());
|
||||||
return make_shared<SquaredDifference>(new_args.at(0), new_args.at(1), get_autob());
|
|
||||||
}
|
}
|
||||||
|
@ -396,6 +396,7 @@ set(MULTI_TEST_SRC
|
|||||||
backend/softmax.in.cpp
|
backend/softmax.in.cpp
|
||||||
backend/split.in.cpp
|
backend/split.in.cpp
|
||||||
backend/sqrt.in.cpp
|
backend/sqrt.in.cpp
|
||||||
|
backend/squared_difference.in.cpp
|
||||||
backend/subtract.in.cpp
|
backend/subtract.in.cpp
|
||||||
backend/tan.in.cpp
|
backend/tan.in.cpp
|
||||||
backend/tanh.in.cpp
|
backend/tanh.in.cpp
|
||||||
|
133
ngraph/test/backend/squared_difference.in.cpp
Normal file
133
ngraph/test/backend/squared_difference.in.cpp
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cinttypes>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
#ifdef ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
|
||||||
|
#define DEFAULT_FLOAT_TOLERANCE_BITS ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
|
||||||
|
#define DEFAULT_DOUBLE_TOLERANCE_BITS ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
|
||||||
|
#endif
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
#include "ngraph/ngraph.hpp"
|
||||||
|
#include "util/engine/test_engines.hpp"
|
||||||
|
#include "util/test_case.hpp"
|
||||||
|
#include "util/test_control.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace ngraph;
|
||||||
|
|
||||||
|
static string s_manifest = "${MANIFEST}";
|
||||||
|
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, squared_difference_no_broadcast)
|
||||||
|
{
|
||||||
|
Shape shape{1, 2};
|
||||||
|
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||||
|
auto B = make_shared<op::Parameter>(element::f32, shape);
|
||||||
|
auto f = make_shared<Function>(make_shared<op::SquaredDifference>(A, B), ParameterVector{A, B});
|
||||||
|
|
||||||
|
vector<float> a{256, 56};
|
||||||
|
vector<float> b{256, 56};
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(f);
|
||||||
|
test_case.add_multiple_inputs<float>({a, b});
|
||||||
|
test_case.add_expected_output<float>(shape, {0, 0});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, squared_difference_negative_numbers)
|
||||||
|
{
|
||||||
|
Shape shape{2, 2};
|
||||||
|
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||||
|
auto B = make_shared<op::Parameter>(element::f32, shape);
|
||||||
|
auto f = make_shared<Function>(make_shared<op::SquaredDifference>(A, B), ParameterVector{A, B});
|
||||||
|
|
||||||
|
vector<float> a{256, 56, -21, -14};
|
||||||
|
vector<float> b{-112, 56, 6, -8};
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(f);
|
||||||
|
test_case.add_multiple_inputs<float>({a, b});
|
||||||
|
test_case.add_expected_output<float>(shape, {135424, 0, 729, 36});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, squared_difference_broadcast)
|
||||||
|
{
|
||||||
|
Shape shape_a{1, 2};
|
||||||
|
Shape shape_b{3, 2, 2};
|
||||||
|
auto A = make_shared<op::Parameter>(element::f32, shape_a);
|
||||||
|
auto B = make_shared<op::Parameter>(element::f32, shape_b);
|
||||||
|
auto f = make_shared<Function>(make_shared<op::SquaredDifference>(A, B), ParameterVector{A, B});
|
||||||
|
|
||||||
|
vector<float> a{1, 2};
|
||||||
|
vector<float> b{5, 6, 7, 8, 2, 3, 1, 5, 6, 7, 1, 3};
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(f);
|
||||||
|
test_case.add_multiple_inputs<float>({a, b});
|
||||||
|
test_case.add_expected_output<float>(shape_b, {16, 16, 36, 36, 1, 1, 0, 9, 25, 25, 0, 1});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, squared_difference_scalars)
|
||||||
|
{
|
||||||
|
Shape shape{};
|
||||||
|
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||||
|
auto B = make_shared<op::Parameter>(element::f32, shape);
|
||||||
|
auto f = make_shared<Function>(make_shared<op::SquaredDifference>(A, B), ParameterVector{A, B});
|
||||||
|
|
||||||
|
vector<float> a{57};
|
||||||
|
vector<float> b{13};
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(f);
|
||||||
|
test_case.add_multiple_inputs<float>({a, b});
|
||||||
|
test_case.add_expected_output<float>(shape, {1936});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, squared_difference_vector_and_scalar)
|
||||||
|
{
|
||||||
|
Shape shape_a{2, 2};
|
||||||
|
Shape shape_b{};
|
||||||
|
auto A = make_shared<op::Parameter>(element::f32, shape_a);
|
||||||
|
auto B = make_shared<op::Parameter>(element::f32, shape_b);
|
||||||
|
auto f = make_shared<Function>(make_shared<op::SquaredDifference>(A, B), ParameterVector{A, B});
|
||||||
|
|
||||||
|
vector<float> a{2, 4, 7, 8};
|
||||||
|
vector<float> b{8};
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(f);
|
||||||
|
test_case.add_multiple_inputs<float>({a, b});
|
||||||
|
test_case.add_expected_output<float>(shape_a, {36, 16, 1, 0});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, squared_difference_in_place)
|
||||||
|
{
|
||||||
|
Shape shape{2, 2};
|
||||||
|
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||||
|
auto B = make_shared<op::Parameter>(element::f32, shape);
|
||||||
|
auto T = make_shared<op::SquaredDifference>(A, B);
|
||||||
|
auto T2 = make_shared<op::SquaredDifference>(T, T);
|
||||||
|
|
||||||
|
auto f = make_shared<Function>(T2, ParameterVector{A, B});
|
||||||
|
|
||||||
|
vector<float> a{1, 2, 3, 4};
|
||||||
|
vector<float> b{5, 6, 7, 8};
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(f);
|
||||||
|
test_case.add_multiple_inputs<float>({a, b});
|
||||||
|
test_case.add_expected_output<float>(shape, {0, 0 ,0 ,0});
|
||||||
|
test_case.run();
|
||||||
|
}
|
@ -651,7 +651,7 @@ namespace
|
|||||||
{
|
{
|
||||||
op::SquaredDifference node;
|
op::SquaredDifference node;
|
||||||
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
|
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
|
||||||
EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
|
EXPECT_TRUE(op::is_binary_elementwise_arithmetic(&node));
|
||||||
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
|
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
|
||||||
EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
|
EXPECT_FALSE(op::is_binary_elementwise_logical(&node));
|
||||||
}
|
}
|
||||||
|
@ -2114,7 +2114,7 @@ namespace
|
|||||||
outputs[0]->get_data_ptr<T>(),
|
outputs[0]->get_data_ptr<T>(),
|
||||||
inputs[0]->get_shape(),
|
inputs[0]->get_shape(),
|
||||||
inputs[1]->get_shape(),
|
inputs[1]->get_shape(),
|
||||||
ngraph::op::AutoBroadcastSpec::NUMPY);
|
op->get_autob());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,34 +2,8 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
#include "gtest/gtest.h"
|
#include "arithmetic_ops.hpp"
|
||||||
#include "ngraph/ngraph.hpp"
|
|
||||||
#include "util/type_prop.hpp"
|
|
||||||
|
|
||||||
using namespace std;
|
using Type = ::testing::Types<ngraph::op::SquaredDifference>;
|
||||||
using namespace ngraph;
|
|
||||||
|
|
||||||
TEST(type_prop, squared_difference)
|
INSTANTIATE_TYPED_TEST_CASE_P(type_prop_squared_difference, ArithmeticOperator, Type);
|
||||||
{
|
|
||||||
const auto x1 = make_shared<op::Parameter>(element::f64, Shape{2, 2});
|
|
||||||
const auto x2 = make_shared<op::Parameter>(element::f64, Shape{3, 2});
|
|
||||||
const auto x3 = make_shared<op::Parameter>(element::f64, Shape{1, 2});
|
|
||||||
|
|
||||||
try
|
|
||||||
{
|
|
||||||
const auto squared_diff = make_shared<op::SquaredDifference>(x1, x2);
|
|
||||||
FAIL() << "SquaredDifference node was created with incorrect data.";
|
|
||||||
}
|
|
||||||
catch (const NodeValidationFailure& error)
|
|
||||||
{
|
|
||||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto squared_diff = make_shared<op::SquaredDifference>(x1, x3);
|
|
||||||
EXPECT_EQ(squared_diff->get_element_type(), element::f64);
|
|
||||||
EXPECT_EQ(squared_diff->get_shape(), (Shape{2, 2}));
|
|
||||||
EXPECT_EQ(squared_diff->get_autob(), op::AutoBroadcastType::NUMPY);
|
|
||||||
|
|
||||||
const auto squared_diff_no_args = make_shared<op::SquaredDifference>();
|
|
||||||
EXPECT_EQ(squared_diff_no_args->get_autob(), op::AutoBroadcastType::NUMPY);
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user