Shell implementation for RandomUniform. (#6782)
* Added shell implementation for RandomUniform. * Small correction. * Small correction. * Corrected wrong type. * Corrected error message, corrected setters.
This commit is contained in:
parent
a3d9f00d98
commit
e70e7e1e9d
69
ngraph/core/include/ngraph/op/random_uniform.hpp
Normal file
69
ngraph/core/include/ngraph/op/random_uniform.hpp
Normal file
@ -0,0 +1,69 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace v8
|
||||
{
|
||||
/// \brief Tensor RandomUniform operation.
|
||||
class NGRAPH_API RandomUniform : public Op
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
RandomUniform() = default;
|
||||
|
||||
///
|
||||
/// \brief Constructs a RandomUniform operation.
|
||||
///
|
||||
/// \param out_shape Node producing the tensor with output shape.
|
||||
/// \param min_val Node producing the tensor with minimum value.
|
||||
/// \param max_val Node producing the tensor with maximum value.
|
||||
/// \param out_type Output type of the tensor.
|
||||
/// \param global_seed Global seed value.
|
||||
/// \param op_seed Operational seed value.
|
||||
RandomUniform(const Output<Node>& out_shape,
|
||||
const Output<Node>& min_val,
|
||||
const Output<Node>& max_val,
|
||||
const ngraph::element::Type& out_type,
|
||||
uint64_t global_seed,
|
||||
uint64_t op_seed);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
/// \return The output tensor type.
|
||||
const ngraph::element::Type& get_out_type() const { return m_output_type; }
|
||||
void set_out_type(const ngraph::element::Type& output_type)
|
||||
{
|
||||
m_output_type = output_type;
|
||||
}
|
||||
|
||||
/// \return The global seed value.
|
||||
uint64_t get_global_seed() const { return m_global_seed; }
|
||||
void set_global_seed(uint64_t seed) { m_global_seed = seed; }
|
||||
|
||||
/// \return The operational seed value.
|
||||
uint64_t get_op_seed() const { return m_op_seed; }
|
||||
void set_op_seed(uint64_t seed2) { m_op_seed = seed2; }
|
||||
|
||||
protected:
|
||||
ngraph::element::Type m_output_type;
|
||||
uint64_t m_global_seed;
|
||||
uint64_t m_op_seed;
|
||||
};
|
||||
} // namespace v8
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
@ -112,6 +112,7 @@
|
||||
#include "ngraph/op/prior_box_clustered.hpp"
|
||||
#include "ngraph/op/proposal.hpp"
|
||||
#include "ngraph/op/psroi_pooling.hpp"
|
||||
#include "ngraph/op/random_uniform.hpp"
|
||||
#include "ngraph/op/range.hpp"
|
||||
#include "ngraph/op/read_value.hpp"
|
||||
#include "ngraph/op/reduce_l1.hpp"
|
||||
|
@ -181,3 +181,4 @@ NGRAPH_OP(AdaptiveMaxPool, ngraph::op::v8)
|
||||
NGRAPH_OP(DeformableConvolution, ngraph::op::v8)
|
||||
NGRAPH_OP(MatrixNms, ngraph::op::v8)
|
||||
NGRAPH_OP(MulticlassNms, ngraph::op::v8)
|
||||
NGRAPH_OP(RandomUniform, ngraph::op::v8)
|
||||
|
144
ngraph/core/src/op/random_uniform.cpp
Normal file
144
ngraph/core/src/op/random_uniform.cpp
Normal file
@ -0,0 +1,144 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/op/random_uniform.hpp"
|
||||
#include <ngraph/validation_util.hpp>
|
||||
#include "itt.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v8::RandomUniform, "RandomUniform", 8);
|
||||
|
||||
op::v8::RandomUniform::RandomUniform(const Output<Node>& out_shape,
|
||||
const Output<Node>& min_val,
|
||||
const Output<Node>& max_val,
|
||||
const ngraph::element::Type& out_type,
|
||||
uint64_t global_seed,
|
||||
uint64_t op_seed)
|
||||
: Op({out_shape, min_val, max_val})
|
||||
, m_output_type(out_type)
|
||||
, m_global_seed(global_seed)
|
||||
, m_op_seed(op_seed)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::v8::RandomUniform::validate_and_infer_types()
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v8_RandomUniform_validate_and_infer_types);
|
||||
|
||||
const auto& shape_et = get_input_element_type(0);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
shape_et.is_dynamic() || shape_et == element::i32 ||
|
||||
shape_et == element::i64,
|
||||
"Type of the input should be int32 or int64.");
|
||||
|
||||
PartialShape output_shape = PartialShape::dynamic();
|
||||
const auto& input_shape = get_input_partial_shape(0);
|
||||
if (input_shape.rank().is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_shape.rank() == 1,
|
||||
"The rank of the tensor defining output shape must be equal to 1.");
|
||||
if (const auto& const_shape = get_constant_from_source(input_value(0)))
|
||||
{
|
||||
output_shape = PartialShape(const_shape->cast_vector<int64_t>());
|
||||
}
|
||||
}
|
||||
|
||||
const auto& min_pshape = get_input_partial_shape(1);
|
||||
const auto& max_pshape = get_input_partial_shape(2);
|
||||
if (min_pshape.is_static())
|
||||
{
|
||||
const auto& min_rank = min_pshape.rank().get_length();
|
||||
NODE_VALIDATION_CHECK(this, min_rank <= 1, "Min value must be a scalar or 1D tensor.");
|
||||
|
||||
if (min_rank == 1)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, min_pshape.compatible(Shape{1}), "'min_val' should have 1 element.");
|
||||
}
|
||||
}
|
||||
|
||||
if (max_pshape.is_static())
|
||||
{
|
||||
const auto& max_rank = max_pshape.rank().get_length();
|
||||
NODE_VALIDATION_CHECK(this, max_rank <= 1, "Max value must be a scalar or 1D tensor.");
|
||||
|
||||
if (max_rank == 1)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, max_pshape.compatible(Shape{1}), "'max_val' should have 1 element.");
|
||||
}
|
||||
}
|
||||
|
||||
const element::Type& min_element_type = get_input_element_type(1);
|
||||
element::Type max_element_type = get_input_element_type(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
min_element_type == max_element_type,
|
||||
"'min_val' should have the same type as 'max_val'.");
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
min_element_type == get_out_type(),
|
||||
"'min_val' and 'max_val' should have the same type as 'out_type' attribute.");
|
||||
|
||||
if (const auto& const_min = get_constant_from_source(input_value(1)))
|
||||
{
|
||||
if (const auto& const_max = get_constant_from_source(input_value(2)))
|
||||
{
|
||||
if (get_out_type() == ngraph::element::Type_t::i64 ||
|
||||
get_out_type() == ngraph::element::Type_t::i32)
|
||||
{
|
||||
int64_t min_val = const_min->cast_vector<int64_t>()[0];
|
||||
int64_t max_val = const_max->cast_vector<int64_t>()[0];
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
min_val < max_val,
|
||||
"Min value must be less than max value. Got "
|
||||
"min value: ",
|
||||
min_val,
|
||||
", max value: ",
|
||||
max_val);
|
||||
}
|
||||
else if (get_out_type().is_real())
|
||||
{
|
||||
double min_val = const_min->cast_vector<double>()[0];
|
||||
double max_val = const_max->cast_vector<double>()[0];
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
min_val < max_val,
|
||||
"Min value must be less than max value. Got "
|
||||
"min value: ",
|
||||
min_val,
|
||||
", max value: ",
|
||||
max_val);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error("Unsupported output type of RandomUniform: " +
|
||||
get_out_type().get_type_name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
set_output_type(0, get_out_type(), output_shape);
|
||||
}
|
||||
|
||||
bool op::v8::RandomUniform::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v8_RandomUniform_visit_attributes);
|
||||
visitor.on_attribute("output_type", m_output_type);
|
||||
visitor.on_attribute("op_seed", m_op_seed);
|
||||
visitor.on_attribute("global_seed", m_global_seed);
|
||||
return true;
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v8::RandomUniform::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v8_Roll_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<v8::RandomUniform>(
|
||||
new_args[0], new_args[1], new_args[2], m_output_type, m_global_seed, m_op_seed);
|
||||
}
|
@ -178,6 +178,7 @@ set(SRC
|
||||
type_prop/proposal.cpp
|
||||
type_prop/psroi_pooling.cpp
|
||||
type_prop/prior_box_clustered.cpp
|
||||
type_prop/random_uniform.cpp
|
||||
type_prop/range.cpp
|
||||
type_prop/read_value.cpp
|
||||
type_prop/reduce_l1.cpp
|
||||
@ -298,6 +299,7 @@ set(SRC
|
||||
visitors/op/prior_box_clustered.cpp
|
||||
visitors/op/proposal.cpp
|
||||
visitors/op/psroi_pooling.cpp
|
||||
visitors/op/random_uniform.cpp
|
||||
visitors/op/reduce_l1.cpp
|
||||
visitors/op/reduce_l2.cpp
|
||||
visitors/op/reduce_logical_and.cpp
|
||||
|
254
ngraph/test/type_prop/random_uniform.cpp
Normal file
254
ngraph/test/type_prop/random_uniform.cpp
Normal file
@ -0,0 +1,254 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/opsets/opset8.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(type_prop, random_uniform_type_shape)
|
||||
{
|
||||
auto out_shape = opset8::Constant::create(element::i64, Shape{4}, {2, 3, 4, 5});
|
||||
auto min_val = make_shared<opset8::Constant>(element::f32, Shape{}, 0.f);
|
||||
auto max_val = make_shared<opset8::Constant>(element::f32, Shape{}, 1.f);
|
||||
|
||||
auto r =
|
||||
make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::f32, 120, 100);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 3, 4, 5}));
|
||||
}
|
||||
|
||||
TEST(type_prop, random_uniform_dynamic_shape)
|
||||
{
|
||||
auto out_shape =
|
||||
make_shared<opset8::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
||||
auto min_val = make_shared<opset8::Constant>(element::i64, Shape{}, 5);
|
||||
auto max_val = make_shared<opset8::Constant>(element::i64, Shape{}, 10);
|
||||
|
||||
auto r =
|
||||
make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::i64, 100, 200);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::i64);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
||||
}
|
||||
|
||||
TEST(type_prop, random_uniform_dynamic_rank)
|
||||
{
|
||||
auto out_shape = make_shared<opset8::Parameter>(element::i32, PartialShape::dynamic());
|
||||
auto min_val = make_shared<opset8::Constant>(element::f64, Shape{}, 5);
|
||||
auto max_val = make_shared<opset8::Constant>(element::f64, Shape{}, 10);
|
||||
|
||||
auto r =
|
||||
make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::f64, 100, 200);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f64);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
||||
}
|
||||
|
||||
TEST(type_prop, random_uniform_invalid_out_shape_type)
|
||||
{
|
||||
auto out_shape = opset8::Constant::create(element::f64, Shape{4}, {2, 3, 4, 5});
|
||||
auto min_val = make_shared<opset8::Constant>(element::f32, Shape{}, 0.f);
|
||||
auto max_val = make_shared<opset8::Constant>(element::f32, Shape{}, 1.f);
|
||||
|
||||
try
|
||||
{
|
||||
auto r =
|
||||
make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::f32, 120, 100);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Unexpected pass with invalid output shape.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Type of the input should be int32 or int64."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Check failed for unexpected reason.";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, random_uniform_invalid_out_shape_rank)
|
||||
{
|
||||
auto out_shape = make_shared<opset8::Parameter>(element::i32, Shape{3, 2});
|
||||
auto min_val = make_shared<opset8::Constant>(element::f32, Shape{}, 0.f);
|
||||
auto max_val = make_shared<opset8::Constant>(element::f32, Shape{}, 1.f);
|
||||
try
|
||||
{
|
||||
auto r =
|
||||
make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::f32, 120, 100);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Unexpected pass with invalid output shape.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("The rank of the tensor defining output shape must be equal to 1."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Check failed for unexpected reason.";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, random_uniform_invalid_min_val)
|
||||
{
|
||||
auto out_shape = opset8::Constant::create(element::i32, Shape{4}, {2, 3, 4, 5});
|
||||
auto min_val = opset8::Constant::create(element::f32, Shape{2}, {2, 3});
|
||||
auto max_val = make_shared<opset8::Constant>(element::f32, Shape{}, 1.f);
|
||||
|
||||
try
|
||||
{
|
||||
auto r =
|
||||
make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::f32, 120, 100);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Unexpected pass with invalid min value.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("'min_val' should have 1 element."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Check failed for unexpected reason.";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, random_uniform_invalid_max_val)
|
||||
{
|
||||
auto out_shape = opset8::Constant::create(element::i32, Shape{4}, {2, 3, 4, 5});
|
||||
auto min_val = make_shared<opset8::Constant>(element::f32, Shape{}, 0.f);
|
||||
auto max_val = opset8::Constant::create(element::f32, Shape{3}, {2, 3, 5});
|
||||
|
||||
try
|
||||
{
|
||||
auto r =
|
||||
make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::f32, 120, 100);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Unexpected pass with invalid max value.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("'max_val' should have 1 element."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Check failed for unexpected reason.";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, random_uniform_invalid_min_max_val_type_case1)
|
||||
{
|
||||
auto out_shape = opset8::Constant::create(element::i64, Shape{4}, {2, 3, 4, 5});
|
||||
auto min_val = make_shared<opset8::Constant>(element::f32, Shape{}, 0.f);
|
||||
auto max_val = make_shared<opset8::Constant>(element::i32, Shape{}, 100);
|
||||
|
||||
try
|
||||
{
|
||||
auto r =
|
||||
make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::f32, 120, 100);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Unexpected pass with invalid min value type.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("'min_val' should have the same type as 'max_val'."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, random_uniform_invalid_min_max_val_type_case2)
|
||||
{
|
||||
auto out_shape = opset8::Constant::create(element::i64, Shape{4}, {2, 3, 4, 5});
|
||||
auto min_val = make_shared<opset8::Constant>(element::f32, Shape{}, 0.f);
|
||||
auto max_val = make_shared<opset8::Constant>(element::f32, Shape{}, 1.f);
|
||||
|
||||
try
|
||||
{
|
||||
auto r =
|
||||
make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::i32, 120, 100);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Unexpected pass with invalid min and max value type.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string(
|
||||
"'min_val' and 'max_val' should have the same type as 'out_type' attribute."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, random_uniform_invalid_min_max_values_case1)
|
||||
{
|
||||
auto out_shape = opset8::Constant::create(element::i64, Shape{4}, {2, 3, 4, 5});
|
||||
auto min_val = make_shared<opset8::Constant>(element::f32, Shape{}, 1.f);
|
||||
auto max_val = make_shared<opset8::Constant>(element::f32, Shape{}, 0.f);
|
||||
|
||||
try
|
||||
{
|
||||
auto r =
|
||||
make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::f32, 120, 100);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Unexpected pass with invalid min and max values.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Min value must be less than max value."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, random_uniform_invalid_min_max_values_case2)
|
||||
{
|
||||
auto out_shape = opset8::Constant::create(element::i64, Shape{4}, {2, 3, 4, 5});
|
||||
auto min_val = make_shared<opset8::Constant>(element::i32, Shape{}, 100);
|
||||
auto max_val = make_shared<opset8::Constant>(element::i32, Shape{}, 100);
|
||||
|
||||
try
|
||||
{
|
||||
auto r =
|
||||
make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::i32, 120, 100);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Unexpected pass with invalid min and max values.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Min value must be less than max value."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, random_uniform_min_max_1d_tensors)
|
||||
{
|
||||
auto out_shape = opset8::Constant::create(element::i64, Shape{4}, {2, 3, 4, 5});
|
||||
auto min_val = opset8::Constant::create(element::f32, Shape{1}, {-1.0});
|
||||
auto max_val = opset8::Constant::create(element::f32, Shape{1}, {2.0});
|
||||
|
||||
auto r =
|
||||
make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::f32, 120, 100);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 3, 4, 5}));
|
||||
}
|
36
ngraph/test/visitors/op/random_uniform.cpp
Normal file
36
ngraph/test/visitors/op/random_uniform.cpp
Normal file
@ -0,0 +1,36 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "ngraph/opsets/opset8.hpp"
|
||||
|
||||
#include "util/visitor.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using ngraph::test::NodeBuilder;
|
||||
using ngraph::test::ValueMap;
|
||||
|
||||
TEST(attributes, random_uniform_op)
|
||||
{
|
||||
NodeBuilder::get_ops().register_factory<opset8::RandomUniform>();
|
||||
auto out_shape =
|
||||
make_shared<opset8::Constant>(element::i64, Shape{3}, vector<int64_t>{3, 2, 4});
|
||||
auto min_val = make_shared<opset8::Constant>(element::f32, Shape{}, 0);
|
||||
auto max_val = make_shared<opset8::Constant>(element::f32, Shape{}, 1);
|
||||
|
||||
const auto random_uniform = make_shared<opset8::RandomUniform>(
|
||||
out_shape, min_val, max_val, element::Type_t::f32, 150, 10);
|
||||
NodeBuilder builder(random_uniform);
|
||||
auto g_random_uniform = as_type_ptr<opset8::RandomUniform>(builder.create());
|
||||
|
||||
const auto expected_attr_count = 3;
|
||||
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
|
||||
EXPECT_EQ(g_random_uniform->get_global_seed(), random_uniform->get_global_seed());
|
||||
EXPECT_EQ(g_random_uniform->get_op_seed(), random_uniform->get_op_seed());
|
||||
EXPECT_EQ(g_random_uniform->get_out_type(), random_uniform->get_out_type());
|
||||
}
|
Loading…
Reference in New Issue
Block a user