Source code for dioptra_builtins.registry.art
# This Software (Dioptra) is being made available as a public service by the
# National Institute of Standards and Technology (NIST), an Agency of the United
# States Department of Commerce. This software was developed in part by employees of
# NIST and in part by NIST contractors. Copyright in portions of this software that
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
# to Title 17 United States Code Section 105, works of NIST employees are not
# subject to copyright protection in the United States. However, NIST may hold
# international copyright in software created by its employees and domestic
# copyright (or licensing rights) in portions of software that were assigned or
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
# being made available under the Creative Commons Attribution 4.0 International
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
# of the software developed or licensed by NIST.
#
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
"""A task plugin module for interfacing the |ART| with the MLFlow model registry.
.. |ART| replace:: `Adversarial Robustness Toolbox\
<https://adversarial-robustness-toolbox.readthedocs.io/en/latest/>`__
"""
from __future__ import annotations
from typing import Any, Dict, Optional
import numpy as np
import structlog
from structlog.stdlib import BoundLogger
from dioptra import pyplugs
from dioptra.sdk.exceptions import ARTDependencyError, TensorflowDependencyError
from dioptra.sdk.utilities.decorators import require_package
from .mlflow import load_tensorflow_keras_classifier
LOGGER: BoundLogger = structlog.stdlib.get_logger()
try:
from art.estimators.classification import TensorFlowV2Classifier
except ImportError: # pragma: nocover
LOGGER.warn(
"Unable to import one or more optional packages, functionality may be reduced",
package="art",
)
try:
from tensorflow.keras import losses
from tensorflow.keras.models import Sequential
except ImportError: # pragma: nocover
LOGGER.warn(
"Unable to import one or more optional packages, functionality may be reduced",
package="tensorflow",
)
[docs]@pyplugs.register
@require_package("art", exc_type=ARTDependencyError)
@require_package("tensorflow", exc_type=TensorflowDependencyError)
def load_wrapped_tensorflow_keras_classifier(
name: str,
version: int,
imagenet_preprocessing: bool = False,
classifier_kwargs: Optional[Dict[str, Any]] = None,
) -> TensorFlowV2Classifier:
"""Loads and wraps a registered Keras classifier for compatibility with the |ART|.
Args:
name: The name of the registered model in the MLFlow model registry.
version: The version number of the registered model in the MLFlow registry.
classifier_kwargs: A dictionary mapping argument names to values which will
be passed to the TensorFlowV2Classifier constructor.
Returns:
A trained :py:class:`~art.estimators.classification.TensorFlowV2Classifier`
object.
See Also:
- :py:class:`art.estimators.classification.TensorFlowV2Classifier`
- :py:func:`.mlflow.load_tensorflow_keras_classifier`
"""
classifier_kwargs = classifier_kwargs or {}
keras_classifier: Sequential = load_tensorflow_keras_classifier(
name=name, version=version
)
nb_classes = keras_classifier.output_shape[1]
input_shape = keras_classifier.input_shape
loss_object = losses.get(keras_classifier.loss)
preprocessing = (
(np.array([103.939, 116.779, 123.680]), np.array([1.0, 1.0, 1.0]))
if imagenet_preprocessing
else None
)
wrapped_keras_classifier: TensorFlowV2Classifier = TensorFlowV2Classifier(
model=keras_classifier,
nb_classes=nb_classes,
input_shape=input_shape,
loss_object=loss_object,
preprocessing=preprocessing,
**classifier_kwargs,
)
LOGGER.info(
"Wrap Keras classifier for compatibility with Adversarial Robustness Toolbox"
)
return wrapped_keras_classifier