123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235 |
- # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Label map utility functions."""
- import logging
- import tensorflow as tf
- from google.protobuf import text_format
- from segment.sheet_resolve.lib.ssd_model.protos import string_int_label_map_pb2
- def _validate_label_map(label_map):
- """Checks if a label map is valid.
- Args:
- label_map: StringIntLabelMap to validate.
- Raises:
- ValueError: if label map is invalid.
- """
- for item in label_map.item:
- if item.id < 0:
- raise ValueError('Label map ids should be >= 0.')
- if (item.id == 0 and item.name != 'background' and
- item.display_name != 'background'):
- raise ValueError('Label map id 0 is reserved for the background label')
def create_category_index(categories):
    """Index a list of COCO-compatible category dicts by their 'id' field.

    Args:
      categories: an iterable of dicts, each of which has at least the keys:
        'id': (required) an integer id uniquely identifying this category.
        'name': (required) string representing category name
          e.g., 'cat', 'dog', 'pizza'.

    Returns:
      category_index: a dict mapping each category's 'id' to that category dict
        (the same object, not a copy). If several categories share an id, the
        last one wins.
    """
    return {category['id']: category for category in categories}
def get_max_label_map_index(label_map):
    """Return the largest item id in a label map.

    Args:
      label_map: a StringIntLabelMapProto

    Returns:
      an integer, the maximum `id` over `label_map.item`.

    Raises:
      ValueError: if the label map has no items (raised by the built-in `max`).
    """
    item_ids = (entry.id for entry in label_map.item)
    return max(item_ids)
def convert_label_map_to_categories(label_map,
                                    max_num_classes,
                                    use_display_name=True):
    """Given label map proto returns categories list compatible with eval.

    This function converts a label map proto into a list of dicts, each of
    which has the following keys:
      'id': (required) an integer id uniquely identifying this category.
      'name': (required) string representing category name
        e.g., 'cat', 'dog', 'pizza'.

    An item is kept only if 0 < id <= max_num_classes, i.e. label ids are
    assumed to start at 1 (id 0 is the background and is always dropped).
    If several items map to the same id, only the first one is kept.

    Args:
      label_map: a StringIntLabelMapProto or None. If falsy (e.g. None), a
        default categories list is created with max_num_classes categories
        named 'category_1' .. 'category_<max_num_classes>'.
      max_num_classes: maximum number of (consecutive) label indices to include.
      use_display_name: (boolean) choose whether to load 'display_name' field as
        category name. If False or if the display_name field does not exist,
        uses 'name' field as category names instead.

    Returns:
      categories: a list of dictionaries representing all possible categories,
        in label map order.
    """
    if not label_map:
        label_id_offset = 1
        return [{
            'id': class_id + label_id_offset,
            'name': 'category_{}'.format(class_id + label_id_offset)
        } for class_id in range(max_num_classes)]

    categories = []
    # A set keeps the duplicate-id check O(1) per item (a list made it O(n^2)).
    seen_ids = set()
    for item in label_map.item:
        if not 0 < item.id <= max_num_classes:
            logging.info(
                'Ignore item %d since it falls outside of requested '
                'label range.', item.id)
            continue
        if item.id in seen_ids:
            # Keep only the first item for any given id.
            continue
        seen_ids.add(item.id)
        if use_display_name and item.HasField('display_name'):
            name = item.display_name
        else:
            name = item.name
        categories.append({'id': item.id, 'name': name})
    return categories
def load_labelmap(path):
    """Loads a StringIntLabelMap proto from a text-format or binary file.

    The file contents are first parsed as a text-format proto. If they are not
    valid UTF-8 or not valid text format, they are parsed as a serialized
    binary proto instead.

    Args:
      path: path to a StringIntLabelMap proto file (text or binary format).

    Returns:
      a StringIntLabelMapProto

    Raises:
      ValueError: if the parsed label map fails `_validate_label_map`.
      google.protobuf.message.DecodeError: if the contents are neither a valid
        text-format nor a valid binary proto.
    """
    # Read raw bytes. In text mode ('r') the file would be decoded as UTF-8 up
    # front, so a binary proto would either fail to decode or reach
    # ParseFromString as a str (it requires bytes). Either way the binary
    # fallback below could never work.
    with tf.gfile.GFile(path, 'rb') as fid:
        label_map_bytes = fid.read()
    label_map = string_int_label_map_pb2.StringIntLabelMap()
    try:
        text_format.Merge(label_map_bytes.decode('utf-8'), label_map)
    except (UnicodeDecodeError, text_format.ParseError):
        # ParseFromString clears the message first, so any partial state left
        # by the failed text parse is discarded.
        label_map.ParseFromString(label_map_bytes)
    _validate_label_map(label_map)
    return label_map
def get_label_map_dict(label_map_path,
                       use_display_name=False,
                       fill_in_gaps_and_background=False):
    """Reads a label map and returns a dictionary of label names to id.

    Args:
      label_map_path: path to StringIntLabelMap proto text file.
      use_display_name: whether to use the label map items' display names as keys.
      fill_in_gaps_and_background: whether to fill in gaps and background with
        respect to the id field in the proto. The id: 0 is reserved for the
        'background' class and will be added if it is missing. All other missing
        ids in range(1, max(id)) will be added with a dummy class name
        ("class_<id>") if they are missing.

    Returns:
      A dictionary mapping label names to id. If several items share a name,
      the last one in the label map wins.

    Raises:
      ValueError: if fill_in_gaps_and_background and label_map has non-integer or
        negative values, or if the label map itself fails validation on load.
    """
    label_map = load_labelmap(label_map_path)
    label_map_dict = {}
    for item in label_map.item:
        key = item.display_name if use_display_name else item.name
        label_map_dict[key] = item.id

    if fill_in_gaps_and_background:
        values = set(label_map_dict.values())
        if not all(isinstance(value, int) for value in values):
            raise ValueError('The values in label map must be integers in order to '
                             'fill_in_gaps_and_background.')
        if not all(value >= 0 for value in values):
            raise ValueError('The values in the label map must be non-negative.')
        if 0 not in values:
            label_map_dict['background'] = 0
        # `values` is empty for an empty label map; max() would raise there.
        if values and len(values) != max(values) + 1:
            # There are gaps in the labels; fill them with dummy class names.
            for value in range(1, max(values)):
                if value not in values:
                    label_map_dict['class_{}'.format(value)] = value
    return label_map_dict
def create_categories_from_labelmap(label_map_path, use_display_name=True):
    """Reads a label map and returns categories list compatible with eval.

    This function converts label map proto and returns a list of dicts, each of
    which has the following keys:
      'id': an integer id uniquely identifying this category.
      'name': string representing category name e.g., 'cat', 'dog'.

    Args:
      label_map_path: Path to `StringIntLabelMap` proto text file.
      use_display_name: (boolean) choose whether to load 'display_name' field
        as category name. If False or if the display_name field does not exist,
        uses 'name' field as category names instead.

    Returns:
      categories: a list of dictionaries representing all possible categories.
        Empty if the label map has no items.
    """
    label_map = load_labelmap(label_map_path)
    item_ids = [item.id for item in label_map.item]
    # An empty label map has no maximum id. Treat it as zero classes instead of
    # letting max() raise a cryptic "empty sequence" ValueError.
    max_num_classes = max(item_ids) if item_ids else 0
    return convert_label_map_to_categories(label_map, max_num_classes,
                                           use_display_name)
def create_category_index_from_labelmap(label_map_path, use_display_name=True):
    """Load a label map file and index its categories by id.

    Args:
      label_map_path: Path to `StringIntLabelMap` proto text file.
      use_display_name: (boolean) choose whether to load 'display_name' field
        as category name. If False or if the display_name field does not exist,
        uses 'name' field as category names instead.

    Returns:
      A category index: a dict mapping each integer id to its category dict, e.g.
      {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
    """
    return create_category_index(
        create_categories_from_labelmap(label_map_path, use_display_name))
def create_class_agnostic_category_index():
    """Return a fresh category index containing only the `object` class (id 1)."""
    object_category = {'id': 1, 'name': 'object'}
    return {object_category['id']: object_category}
|