Source code for dioptra_builtins.metrics.distance

# This Software (Dioptra) is being made available as a public service by the
# National Institute of Standards and Technology (NIST), an Agency of the United
# States Department of Commerce. This software was developed in part by employees of
# NIST and in part by NIST contractors. Copyright in portions of this software that
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
# to Title 17 United States Code Section 105, works of NIST employees are not
# subject to copyright protection in the United States. However, NIST may hold
# international copyright in software created by its employees and domestic
# copyright (or licensing rights) in portions of software that were assigned or
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
# being made available under the Creative Commons Attribution 4.0 International
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
# of the software developed or licensed by NIST.
#
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
"""A task plugin module for getting functions from a distance metric registry.

.. |Linf| replace:: L\\ :sub:`∞`
.. |L1| replace:: L\\ :sub:`1`
.. |L2| replace:: L\\ :sub:`2`
"""

from __future__ import annotations

from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import structlog
from scipy.stats import wasserstein_distance
from sklearn.metrics.pairwise import paired_distances
from structlog.stdlib import BoundLogger

from dioptra import pyplugs

from .exceptions import UnknownDistanceMetricError

LOGGER: BoundLogger = structlog.stdlib.get_logger()


@pyplugs.register
def get_distance_metric_list(
    request: List[Dict[str, str]]
) -> List[Tuple[str, Callable[..., np.ndarray]]]:
    """Gets multiple distance metric functions from the registry.

    The following metrics are available in the registry,

    - `l_inf_norm`
    - `l_1_norm`
    - `l_2_norm`
    - `paired_cosine_similarities`
    - `paired_euclidean_distances`
    - `paired_manhattan_distances`
    - `paired_wasserstein_distances`

    Args:
        request: A list of dictionaries with the keys `name` and `func`. The `func`
            key is used to lookup the metric function in the registry and must match
            one of the metric names listed above. The `name` key is a human-readable
            label for the metric function.

    Returns:
        A list of tuples with two elements. The first element of each tuple is the
        label from the `name` key of `request`, and the second element is the
        callable metric function. Entries whose `func` is not found in the registry
        are skipped (a warning is logged), so the result may be shorter than
        `request`.

    Raises:
        KeyError: If an entry of `request` lacks the `func` key, or lacks the `name`
            key.
    """
    distance_metrics_list: List[Tuple[str, Callable[..., np.ndarray]]] = []

    for metric in request:
        metric_callable: Optional[Callable[..., np.ndarray]] = (
            DISTANCE_METRICS_REGISTRY.get(metric["func"])
        )

        if metric_callable is None:
            # Unknown metrics are deliberately non-fatal here: log and move on.
            # `warning` is the canonical method name; `warn` is only a legacy alias
            # mirroring the deprecated stdlib `Logger.warn`.
            LOGGER.warning(
                "Distance metric not in registry, skipping...",
                name=metric["name"],
                func=metric["func"],
            )
            continue

        distance_metrics_list.append((metric["name"], metric_callable))

    return distance_metrics_list
@pyplugs.register
def get_distance_metric(func: str) -> Callable[..., np.ndarray]:
    """Looks up a single distance metric function in the registry.

    The following metrics are available in the registry,

    - `l_inf_norm`
    - `l_1_norm`
    - `l_2_norm`
    - `paired_cosine_similarities`
    - `paired_euclidean_distances`
    - `paired_manhattan_distances`
    - `paired_wasserstein_distances`

    Args:
        func: The registry name of the distance metric to return. It must match one
            of the metric names listed above.

    Returns:
        The callable distance metric function registered under `func`.

    Raises:
        UnknownDistanceMetricError: If `func` does not match any metric in the
            registry. The failure is also logged at error level.
    """
    registered_metric: Optional[Callable[..., np.ndarray]] = (
        DISTANCE_METRICS_REGISTRY.get(func)
    )

    # Happy path first; everything below handles the unknown-name case.
    if registered_metric is not None:
        return registered_metric

    LOGGER.error(
        "Distance metric not in registry",
        func=func,
    )
    raise UnknownDistanceMetricError(
        f"Could not find any distance metric named {func!r} in the metrics "
        "plugin collection. Check spelling and try again."
    )
def l_inf_norm(y_true, y_pred) -> np.ndarray:
    """Computes the |Linf| norm of the difference between two batches of matrices.

    Args:
        y_true: A batch of matrices containing the original or target values.
        y_pred: A batch of matrices containing the perturbed or predicted values.

    Returns:
        A :py:class:`numpy.ndarray` containing a batch of |Linf| norms.
    """
    # Delegates to the shared helper with the norm order set to infinity.
    return _matrix_difference_l_norm(y_true=y_true, y_pred=y_pred, order=np.inf)
def l_1_norm(y_true, y_pred) -> np.ndarray:
    """Computes the |L1| norm of the difference between two batches of matrices.

    Args:
        y_true: A batch of matrices containing the original or target values.
        y_pred: A batch of matrices containing the perturbed or predicted values.

    Returns:
        A :py:class:`numpy.ndarray` containing a batch of |L1| norms.
    """
    # Delegates to the shared helper with the norm order set to 1.
    return _matrix_difference_l_norm(y_true=y_true, y_pred=y_pred, order=1)
def l_2_norm(y_true, y_pred) -> np.ndarray:
    """Computes the |L2| norm of the difference between two batches of matrices.

    Args:
        y_true: A batch of matrices containing the original or target values.
        y_pred: A batch of matrices containing the perturbed or predicted values.

    Returns:
        A :py:class:`numpy.ndarray` containing a batch of |L2| norms.
    """
    # Delegates to the shared helper with the norm order set to 2.
    return _matrix_difference_l_norm(y_true=y_true, y_pred=y_pred, order=2)
def paired_cosine_similarities(y_true, y_pred) -> np.ndarray:
    """Computes the pairwise cosine similarity between two batches of matrices.

    Each sample is flattened to a one-dimensional array and scaled by its |L2| norm;
    the similarity is the sum of the element-wise products of the scaled samples.

    Args:
        y_true: A batch of matrices containing the original or target values.
        y_pred: A batch of matrices containing the perturbed or predicted values.

    Returns:
        A :py:class:`numpy.ndarray` containing a batch of cosine similarities.
    """
    unit_true: np.ndarray = _normalize_batch(_flatten_batch(y_true), order=2)
    unit_pred: np.ndarray = _normalize_batch(_flatten_batch(y_pred), order=2)

    # Row-wise dot product of the normalized samples.
    elementwise_product: np.ndarray = unit_true * unit_pred
    return np.sum(elementwise_product, axis=1)
def paired_euclidean_distances(y_true, y_pred) -> np.ndarray:
    """Computes the pairwise Euclidean distance between two batches of matrices.

    The Euclidean distance is equivalent to the |L2| norm, so this is an alias for
    :py:func:`l_2_norm`.

    Args:
        y_true: A batch of matrices containing the original or target values.
        y_pred: A batch of matrices containing the perturbed or predicted values.

    Returns:
        A :py:class:`numpy.ndarray` containing a batch of Euclidean distances.
    """
    return l_2_norm(y_true=y_true, y_pred=y_pred)
def paired_manhattan_distances(y_true, y_pred) -> np.ndarray:
    """Computes the pairwise Manhattan distance between two batches of matrices.

    The Manhattan distance is equivalent to the |L1| norm, so this is an alias for
    :py:func:`l_1_norm`.

    Args:
        y_true: A batch of matrices containing the original or target values.
        y_pred: A batch of matrices containing the perturbed or predicted values.

    Returns:
        A :py:class:`numpy.ndarray` containing a batch of Manhattan distances.
    """
    return l_1_norm(y_true=y_true, y_pred=y_pred)
def paired_wasserstein_distances(y_true, y_pred, **kwargs) -> np.ndarray:
    """Computes the pairwise Wasserstein distance between two batches of matrices.

    Each sample is flattened to a one-dimensional array before the distance between
    corresponding samples of the two batches is evaluated.

    Args:
        y_true: A batch of matrices containing the original or target values.
        y_pred: A batch of matrices containing the perturbed or predicted values.
        **kwargs: Extra keyword arguments forwarded unchanged to
            :py:func:`scipy.stats.wasserstein_distance` for every sample pair.

    Returns:
        A :py:class:`numpy.ndarray` containing a batch of Wasserstein distances.

    See Also:
        - :py:func:`scipy.stats.wasserstein_distance`
    """

    def _sample_distance(u, v):
        # Closure so the caller's kwargs reach every per-sample evaluation.
        return wasserstein_distance(u_values=u, v_values=v, **kwargs)

    return paired_distances(
        X=_flatten_batch(y_true),
        Y=_flatten_batch(y_pred),
        metric=_sample_distance,
    )
def _flatten_batch(X: np.ndarray) -> np.ndarray:
    """Reshapes a batch so that every sample becomes a one-dimensional array.

    The first axis is kept as the sample axis and all remaining axes are collapsed
    into a single axis.

    Args:
        X: A batch of matrices.

    Returns:
        A :py:class:`numpy.ndarray` containing a batch of one-dimensional arrays.
    """
    batch_size: int = X.shape[0]
    # Product of the trailing dimensions is the per-sample element count.
    flat_length: int = int(np.prod(X.shape[1:]))
    flat_shape: Tuple[int, int] = (batch_size, flat_length)

    return X.reshape(flat_shape)


def _matrix_difference_l_norm(y_true, y_pred, order) -> np.ndarray:
    """Computes a per-sample norm of the difference between two batches.

    Args:
        y_true: A batch of matrices containing the original or target values.
        y_pred: A batch of matrices containing the perturbed or predicted values.
        order: The order of the norm, see :py:func:`numpy.linalg.norm` for the full
            list of norms that can be calculated.

    Returns:
        A :py:class:`numpy.ndarray` containing a batch of norms.

    See Also:
        - :py:func:`numpy.linalg.norm`
    """
    # Flatten after subtracting so each row holds one sample's full difference.
    flat_difference: np.ndarray = _flatten_batch(y_true - y_pred)

    return np.linalg.norm(flat_difference, ord=order, axis=1)


def _normalize_batch(X: np.ndarray, order: int) -> np.ndarray:
    """Divides every sample in a batch by its own norm.

    Args:
        X: A batch of matrices to be normalized.
        order: The order of the norm used for normalization, see
            :py:func:`numpy.linalg.norm` for the full list of available norms.

    Returns:
        A :py:class:`numpy.ndarray` containing a batch of normalized matrices.

    See Also:
        - :py:func:`numpy.linalg.norm`
    """
    row_norms: np.ndarray = np.linalg.norm(X, axis=1, ord=order)
    # Column-vector shape lets the division broadcast across each row of X.
    column_of_norms: np.ndarray = row_norms.reshape((row_norms.shape[0], 1))

    return X / column_of_norms


# Maps the public metric names accepted by the lookup functions to their callables.
DISTANCE_METRICS_REGISTRY: Dict[str, Callable[..., Any]] = {
    "l_inf_norm": l_inf_norm,
    "l_1_norm": l_1_norm,
    "l_2_norm": l_2_norm,
    "paired_cosine_similarities": paired_cosine_similarities,
    "paired_euclidean_distances": paired_euclidean_distances,
    "paired_manhattan_distances": paired_manhattan_distances,
    "paired_wasserstein_distances": paired_wasserstein_distances,
}