# Source code for dioptra_builtins.backend_configs.tensorflow

# This Software (Dioptra) is being made available as a public service by the
# National Institute of Standards and Technology (NIST), an Agency of the United
# States Department of Commerce. This software was developed in part by employees of
# NIST and in part by NIST contractors. Copyright in portions of this software that
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
# to Title 17 United States Code Section 105, works of NIST employees are not
# subject to copyright protection in the United States. However, NIST may hold
# international copyright in software created by its employees and domestic
# copyright (or licensing rights) in portions of software that were assigned or
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
# being made available under the Creative Commons Attribution 4.0 International
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
# of the software developed or licensed by NIST.
#
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
"""A task plugin module for initializing and configuring Tensorflow."""

from __future__ import annotations

import structlog
from structlog.stdlib import BoundLogger

from dioptra import pyplugs
from dioptra.sdk.exceptions import TensorflowDependencyError
from dioptra.sdk.utilities.decorators import require_package

# Module-level structured logger shared by the plugin below.
LOGGER: BoundLogger = structlog.stdlib.get_logger()


# Tensorflow is an optional dependency: importing this module must not fail when it
# is absent. Plugins that need ``tf`` are guarded by ``require_package`` instead.
try:
    import tensorflow as tf

except ImportError:  # pragma: no cover
    # ``warning`` is the canonical method name; ``warn`` is a deprecated alias.
    LOGGER.warning(
        "Unable to import one or more optional packages, functionality may be reduced",
        package="tensorflow",
    )


@pyplugs.register
@require_package("tensorflow", exc_type=TensorflowDependencyError)
def init_tensorflow(seed: int) -> None:
    """Initializes Tensorflow to ensure compatibility and reproducibility.

    This task plugin **must** be run before any other features from Tensorflow are
    used. It disables Tensorflow's eager execution, which is not compatible with
    Dioptra's entry point structure, and sets Tensorflow's internal seed for its
    random number generator.

    Args:
        seed: The seed to use for Tensorflow's random number generator.

    Raises:
        TensorflowDependencyError: Presumably raised by the ``require_package``
            decorator when the ``tensorflow`` package is unavailable.
            NOTE(review): confirm against ``require_package``.
    """
    # Switch to graph mode first, so the seed below is applied under the same
    # execution mode that later plugins will use.
    tf.compat.v1.disable_eager_execution()
    tf.random.set_seed(seed)