diff --git a/models/public/gcn/README.md b/models/public/gcn/README.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/models/public/gcn/accuracy-check.yml b/models/public/gcn/accuracy-check.yml new file mode 100644 index 00000000000..f1a07e4d994 --- /dev/null +++ b/models/public/gcn/accuracy-check.yml @@ -0,0 +1,24 @@ +models: +- name: GCN + launchers: + - framework: DGL + adapter: node_classification + device: CPU + model: gcn_model.pt + module: GCN.py + module_name: GCN + batch: 32 + output_names: + - logits + + datasets: + - name: Cora + reader: graph(dgl)_reader + data_source: graph.bin + annotation_conversion: + converter: DGL_converter + graph_path: graph.bin + metrics: + - name: node_accuracy_name + type: node_accuracy + reference: 0.778 diff --git a/models/public/gcn/model.yml b/models/public/gcn/model.yml new file mode 100644 index 00000000000..642451e1909 --- /dev/null +++ b/models/public/gcn/model.yml @@ -0,0 +1,22 @@ +description: >- + Tmp +task_type: node_classification +files: + - name: gcn_model.pt + size: 94635 + source: https://raw.githubusercontent.com/itlab-vision/itlab-vision-dl-benchmark-models/main/dgl/models/classification/GCN/gcn_model.pt + checksum: abceacb966cf92ce225e6b7e9b29b1a165e6283f0a780a1617344405d5661bd74ff4e3f7c6d7d0c14fbb44f486d24c2f + - name: GCN.py + size: 736 + source: https://raw.githubusercontent.com/itlab-vision/itlab-vision-dl-benchmark-models/main/dgl/models/classification/GCN/GCN.py + checksum: a8cf92d876d5c4f495c8fc9c0354a1c337e60038c4c35b11959d5e56105c2f85d60378a41ba2436c7176dd9e708f761c + - name: graph.bin + size: 50908 + source: https://raw.githubusercontent.com/itlab-vision/dl-benchmark/master/tests/smoke_test/test_graph/dgl/default_graph.bin + checksum: 7cf6911b0bd1a7dfd1aa5dc03193f3d0b805774ffd5aff6289763902568cbfd742bae87f6110c320e90ded6344850c9d +model_optimizer_args: + - --input=$dl_dir/graph.bin + - --input_model=$dl_dir/gcn_model.pt + - --model_class=$dl_dir/GCN.py +framework: dgl_pytorch 
class GraphNodeClassificationAdapter(Adapter):
    """
    Class for converting output of node classification model to ClassificationPrediction representation
    """
    __provider__ = 'node_classification'
    prediction_types = (ClassificationPrediction, )

    @classmethod
    def parameters(cls):
        """
        Declare every option that ``configure`` reads.

        All options are optional, so an adapter configured only by its type
        name keeps working.
        """
        parameters = super().parameters()
        parameters.update({
            'block': BoolField(
                optional=True, default=False, description="process whole batch as a single data block"
            ),
            'classification_output': StringField(optional=True, description="target output layer name"),
            'label_as_array': BoolField(
                optional=True, default=False, description="produce ClassificationPrediction's label as array"
            ),
            'multi_label_threshold': NumberField(
                optional=True, value_type=float, description="threshold for multi label"
            ),
        })

        return parameters

    def configure(self):
        """Read adapter options from the config and mark the output name as not yet resolved."""
        self.label_as_array = self.get_value_from_config('label_as_array')
        self.block = self.get_value_from_config('block')
        self.classification_out = self.get_value_from_config('classification_output')
        self.multilabel_thresh = self.get_value_from_config('multi_label_threshold')
        self.output_verified = False

    def select_output_blob(self, outputs):
        """
        Resolve the name of the output that holds the predictions.

        A configured ``classification_output`` is checked against ``outputs``;
        otherwise the selection is delegated to the base class and its result
        is stored in ``classification_out``.
        """
        self.output_verified = True
        if self.classification_out:
            self.classification_out = self.check_output_name(self.classification_out, outputs)
            return
        super().select_output_blob(outputs)
        self.classification_out = self.output_blob

    def process(self, raw, identifiers, frame_meta):
        """
        Args:
            identifiers: list of input data identifiers
            raw: output of model
            frame_meta: list of meta information about each frame
        Returns:
            list of ClassificationPrediction objects
        """
        if not self.output_verified:
            self.select_output_blob(raw)
        multi_infer = frame_meta[-1].get('multi_infer', False) if frame_meta else False
        raw_prediction = self._extract_predictions(raw, frame_meta)
        # Prediction tensor; indexed by the name resolved in select_output_blob
        # so that a configured classification_output is honoured.
        prediction = raw_prediction[self.classification_out]
        if multi_infer:
            # Several inference results were stacked along axis 0: average them.
            prediction = np.mean(prediction, axis=0)
        if len(np.shape(prediction)) == 1:
            prediction = np.expand_dims(prediction, axis=0)
        prediction = np.reshape(prediction, (prediction.shape[0], -1))

        if self.block:
            # Whole batch is represented by a single prediction object.
            return [self.prepare_representation(identifiers[0], prediction)]

        return [
            self.prepare_representation(identifier, output)
            for identifier, output in zip(identifiers, prediction)
        ]

    def prepare_representation(self, identifier, prediction):
        """Wrap one prediction array into a ClassificationPrediction."""
        return ClassificationPrediction(
            identifier, prediction, self.label_as_array,
            multilabel_threshold=self.multilabel_thresh)

    @staticmethod
    def _extract_predictions(outputs_list, meta):
        """
        Return a name -> data mapping of model outputs.

        Without the ``multi_infer`` flag in the last meta entry, the first
        element of ``outputs_list`` is returned (or ``outputs_list`` itself if
        it is already a dict). With the flag set, the values of every output
        are stacked across all entries of ``outputs_list``.
        """
        is_multi_infer = meta[-1].get('multi_infer', False) if meta else False
        if not is_multi_infer:
            return outputs_list[0] if not isinstance(outputs_list, dict) else outputs_list

        return {
            output_key: np.asarray([output[output_key] for output in outputs_list])
            for output_key in outputs_list[0].keys()
        }
class DGLConverter(GraphFileBasedAnnotationConverter):
    """Build a classification annotation from the node labels stored in a DGL graph file."""
    __provider__ = 'DGL_converter'
    annotation_types = (ClassificationAnnotation, )

    def convert(self, check_content=False, **kwargs):
        """
        Load the graph file and expose its node labels as one annotation.

        Args:
            check_content: accepted for interface compatibility; not used here.
            **kwargs: accepted for interface compatibility; not used here.
        Returns:
            ConverterReturn with a single ClassificationAnnotation (empty
            identifier, ``label`` set to the graph's ``ndata["label"]``) and
            meta ``{'labels': labels}``.
        """
        # load_graphs result is indexed as [0][0]: the first graph of the
        # first returned element.
        loaded = dgl.data.utils.load_graphs(str(Path(self.graph_path)))
        first_graph = loaded[0][0]

        labels = first_graph.ndata["label"]

        annotation = [
            ClassificationAnnotation(identifier='', label=labels)
        ]

        return ConverterReturn(annotation, {'labels': labels}, None)
class GraphFileBasedAnnotationConverter(BaseFormatConverter):
    """Base converter for annotations that are read from a single graph file."""

    @classmethod
    def parameters(cls):
        """Extend the inherited parameters with the ``graph_path`` option."""
        merged = super().parameters()
        merged['graph_path'] = PathField(is_directory=False, description="Path to graph data.")
        return merged

    def configure(self):
        """Store the configured graph file path on the instance."""
        configured_path = self.get_value_from_config('graph_path')
        self.graph_path = configured_path

    def convert(self, check_content=False, **kwargs):
        """Placeholder that does nothing and returns None."""
        return None
class DGLGraphReader(BaseReader):
    """Data reader that loads a DGL graph from a file."""
    __provider__ = 'graph(dgl)_reader'

    def configure(self):
        """
        Validate ``data_source``.

        Raises:
            ConfigError: if no data source is set and setting it later is not
                allowed (``_postpone_data_source`` is falsy).
        """
        if self.data_source:
            self.data_source = get_path(self.data_source, is_directory=False)
            return
        if self._postpone_data_source:
            return
        raise ConfigError('data_source parameter is required to create "{}" '
                          'data reader and read data'.format(self.__provider__))

    def read(self, data_id):
        """Return the first graph stored in the file addressed by ``data_id``."""
        if self.data_source is None:
            graph_file = data_id
        else:
            graph_file = self.data_source / data_id

        loaded = dgl.data.utils.load_graphs(str(Path(graph_file)))
        return loaded[0][0]
class DGLGraphModel(BaseCascadeModel):
    """Cascade-model stub for DGL graphs; every operation is currently a no-op."""

    def __init__(self, network_info, launcher, models_args, is_blob, delayed_model_loading=False):
        # Only network_info and launcher are forwarded; the remaining
        # arguments are accepted but not used.
        super().__init__(network_info, launcher)

    def predict(self, identifiers, input_data, encoder_callback=None):
        """Not implemented: returns None."""
        return None

    def reset(self):
        """Not implemented: returns None."""
        return None

    def save_encoder_predictions(self):
        """Not implemented: returns None."""
        return None

    def _add_raw_encoder_predictions(self, encoder_prediction):
        """Not implemented: returns None."""
        return None
class DGLLauncher(Launcher):
    """Launcher that runs a serialized PyTorch module on a DGL graph."""
    __provider__ = 'DGL'

    def __init__(self, config_entry: dict, *args, **kwargs):
        """
        Args:
            config_entry: the launcher section of the accuracy checker
                configuration (``model``, ``module``, ``module_name``,
                ``device``, ...).
        Raises:
            ValueError: if DGL or Torch is not installed, or the device or
                model file type is not supported.
        """
        super().__init__(config_entry, *args, **kwargs)
        try:
            import dgl  # pylint: disable=C0415
            self._dgl = dgl
        except ImportError as import_error:
            raise ValueError(
                "DGL isn't installed. Please, install it before using. \n{}".format(
                    import_error.msg
                )
            ) from import_error

        try:
            import torch  # pylint: disable=C0415
            self._torch = torch
        except ImportError as import_error:
            raise ValueError(
                "Torch isn't installed. Please, install it before using. \n{}".format(
                    import_error.msg
                )
            ) from import_error

        self.validate_config(config_entry)
        # Read through get_value_from_config so the declared default applies
        # when the config omits 'device'.
        self.device = self._get_device_to_infer(self.get_value_from_config('device'))

        self.module = self.load_module(
            config_entry.get('model'),
            config_entry.get('module'),
            config_entry.get('module_name')
        )

        self._batch = self.get_value_from_config('batch')

        self._generate_inputs()
        # NOTE(review): 'output_names' is not declared in parameters() of this
        # class; confirm the base Launcher declares it, otherwise the fallback
        # ['output'] is always used.
        self.output_names = self.get_value_from_config('output_names') or ['output']

    def _get_device_to_infer(self, device):
        """
        Map a configured device name to a torch device.

        The name is matched case-insensitively so that the declared default
        ``'cpu'`` is accepted as well as ``'CPU'``; ``None`` falls back to CPU.

        Raises:
            ValueError: for any name other than CPU or GPU.
        """
        torch_device_names = {'CPU': 'cpu', 'GPU': 'cuda'}
        requested = 'CPU' if device is None else str(device).upper()
        if requested not in torch_device_names:
            raise ValueError('The device is not supported')
        return self._torch.device(torch_device_names[requested])

    def load_module(self, model_path, module_path, module_name):
        """
        Import the network definition and load the serialized model.

        Args:
            model_path: path to the serialized model; only ``.pt`` is accepted.
            module_path: path to the Python file that defines the network class.
            module_name: name of the network class inside that file.
        Returns:
            the loaded module, moved to ``self.device`` and switched to eval mode.
        Raises:
            ValueError: if the model extension is not supported or the module
                file cannot be imported.
        """
        # str() keeps this working for both plain strings and Path objects.
        file_type = os.path.splitext(str(model_path))[1].lstrip('.')
        supported_extensions = ['pt']
        if file_type not in supported_extensions:
            raise ValueError(f'The file type {file_type} is not supported')

        spec = importlib.util.spec_from_file_location(module_name, module_path)
        if spec is None or spec.loader is None:
            raise ValueError(f'Cannot import module {module_name} from {module_path}')
        network_module = importlib.util.module_from_spec(spec)
        sys.modules[f'{module_name}'] = network_module
        spec.loader.exec_module(network_module)

        # The network class is also exposed on __main__ before unpickling
        # (presumably because the model was saved from a script run as
        # __main__ - confirm against the model source).
        import __main__
        setattr(__main__, module_name, getattr(network_module, module_name))
        # SECURITY: torch.load unpickles the file and can execute arbitrary
        # code; only load model files from trusted sources.
        module = self._torch.load(model_path, map_location=self.device)
        module.to(self.device)
        module.eval()

        return module

    def _generate_inputs(self):
        """Fill ``self._inputs`` with name -> shape from the config, or a single default 'input'."""
        default_shape = (self.batch, ) + (-1, ) * 3
        config_inputs = self.config.get('inputs')
        if not config_inputs:
            self._inputs = {'input': default_shape}
            return
        input_shapes = OrderedDict()
        for input_description in config_inputs:
            input_shapes[input_description['name']] = input_description.get('shape', default_shape)
        self._inputs = input_shapes

    @property
    def inputs(self):
        return self._inputs

    @property
    def batch(self):
        return self._batch

    @property
    def output_blob(self):
        return next(iter(self.output_names))

    @classmethod
    def parameters(cls):
        """Add the DGL-specific launch parameters to the inherited ones."""
        parameters = super().parameters()
        parameters.update({
            'model': PathField(description="Path to model.", file_or_directory=True),
            'module': StringField(description='Network module for loading'),
            'device': StringField(optional=True, default='cpu', description='Device name: CPU or GPU'),
            'module_name': StringField(description='Network module name')
        })
        return parameters

    def predict(self, inputs, metadata=None, **kwargs):
        """
        Run the model on the graph of the first input and return per-node class indices.

        Args:
            inputs: list of feed dicts; the graph is taken from
                ``inputs[0]['input'][0]`` and node features from its
                ``ndata['feat']``.
            metadata: not used.
        Returns:
            a one-element list with a dict mapping ``self.output_blob`` to the
            argmax over dimension 1 of the model output.
        """
        # Keep the graph (and its node data) on the same device as the model.
        input_graph = inputs[0]['input'][0].to(self.device)
        features = input_graph.ndata['feat']
        with self._torch.inference_mode():
            predictions = self.module(input_graph, features).argmax(dim=1)
        # Key the result by the advertised output name so the adapter can
        # find it.
        return [{self.output_blob: predictions}]

    def release(self):
        """
        Releases launcher.
        """
        # Guarded so that a repeated release, or a release after a failed
        # __init__, does not raise AttributeError.
        if hasattr(self, 'module'):
            del self.module
class ClassificationGraphAccuracy(PerImageEvaluationMetric):
    """
    Class for evaluating accuracy metric of graph node classification models.
    """

    __provider__ = 'node_accuracy'

    annotation_types = (ClassificationAnnotation, TextClassificationAnnotation)
    prediction_types = (ClassificationPrediction, ArgMaxClassificationPrediction)

    @classmethod
    def parameters(cls):
        parameters = super().parameters()
        parameters.update({
            'top_k': NumberField(
                value_type=int, min_value=1, optional=True, default=1,
                description="The number of classes with the highest probability, which will be used to decide "
                            "if prediction is correct."
            ),
            'match': BoolField(optional=True, default=False),
            'cast_to_int': BoolField(optional=True, default=False)
        })

        return parameters

    def configure(self):
        """
        Read metric options and initialise the per-update accuracy storage.

        Raises the UnsupportedPackage error if sklearn's accuracy_score could
        not be imported.
        """
        self.top_k = self.get_value_from_config('top_k')
        self.match = self.get_value_from_config('match')
        self.cast_to_int = self.get_value_from_config('cast_to_int')
        self.summary_helper = None

        if isinstance(accuracy_score, UnsupportedPackage):
            accuracy_score.raise_error(self.__provider__)
        # One accuracy value is appended per update() call.
        self.accuracy = []
        if self.profiler:
            self.summary_helper = ClassificationProfilingSummaryHelper()

    def set_profiler(self, profiler):
        self.profiler = profiler
        self.summary_helper = ClassificationProfilingSummaryHelper()

    def update(self, annotation, prediction):
        """
        Score one annotation/prediction pair.

        ``prediction.scores`` is passed to accuracy_score as the predicted
        labels and compared with ``annotation.label``.

        Returns:
            the accuracy for this pair, which is also stored for evaluate().
        """
        pred_labels = prediction.scores

        accuracy = accuracy_score(annotation.label, pred_labels)
        self.accuracy.append(accuracy)

        if self.profiler:
            self.summary_helper.submit_data(annotation.label, prediction.top_k(self.top_k), prediction.scores)
            self.profiler.update(
                annotation.identifier, annotation.label, prediction.top_k(self.top_k), self.name, accuracy,
                prediction.scores
            )
        return accuracy

    def evaluate(self, annotations, predictions):
        """
        Return the mean of the accuracies collected by update().

        When a profiler is attached, its summary is written first; the metric
        value is returned in both cases.
        """
        if self.profiler:
            self.profiler.finish()
            summary = self.summary_helper.get_summary_report()
            self.profiler.write_summary(summary)
        return np.mean(self.accuracy)

    def reset(self):
        """Drop the collected accuracies and reset the profiler, if any."""
        # self.accuracy is a plain list (see configure), so it is re-created
        # rather than reset through a method.
        self.accuracy = []

        if self.profiler:
            self.profiler.reset()
'C:\\Users\\Atikin\\Desktop\\Programming\\open_model_zoo\\models\\public\\gcn\\accuracy-check.yml']) + main() \ No newline at end of file diff --git a/tools/model_tools/src/omz_tools/__init__.py b/tools/model_tools/src/omz_tools/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tools/model_tools/src/omz_tools/_common.py b/tools/model_tools/src/omz_tools/_common.py index ee8d30f3554..94aaba7722c 100644 --- a/tools/model_tools/src/omz_tools/_common.py +++ b/tools/model_tools/src/omz_tools/_common.py @@ -38,6 +38,7 @@ 'onnx': None, 'pytorch': 'pytorch_to_onnx.py', 'tf': None, + 'dgl_pytorch': None, } KNOWN_PRECISIONS = { 'FP16', 'FP16-INT1', 'FP16-INT8', @@ -74,6 +75,7 @@ 'time_series', 'token_recognition', 'background_matting', + 'node_classification', } diff --git a/tools/model_tools/src/omz_tools/_configuration.py b/tools/model_tools/src/omz_tools/_configuration.py index 8bcdf62e37f..a016ea9ec47 100644 --- a/tools/model_tools/src/omz_tools/_configuration.py +++ b/tools/model_tools/src/omz_tools/_configuration.py @@ -38,7 +38,7 @@ def __init__(self, name, size, checksum, source): @classmethod def deserialize(cls, file): name = validation.validate_relative_path('"name"', file['name']) - + with validation.deserialization_context('In file "{}"'.format(name)): size = validation.validate_nonnegative_int('"size"', file['size'])