[GPU] Fix canonicalization for fused dep's shape (#19667)
* [GPU] Fix canonicalization for fused dep's shape Signed-off-by: Andrew Park <andrew.park@intel.com> * Update TC to reproducible on the latest master Signed-off-by: Andrew Park <andrew.park@intel.com> * Fix custom canonicalize shapes for Gather --------- Signed-off-by: Andrew Park <andrew.park@intel.com>
This commit is contained in:
committed by
GitHub
parent
631d6d3980
commit
394e58fafb
@@ -116,7 +116,12 @@ public:
|
||||
out_layout.format = format::adjust_to_rank(out_layout.format, output_pshape.size());
|
||||
}
|
||||
|
||||
return primitive_impl::static_canonicalize_shapes(updated_impl_params);
|
||||
for (auto& input_layout : updated_impl_params.input_layouts) {
|
||||
input_layout.set_partial_shape(extend_shape_to_rank_from_end(input_layout.get_partial_shape()));
|
||||
}
|
||||
out_layout.set_partial_shape(extend_shape_to_rank_from_end(out_layout.get_partial_shape()));
|
||||
|
||||
return updated_impl_params;
|
||||
}
|
||||
|
||||
kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const override {
|
||||
|
||||
@@ -226,6 +226,45 @@ TEST(canonicalization, gather) {
|
||||
}
|
||||
}
|
||||
|
||||
struct fusing_gather_eltwise_params {
|
||||
ov::PartialShape data_shape;
|
||||
ov::Shape out_shape;
|
||||
int64_t axis;
|
||||
int64_t batch_dim;
|
||||
bool support_neg_ind;
|
||||
};
|
||||
|
||||
std::vector<std::pair<Shapes, fusing_gather_eltwise_params>> fusing_gather_eltwise_shapes_with_params {
|
||||
{
|
||||
{{{}, {}}, {{4624, 4, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {4624, 1, 1, 1}}, {{4624, 1, 1, 1}}},
|
||||
{{4624, 4}, {4624}, 1, 0, true}
|
||||
}
|
||||
};
|
||||
|
||||
TEST(canonicalization, fusing_gather_eltwise) {
|
||||
for (const auto& shapes : fusing_gather_eltwise_shapes_with_params) {
|
||||
layout input_gather_layout = create_default_layout(shapes.second.data_shape);
|
||||
layout indices_layout_first = create_default_layout(std::get<0>(shapes.first)[0]);
|
||||
layout indices_layout_second = create_default_layout(std::get<0>(shapes.first)[0]);
|
||||
layout input_mul_layout = create_default_layout(std::get<0>(shapes.first)[1]);
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("input", input_gather_layout));
|
||||
topology.add(input_layout("indices_first", indices_layout_first));
|
||||
topology.add(input_layout("indices_second", indices_layout_second));
|
||||
topology.add(input_layout("data", input_mul_layout));
|
||||
topology.add(gather("gather_first", input_info("input"), input_info("indices_first"), shapes.second.axis,
|
||||
shapes.second.out_shape, shapes.second.batch_dim, shapes.second.support_neg_ind));
|
||||
topology.add(gather("gather_second", input_info("input"), input_info("indices_second"), shapes.second.axis,
|
||||
shapes.second.out_shape, shapes.second.batch_dim, shapes.second.support_neg_ind));
|
||||
topology.add(eltwise("mul", {input_info("gather_first"), input_info("data")}, eltwise_mode::prod));
|
||||
topology.add(eltwise("add", {input_info("gather_second"), input_info("mul")}, eltwise_mode::sum));
|
||||
topology.add(reorder("out_reorder", input_info("add"), format::bfyx, data_types::f32));
|
||||
|
||||
canonicalization_test(topology, "gather_first", std::get<1>(shapes.first), std::get<2>(shapes.first), true);
|
||||
}
|
||||
}
|
||||
|
||||
struct fusing_gemm_eltwise_params {
|
||||
ov::PartialShape input_gemm_first;
|
||||
ov::PartialShape weights_gemm_first;
|
||||
|
||||
Reference in New Issue
Block a user