Test output shape calculation for Broadcast op, relax restrictions for partially dynamic input data (#1247)

Mateusz Bencer 2020-08-10 13:39:14 +02:00 committed by GitHub
parent ffe8599c30
commit ae48d9deb8
5 changed files with 638 additions and 45 deletions
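
In practical terms, the relaxed restriction means BIDIRECTIONAL shape inference no longer needs a fully static input shape. A minimal sketch of the new behaviour, assuming the ngraph umbrella header ngraph/ngraph.hpp and reusing the shapes from the new type_prop test broadcast_v3_bidirectional_partially_dynamic_input added below:

// Sketch: Broadcast v3 (BIDIRECTIONAL) over a partially dynamic input {16, 1, ?}.
// The inferred output partial shape is expected to be {1, 16, 50, 50}.
#include <iostream>
#include <memory>
#include <vector>

#include <ngraph/ngraph.hpp>

int main()
{
    using namespace ngraph;

    auto data = std::make_shared<op::Parameter>(
        element::f32, PartialShape{16, 1, Dimension::dynamic()});
    auto target_shape = op::Constant::create(
        element::i64, Shape{4}, std::vector<int64_t>{1, 1, 50, 50});

    auto bc = std::make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");

    // Prints the partial shape inferred by validate_and_infer_types().
    std::cout << bc->get_output_partial_shape(0) << std::endl;
    return 0;
}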


@@ -90,15 +90,21 @@ std::pair<bool, AxisSet> op::v3::Broadcast::get_broadcast_axes() const
namespace
{
PartialShape
get_result_shape_bidirectional(const Node* this_ptr, Shape& arg_shape, Shape& target_shape)
PartialShape get_result_shape_bidirectional(const Node* this_ptr,
const PartialShape& arg_shape,
Shape& target_shape)
{
if (arg_shape.rank().is_dynamic())
{
return PartialShape::dynamic();
}
auto arg_shape_vec = static_cast<std::vector<Dimension>>(arg_shape);
PartialShape result_shape;
// Add left padding to shorter target or argument shape
const auto target_padded_rank = std::max(arg_shape.size(), target_shape.size());
while (arg_shape.size() < target_padded_rank)
const auto target_padded_rank = std::max(arg_shape_vec.size(), target_shape.size());
while (arg_shape_vec.size() < target_padded_rank)
{
arg_shape.insert(arg_shape.begin(), 1);
arg_shape_vec.insert(arg_shape_vec.begin(), 1);
}
while (target_shape.size() < target_padded_rank)
{
@@ -108,15 +114,28 @@ namespace
result_shape = target_shape;
for (auto i = 0; i < target_shape.size(); ++i)
{
if (arg_shape_vec[i].is_dynamic())
{
if (target_shape[i] == 1)
{
result_shape[i] = Dimension::dynamic();
}
else
{
result_shape[i] = target_shape[i];
}
continue;
}
const size_t arg_shape_dim = arg_shape_vec[i].get_length();
NODE_VALIDATION_CHECK(this_ptr,
arg_shape[i] == 1 || target_shape[i] == 1 ||
arg_shape[i] == target_shape[i],
arg_shape_dim == 1 || target_shape[i] == 1 ||
arg_shape_dim == target_shape[i],
"Broadcast incorrect target shape. Expecting either 1 or ",
arg_shape[i],
arg_shape_dim,
". Got ",
target_shape[i]);
result_shape[i] = std::max(arg_shape[i], target_shape[i]);
result_shape[i] = std::max(arg_shape_dim, target_shape[i]);
}
return result_shape;
}
@@ -143,9 +162,9 @@ void op::v3::Broadcast::validate_and_infer_types()
auto result_shape = get_output_partial_shape(0);
if (m_mode.m_type == BroadcastType::BIDIRECTIONAL)
{
if (get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static())
if (get_input_partial_shape(0).rank().is_static() && get_input_partial_shape(1).is_static())
{
auto arg_shape = get_input_shape(0);
auto arg_shape = get_input_partial_shape(0);
const auto shape_constant =
as_type_ptr<op::v0::Constant>(input_value(1).get_node_shared_ptr());
@@ -196,7 +215,8 @@ bool op::v3::Broadcast::evaluate(const HostTensorVector& outputs,
{
auto arg_shape = inputs[0]->get_shape();
Shape target_shape = op::util::BroadcastBase::get_target_shape(inputs[1]);
PartialShape result_shape = get_result_shape_bidirectional(this, arg_shape, target_shape);
PartialShape result_shape =
get_result_shape_bidirectional(this, PartialShape{arg_shape}, target_shape);
auto pair_broadcast_axes =
get_broadcast_axes_bidirectional(arg_shape, result_shape.to_shape());
return op::util::BroadcastBase::evaluate_broadcast(

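For reference, the per-dimension rule applied by get_result_shape_bidirectional once both shapes are left-padded to a common rank can be written down without the ngraph types. A standalone sketch (not ngraph code), with std::optional<size_t> standing in for Dimension and std::nullopt meaning a dynamic dimension:

// Sketch of the bidirectional merge rule from get_result_shape_bidirectional:
//  - dynamic arg dim + target 1      -> dynamic result dim
//  - dynamic arg dim + target N (>1) -> N
//  - static arg dim                  -> dims must be broadcast-compatible, result is max
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <optional>

using Dim = std::optional<std::size_t>; // std::nullopt models a dynamic dimension

Dim merge_bidirectional(Dim arg, std::size_t target)
{
    if (!arg)
    {
        return target == 1 ? Dim{} : Dim{target};
    }
    assert(*arg == 1 || target == 1 || *arg == target); // same check as NODE_VALIDATION_CHECK
    return Dim{std::max(*arg, target)};
}

int main()
{
    assert(merge_bidirectional(Dim{}, 1) == Dim{});     // {?} vs 1  -> ?
    assert(merge_bidirectional(Dim{}, 50) == Dim{50});  // {?} vs 50 -> 50
    assert(merge_bidirectional(Dim{16}, 1) == Dim{16}); // 16  vs 1  -> 16
    return 0;
}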

@@ -47,35 +47,79 @@ op::util::BroadcastBase::BroadcastBase(const Output<Node>& arg,
{
}
PartialShape op::util::BroadcastBase::get_result_shape_numpy_pdpd(
const Shape& arg0_shape,
PartialShape op::util::BroadcastBase::get_result_shape_pdpd(
const PartialShape& arg0_shape,
const Shape& target_shape,
const op::BroadcastModeSpec& broadcast_spec) const
{
if (arg0_shape.rank().is_dynamic())
{
return PartialShape::dynamic(target_shape.size());
}
const auto arg_rank_length = arg0_shape.rank().get_length();
PartialShape result_shape = target_shape;
auto start_axis = (broadcast_spec.m_type == op::BroadcastType::PDPD)
? broadcast_spec.m_axis
: target_shape.size() - arg0_shape.size();
auto start_axis = broadcast_spec.m_axis;
NODE_VALIDATION_CHECK(this,
start_axis >= 0,
"Broadcast target_shape has smaller rank ",
target_shape.size(),
" than arg shape ",
arg0_shape.size());
arg_rank_length);
for (auto i = start_axis; i < target_shape.size(); i++)
{
if (arg0_shape[i - start_axis].is_dynamic())
{
result_shape[i] = Dimension::dynamic();
continue;
}
const size_t arg_dim = arg0_shape[i - start_axis].get_length();
NODE_VALIDATION_CHECK(this,
arg0_shape[i - start_axis] == 1 || target_shape[i] == 1 ||
arg0_shape[i - start_axis] == target_shape[i],
arg_dim == 1 || target_shape[i] == 1 || arg_dim == target_shape[i],
"Broadcast incorrect target shape. Expecting either 1 or ",
arg0_shape[i - start_axis],
arg_dim,
" . Got ",
target_shape[i]);
result_shape[i] = std::max(arg0_shape[i - start_axis], target_shape[i]);
result_shape[i] = std::max(arg_dim, target_shape[i]);
}
return result_shape;
}
void op::util::BroadcastBase::validate_target_shape_numpy(const PartialShape& arg_shape,
const Shape& target_shape) const
{
if (arg_shape.rank().is_dynamic())
{
return;
}
const auto arg_rank_length = arg_shape.rank().get_length();
auto start_axis = target_shape.size() - arg_rank_length;
NODE_VALIDATION_CHECK(this,
start_axis >= 0,
"Broadcast target_shape has smaller rank ",
target_shape.size(),
" than arg shape ",
arg_rank_length);
for (auto i = start_axis; i < target_shape.size(); i++)
{
if (arg_shape[i - start_axis].is_dynamic())
{
continue;
}
const size_t arg_dim = arg_shape[i - start_axis].get_length();
NODE_VALIDATION_CHECK(this,
arg_dim == 1 || arg_dim == target_shape[i],
"Input shape dimension equal ",
arg_dim,
" cannot be broadcasted (numpy mode) to ",
target_shape[i],
". Allowed input dimension value would be 1",
target_shape[i] != 1
? (std::string(" or ") + std::to_string(target_shape[i])).c_str()
: "");
}
}
void op::util::BroadcastBase::validate_target_shape_none(const Shape& arg_shape,
const AxisVector& axes_mapping_val,
const Shape& target_shape) const
@@ -142,13 +186,28 @@ void op::util::BroadcastBase::validate_and_infer_types()
}
PartialShape result_shape{PartialShape::dynamic()};
auto input_rank = input_value(0).get_partial_shape().rank();
auto output_rank = input_value(1).get_partial_shape();
if (input_rank.is_static() && output_rank.is_static() && output_rank[0].is_static())
const auto& input_shape = get_input_partial_shape(0);
const auto input_rank = input_shape.rank();
const auto& target_shape = input_value(1).get_partial_shape();
const bool is_target_shape_known =
target_shape.rank().is_static() && target_shape[0].is_static();
if (m_mode.m_type == BroadcastType::BIDIRECTIONAL)
{
result_shape =
PartialShape::dynamic(std::max(input_rank.get_length(), output_rank[0].get_length()));
if (input_rank.is_static() && is_target_shape_known)
{
result_shape = PartialShape::dynamic(
std::max(input_rank.get_length(), target_shape[0].get_length()));
}
}
else
{
if (is_target_shape_known)
{
result_shape = PartialShape::dynamic(target_shape[0].get_length());
}
}
const auto shape_constant = as_type_ptr<op::v0::Constant>(input_value(1).get_node_shared_ptr());
if (auto concat = as_type_ptr<op::v0::Concat>(input_value(1).get_node_shared_ptr()))
@@ -206,17 +265,21 @@ void op::util::BroadcastBase::validate_and_infer_types()
}
}
}
else if (m_mode.m_type == BroadcastType::NUMPY || m_mode.m_type == BroadcastType::PDPD)
else if (m_mode.m_type == BroadcastType::NUMPY)
{
if (get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static())
if (shape_constant)
{
auto arg_shape = get_input_shape(0);
if (shape_constant)
{
const auto target_shape = shape_constant->get_shape_val();
result_shape = get_result_shape_numpy_pdpd(arg_shape, target_shape, m_mode);
}
const auto target_shape = shape_constant->get_shape_val();
result_shape = target_shape;
validate_target_shape_numpy(input_shape, target_shape);
}
}
else if (m_mode.m_type == BroadcastType::PDPD)
{
if (shape_constant)
{
const auto target_shape = shape_constant->get_shape_val();
result_shape = get_result_shape_pdpd(input_shape, target_shape, m_mode);
}
}
set_output_type(0, get_input_element_type(0), result_shape);
@@ -490,9 +553,16 @@ bool op::util::BroadcastBase::evaluate(const HostTensorVector& outputs,
validate_target_shape_none(inputs[0]->get_shape(), axes_mapping_val, target_shape);
result_shape = target_shape;
}
else if (m_mode.m_type == BroadcastType::NUMPY || m_mode.m_type == BroadcastType::PDPD)
else if (m_mode.m_type == BroadcastType::PDPD)
{
result_shape = get_result_shape_numpy_pdpd(arg_shape, target_shape, m_mode);
result_shape = get_result_shape_pdpd(arg_shape, target_shape, m_mode);
pair_broadcast_axes =
get_broadcast_axes_numpy_pdpd(arg_shape, result_shape.to_shape(), m_mode);
}
else if (m_mode.m_type == BroadcastType::NUMPY)
{
result_shape = target_shape;
validate_target_shape_numpy(arg_shape, target_shape);
pair_broadcast_axes =
get_broadcast_axes_numpy_pdpd(arg_shape, result_shape.to_shape(), m_mode);
}

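The NUMPY path now only validates the data shape against the target shape and forwards the target shape as the result, while PDPD keeps computing a result shape. The validation itself reduces to the right-aligned rule sketched below (again with std::optional modelling a dynamic dimension, not the ngraph API); the example inputs mirror broadcast_numpy_input_partially_dynamic and broadcast_numpy_static_dims_incorrect:

// Sketch of validate_target_shape_numpy: right-align the input shape against the
// target shape, skip dynamic dims, and require every static dim to be 1 or equal
// to the corresponding target dim. On success the result shape is the target shape.
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

using Dim = std::optional<std::size_t>; // std::nullopt models a dynamic dimension

bool numpy_broadcastable(const std::vector<Dim>& arg, const std::vector<std::size_t>& target)
{
    if (arg.size() > target.size())
    {
        return false; // "Broadcast target_shape has smaller rank than arg shape"
    }
    const std::size_t start_axis = target.size() - arg.size();
    for (std::size_t i = start_axis; i < target.size(); ++i)
    {
        const Dim d = arg[i - start_axis];
        if (d && *d != 1 && *d != target[i])
        {
            return false; // static dim can only be 1 or target[i]
        }
    }
    return true;
}

int main()
{
    // {2, 3, ?} broadcasts to {1, 2, 3, 4}
    assert(numpy_broadcastable({Dim{2}, Dim{3}, Dim{}}, {1, 2, 3, 4}));
    // {?, 999, 3, 4} does not broadcast to {1, 2, 3, 4} (999 vs 2)
    assert(!numpy_broadcastable({Dim{}, Dim{999}, Dim{3}, Dim{4}}, {1, 2, 3, 4}));
    return 0;
}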

@@ -77,9 +77,13 @@ namespace ngraph
const AxisSet& broadcast_axes) const;
PartialShape
get_result_shape_numpy_pdpd(const Shape& arg0_shape,
const Shape& target_shape,
const op::BroadcastModeSpec& broadcast_spec) const;
get_result_shape_pdpd(const PartialShape& arg0_shape,
const Shape& target_shape,
const op::BroadcastModeSpec& broadcast_spec) const;
void validate_target_shape_numpy(const PartialShape& arg_shape,
const Shape& target_shape) const;
static std::pair<bool, AxisSet>
get_broadcast_axes_numpy_pdpd(const Shape& arg_shape,
const Shape& result_shape,


@@ -315,7 +315,7 @@ TEST(eval, evaluate_broadcast_v3_numpy_vs_bidi)
Shape in_shape{1, 4, 1};
auto A = make_shared<op::Parameter>(element::f32, in_shape);
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {1, 1, 4});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {1, 4, 4});
auto bcast_v3_num = make_shared<op::v3::Broadcast>(A, target_shape, op::BroadcastType::NUMPY);
auto fun_num = make_shared<Function>(OutputVector{bcast_v3_num}, ParameterVector{A});
@@ -343,6 +343,26 @@ TEST(eval, evaluate_broadcast_v3_numpy_vs_bidi)
ASSERT_EQ(expec2, result_val2);
}
TEST(eval, evaluate_broadcast_v3_bidi_3d)
{
Shape in_shape{1, 4, 1};
auto A = make_shared<op::Parameter>(element::f32, in_shape);
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {1, 1, 3});
auto bcast_v3_num =
make_shared<op::v3::Broadcast>(A, target_shape, op::BroadcastType::BIDIRECTIONAL);
auto fun_num = make_shared<Function>(OutputVector{bcast_v3_num}, ParameterVector{A});
auto result = make_shared<HostTensor>();
ASSERT_TRUE(fun_num->evaluate(
{result}, {make_host_tensor<element::Type_t::f32>(in_shape, {1.0f, 2.0f, 3.0f, 4.0f})}));
EXPECT_EQ(result->get_element_type(), element::f32);
EXPECT_EQ(result->get_partial_shape(), (PartialShape{1, 4, 3}));
auto result_val = read_vector<float>(result);
vector<float> expec{1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f};
ASSERT_EQ(expec, result_val);
}
TEST(eval, evaluate_broadcast_v3_bidi_4d)
{
Shape in_shape{4, 1, 1};


@@ -439,6 +439,359 @@ TYPED_TEST_P(BroadcastTests, broadcast_axes_et_wrong)
}
}
// EXPLICIT MODE
TYPED_TEST_P(BroadcastTests, broadcast_explicit_all_inputs_dynamic)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
// const axes mapping
const auto axes_mapping_const =
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 1, 2});
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
}
TYPED_TEST_P(BroadcastTests, broadcast_explicit_target_shape_static_rank)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
// const axes mapping
const auto axes_mapping_const =
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 1, 2});
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
}
TYPED_TEST_P(BroadcastTests, broadcast_explicit_const_target_shape)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto target_shape =
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{1, 2, 3});
const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
ASSERT_EQ(bc->get_shape(), (Shape{1, 2, 3}));
// const axes mapping
const auto axes_mapping_const =
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 2, 1});
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
ASSERT_EQ(bc->get_shape(), (Shape{1, 2, 3}));
}
TYPED_TEST_P(BroadcastTests, broadcast_explicit_input_rank_static)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
// const axes mapping
const auto axes_mapping_const =
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 2, 1});
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
}
TYPED_TEST_P(BroadcastTests, broadcast_explicit_target_shape_and_input_data_rank_static)
{
// static rank data
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
// const axes mapping
const auto axes_mapping_const =
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 2, 1});
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
}
TYPED_TEST_P(BroadcastTests, broadcast_explicit_const_target_shape_static_rank_input)
{
const auto target_shape =
op::Constant::create(element::i64, Shape{4}, vector<int64_t>{1, 1, 5, 10});
// static rank data
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_EQ(bc->get_shape(), (Shape{1, 1, 5, 10}));
// const axes mapping
const auto axes_mapping_const =
op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 2, 1, 3});
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_EQ(bc->get_shape(), (Shape{1, 1, 5, 10}));
}
TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
// dynamic target shape and axes mapping
auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
// const axes mapping
const auto axes_mapping_const =
op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 2, 1, 3});
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
// static rank target shape
target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
// static rank target shape and const axes mapping
target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
}
TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape_const_target_shape)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape{4});
auto target_shape = op::Constant::create(element::i64, Shape{4}, vector<int64_t>{1, 4, 2, 3});
// dynamic axes mapping
const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_EQ(bc->get_shape(), (Shape{1, 4, 2, 3}));
// const axes mapping
const auto axes_mapping_const =
op::Constant::create(element::i64, Shape{1}, vector<int64_t>{1});
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_EQ(bc->get_shape(), (Shape{1, 4, 2, 3}));
}
TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_target_shape)
{
// dynamic input
auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape{4});
const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
// static rank input
data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(2));
bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
}
// NUMPY MODE
TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_shape_dynamic)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
// dynamic output shape
auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
// static rank target shape
target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
}
TYPED_TEST_P(BroadcastTests, broadcast_numpy_target_shape_constant)
{
// dynamic data
auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto target_shape =
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{1, 2, 3});
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
// static rank data
data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(2));
bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
}
TYPED_TEST_P(BroadcastTests, broadcast_numpy_target_shape_dynamic)
{
// static rank data
auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
// static shape data
data = make_shared<op::Parameter>(element::f32, PartialShape{3, 4, 5, 6});
bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
}
TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_target_shape_static_rank)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
const auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
}
TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_static_shape)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
// static rank target_shape
auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
// constant target_shape
const auto target_shape_const =
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{3, 2, 3});
bc = make_shared<TypeParam>(data, target_shape_const, "NUMPY");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{3, 2, 3}));
}
TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_partially_dynamic)
{
const Shape expected_target_shape{1, 2, 3, 4};
const auto target_shape = op::Constant::create(
element::i64,
{expected_target_shape.size()},
std::vector<int64_t>(expected_target_shape.begin(), expected_target_shape.end()));
auto data = make_shared<op::Parameter>(element::f32, PartialShape{2, 3, Dimension::dynamic()});
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
data = make_shared<op::Parameter>(element::f32,
PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()});
bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
data = make_shared<op::Parameter>(element::f32,
PartialShape{2, Dimension::dynamic(), Dimension::dynamic()});
bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
data = make_shared<op::Parameter>(
element::f32,
PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
}
TYPED_TEST_P(BroadcastTests, broadcast_numpy_static_dims_incorrect)
{
const auto target_shape = op::Constant::create(element::i64, Shape{4}, {1, 2, 3, 4});
auto data =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 999, 3, 4});
try
{
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Input shape dimension equal 999 cannot be broadcasted (numpy mode) "
"to 2. Allowed input dimension value would be 1 or 2");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
data = make_shared<op::Parameter>(
element::f32,
PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 888});
try
{
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Input shape dimension equal 888 cannot be broadcasted (numpy mode) "
"to 4. Allowed input dimension value would be 1 or 4");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
data = make_shared<op::Parameter>(
element::f32,
PartialShape{5, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
try
{
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Input shape dimension equal 5 cannot be broadcasted (numpy mode) to "
"1. Allowed input dimension value would be 1");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
REGISTER_TYPED_TEST_CASE_P(BroadcastTests,
broadcast_numpy,
broadcast_axes_mapping,
@@ -451,7 +804,23 @@ REGISTER_TYPED_TEST_CASE_P(BroadcastTests,
broadcast_axes_wrong_rank,
broadcast_fully_dynamic_target_shape,
broadcast_broadcast_shape_et_wrong,
broadcast_axes_et_wrong);
broadcast_axes_et_wrong,
broadcast_explicit_all_inputs_dynamic,
broadcast_explicit_target_shape_static_rank,
broadcast_explicit_const_target_shape,
broadcast_explicit_input_rank_static,
broadcast_explicit_target_shape_and_input_data_rank_static,
broadcast_explicit_const_target_shape_static_rank_input,
broadcast_explicit_static_input_shape,
broadcast_explicit_static_input_shape_const_target_shape,
broadcast_explicit_static_target_shape,
broadcast_numpy_input_shape_dynamic,
broadcast_numpy_target_shape_constant,
broadcast_numpy_target_shape_dynamic,
broadcast_numpy_input_target_shape_static_rank,
broadcast_numpy_input_static_shape,
broadcast_numpy_input_partially_dynamic,
broadcast_numpy_static_dims_incorrect);
typedef ::testing::Types<op::v1::Broadcast, op::v3::Broadcast> BroadcastTypes;
// the last empty argument resolves compiler warning on MAC:
@@ -696,7 +1065,8 @@ TEST(type_prop, broadcast_v3_output_rank_deduced_from_arg)
const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), 8, 6, 4}));
}
TEST(type_prop, broadcast_v3_output_rank_deduced_from_new_shape_input)
@@ -706,5 +1076,114 @@ TEST(type_prop, broadcast_v3_output_rank_deduced_from_new_shape_input)
const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(5)));
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 5);
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(
PartialShape{8, 6, Dimension::dynamic(), 5, Dimension::dynamic()}));
}
TEST(type_prop, broadcast_v3_bidirectional_dynamic_input)
{
const auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
// dynamic target shape
auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
// static rank target shape
target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
// constant target shape
const auto target_shape_const = op::Constant::create(element::i64, {3}, {2, 4, 6});
broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, broadcast_v3_bidirectional_static_rank_input)
{
const auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
// dynamic target shape
auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
// static rank target shape
target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
// constant target shape
const auto target_shape_const = op::Constant::create(element::i64, {3}, {2, 4, 6});
broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_dynamic());
}
TEST(type_prop, broadcast_v3_bidirectional_static_shape_input)
{
const auto arg = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 1});
// dynamic target shape
auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
// static rank target shape
target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
// constant target shape
auto target_shape_const = op::Constant::create(element::i64, {4}, {2, 2, 3, 2});
broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_static());
ASSERT_EQ(broadcast_v3->get_output_partial_shape(0), (PartialShape{2, 2, 3, 2}));
target_shape_const = op::Constant::create(element::i64, {4}, {5, 2, 3, 7});
broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_static());
ASSERT_EQ(broadcast_v3->get_output_partial_shape(0), (PartialShape{5, 2, 3, 7}));
}
TEST(type_prop, broadcast_v3_bidirectional_partially_dynamic_input)
{
const auto target_shape =
op::Constant::create(element::i64, Shape{4}, vector<int64_t>{1, 1, 50, 50});
auto data = make_shared<op::Parameter>(element::f32, PartialShape{16, 1, Dimension::dynamic()});
auto bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, 16, 50, 50}));
data = make_shared<op::Parameter>(element::f32,
PartialShape{Dimension::dynamic(), 1, Dimension::dynamic()});
bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic(), 50, 50}));
data = make_shared<op::Parameter>(element::f32,
PartialShape{16, Dimension::dynamic(), Dimension::dynamic()});
bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, 16, 50, 50}));
data = make_shared<op::Parameter>(
element::f32,
PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic(), 50, 50}));
}