"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
|
|
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
|
|
from mo.utils.simple_proto_parser import SimpleProtoParser
|
|
|
|
# ---------------------------------------------------------------------------
# Input fixtures for the parser tests.
#
# ``correct_proto_message_*`` are inputs for which the tests expect a parsed
# dictionary; ``incorrect_proto_message_*`` are inputs for which the tests
# expect ``None``.
# ---------------------------------------------------------------------------

# Nested blocks with integer values.
correct_proto_message_1 = (
    'model { faster_rcnn { num_classes: 90 image_resizer { keep_aspect_ratio_resizer {'
    ' min_dimension: 600 max_dimension: 1024 }}}}'
)

# The same key ('scales', 'aspect_ratios') repeated several times.
correct_proto_message_2 = (
    ' first_stage_anchor_generator {grid_anchor_generator {height_stride: 16 width_stride:'
    ' 16 scales: 0.25 scales: 0.5 scales: 1.0 scales: 2.0 aspect_ratios: 0.5 aspect_ratios:'
    ' 1.0 aspect_ratios: 2.0}}'
)

# Newlines between tokens, boolean literals and a bare identifier value.
correct_proto_message_3 = (
    ' initializer \n{variance_scaling_initializer \n{\nfactor: 1.0 uniform: true bla: false '
    'mode: FAN_AVG}}'
)

# Double-quoted string values (one of them contains a space).
correct_proto_message_4 = (
    'train_input_reader {label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"'
    ' tf_record_input_reader { input_path: "PATH_TO_BE_CONFIGURED/ mscoco_train.record" }}'
)

# Same content as message 3 with '#' comment lines interleaved.
correct_proto_message_5 = (
    ' initializer \n # abc \n{variance_scaling_initializer \n{\nfactor: 1.0 \n # sd '
    '\nuniform: true bla: false mode: FAN_AVG}}'
)

# Bracketed list syntax mixed with a repeated key.
correct_proto_message_6 = (
    ' first_stage_anchor_generator {grid_anchor_generator {height_stride: 16 width_stride:'
    ' 16 scales: [ 0.25, 0.5, 1.0, 2.0] aspect_ratios: 0.5 aspect_ratios:'
    ' 1.0 aspect_ratios: 2.0}}'
)

# Bracketed list plus an empty list.
correct_proto_message_7 = (
    ' first_stage_anchor_generator {grid_anchor_generator {height_stride: 16 width_stride:'
    ' 16 scales: [ 0.25, 0.5, 1.0, 2.0] aspect_ratios: [] }}'
)

# List with a trailing comma.
correct_proto_message_8 = 'model {good_list: [3.0, 5.0, ]}'

# Commas placed after field values.
correct_proto_message_9 = (
    ' first_stage_anchor_generator {grid_anchor_generator {height_stride: 16, width_stride:'
    ' 16 scales: [ 0.25, 0.5, 1.0, 2.0], aspect_ratios: [] }}'
)

# Quoted value containing a backslash. The backslash is written as '\\' so the
# literal holds exactly one backslash without relying on the invalid escape '\m'.
correct_proto_message_10 = (
    'train_input_reader {label_map_path: "C:\\mscoco_label_map.pbtxt"'
    ' tf_record_input_reader { input_path: "PATH_TO_BE_CONFIGURED/ mscoco_train.record" }}'
)

# Quoted value containing a backslash, brackets, braces and a comma. The
# backslash is written as '\\' to avoid the invalid escape '\['.
correct_proto_message_11 = 'model {path: "C:\\[{],}" other_value: [1, 2, 3, 4]}'

# Key inside a block with no ':' and no value.
incorrect_proto_message_1 = 'model { bad_no_value }'

# Scalar value followed by a block, with more '{' than '}'.
incorrect_proto_message_2 = 'model { abc: 3 { }'

# Two values after a single key.
incorrect_proto_message_3 = 'model { too_many_values: 3 4 }'

# Input ends right after the ':'.
incorrect_proto_message_4 = 'model { missing_values: '

# Two tokens after the key followed by a stray ':'.
incorrect_proto_message_5 = 'model { missing_values: aa bb : }'

# Top-level key followed by ':' and nothing else.
incorrect_proto_message_6 = 'model : '

# List with an empty element between two commas.
incorrect_proto_message_7 = 'model : {bad_list: [3.0, 4, , 4.0]}'
|
|
|
|
|
|
class TestingSimpleProtoParser(unittest.TestCase):
    """Tests for ``SimpleProtoParser.parse_from_string`` and ``SimpleProtoParser.parse_file``.

    Well-formed inputs are expected to be returned as nested dictionaries, while
    malformed inputs and unreadable or missing files are expected to yield ``None``.
    """

    def _write_temp_proto(self, content):
        """Write ``content`` to a fresh temporary file and return the file path.

        Deletion of the file is registered with ``addCleanup`` right after it is
        created, so the file is removed even if a later step of the test fails.
        """
        tmp = tempfile.NamedTemporaryFile('wt', delete=False)
        self.addCleanup(os.unlink, tmp.name)
        with tmp:
            tmp.write(content)
        return tmp.name

    def test_correct_proto_reader_from_string_1(self):
        """Nested blocks with integer values are parsed into nested dictionaries."""
        result = SimpleProtoParser().parse_from_string(correct_proto_message_1)
        expected_result = {'model': {'faster_rcnn': {'num_classes': 90, 'image_resizer': {
            'keep_aspect_ratio_resizer': {'min_dimension': 600, 'max_dimension': 1024}}}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_2(self):
        """Values of a key that is repeated in the input are collected into a list."""
        result = SimpleProtoParser().parse_from_string(correct_proto_message_2)
        expected_result = {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': [0.5, 1.0, 2.0]}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_3(self):
        """Newlines are accepted; ``true``/``false`` become booleans, a bare identifier stays a string."""
        result = SimpleProtoParser().parse_from_string(correct_proto_message_3)
        expected_result = {
            'initializer': {
                'variance_scaling_initializer': {'factor': 1.0, 'uniform': True, 'bla': False, 'mode': 'FAN_AVG'}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_4(self):
        """Double-quoted values are returned as strings without the surrounding quotes."""
        result = SimpleProtoParser().parse_from_string(correct_proto_message_4)
        expected_result = {
            'train_input_reader': {'label_map_path': "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt",
                                   'tf_record_input_reader': {
                                       'input_path': "PATH_TO_BE_CONFIGURED/ mscoco_train.record"}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_comments(self):
        """Lines starting with ``#`` do not affect the parsed result."""
        result = SimpleProtoParser().parse_from_string(correct_proto_message_5)
        expected_result = {
            'initializer': {
                'variance_scaling_initializer': {'factor': 1.0, 'uniform': True, 'bla': False, 'mode': 'FAN_AVG'}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_lists(self):
        """A bracketed list and a repeated key both produce list values."""
        result = SimpleProtoParser().parse_from_string(correct_proto_message_6)
        expected_result = {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': [0.5, 1.0, 2.0]}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_empty_list(self):
        """An empty bracketed list ``[]`` is parsed as an empty list."""
        result = SimpleProtoParser().parse_from_string(correct_proto_message_7)
        expected_result = {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': []}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_comma_trailing_list(self):
        """A trailing comma inside a list does not add an element."""
        result = SimpleProtoParser().parse_from_string(correct_proto_message_8)
        expected_result = {'model': {'good_list': [3.0, 5.0]}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_redundant_commas(self):
        """Commas placed after field values do not change the parsed result."""
        result = SimpleProtoParser().parse_from_string(correct_proto_message_9)
        expected_result = {'first_stage_anchor_generator': {
            'grid_anchor_generator': {'height_stride': 16, 'width_stride': 16, 'scales': [0.25, 0.5, 1.0, 2.0],
                                      'aspect_ratios': []}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_windows_path(self):
        """A backslash inside a quoted value is preserved in the parsed string."""
        result = SimpleProtoParser().parse_from_string(correct_proto_message_10)
        # The backslash is spelled '\\' to avoid the invalid escape sequence '\m';
        # the runtime value still contains a single backslash.
        expected_result = {
            'train_input_reader': {'label_map_path': "C:\\mscoco_label_map.pbtxt",
                                   'tf_record_input_reader': {
                                       'input_path': "PATH_TO_BE_CONFIGURED/ mscoco_train.record"}}}
        self.assertDictEqual(result, expected_result)

    def test_correct_proto_reader_from_string_with_special_characters_in_string(self):
        """Brackets, braces and commas inside a quoted value are kept as plain text."""
        result = SimpleProtoParser().parse_from_string(correct_proto_message_11)
        # '\\' instead of the invalid escape sequence '\['; same runtime value.
        expected_result = {'model': {'path': "C:\\[{],}",
                                     'other_value': [1, 2, 3, 4]}}
        self.assertDictEqual(result, expected_result)

    def test_incorrect_proto_reader_from_string_1(self):
        """``'model { bad_no_value }'`` (key without a value) yields ``None``."""
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_1)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_2(self):
        """``'model { abc: 3 { }'`` yields ``None``."""
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_2)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_3(self):
        """``'model { too_many_values: 3 4 }'`` (two values for one key) yields ``None``."""
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_3)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_4(self):
        """``'model { missing_values: '`` (input ends after the colon) yields ``None``."""
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_4)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_5(self):
        """``'model { missing_values: aa bb : }'`` yields ``None``."""
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_5)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_6(self):
        """``'model : '`` yields ``None``."""
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_6)
        self.assertIsNone(result)

    def test_incorrect_proto_reader_from_string_7(self):
        """A list with an empty element (``[3.0, 4, , 4.0]``) yields ``None``."""
        result = SimpleProtoParser().parse_from_string(incorrect_proto_message_7)
        self.assertIsNone(result)

    def test_correct_proto_reader_from_file(self):
        """``parse_file`` returns the same dictionary as parsing the file content as a string."""
        file_name = self._write_temp_proto(correct_proto_message_1)

        result = SimpleProtoParser().parse_file(file_name)
        expected_result = {'model': {'faster_rcnn': {'num_classes': 90, 'image_resizer': {
            'keep_aspect_ratio_resizer': {'min_dimension': 600, 'max_dimension': 1024}}}}}
        self.assertDictEqual(result, expected_result)

    @unittest.skipIf(sys.platform.startswith("win"), "chmod() on Windows do nor support not writable dir")
    def test_proto_reader_from_non_readable_file(self):
        """``parse_file`` yields ``None`` for a file whose permission bits are all cleared."""
        # Permission bits are not enforced for the superuser, so the file would still
        # be readable and the expectation below could not hold.
        # NOTE(review): assumes parse_file relies on ordinary file permissions - confirm.
        if hasattr(os, 'geteuid') and os.geteuid() == 0:
            self.skipTest('file permission bits are not enforced for the root user')

        file_name = self._write_temp_proto(correct_proto_message_1)
        os.chmod(file_name, 0o000)

        result = SimpleProtoParser().parse_file(file_name)
        self.assertIsNone(result)

    def test_proto_reader_from_non_existing_file(self):
        """``parse_file`` yields ``None`` for the path ``/non/existing/file``."""
        result = SimpleProtoParser().parse_file('/non/existing/file')
        self.assertIsNone(result)
|