Files
openvino/model-optimizer/mo/utils/simple_proto_parser_test.py
2020-02-11 22:48:49 +03:00

203 lines
9.9 KiB
Python

"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import sys
import tempfile
import unittest
from mo.utils.simple_proto_parser import SimpleProtoParser
# Well-formed messages in the simplified protobuf text format. Each one is parsed by a
# ``test_correct_*`` case in TestingSimpleProtoParser and compared against an expected dict.
correct_proto_message_1 = 'model { faster_rcnn { num_classes: 90 image_resizer { keep_aspect_ratio_resizer {' \
                          ' min_dimension: 600 max_dimension: 1024 }}}}'
# Repeated keys ("scales", "aspect_ratios") are expected to be collected into lists.
correct_proto_message_2 = ' first_stage_anchor_generator {grid_anchor_generator {height_stride: 16 width_stride:' \
                          ' 16 scales: 0.25 scales: 0.5 scales: 1.0 scales: 2.0 aspect_ratios: 0.5 aspect_ratios:' \
                          ' 1.0 aspect_ratios: 2.0}}'
# Newlines between tokens plus boolean and enum-like bare values.
correct_proto_message_3 = ' initializer \n{variance_scaling_initializer \n{\nfactor: 1.0 uniform: true bla: false ' \
                          'mode: FAN_AVG}}'
# Quoted string values, including one that contains a space.
correct_proto_message_4 = 'train_input_reader {label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"' \
                          ' tf_record_input_reader { input_path: "PATH_TO_BE_CONFIGURED/ mscoco_train.record" }}'
# Same content as message 3, interleaved with '#' comment lines.
correct_proto_message_5 = ' initializer \n # abc \n{variance_scaling_initializer \n{\nfactor: 1.0 \n # sd ' \
                          '\nuniform: true bla: false mode: FAN_AVG}}'
# Bracketed list syntax mixed with repeated keys.
correct_proto_message_6 = ' first_stage_anchor_generator {grid_anchor_generator {height_stride: 16 width_stride:' \
                          ' 16 scales: [ 0.25, 0.5, 1.0, 2.0] aspect_ratios: 0.5 aspect_ratios:' \
                          ' 1.0 aspect_ratios: 2.0}}'
# An empty bracketed list.
correct_proto_message_7 = ' first_stage_anchor_generator {grid_anchor_generator {height_stride: 16 width_stride:' \
                          ' 16 scales: [ 0.25, 0.5, 1.0, 2.0] aspect_ratios: [] }}'
# A list with a trailing comma.
correct_proto_message_8 = 'model {good_list: [3.0, 5.0, ]}'
# Commas used as separators between key/value pairs.
correct_proto_message_9 = ' first_stage_anchor_generator {grid_anchor_generator {height_stride: 16, width_stride:' \
                          ' 16 scales: [ 0.25, 0.5, 1.0, 2.0], aspect_ratios: [] }}'
# Windows-style paths. Raw strings are required here: "\m" and "\[" are not valid Python escape
# sequences (DeprecationWarning, SyntaxWarning since Python 3.12). The raw literals produce exactly
# the same characters (a literal backslash) that the previous non-raw literals did.
correct_proto_message_10 = r'train_input_reader {label_map_path: "C:\mscoco_label_map.pbtxt"' \
                           ' tf_record_input_reader { input_path: "PATH_TO_BE_CONFIGURED/ mscoco_train.record" }}'
# A quoted string containing bracket/brace/comma characters that must not be treated as syntax.
correct_proto_message_11 = r'model {path: "C:\[{],}" other_value: [1, 2, 3, 4]}'

# Malformed messages. Each one is parsed by a ``test_incorrect_*`` case, which expects None.
incorrect_proto_message_1 = 'model { bad_no_value }'
incorrect_proto_message_2 = 'model { abc: 3 { }'
incorrect_proto_message_3 = 'model { too_many_values: 3 4 }'
incorrect_proto_message_4 = 'model { missing_values: '
incorrect_proto_message_5 = 'model { missing_values: aa bb : }'
incorrect_proto_message_6 = 'model : '
incorrect_proto_message_7 = 'model : {bad_list: [3.0, 4, , 4.0]}'
class TestingSimpleProtoParser(unittest.TestCase):
    """Unit tests for SimpleProtoParser parsing from strings and from files.

    Successful parses are compared against nested dicts; failed parses are expected to
    return None rather than raise.
    """

    def _assert_parsed(self, message, expected_result):
        """Parse ``message`` from a string and check that it yields ``expected_result``."""
        result = SimpleProtoParser().parse_from_string(message)
        self.assertDictEqual(result, expected_result)

    def _assert_rejected(self, message):
        """Parse ``message`` from a string and check that the parser reports failure with None."""
        self.assertIsNone(SimpleProtoParser().parse_from_string(message))

    def _write_temp_file(self, content):
        """Write ``content`` to a new temporary file and return its path.

        Removal is registered with ``addCleanup`` before anything else happens, so the file is
        deleted even when a later assertion in the test fails.
        """
        temp_file = tempfile.NamedTemporaryFile('wt', delete=False)
        self.addCleanup(os.unlink, temp_file.name)
        with temp_file:
            temp_file.write(content)
        return temp_file.name

    def test_correct_proto_reader_from_string_1(self):
        self._assert_parsed(correct_proto_message_1, {
            'model': {'faster_rcnn': {'num_classes': 90, 'image_resizer': {
                'keep_aspect_ratio_resizer': {'min_dimension': 600, 'max_dimension': 1024}}}}})

    def test_correct_proto_reader_from_string_2(self):
        self._assert_parsed(correct_proto_message_2, {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': [0.5, 1.0, 2.0]}}})

    def test_correct_proto_reader_from_string_3(self):
        self._assert_parsed(correct_proto_message_3, {
            'initializer': {
                'variance_scaling_initializer': {'factor': 1.0, 'uniform': True, 'bla': False, 'mode': 'FAN_AVG'}}})

    def test_correct_proto_reader_from_string_4(self):
        self._assert_parsed(correct_proto_message_4, {
            'train_input_reader': {'label_map_path': "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt",
                                   'tf_record_input_reader': {
                                       'input_path': "PATH_TO_BE_CONFIGURED/ mscoco_train.record"}}})

    def test_correct_proto_reader_from_string_with_comments(self):
        self._assert_parsed(correct_proto_message_5, {
            'initializer': {
                'variance_scaling_initializer': {'factor': 1.0, 'uniform': True, 'bla': False, 'mode': 'FAN_AVG'}}})

    def test_correct_proto_reader_from_string_with_lists(self):
        self._assert_parsed(correct_proto_message_6, {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': [0.5, 1.0, 2.0]}}})

    def test_correct_proto_reader_from_string_with_empty_list(self):
        self._assert_parsed(correct_proto_message_7, {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': []}}})

    def test_correct_proto_reader_from_string_with_comma_trailing_list(self):
        self._assert_parsed(correct_proto_message_8, {'model': {'good_list': [3.0, 5.0]}})

    def test_correct_proto_reader_from_string_with_redundant_commas(self):
        self._assert_parsed(correct_proto_message_9, {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': []}}})

    def test_correct_proto_reader_from_string_with_windows_path(self):
        # Raw string: "\m" is not a valid Python escape; the value is a literal backslash + "m".
        self._assert_parsed(correct_proto_message_10, {
            'train_input_reader': {'label_map_path': r"C:\mscoco_label_map.pbtxt",
                                   'tf_record_input_reader': {
                                       'input_path': "PATH_TO_BE_CONFIGURED/ mscoco_train.record"}}})

    def test_correct_proto_reader_from_string_with_special_characters_in_string(self):
        # Raw string: "\[" is not a valid Python escape; the value is a literal backslash + "[".
        self._assert_parsed(correct_proto_message_11, {'model': {'path': r"C:\[{],}",
                                                                 'other_value': [1, 2, 3, 4]}})

    def test_incorrect_proto_reader_from_string_1(self):
        self._assert_rejected(incorrect_proto_message_1)

    def test_incorrect_proto_reader_from_string_2(self):
        self._assert_rejected(incorrect_proto_message_2)

    def test_incorrect_proto_reader_from_string_3(self):
        self._assert_rejected(incorrect_proto_message_3)

    def test_incorrect_proto_reader_from_string_4(self):
        self._assert_rejected(incorrect_proto_message_4)

    def test_incorrect_proto_reader_from_string_5(self):
        self._assert_rejected(incorrect_proto_message_5)

    def test_incorrect_proto_reader_from_string_6(self):
        self._assert_rejected(incorrect_proto_message_6)

    def test_incorrect_proto_reader_from_string_7(self):
        self._assert_rejected(incorrect_proto_message_7)

    def test_correct_proto_reader_from_file(self):
        file_name = self._write_temp_file(correct_proto_message_1)
        result = SimpleProtoParser().parse_file(file_name)
        expected_result = {'model': {'faster_rcnn': {'num_classes': 90, 'image_resizer': {
            'keep_aspect_ratio_resizer': {'min_dimension': 600, 'max_dimension': 1024}}}}}
        self.assertDictEqual(result, expected_result)

    @unittest.skipIf(sys.platform.startswith("win"), "chmod() on Windows does not support non-readable files")
    def test_proto_reader_from_non_readable_file(self):
        # Root bypasses POSIX permission bits, so a mode-000 file stays readable and the
        # expectation below would not hold (common when tests run inside containers).
        if hasattr(os, 'geteuid') and os.geteuid() == 0:
            self.skipTest("file permission bits are not enforced for the root user")
        file_name = self._write_temp_file(correct_proto_message_1)
        os.chmod(file_name, 0o000)
        self.assertIsNone(SimpleProtoParser().parse_file(file_name))

    def test_proto_reader_from_non_existing_file(self):
        result = SimpleProtoParser().parse_file('/non/existing/file')
        self.assertIsNone(result)