Test calculation output shape for Broadcast op, relax restrictions for partially dynamic input data (#1247)
This commit is contained in:
parent
ffe8599c30
commit
ae48d9deb8
@ -90,15 +90,21 @@ std::pair<bool, AxisSet> op::v3::Broadcast::get_broadcast_axes() const
|
||||
|
||||
namespace
|
||||
{
|
||||
PartialShape
|
||||
get_result_shape_bidirectional(const Node* this_ptr, Shape& arg_shape, Shape& target_shape)
|
||||
PartialShape get_result_shape_bidirectional(const Node* this_ptr,
|
||||
const PartialShape& arg_shape,
|
||||
Shape& target_shape)
|
||||
{
|
||||
if (arg_shape.rank().is_dynamic())
|
||||
{
|
||||
return PartialShape::dynamic();
|
||||
}
|
||||
auto arg_shape_vec = static_cast<std::vector<Dimension>>(arg_shape);
|
||||
PartialShape result_shape;
|
||||
// Add left padding to shorter target or argument shape
|
||||
const auto target_padded_rank = std::max(arg_shape.size(), target_shape.size());
|
||||
while (arg_shape.size() < target_padded_rank)
|
||||
const auto target_padded_rank = std::max(arg_shape_vec.size(), target_shape.size());
|
||||
while (arg_shape_vec.size() < target_padded_rank)
|
||||
{
|
||||
arg_shape.insert(arg_shape.begin(), 1);
|
||||
arg_shape_vec.insert(arg_shape_vec.begin(), 1);
|
||||
}
|
||||
while (target_shape.size() < target_padded_rank)
|
||||
{
|
||||
@ -108,15 +114,28 @@ namespace
|
||||
result_shape = target_shape;
|
||||
for (auto i = 0; i < target_shape.size(); ++i)
|
||||
{
|
||||
if (arg_shape_vec[i].is_dynamic())
|
||||
{
|
||||
if (target_shape[i] == 1)
|
||||
{
|
||||
result_shape[i] = Dimension::dynamic();
|
||||
}
|
||||
else
|
||||
{
|
||||
result_shape[i] = target_shape[i];
|
||||
}
|
||||
continue;
|
||||
}
|
||||
const size_t arg_shape_dim = arg_shape_vec[i].get_length();
|
||||
NODE_VALIDATION_CHECK(this_ptr,
|
||||
arg_shape[i] == 1 || target_shape[i] == 1 ||
|
||||
arg_shape[i] == target_shape[i],
|
||||
arg_shape_dim == 1 || target_shape[i] == 1 ||
|
||||
arg_shape_dim == target_shape[i],
|
||||
"Broadcast incorrect target shape. Expecting either 1 or ",
|
||||
arg_shape[i],
|
||||
arg_shape_dim,
|
||||
". Got ",
|
||||
target_shape[i]);
|
||||
|
||||
result_shape[i] = std::max(arg_shape[i], target_shape[i]);
|
||||
result_shape[i] = std::max(arg_shape_dim, target_shape[i]);
|
||||
}
|
||||
return result_shape;
|
||||
}
|
||||
@ -143,9 +162,9 @@ void op::v3::Broadcast::validate_and_infer_types()
|
||||
auto result_shape = get_output_partial_shape(0);
|
||||
if (m_mode.m_type == BroadcastType::BIDIRECTIONAL)
|
||||
{
|
||||
if (get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static())
|
||||
if (get_input_partial_shape(0).rank().is_static() && get_input_partial_shape(1).is_static())
|
||||
{
|
||||
auto arg_shape = get_input_shape(0);
|
||||
auto arg_shape = get_input_partial_shape(0);
|
||||
|
||||
const auto shape_constant =
|
||||
as_type_ptr<op::v0::Constant>(input_value(1).get_node_shared_ptr());
|
||||
@ -196,7 +215,8 @@ bool op::v3::Broadcast::evaluate(const HostTensorVector& outputs,
|
||||
{
|
||||
auto arg_shape = inputs[0]->get_shape();
|
||||
Shape target_shape = op::util::BroadcastBase::get_target_shape(inputs[1]);
|
||||
PartialShape result_shape = get_result_shape_bidirectional(this, arg_shape, target_shape);
|
||||
PartialShape result_shape =
|
||||
get_result_shape_bidirectional(this, PartialShape{arg_shape}, target_shape);
|
||||
auto pair_broadcast_axes =
|
||||
get_broadcast_axes_bidirectional(arg_shape, result_shape.to_shape());
|
||||
return op::util::BroadcastBase::evaluate_broadcast(
|
||||
|
@ -47,35 +47,79 @@ op::util::BroadcastBase::BroadcastBase(const Output<Node>& arg,
|
||||
{
|
||||
}
|
||||
|
||||
PartialShape op::util::BroadcastBase::get_result_shape_numpy_pdpd(
|
||||
const Shape& arg0_shape,
|
||||
PartialShape op::util::BroadcastBase::get_result_shape_pdpd(
|
||||
const PartialShape& arg0_shape,
|
||||
const Shape& target_shape,
|
||||
const op::BroadcastModeSpec& broadcast_spec) const
|
||||
{
|
||||
if (arg0_shape.rank().is_dynamic())
|
||||
{
|
||||
return PartialShape::dynamic(target_shape.size());
|
||||
}
|
||||
const auto arg_rank_length = arg0_shape.rank().get_length();
|
||||
PartialShape result_shape = target_shape;
|
||||
auto start_axis = (broadcast_spec.m_type == op::BroadcastType::PDPD)
|
||||
? broadcast_spec.m_axis
|
||||
: target_shape.size() - arg0_shape.size();
|
||||
auto start_axis = broadcast_spec.m_axis;
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
start_axis >= 0,
|
||||
"Broadcast target_shape has smaller rank ",
|
||||
target_shape.size(),
|
||||
" than arg shape ",
|
||||
arg0_shape.size());
|
||||
arg_rank_length);
|
||||
for (auto i = start_axis; i < target_shape.size(); i++)
|
||||
{
|
||||
if (arg0_shape[i - start_axis].is_dynamic())
|
||||
{
|
||||
result_shape[i] = Dimension::dynamic();
|
||||
continue;
|
||||
}
|
||||
const size_t arg_dim = arg0_shape[i - start_axis].get_length();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
arg0_shape[i - start_axis] == 1 || target_shape[i] == 1 ||
|
||||
arg0_shape[i - start_axis] == target_shape[i],
|
||||
arg_dim == 1 || target_shape[i] == 1 || arg_dim == target_shape[i],
|
||||
"Broadcast incorrect target shape. Expecting either 1 or ",
|
||||
arg0_shape[i - start_axis],
|
||||
arg_dim,
|
||||
" . Got ",
|
||||
target_shape[i]);
|
||||
result_shape[i] = std::max(arg0_shape[i - start_axis], target_shape[i]);
|
||||
result_shape[i] = std::max(arg_dim, target_shape[i]);
|
||||
}
|
||||
return result_shape;
|
||||
}
|
||||
|
||||
void op::util::BroadcastBase::validate_target_shape_numpy(const PartialShape& arg_shape,
|
||||
const Shape& target_shape) const
|
||||
{
|
||||
if (arg_shape.rank().is_dynamic())
|
||||
{
|
||||
return;
|
||||
}
|
||||
const auto arg_rank_length = arg_shape.rank().get_length();
|
||||
auto start_axis = target_shape.size() - arg_rank_length;
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
start_axis >= 0,
|
||||
"Broadcast target_shape has smaller rank ",
|
||||
target_shape.size(),
|
||||
" than arg shape ",
|
||||
arg_rank_length);
|
||||
for (auto i = start_axis; i < target_shape.size(); i++)
|
||||
{
|
||||
if (arg_shape[i - start_axis].is_dynamic())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
const size_t arg_dim = arg_shape[i - start_axis].get_length();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
arg_dim == 1 || arg_dim == target_shape[i],
|
||||
"Input shape dimension equal ",
|
||||
arg_dim,
|
||||
" cannot be broadcasted (numpy mode) to ",
|
||||
target_shape[i],
|
||||
". Allowed input dimension value would be 1",
|
||||
target_shape[i] != 1
|
||||
? (std::string(" or ") + std::to_string(target_shape[i])).c_str()
|
||||
: "");
|
||||
}
|
||||
}
|
||||
|
||||
void op::util::BroadcastBase::validate_target_shape_none(const Shape& arg_shape,
|
||||
const AxisVector& axes_mapping_val,
|
||||
const Shape& target_shape) const
|
||||
@ -142,13 +186,28 @@ void op::util::BroadcastBase::validate_and_infer_types()
|
||||
}
|
||||
|
||||
PartialShape result_shape{PartialShape::dynamic()};
|
||||
auto input_rank = input_value(0).get_partial_shape().rank();
|
||||
auto output_rank = input_value(1).get_partial_shape();
|
||||
if (input_rank.is_static() && output_rank.is_static() && output_rank[0].is_static())
|
||||
const auto& input_shape = get_input_partial_shape(0);
|
||||
const auto input_rank = input_shape.rank();
|
||||
const auto& target_shape = input_value(1).get_partial_shape();
|
||||
const bool is_target_shape_known =
|
||||
target_shape.rank().is_static() && target_shape[0].is_static();
|
||||
|
||||
if (m_mode.m_type == BroadcastType::BIDIRECTIONAL)
|
||||
{
|
||||
result_shape =
|
||||
PartialShape::dynamic(std::max(input_rank.get_length(), output_rank[0].get_length()));
|
||||
if (input_rank.is_static() && is_target_shape_known)
|
||||
{
|
||||
result_shape = PartialShape::dynamic(
|
||||
std::max(input_rank.get_length(), target_shape[0].get_length()));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (is_target_shape_known)
|
||||
{
|
||||
result_shape = PartialShape::dynamic(target_shape[0].get_length());
|
||||
}
|
||||
}
|
||||
|
||||
const auto shape_constant = as_type_ptr<op::v0::Constant>(input_value(1).get_node_shared_ptr());
|
||||
|
||||
if (auto concat = as_type_ptr<op::v0::Concat>(input_value(1).get_node_shared_ptr()))
|
||||
@ -206,17 +265,21 @@ void op::util::BroadcastBase::validate_and_infer_types()
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (m_mode.m_type == BroadcastType::NUMPY || m_mode.m_type == BroadcastType::PDPD)
|
||||
else if (m_mode.m_type == BroadcastType::NUMPY)
|
||||
{
|
||||
if (get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static())
|
||||
if (shape_constant)
|
||||
{
|
||||
auto arg_shape = get_input_shape(0);
|
||||
|
||||
if (shape_constant)
|
||||
{
|
||||
const auto target_shape = shape_constant->get_shape_val();
|
||||
result_shape = get_result_shape_numpy_pdpd(arg_shape, target_shape, m_mode);
|
||||
}
|
||||
const auto target_shape = shape_constant->get_shape_val();
|
||||
result_shape = target_shape;
|
||||
validate_target_shape_numpy(input_shape, target_shape);
|
||||
}
|
||||
}
|
||||
else if (m_mode.m_type == BroadcastType::PDPD)
|
||||
{
|
||||
if (shape_constant)
|
||||
{
|
||||
const auto target_shape = shape_constant->get_shape_val();
|
||||
result_shape = get_result_shape_pdpd(input_shape, target_shape, m_mode);
|
||||
}
|
||||
}
|
||||
set_output_type(0, get_input_element_type(0), result_shape);
|
||||
@ -490,9 +553,16 @@ bool op::util::BroadcastBase::evaluate(const HostTensorVector& outputs,
|
||||
validate_target_shape_none(inputs[0]->get_shape(), axes_mapping_val, target_shape);
|
||||
result_shape = target_shape;
|
||||
}
|
||||
else if (m_mode.m_type == BroadcastType::NUMPY || m_mode.m_type == BroadcastType::PDPD)
|
||||
else if (m_mode.m_type == BroadcastType::PDPD)
|
||||
{
|
||||
result_shape = get_result_shape_numpy_pdpd(arg_shape, target_shape, m_mode);
|
||||
result_shape = get_result_shape_pdpd(arg_shape, target_shape, m_mode);
|
||||
pair_broadcast_axes =
|
||||
get_broadcast_axes_numpy_pdpd(arg_shape, result_shape.to_shape(), m_mode);
|
||||
}
|
||||
else if (m_mode.m_type == BroadcastType::NUMPY)
|
||||
{
|
||||
result_shape = target_shape;
|
||||
validate_target_shape_numpy(arg_shape, target_shape);
|
||||
pair_broadcast_axes =
|
||||
get_broadcast_axes_numpy_pdpd(arg_shape, result_shape.to_shape(), m_mode);
|
||||
}
|
||||
|
@ -77,9 +77,13 @@ namespace ngraph
|
||||
const AxisSet& broadcast_axes) const;
|
||||
|
||||
PartialShape
|
||||
get_result_shape_numpy_pdpd(const Shape& arg0_shape,
|
||||
const Shape& target_shape,
|
||||
const op::BroadcastModeSpec& broadcast_spec) const;
|
||||
get_result_shape_pdpd(const PartialShape& arg0_shape,
|
||||
const Shape& target_shape,
|
||||
const op::BroadcastModeSpec& broadcast_spec) const;
|
||||
|
||||
void validate_target_shape_numpy(const PartialShape& arg_shape,
|
||||
const Shape& target_shape) const;
|
||||
|
||||
static std::pair<bool, AxisSet>
|
||||
get_broadcast_axes_numpy_pdpd(const Shape& arg_shape,
|
||||
const Shape& result_shape,
|
||||
|
@ -315,7 +315,7 @@ TEST(eval, evaluate_broadcast_v3_numpy_vs_bidi)
|
||||
Shape in_shape{1, 4, 1};
|
||||
|
||||
auto A = make_shared<op::Parameter>(element::f32, in_shape);
|
||||
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {1, 1, 4});
|
||||
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {1, 4, 4});
|
||||
auto bcast_v3_num = make_shared<op::v3::Broadcast>(A, target_shape, op::BroadcastType::NUMPY);
|
||||
auto fun_num = make_shared<Function>(OutputVector{bcast_v3_num}, ParameterVector{A});
|
||||
|
||||
@ -343,6 +343,26 @@ TEST(eval, evaluate_broadcast_v3_numpy_vs_bidi)
|
||||
ASSERT_EQ(expec2, result_val2);
|
||||
}
|
||||
|
||||
TEST(eval, evaluate_broadcast_v3_bidi_3d)
|
||||
{
|
||||
Shape in_shape{1, 4, 1};
|
||||
|
||||
auto A = make_shared<op::Parameter>(element::f32, in_shape);
|
||||
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {1, 1, 3});
|
||||
auto bcast_v3_num =
|
||||
make_shared<op::v3::Broadcast>(A, target_shape, op::BroadcastType::BIDIRECTIONAL);
|
||||
auto fun_num = make_shared<Function>(OutputVector{bcast_v3_num}, ParameterVector{A});
|
||||
|
||||
auto result = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun_num->evaluate(
|
||||
{result}, {make_host_tensor<element::Type_t::f32>(in_shape, {1.0f, 2.0f, 3.0f, 4.0f})}));
|
||||
EXPECT_EQ(result->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result->get_partial_shape(), (PartialShape{1, 4, 3}));
|
||||
auto result_val = read_vector<float>(result);
|
||||
vector<float> expec{1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f};
|
||||
ASSERT_EQ(expec, result_val);
|
||||
}
|
||||
|
||||
TEST(eval, evaluate_broadcast_v3_bidi_4d)
|
||||
{
|
||||
Shape in_shape{4, 1, 1};
|
||||
|
@ -439,6 +439,359 @@ TYPED_TEST_P(BroadcastTests, broadcast_axes_et_wrong)
|
||||
}
|
||||
}
|
||||
|
||||
// EXPLICIT MODE
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_explicit_all_inputs_dynamic)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// const axes mapping
|
||||
const auto axes_mapping_const =
|
||||
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 1, 2});
|
||||
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_explicit_target_shape_static_rank)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
|
||||
const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// const axes mapping
|
||||
const auto axes_mapping_const =
|
||||
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 1, 2});
|
||||
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_explicit_const_target_shape)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto target_shape =
|
||||
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{1, 2, 3});
|
||||
const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
|
||||
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
|
||||
ASSERT_EQ(bc->get_shape(), (Shape{1, 2, 3}));
|
||||
|
||||
// const axes mapping
|
||||
const auto axes_mapping_const =
|
||||
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 2, 1});
|
||||
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
|
||||
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
|
||||
ASSERT_EQ(bc->get_shape(), (Shape{1, 2, 3}));
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_explicit_input_rank_static)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
|
||||
const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// const axes mapping
|
||||
const auto axes_mapping_const =
|
||||
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 2, 1});
|
||||
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_explicit_target_shape_and_input_data_rank_static)
|
||||
{
|
||||
// static rank data
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
|
||||
const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
|
||||
auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// const axes mapping
|
||||
const auto axes_mapping_const =
|
||||
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 2, 1});
|
||||
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_explicit_const_target_shape_static_rank_input)
|
||||
{
|
||||
const auto target_shape =
|
||||
op::Constant::create(element::i64, Shape{4}, vector<int64_t>{1, 1, 5, 10});
|
||||
// static rank data
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
|
||||
auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_EQ(bc->get_shape(), (Shape{1, 1, 5, 10}));
|
||||
|
||||
// const axes mapping
|
||||
const auto axes_mapping_const =
|
||||
op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 2, 1, 3});
|
||||
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_EQ(bc->get_shape(), (Shape{1, 1, 5, 10}));
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
|
||||
// dynamic target shape and axes mapping
|
||||
auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// const axes mapping
|
||||
const auto axes_mapping_const =
|
||||
op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 2, 1, 3});
|
||||
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// static rank target shape
|
||||
target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
|
||||
bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// static rank target shape and const axes mapping
|
||||
target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
|
||||
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape_const_target_shape)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape{4});
|
||||
auto target_shape = op::Constant::create(element::i64, Shape{4}, vector<int64_t>{1, 4, 2, 3});
|
||||
// dynamic axes mapping
|
||||
const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_EQ(bc->get_shape(), (Shape{1, 4, 2, 3}));
|
||||
|
||||
// const axes mapping
|
||||
const auto axes_mapping_const =
|
||||
op::Constant::create(element::i64, Shape{1}, vector<int64_t>{1});
|
||||
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_EQ(bc->get_shape(), (Shape{1, 4, 2, 3}));
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_target_shape)
|
||||
{
|
||||
// dynamic input
|
||||
auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape{4});
|
||||
const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
|
||||
|
||||
// static rank input
|
||||
data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
|
||||
}
|
||||
|
||||
// NUMPY MODE
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_shape_dynamic)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
// dynamic output shape
|
||||
auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// static rank target shape
|
||||
target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
|
||||
bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_numpy_target_shape_constant)
|
||||
{
|
||||
// dynamic data
|
||||
auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto target_shape =
|
||||
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{1, 2, 3});
|
||||
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
|
||||
|
||||
// static rank data
|
||||
data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_numpy_target_shape_dynamic)
|
||||
{
|
||||
// static rank data
|
||||
auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
|
||||
const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// static shape data
|
||||
data = make_shared<op::Parameter>(element::f32, PartialShape{3, 4, 5, 6});
|
||||
bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_target_shape_static_rank)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
|
||||
const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
|
||||
|
||||
const auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_static_shape)
|
||||
{
|
||||
const auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
|
||||
// static rank target_shape
|
||||
auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
|
||||
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// constant target_shape
|
||||
const auto target_shape_const =
|
||||
op::Constant::create(element::i64, Shape{3}, vector<int64_t>{3, 2, 3});
|
||||
bc = make_shared<TypeParam>(data, target_shape_const, "NUMPY");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{3, 2, 3}));
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_partially_dynamic)
|
||||
{
|
||||
const Shape expected_target_shape{1, 2, 3, 4};
|
||||
const auto target_shape = op::Constant::create(
|
||||
element::i64,
|
||||
{expected_target_shape.size()},
|
||||
std::vector<int64_t>(expected_target_shape.begin(), expected_target_shape.end()));
|
||||
|
||||
auto data = make_shared<op::Parameter>(element::f32, PartialShape{2, 3, Dimension::dynamic()});
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
|
||||
|
||||
data = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()});
|
||||
bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
|
||||
|
||||
data = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{2, Dimension::dynamic(), Dimension::dynamic()});
|
||||
bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
|
||||
|
||||
data = make_shared<op::Parameter>(
|
||||
element::f32,
|
||||
PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
|
||||
bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_numpy_static_dims_incorrect)
|
||||
{
|
||||
const auto target_shape = op::Constant::create(element::i64, Shape{4}, {1, 2, 3, 4});
|
||||
|
||||
auto data =
|
||||
make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 999, 3, 4});
|
||||
try
|
||||
{
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
"Input shape dimension equal 999 cannot be broadcasted (numpy mode) "
|
||||
"to 2. Allowed input dimension value would be 1 or 2");
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
|
||||
data = make_shared<op::Parameter>(
|
||||
element::f32,
|
||||
PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 888});
|
||||
try
|
||||
{
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
"Input shape dimension equal 888 cannot be broadcasted (numpy mode) "
|
||||
"to 4. Allowed input dimension value would be 1 or 4");
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
|
||||
data = make_shared<op::Parameter>(
|
||||
element::f32,
|
||||
PartialShape{5, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
|
||||
try
|
||||
{
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
"Input shape dimension equal 5 cannot be broadcasted (numpy mode) to "
|
||||
"1. Allowed input dimension value would be 1");
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_TYPED_TEST_CASE_P(BroadcastTests,
|
||||
broadcast_numpy,
|
||||
broadcast_axes_mapping,
|
||||
@ -451,7 +804,23 @@ REGISTER_TYPED_TEST_CASE_P(BroadcastTests,
|
||||
broadcast_axes_wrong_rank,
|
||||
broadcast_fully_dynamic_target_shape,
|
||||
broadcast_broadcast_shape_et_wrong,
|
||||
broadcast_axes_et_wrong);
|
||||
broadcast_axes_et_wrong,
|
||||
broadcast_explicit_all_inputs_dynamic,
|
||||
broadcast_explicit_target_shape_static_rank,
|
||||
broadcast_explicit_const_target_shape,
|
||||
broadcast_explicit_input_rank_static,
|
||||
broadcast_explicit_target_shape_and_input_data_rank_static,
|
||||
broadcast_explicit_const_target_shape_static_rank_input,
|
||||
broadcast_explicit_static_input_shape,
|
||||
broadcast_explicit_static_input_shape_const_target_shape,
|
||||
broadcast_explicit_static_target_shape,
|
||||
broadcast_numpy_input_shape_dynamic,
|
||||
broadcast_numpy_target_shape_constant,
|
||||
broadcast_numpy_target_shape_dynamic,
|
||||
broadcast_numpy_input_target_shape_static_rank,
|
||||
broadcast_numpy_input_static_shape,
|
||||
broadcast_numpy_input_partially_dynamic,
|
||||
broadcast_numpy_static_dims_incorrect);
|
||||
|
||||
typedef ::testing::Types<op::v1::Broadcast, op::v3::Broadcast> BroadcastTypes;
|
||||
// the last empty argument resolves compiler warning on MAC:
|
||||
@ -696,7 +1065,8 @@ TEST(type_prop, broadcast_v3_output_rank_deduced_from_arg)
|
||||
const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
|
||||
|
||||
const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(
|
||||
PartialShape{Dimension::dynamic(), 8, 6, 4}));
|
||||
}
|
||||
|
||||
TEST(type_prop, broadcast_v3_output_rank_deduced_from_new_shape_input)
|
||||
@ -706,5 +1076,114 @@ TEST(type_prop, broadcast_v3_output_rank_deduced_from_new_shape_input)
|
||||
const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
|
||||
|
||||
const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(5)));
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 5);
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(
|
||||
PartialShape{8, 6, Dimension::dynamic(), 5, Dimension::dynamic()}));
|
||||
}
|
||||
|
||||
TEST(type_prop, broadcast_v3_bidirectional_dynamic_input)
|
||||
{
|
||||
const auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
|
||||
// dynamic target shape
|
||||
auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// static rank target shape
|
||||
target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
|
||||
broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// constant target shape
|
||||
const auto target_shape_const = op::Constant::create(element::i64, {3}, {2, 4, 6});
|
||||
broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
|
||||
}
|
||||
|
||||
TEST(type_prop, broadcast_v3_bidirectional_static_rank_input)
|
||||
{
|
||||
const auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
|
||||
|
||||
// dynamic target shape
|
||||
auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// static rank target shape
|
||||
target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
|
||||
broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// constant target shape
|
||||
const auto target_shape_const = op::Constant::create(element::i64, {3}, {2, 4, 6});
|
||||
broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_dynamic());
|
||||
}
|
||||
|
||||
TEST(type_prop, broadcast_v3_bidirectional_static_shape_input)
|
||||
{
|
||||
const auto arg = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 1});
|
||||
|
||||
// dynamic target shape
|
||||
auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// static rank target shape
|
||||
target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
|
||||
broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
|
||||
|
||||
// constant target shape
|
||||
auto target_shape_const = op::Constant::create(element::i64, {4}, {2, 2, 3, 2});
|
||||
broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_static());
|
||||
ASSERT_EQ(broadcast_v3->get_output_partial_shape(0), (PartialShape{2, 2, 3, 2}));
|
||||
|
||||
target_shape_const = op::Constant::create(element::i64, {4}, {5, 2, 3, 7});
|
||||
broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_static());
|
||||
ASSERT_EQ(broadcast_v3->get_output_partial_shape(0), (PartialShape{5, 2, 3, 7}));
|
||||
}
|
||||
|
||||
TEST(type_prop, broadcast_v3_bidirectional_partially_dynamic_input)
|
||||
{
|
||||
const auto target_shape =
|
||||
op::Constant::create(element::i64, Shape{4}, vector<int64_t>{1, 1, 50, 50});
|
||||
|
||||
auto data = make_shared<op::Parameter>(element::f32, PartialShape{16, 1, Dimension::dynamic()});
|
||||
auto bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, 16, 50, 50}));
|
||||
|
||||
data = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{Dimension::dynamic(), 1, Dimension::dynamic()});
|
||||
bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic(), 50, 50}));
|
||||
|
||||
data = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{16, Dimension::dynamic(), Dimension::dynamic()});
|
||||
bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, 16, 50, 50}));
|
||||
|
||||
data = make_shared<op::Parameter>(
|
||||
element::f32,
|
||||
PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
|
||||
bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic(), 50, 50}));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user