Source code for dioptra_builtins.estimators.methods

# This Software (Dioptra) is being made available as a public service by the
# National Institute of Standards and Technology (NIST), an Agency of the United
# States Department of Commerce. This software was developed in part by employees of
# NIST and in part by NIST contractors. Copyright in portions of this software that
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
# to Title 17 United States Code Section 105, works of NIST employees are not
# subject to copyright protection in the United States. However, NIST may hold
# international copyright in software created by its employees and domestic
# copyright (or licensing rights) in portions of software that were assigned or
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
# being made available under the Creative Commons Attribution 4.0 International
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
# of the software developed or licensed by NIST.
#
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
from __future__ import annotations

import datetime
import time
from typing import Any, Dict, Optional

import mlflow
import structlog
from structlog.stdlib import BoundLogger

from dioptra import pyplugs
from dioptra.sdk.generics import estimator_predict, fit_estimator

LOGGER: BoundLogger = structlog.stdlib.get_logger()


@pyplugs.register
def fit(
    estimator: Any,
    x: Any = None,
    y: Any = None,
    fit_kwargs: Optional[Dict[str, Any]] = None,
) -> Any:
    """Fits the estimator to the given data.

    This task plugin wraps :py:func:`~dioptra.sdk.generics.fit_estimator`, which is a
    generic function that uses multiple argument dispatch to handle the estimator
    fitting method for different machine learning libraries. The modules attached to
    the advertised plugin entry point `dioptra.generics.fit_estimator` are used to
    build the function dispatch registry at runtime. For more information on the
    supported fitting methods and `fit_kwargs` arguments, please refer to the
    documentation of the registered dispatch functions.

    The elapsed fitting time is logged to MLflow as the metric
    `training_time_in_minutes` and is also included in the completion log record.

    Args:
        estimator: The model to be trained.
        x: The input data to be used for training.
        y: The target data to be used for training.
        fit_kwargs: An optional dictionary of keyword arguments to pass to the
            dispatched function. `None` or an empty dictionary means no extra
            keyword arguments are passed.

    Returns:
        The object returned by the estimator's fitting function. For further details
        on the type of object this method can return, see the documentation for the
        registered dispatch functions.

    See Also:
        - :py:func:`dioptra.sdk.generics.fit_estimator`
    """
    fit_kwargs = fit_kwargs or {}

    # Wall-clock timestamps are kept purely for the log records. The duration is
    # taken from a monotonic clock instead, so that a DST transition or a system
    # clock adjustment during a long training run cannot skew (or make negative)
    # the reported training time.
    time_start: datetime.datetime = datetime.datetime.now()
    clock_start: float = time.monotonic()

    LOGGER.info(
        "Begin estimator fit",
        timestamp=time_start.isoformat(),
    )

    estimator_fit_result: Any = fit_estimator(estimator, x, y, **fit_kwargs)

    total_seconds: float = time.monotonic() - clock_start
    time_end: datetime.datetime = datetime.datetime.now()
    total_minutes: float = total_seconds / 60

    mlflow.log_metric("training_time_in_minutes", total_minutes)
    LOGGER.info(
        "Estimator fit complete",
        timestamp=time_end.isoformat(),
        total_minutes=total_minutes,
    )

    return estimator_fit_result
@pyplugs.register
def predict(
    estimator: Any,
    x: Any = None,
    predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Any:
    """Generates predictions for the given input data using the estimator.

    This task plugin is a thin wrapper around
    :py:func:`~dioptra.sdk.generics.estimator_predict`, a generic function that
    relies on multiple argument dispatch to select the prediction method appropriate
    for the machine learning library in use. The dispatch registry is built at
    runtime from the modules attached to the advertised plugin entry point
    `dioptra.generics.estimator_predict`. The supported prediction methods and the
    accepted `predict_kwargs` arguments are described in the documentation of the
    registered dispatch functions.

    Args:
        estimator: A trained model to be used to generate predictions.
        x: The input data for which to generate predictions.
        predict_kwargs: An optional dictionary of keyword arguments to forward to
            the dispatched function.

    Returns:
        The object returned by the estimator's predict function. The concrete type
        depends on the registered dispatch function that handles the call.

    See Also:
        - :py:func:`dioptra.sdk.generics.estimator_predict`
    """
    # A missing or empty kwargs mapping means there is nothing extra to forward.
    if not predict_kwargs:
        return estimator_predict(estimator, x)

    return estimator_predict(estimator, x, **predict_kwargs)