[GPU] Fix canonicalization for fused dep's shape (#19667)

* [GPU] Fix canonicalization for fused dep's shape

Signed-off-by: Andrew Park <andrew.park@intel.com>

* Update TC to reproducible on the latest master

Signed-off-by: Andrew Park <andrew.park@intel.com>

* Fix custom canonicalize shapes for Gather

---------

Signed-off-by: Andrew Park <andrew.park@intel.com>
This commit is contained in:
Andrew Kwangwoong Park
2023-09-20 08:57:10 +09:00
committed by GitHub
parent 631d6d3980
commit 394e58fafb
2 changed files with 45 additions and 1 deletions

View File

@@ -116,7 +116,12 @@ public:
out_layout.format = format::adjust_to_rank(out_layout.format, output_pshape.size());
}
return primitive_impl::static_canonicalize_shapes(updated_impl_params);
for (auto& input_layout : updated_impl_params.input_layouts) {
input_layout.set_partial_shape(extend_shape_to_rank_from_end(input_layout.get_partial_shape()));
}
out_layout.set_partial_shape(extend_shape_to_rank_from_end(out_layout.get_partial_shape()));
return updated_impl_params;
}
kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const override {

View File

@@ -226,6 +226,45 @@ TEST(canonicalization, gather) {
}
}
struct fusing_gather_eltwise_params {
ov::PartialShape data_shape;
ov::Shape out_shape;
int64_t axis;
int64_t batch_dim;
bool support_neg_ind;
};
std::vector<std::pair<Shapes, fusing_gather_eltwise_params>> fusing_gather_eltwise_shapes_with_params {
{
{{{}, {}}, {{4624, 4, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {4624, 1, 1, 1}}, {{4624, 1, 1, 1}}},
{{4624, 4}, {4624}, 1, 0, true}
}
};
TEST(canonicalization, fusing_gather_eltwise) {
for (const auto& shapes : fusing_gather_eltwise_shapes_with_params) {
layout input_gather_layout = create_default_layout(shapes.second.data_shape);
layout indices_layout_first = create_default_layout(std::get<0>(shapes.first)[0]);
layout indices_layout_second = create_default_layout(std::get<0>(shapes.first)[0]);
layout input_mul_layout = create_default_layout(std::get<0>(shapes.first)[1]);
topology topology;
topology.add(input_layout("input", input_gather_layout));
topology.add(input_layout("indices_first", indices_layout_first));
topology.add(input_layout("indices_second", indices_layout_second));
topology.add(input_layout("data", input_mul_layout));
topology.add(gather("gather_first", input_info("input"), input_info("indices_first"), shapes.second.axis,
shapes.second.out_shape, shapes.second.batch_dim, shapes.second.support_neg_ind));
topology.add(gather("gather_second", input_info("input"), input_info("indices_second"), shapes.second.axis,
shapes.second.out_shape, shapes.second.batch_dim, shapes.second.support_neg_ind));
topology.add(eltwise("mul", {input_info("gather_first"), input_info("data")}, eltwise_mode::prod));
topology.add(eltwise("add", {input_info("gather_second"), input_info("mul")}, eltwise_mode::sum));
topology.add(reorder("out_reorder", input_info("add"), format::bfyx, data_types::f32));
canonicalization_test(topology, "gather_first", std::get<1>(shapes.first), std::get<2>(shapes.first), true);
}
}
struct fusing_gemm_eltwise_params {
ov::PartialShape input_gemm_first;
ov::PartialShape weights_gemm_first;