Fix Transpose shape issue

This commit is contained in:
River.Li 2023-06-15 13:00:07 +08:00
parent 5053b786a0
commit 6a26a4faa0
4 changed files with 11 additions and 3 deletions

View File

@ -701,6 +701,10 @@ pass::EliminateTranspose::EliminateTranspose() {
} }
const auto& order_values = order_const->cast_vector<int64_t>(); const auto& order_values = order_const->cast_vector<int64_t>();
// Cannot eliminate Transpose when empty order_value.
if (order_values.size() == 0)
return false;
vector<int64_t> ref_values(order_values.size()); vector<int64_t> ref_values(order_values.size());
iota(ref_values.begin(), ref_values.end(), 0); iota(ref_values.begin(), ref_values.end(), 0);
if (order_values != ref_values) { if (order_values != ref_values) {

View File

@ -128,7 +128,9 @@ bool ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
const size_t num_digits_in_pass_index = 3; const size_t num_digits_in_pass_index = 3;
std::string index_str = std::to_string(index); std::string index_str = std::to_string(index);
index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str; index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str;
auto base_filename = func->get_name() + std::string("_") + index_str + std::string("_") + pass->get_name(); static size_t idx = 0;
auto base_filename = func->get_name() + std::string("_") + std::to_string(idx++) + std::string("_") +
index_str + std::string("_") + pass->get_name();
if (m_visualize) { if (m_visualize) {
auto file_ext = "svg"; auto file_ext = "svg";

View File

@ -181,6 +181,9 @@ void SyncInferRequest::redefineMemoryForInputNodes() {
void SyncInferRequest::update_external_inputs() { void SyncInferRequest::update_external_inputs() {
// Update it due to batched_tensors case will update input tensor // Update it due to batched_tensors case will update input tensor
if (m_batched_tensors.size() == 0)
return;
// for (auto input : _compiled_model->get_orig_model()->inputs()) {
for (auto input : get_inputs()) { for (auto input : get_inputs()) {
std::string input_name = get_port_name(input); std::string input_name = get_port_name(input);
if (input_name.empty()) { if (input_name.empty()) {
@ -398,7 +401,7 @@ void SyncInferRequest::check_port(const ov::Output<const ov::Node>& port) const
if (name.empty() || (_input_ports_map.find(name) == _input_ports_map.end() && if (name.empty() || (_input_ports_map.find(name) == _input_ports_map.end() &&
_output_ports_map.find(name) == _output_ports_map.end())) { _output_ports_map.find(name) == _output_ports_map.end())) {
OPENVINO_THROW("cpu plugin checking port failed: cannot find this port!"); OPENVINO_THROW("cpu plugin checking port failed: cannot find this port with name ", name);
} }
} }

View File

@ -119,7 +119,6 @@ Node::Node(const std::shared_ptr<ngraph::Node>& op,
bool isScalar = shape.rank().get_length() == 0; bool isScalar = shape.rank().get_length() == 0;
outputShapes.emplace_back(isScalar ? ngraph::PartialShape{1} : shape); outputShapes.emplace_back(isScalar ? ngraph::PartialShape{1} : shape);
std::cout << typeStr << " : " << op->get_output_element_type(i) << std::endl;
originalOutputPrecisions.emplace_back(details::convertPrecision(op->get_output_element_type(i))); originalOutputPrecisions.emplace_back(details::convertPrecision(op->get_output_element_type(i)));
} }
} }