Remove unnecessary AutoBroadcastSpec parameter in MatMulMultiplyFusion (#10005)

This commit is contained in:
Mateusz Tabaka 2022-02-17 06:51:32 +01:00 committed by GitHub
parent 1fc61299c8
commit ab4a11b3bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -16,8 +16,7 @@ NGRAPH_RTTI_DEFINITION(pass::MatMulMultiplyFusion, "MatMulMultiplyFusion", 0);
static std::shared_ptr<Node> fuse_const_to_weights(const std::shared_ptr<Node>& matmul, static std::shared_ptr<Node> fuse_const_to_weights(const std::shared_ptr<Node>& matmul,
const Output<Node>& weights, const Output<Node>& weights,
std::shared_ptr<opset8::Constant> mul_const, std::shared_ptr<opset8::Constant> mul_const) {
const op::AutoBroadcastSpec& autob) {
auto const_shape = mul_const->get_shape(); auto const_shape = mul_const->get_shape();
auto const_rank = static_cast<int64_t>(const_shape.size()); auto const_rank = static_cast<int64_t>(const_shape.size());
const auto& weights_shape = weights.get_partial_shape(); const auto& weights_shape = weights.get_partial_shape();
@ -149,15 +148,13 @@ pass::MatMulMultiplyFusion::MatMulMultiplyFusion() {
matcher_pass_callback callback = [=](pattern::Matcher& m) { matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map(); const auto& pattern_map = m.get_pattern_value_map();
const auto& weights = pattern_map.at(weights_pattern); const auto& weights = pattern_map.at(weights_pattern);
auto mul = std::dynamic_pointer_cast<opset8::Multiply>(pattern_map.at(mul_pattern).get_node_shared_ptr()); auto mul = pattern_map.at(mul_pattern).get_node_shared_ptr();
if (!mul)
return false;
auto mul_const = std::dynamic_pointer_cast<opset8::Constant>(pattern_map.at(mul_const_pattern).get_node_shared_ptr()); auto mul_const = std::dynamic_pointer_cast<opset8::Constant>(pattern_map.at(mul_const_pattern).get_node_shared_ptr());
if (!mul_const) if (!mul_const)
return false; return false;
auto matmul = pattern_map.at(matmul_pattern).get_node_shared_ptr(); auto matmul = pattern_map.at(matmul_pattern).get_node_shared_ptr();
auto new_weights = fuse_const_to_weights(matmul, weights, mul_const, mul->get_autob()); auto new_weights = fuse_const_to_weights(matmul, weights, mul_const);
if (!new_weights) if (!new_weights)
return false; return false;