Add input transposing for MatMul (#1462)
This commit is contained in:
@@ -258,7 +258,9 @@ std::shared_ptr<Node> makeShuffleChannels(const ngraph::Output<Node> &in,
|
||||
int group);
|
||||
|
||||
std::shared_ptr<Node> makeMatMul(const Output<Node> &A,
|
||||
const Output<Node> &B);
|
||||
const Output<Node> &B,
|
||||
bool transpose_a = false,
|
||||
bool transpose_b = false);
|
||||
|
||||
std::shared_ptr<ngraph::Node> makeReduce(std::vector<ngraph::Output<Node>> &in,
|
||||
const std::vector<int> &reductionAxes,
|
||||
|
||||
@@ -8,8 +8,10 @@ namespace ngraph {
|
||||
namespace builder {
|
||||
|
||||
std::shared_ptr<Node> makeMatMul(const Output<Node>& A,
|
||||
const Output<Node>& B) {
|
||||
return std::make_shared<ngraph::opset3::MatMul>(A, B);
|
||||
const Output<Node>& B,
|
||||
bool transpose_a,
|
||||
bool transpose_b) {
|
||||
return std::make_shared<ngraph::opset3::MatMul>(A, B, transpose_a, transpose_b);
|
||||
}
|
||||
|
||||
} // namespace builder
|
||||
|
||||
Reference in New Issue
Block a user