Refactor visitor tests for BinaryElementwiseArithmetic ops (#6667)
* add binary_elementwise file * change binary_elementwise.hpp to binary_ops.hpp * migrate mod operation test o typed template test * add tests for remaining binary ops * remove comment * fix formatting to match clang-format * add RVO-exploit string concatenating andbeautify the code * add validation for attributes number * add missing visit_attributes() calls * add missing 4th param to NGRAPH_RTTI_DEFINITION calls * fix formatting to match clang-format
This commit is contained in:
parent
1a92a69515
commit
b4ad7a1755
@ -20,8 +20,8 @@ namespace ngraph
|
||||
class NGRAPH_API Broadcast : public util::BroadcastBase
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"Broadcast", 3};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
/// \brief Constructs a broadcast operation.
|
||||
Broadcast() = default;
|
||||
/// \brief Constructs a broadcast operation.
|
||||
|
@ -16,8 +16,8 @@ namespace ngraph
|
||||
class NGRAPH_API Bucketize : public Op
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"Bucketize", 3};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
Bucketize() = default;
|
||||
/// \brief Constructs a Bucketize node
|
||||
|
||||
|
@ -60,8 +60,8 @@ namespace ngraph
|
||||
class NGRAPH_API CumSum : public Op
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"CumSum", 0};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
/// \brief Constructs a cumulative summation operation.
|
||||
CumSum() = default;
|
||||
|
||||
|
@ -19,8 +19,8 @@ namespace ngraph
|
||||
class NGRAPH_API FloorMod : public util::BinaryElementwiseArithmetic
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"FloorMod", 1};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
/// \brief Constructs an uninitialized addition operation
|
||||
FloorMod()
|
||||
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY){};
|
||||
|
@ -16,8 +16,8 @@ namespace ngraph
|
||||
class NGRAPH_API Maximum : public util::BinaryElementwiseArithmetic
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"Maximum", 1};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
/// \brief Constructs a maximum operation.
|
||||
Maximum()
|
||||
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
|
||||
|
@ -16,8 +16,8 @@ namespace ngraph
|
||||
class NGRAPH_API Minimum : public util::BinaryElementwiseArithmetic
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"Minimum", 1};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
/// \brief Constructs a minimum operation.
|
||||
Minimum()
|
||||
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
|
||||
|
@ -17,8 +17,8 @@ namespace ngraph
|
||||
class NGRAPH_API Mod : public util::BinaryElementwiseArithmetic
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"Mod", 0};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
/// \brief Constructs a Mod node.
|
||||
Mod()
|
||||
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
|
||||
|
@ -31,8 +31,8 @@ namespace ngraph
|
||||
class NGRAPH_API Power : public util::BinaryElementwiseArithmetic
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"Power", 1};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
Power()
|
||||
: util::BinaryElementwiseArithmetic(AutoBroadcastSpec::NUMPY)
|
||||
{
|
||||
|
@ -15,8 +15,8 @@ namespace ngraph
|
||||
class NGRAPH_API ROIPooling : public Op
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"ROIPooling", 0};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
ROIPooling() = default;
|
||||
/// \brief Constructs a ROIPooling operation
|
||||
///
|
||||
|
@ -16,8 +16,8 @@ namespace ngraph
|
||||
class NGRAPH_API Tanh : public util::UnaryElementwiseArithmetic
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"Tanh", 0};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
/// \brief Constructs a hyperbolic tangent operation.
|
||||
///
|
||||
/// \param arg Node that produces the input tensor.
|
||||
|
@ -55,6 +55,8 @@ namespace ngraph
|
||||
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
|
||||
|
@ -17,7 +17,7 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::v3::Broadcast::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v3::Broadcast, "Broadcast", 3, op::util::BroadcastBase);
|
||||
|
||||
op::v3::Broadcast::Broadcast(const Output<Node>& arg,
|
||||
const Output<Node>& target_shape,
|
||||
|
@ -8,7 +8,7 @@
|
||||
using namespace ngraph;
|
||||
using namespace std;
|
||||
|
||||
constexpr NodeTypeInfo op::v3::Bucketize::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v3::Bucketize, "Bucketize", 3);
|
||||
|
||||
op::v3::Bucketize::Bucketize(const Output<Node>& data,
|
||||
const Output<Node>& buckets,
|
||||
|
@ -12,7 +12,7 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::v0::CumSum::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v0::CumSum, "CumSum", 0);
|
||||
|
||||
op::v0::CumSum::CumSum(const Output<Node>& arg,
|
||||
const Output<Node>& axis,
|
||||
|
@ -51,7 +51,7 @@ namespace equal
|
||||
|
||||
//------------------------------- v1 -------------------------------------------
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::Equal, "Equal", 1);
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::Equal, "Equal", 1, op::util::BinaryElementwiseComparison);
|
||||
|
||||
op::v1::Equal::Equal(const Output<Node>& arg0,
|
||||
const Output<Node>& arg1,
|
||||
@ -94,5 +94,6 @@ bool op::v1::Equal::has_evaluate() const
|
||||
bool op::v1::Equal::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v1_Equal_visit_attributes);
|
||||
BinaryElementwiseComparison::visit_attributes(visitor);
|
||||
return true;
|
||||
}
|
||||
|
@ -10,7 +10,7 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::v1::FloorMod::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::FloorMod, "FloorMod", 1, op::util::BinaryElementwiseArithmetic);
|
||||
|
||||
op::v1::FloorMod::FloorMod(const Output<Node>& arg0,
|
||||
const Output<Node>& arg1,
|
||||
@ -97,5 +97,6 @@ bool op::v1::FloorMod::has_evaluate() const
|
||||
bool op::v1::FloorMod::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v1_FloorMod_visit_attributes);
|
||||
BinaryElementwiseArithmetic::visit_attributes(visitor);
|
||||
return true;
|
||||
}
|
||||
|
@ -51,7 +51,7 @@ namespace greaterop
|
||||
|
||||
//-------------------------------------- v1 ------------------------------------
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::Greater, "Greater", 1);
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::Greater, "Greater", 1, op::util::BinaryElementwiseComparison);
|
||||
|
||||
op::v1::Greater::Greater(const Output<Node>& arg0,
|
||||
const Output<Node>& arg1,
|
||||
|
@ -51,7 +51,10 @@ namespace greater_equalop
|
||||
|
||||
//---------------------------------- v1 ----------------------------------------
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::GreaterEqual, "GreaterEqual", 1);
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::GreaterEqual,
|
||||
"GreaterEqual",
|
||||
1,
|
||||
op::util::BinaryElementwiseComparison);
|
||||
|
||||
op::v1::GreaterEqual::GreaterEqual(const Output<Node>& arg0,
|
||||
const Output<Node>& arg1,
|
||||
@ -95,5 +98,6 @@ bool op::v1::GreaterEqual::has_evaluate() const
|
||||
bool op::v1::GreaterEqual::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v1_GreaterEqual_visit_attributes);
|
||||
BinaryElementwiseComparison::visit_attributes(visitor);
|
||||
return true;
|
||||
}
|
||||
|
@ -51,7 +51,7 @@ namespace lessop
|
||||
|
||||
// ----------------------------- v1 --------------------------------------------
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::Less, "Less", 1);
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::Less, "Less", 1, op::util::BinaryElementwiseComparison);
|
||||
|
||||
op::v1::Less::Less(const Output<Node>& arg0,
|
||||
const Output<Node>& arg1,
|
||||
|
@ -12,7 +12,7 @@ using namespace ngraph;
|
||||
|
||||
// ---------------------------------- v1 ---------------------------------------
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::LessEqual, "LessEqual", 1);
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::LessEqual, "LessEqual", 1, op::util::BinaryElementwiseComparison);
|
||||
|
||||
op::v1::LessEqual::LessEqual(const Output<Node>& arg0,
|
||||
const Output<Node>& arg1,
|
||||
|
@ -58,7 +58,7 @@ namespace maximumop
|
||||
|
||||
// ------------------------------------ v1 -------------------------------------
|
||||
|
||||
constexpr NodeTypeInfo op::v1::Maximum::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::Maximum, "Maximum", 1, op::util::BinaryElementwiseArithmetic);
|
||||
|
||||
op::v1::Maximum::Maximum(const Output<Node>& arg0,
|
||||
const Output<Node>& arg1,
|
||||
|
@ -56,7 +56,7 @@ namespace minimumop
|
||||
|
||||
// ------------------------------ v1 -------------------------------------------
|
||||
|
||||
constexpr NodeTypeInfo op::v1::Minimum::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::Minimum, "Minimum", 1, op::util::BinaryElementwiseArithmetic);
|
||||
|
||||
op::v1::Minimum::Minimum(const Output<Node>& arg0,
|
||||
const Output<Node>& arg1,
|
||||
|
@ -10,7 +10,7 @@ using namespace ngraph;
|
||||
|
||||
// ------------------------------ v1 -------------------------------------------
|
||||
|
||||
constexpr NodeTypeInfo op::v1::Mod::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::Mod, "Mod", 1, op::util::BinaryElementwiseArithmetic);
|
||||
|
||||
op::v1::Mod::Mod(const Output<Node>& arg0,
|
||||
const Output<Node>& arg1,
|
||||
|
@ -51,7 +51,7 @@ namespace not_equalop
|
||||
|
||||
// ----------------------------------- v1 --------------------------------------
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::NotEqual, "NotEqual", 1);
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::NotEqual, "NotEqual", 1, op::util::BinaryElementwiseComparison);
|
||||
|
||||
op::v1::NotEqual::NotEqual(const Output<Node>& arg0,
|
||||
const Output<Node>& arg1,
|
||||
@ -95,5 +95,6 @@ bool op::v1::NotEqual::has_evaluate() const
|
||||
bool op::v1::NotEqual::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v1_NotEqual_visit_attributes);
|
||||
BinaryElementwiseComparison::visit_attributes(visitor);
|
||||
return true;
|
||||
}
|
||||
|
@ -54,7 +54,7 @@ namespace power
|
||||
|
||||
// ------------------------------ v1 -------------------------------------------
|
||||
|
||||
constexpr NodeTypeInfo op::v1::Power::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::Power, "Power", 1, op::util::BinaryElementwiseArithmetic);
|
||||
|
||||
op::v1::Power::Power(const Output<Node>& arg0,
|
||||
const Output<Node>& arg1,
|
||||
|
@ -8,7 +8,7 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::ROIPooling::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::ROIPooling, "ROIPooling", 0);
|
||||
|
||||
op::ROIPooling::ROIPooling(const Output<Node>& input,
|
||||
const Output<Node>& coords,
|
||||
|
@ -14,7 +14,7 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::Tanh::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v0::Tanh, "Tanh", 0, op::util::UnaryElementwiseArithmetic);
|
||||
|
||||
op::Tanh::Tanh(const Output<Node>& arg)
|
||||
: UnaryElementwiseArithmetic(arg)
|
||||
|
@ -10,6 +10,8 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::util::BinaryElementwiseComparison, "BinaryElementwiseComparison", 0);
|
||||
|
||||
op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const AutoBroadcastSpec& autob)
|
||||
: m_autob(autob)
|
||||
{
|
||||
|
@ -234,6 +234,7 @@ set(SRC
|
||||
visitors/op/acosh.cpp
|
||||
visitors/op/adaptive_avg_pool.cpp
|
||||
visitors/op/adaptive_max_pool.cpp
|
||||
visitors/op/add.cpp
|
||||
visitors/op/asinh.cpp
|
||||
visitors/op/atan.cpp
|
||||
visitors/op/batch_norm.cpp
|
||||
@ -253,16 +254,23 @@ set(SRC
|
||||
visitors/op/detection_output.cpp
|
||||
visitors/op/einsum.cpp
|
||||
visitors/op/elu.cpp
|
||||
visitors/op/equal.cpp
|
||||
visitors/op/erf.cpp
|
||||
visitors/op/extractimagepatches.cpp
|
||||
visitors/op/fake_quantize.cpp
|
||||
visitors/op/floor_mod.cpp
|
||||
visitors/op/floor.cpp
|
||||
visitors/op/gather.cpp
|
||||
visitors/op/gelu.cpp
|
||||
visitors/op/greater_equal.cpp
|
||||
visitors/op/greater.cpp
|
||||
visitors/op/grn.cpp
|
||||
visitors/op/group_conv.cpp
|
||||
visitors/op/interpolate.cpp
|
||||
visitors/op/less_equal.cpp
|
||||
visitors/op/less.cpp
|
||||
visitors/op/log.cpp
|
||||
visitors/op/logical_or.cpp
|
||||
visitors/op/logical_xor.cpp
|
||||
visitors/op/lrn.cpp
|
||||
visitors/op/lstm_cell.cpp
|
||||
@ -270,17 +278,22 @@ set(SRC
|
||||
visitors/op/matmul.cpp
|
||||
visitors/op/matrix_nms.cpp
|
||||
visitors/op/max_pool.cpp
|
||||
visitors/op/maximum.cpp
|
||||
visitors/op/minimum.cpp
|
||||
visitors/op/mish.cpp
|
||||
visitors/op/mod.cpp
|
||||
visitors/op/multiclass_nms.cpp
|
||||
visitors/op/multiply.cpp
|
||||
visitors/op/mvn.cpp
|
||||
visitors/op/negative.cpp
|
||||
visitors/op/non_max_suppression.cpp
|
||||
visitors/op/non_zero.cpp
|
||||
visitors/op/normalize_l2.cpp
|
||||
visitors/op/not_equal.cpp
|
||||
visitors/op/one_hot.cpp
|
||||
visitors/op/pad.cpp
|
||||
visitors/op/parameter.cpp
|
||||
visitors/op/power.cpp
|
||||
visitors/op/prior_box.cpp
|
||||
visitors/op/prior_box_clustered.cpp
|
||||
visitors/op/proposal.cpp
|
||||
@ -314,10 +327,11 @@ set(SRC
|
||||
visitors/op/space_to_batch.cpp
|
||||
visitors/op/space_to_depth.cpp
|
||||
visitors/op/split.cpp
|
||||
visitors/op/sqrt.cpp
|
||||
visitors/op/squared_difference.cpp
|
||||
visitors/op/squeeze.cpp
|
||||
visitors/op/sqrt.cpp
|
||||
visitors/op/strided_slice.cpp
|
||||
visitors/op/subtract.cpp
|
||||
visitors/op/swish.cpp
|
||||
visitors/op/tanh.cpp
|
||||
visitors/op/topk.cpp
|
||||
|
13
ngraph/test/visitors/op/add.cpp
Normal file
13
ngraph/test/visitors/op/add.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Add, ngraph::element::f32>>;
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
60
ngraph/test/visitors/op/binary_ops.hpp
Normal file
60
ngraph/test/visitors/op/binary_ops.hpp
Normal file
@ -0,0 +1,60 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "util/visitor.hpp"
|
||||
|
||||
template <typename T, ngraph::element::Type_t ELEMENT_TYPE>
|
||||
class BinaryOperatorType
|
||||
{
|
||||
public:
|
||||
using op_type = T;
|
||||
static constexpr ngraph::element::Type_t element_type = ELEMENT_TYPE;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class BinaryOperatorVisitor : public testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
class BinaryOperatorTypeName
|
||||
{
|
||||
public:
|
||||
template <typename T>
|
||||
static std::string GetName(int)
|
||||
{
|
||||
using OP_Type = typename T::op_type;
|
||||
constexpr ngraph::element::Type precision(T::element_type);
|
||||
const ngraph::Node::type_info_t typeinfo = OP_Type::get_type_info_static();
|
||||
return std::string{typeinfo.name} + "_" + precision.get_type_name();
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE_P(BinaryOperatorVisitor);
|
||||
|
||||
TYPED_TEST_P(BinaryOperatorVisitor, Auto_Broadcast)
|
||||
{
|
||||
using OP_Type = typename TypeParam::op_type;
|
||||
const ngraph::element::Type_t element_type = TypeParam::element_type;
|
||||
|
||||
ngraph::test::NodeBuilder::get_ops().register_factory<OP_Type>();
|
||||
const auto A =
|
||||
std::make_shared<ngraph::op::Parameter>(element_type, ngraph::PartialShape{1, 2, 3});
|
||||
const auto B =
|
||||
std::make_shared<ngraph::op::Parameter>(element_type, ngraph::PartialShape{3, 2, 1});
|
||||
|
||||
auto auto_broadcast = ngraph::op::AutoBroadcastType::NUMPY;
|
||||
|
||||
const auto op_func = std::make_shared<OP_Type>(A, B, auto_broadcast);
|
||||
ngraph::test::NodeBuilder builder(op_func);
|
||||
const auto g_op_func = ngraph::as_type_ptr<OP_Type>(builder.create());
|
||||
|
||||
const auto expected_attr_count = 1;
|
||||
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
|
||||
EXPECT_EQ(op_func->get_autob(), g_op_func->get_autob());
|
||||
}
|
||||
|
||||
REGISTER_TYPED_TEST_SUITE_P(BinaryOperatorVisitor, Auto_Broadcast);
|
13
ngraph/test/visitors/op/equal.cpp
Normal file
13
ngraph/test/visitors/op/equal.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Equal, ngraph::element::f32>>;
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
13
ngraph/test/visitors/op/floor_mod.cpp
Normal file
13
ngraph/test/visitors/op/floor_mod.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::FloorMod, ngraph::element::f32>>;
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
13
ngraph/test/visitors/op/greater.cpp
Normal file
13
ngraph/test/visitors/op/greater.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Greater, ngraph::element::f32>>;
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
14
ngraph/test/visitors/op/greater_equal.cpp
Normal file
14
ngraph/test/visitors/op/greater_equal.cpp
Normal file
@ -0,0 +1,14 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using Type =
|
||||
::testing::Types<BinaryOperatorType<ngraph::opset1::GreaterEqual, ngraph::element::f32>>;
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
13
ngraph/test/visitors/op/less.cpp
Normal file
13
ngraph/test/visitors/op/less.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Less, ngraph::element::f32>>;
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
13
ngraph/test/visitors/op/less_equal.cpp
Normal file
13
ngraph/test/visitors/op/less_equal.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::LessEqual, ngraph::element::f32>>;
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
14
ngraph/test/visitors/op/logical_or.cpp
Normal file
14
ngraph/test/visitors/op/logical_or.cpp
Normal file
@ -0,0 +1,14 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using Type =
|
||||
::testing::Types<BinaryOperatorType<ngraph::opset1::LogicalOr, ngraph::element::boolean>>;
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
@ -2,33 +2,13 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
#include "ngraph/opsets/opset3.hpp"
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
|
||||
#include "util/visitor.hpp"
|
||||
using Type =
|
||||
::testing::Types<BinaryOperatorType<ngraph::opset1::LogicalXor, ngraph::element::boolean>>;
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using ngraph::test::NodeBuilder;
|
||||
using ngraph::test::ValueMap;
|
||||
|
||||
TEST(attributes, logical_xor_op)
|
||||
{
|
||||
NodeBuilder::get_ops().register_factory<opset1::LogicalXor>();
|
||||
auto x1 = make_shared<op::Parameter>(element::boolean, Shape{200});
|
||||
auto x2 = make_shared<op::Parameter>(element::boolean, Shape{200});
|
||||
|
||||
auto auto_broadcast = op::AutoBroadcastType::NUMPY;
|
||||
|
||||
auto logical_xor = make_shared<opset1::LogicalXor>(x1, x2, auto_broadcast);
|
||||
NodeBuilder builder(logical_xor);
|
||||
auto g_logical_xor = as_type_ptr<opset1::LogicalXor>(builder.create());
|
||||
|
||||
EXPECT_EQ(g_logical_xor->get_autob(), logical_xor->get_autob());
|
||||
}
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
||||
|
13
ngraph/test/visitors/op/maximum.cpp
Normal file
13
ngraph/test/visitors/op/maximum.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Maximum, ngraph::element::f32>>;
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
13
ngraph/test/visitors/op/minimum.cpp
Normal file
13
ngraph/test/visitors/op/minimum.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Minimum, ngraph::element::f32>>;
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
@ -2,33 +2,12 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
#include "ngraph/opsets/opset3.hpp"
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
|
||||
#include "util/visitor.hpp"
|
||||
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Mod, ngraph::element::f32>>;
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using ngraph::test::NodeBuilder;
|
||||
using ngraph::test::ValueMap;
|
||||
|
||||
TEST(attributes, mod_op)
|
||||
{
|
||||
NodeBuilder::get_ops().register_factory<opset1::Mod>();
|
||||
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 2});
|
||||
auto B = make_shared<op::Parameter>(element::f32, Shape{2, 1});
|
||||
|
||||
auto auto_broadcast = op::AutoBroadcastType::NUMPY;
|
||||
|
||||
auto mod = make_shared<opset1::Mod>(A, B, auto_broadcast);
|
||||
NodeBuilder builder(mod);
|
||||
auto g_mod = as_type_ptr<opset1::Mod>(builder.create());
|
||||
|
||||
EXPECT_EQ(g_mod->get_autob(), mod->get_autob());
|
||||
}
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
||||
|
13
ngraph/test/visitors/op/multiply.cpp
Normal file
13
ngraph/test/visitors/op/multiply.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Multiply, ngraph::element::f32>>;
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
13
ngraph/test/visitors/op/not_equal.cpp
Normal file
13
ngraph/test/visitors/op/not_equal.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::NotEqual, ngraph::element::f32>>;
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
13
ngraph/test/visitors/op/power.cpp
Normal file
13
ngraph/test/visitors/op/power.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Power, ngraph::element::f32>>;
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
@ -2,31 +2,13 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
#include "ngraph/opsets/opset3.hpp"
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
|
||||
#include "util/visitor.hpp"
|
||||
using Type =
|
||||
::testing::Types<BinaryOperatorType<ngraph::opset1::SquaredDifference, ngraph::element::f32>>;
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using ngraph::test::NodeBuilder;
|
||||
using ngraph::test::ValueMap;
|
||||
|
||||
TEST(attributes, squared_difference_op)
|
||||
{
|
||||
NodeBuilder::get_ops().register_factory<opset1::SquaredDifference>();
|
||||
auto x1 = make_shared<op::Parameter>(element::i32, Shape{200});
|
||||
auto x2 = make_shared<op::Parameter>(element::i32, Shape{200});
|
||||
auto auto_broadcast = op::AutoBroadcastType::NUMPY;
|
||||
auto squared_difference = make_shared<opset1::SquaredDifference>(x1, x2, auto_broadcast);
|
||||
NodeBuilder builder(squared_difference);
|
||||
auto g_squared_difference = as_type_ptr<opset1::SquaredDifference>(builder.create());
|
||||
|
||||
EXPECT_EQ(g_squared_difference->get_autob(), squared_difference->get_autob());
|
||||
}
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
||||
|
13
ngraph/test/visitors/op/subtract.cpp
Normal file
13
ngraph/test/visitors/op/subtract.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "binary_ops.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
using Type = ::testing::Types<BinaryOperatorType<ngraph::opset1::Subtract, ngraph::element::f32>>;
|
||||
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_with_auto_broadcast,
|
||||
BinaryOperatorVisitor,
|
||||
Type,
|
||||
BinaryOperatorTypeName);
|
Loading…
Reference in New Issue
Block a user