Remove unnecessary AutoBroadcastSpec parameter in MatMulMultiplyFusion (#10005)
This commit is contained in:
parent
1fc61299c8
commit
ab4a11b3bd
@ -16,8 +16,7 @@ NGRAPH_RTTI_DEFINITION(pass::MatMulMultiplyFusion, "MatMulMultiplyFusion", 0);
|
|||||||
|
|
||||||
static std::shared_ptr<Node> fuse_const_to_weights(const std::shared_ptr<Node>& matmul,
|
static std::shared_ptr<Node> fuse_const_to_weights(const std::shared_ptr<Node>& matmul,
|
||||||
const Output<Node>& weights,
|
const Output<Node>& weights,
|
||||||
std::shared_ptr<opset8::Constant> mul_const,
|
std::shared_ptr<opset8::Constant> mul_const) {
|
||||||
const op::AutoBroadcastSpec& autob) {
|
|
||||||
auto const_shape = mul_const->get_shape();
|
auto const_shape = mul_const->get_shape();
|
||||||
auto const_rank = static_cast<int64_t>(const_shape.size());
|
auto const_rank = static_cast<int64_t>(const_shape.size());
|
||||||
const auto& weights_shape = weights.get_partial_shape();
|
const auto& weights_shape = weights.get_partial_shape();
|
||||||
@ -149,15 +148,13 @@ pass::MatMulMultiplyFusion::MatMulMultiplyFusion() {
|
|||||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||||
const auto& pattern_map = m.get_pattern_value_map();
|
const auto& pattern_map = m.get_pattern_value_map();
|
||||||
const auto& weights = pattern_map.at(weights_pattern);
|
const auto& weights = pattern_map.at(weights_pattern);
|
||||||
auto mul = std::dynamic_pointer_cast<opset8::Multiply>(pattern_map.at(mul_pattern).get_node_shared_ptr());
|
auto mul = pattern_map.at(mul_pattern).get_node_shared_ptr();
|
||||||
if (!mul)
|
|
||||||
return false;
|
|
||||||
auto mul_const = std::dynamic_pointer_cast<opset8::Constant>(pattern_map.at(mul_const_pattern).get_node_shared_ptr());
|
auto mul_const = std::dynamic_pointer_cast<opset8::Constant>(pattern_map.at(mul_const_pattern).get_node_shared_ptr());
|
||||||
if (!mul_const)
|
if (!mul_const)
|
||||||
return false;
|
return false;
|
||||||
auto matmul = pattern_map.at(matmul_pattern).get_node_shared_ptr();
|
auto matmul = pattern_map.at(matmul_pattern).get_node_shared_ptr();
|
||||||
|
|
||||||
auto new_weights = fuse_const_to_weights(matmul, weights, mul_const, mul->get_autob());
|
auto new_weights = fuse_const_to_weights(matmul, weights, mul_const);
|
||||||
if (!new_weights)
|
if (!new_weights)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user