Remove unnecessary AutoBroadcastSpec parameter in MatMulMultiplyFusion (#10005)
This commit is contained in:
parent
1fc61299c8
commit
ab4a11b3bd
@ -16,8 +16,7 @@ NGRAPH_RTTI_DEFINITION(pass::MatMulMultiplyFusion, "MatMulMultiplyFusion", 0);
|
||||
|
||||
static std::shared_ptr<Node> fuse_const_to_weights(const std::shared_ptr<Node>& matmul,
|
||||
const Output<Node>& weights,
|
||||
std::shared_ptr<opset8::Constant> mul_const,
|
||||
const op::AutoBroadcastSpec& autob) {
|
||||
std::shared_ptr<opset8::Constant> mul_const) {
|
||||
auto const_shape = mul_const->get_shape();
|
||||
auto const_rank = static_cast<int64_t>(const_shape.size());
|
||||
const auto& weights_shape = weights.get_partial_shape();
|
||||
@ -149,15 +148,13 @@ pass::MatMulMultiplyFusion::MatMulMultiplyFusion() {
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
const auto& weights = pattern_map.at(weights_pattern);
|
||||
auto mul = std::dynamic_pointer_cast<opset8::Multiply>(pattern_map.at(mul_pattern).get_node_shared_ptr());
|
||||
if (!mul)
|
||||
return false;
|
||||
auto mul = pattern_map.at(mul_pattern).get_node_shared_ptr();
|
||||
auto mul_const = std::dynamic_pointer_cast<opset8::Constant>(pattern_map.at(mul_const_pattern).get_node_shared_ptr());
|
||||
if (!mul_const)
|
||||
return false;
|
||||
auto matmul = pattern_map.at(matmul_pattern).get_node_shared_ptr();
|
||||
|
||||
auto new_weights = fuse_const_to_weights(matmul, weights, mul_const, mul->get_autob());
|
||||
auto new_weights = fuse_const_to_weights(matmul, weights, mul_const);
|
||||
if (!new_weights)
|
||||
return false;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user