# Source code for dioptra_builtins.backend_configs.tensorflow
# This Software (Dioptra) is being made available as a public service by the
# National Institute of Standards and Technology (NIST), an Agency of the United
# States Department of Commerce. This software was developed in part by employees of
# NIST and in part by NIST contractors. Copyright in portions of this software that
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
# to Title 17 United States Code Section 105, works of NIST employees are not
# subject to copyright protection in the United States. However, NIST may hold
# international copyright in software created by its employees and domestic
# copyright (or licensing rights) in portions of software that were assigned or
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
# being made available under the Creative Commons Attribution 4.0 International
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
# of the software developed or licensed by NIST.
#
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
"""A task plugin module for initializing and configuring Tensorflow."""
from __future__ import annotations

import structlog
from structlog.stdlib import BoundLogger

from dioptra import pyplugs
from dioptra.sdk.exceptions import TensorflowDependencyError
from dioptra.sdk.utilities.decorators import require_package

# Module-level structured logger used to report the optional-import failure below.
LOGGER: BoundLogger = structlog.stdlib.get_logger()

# Tensorflow is imported defensively: an ImportError is logged as a warning
# instead of propagating, so this module can still be imported without it.
try:
    import tensorflow as tf

except ImportError:  # pragma: nocover
    LOGGER.warning(
        "Unable to import one or more optional packages, functionality may be reduced",
        package="tensorflow",
    )
@pyplugs.register
@require_package("tensorflow", exc_type=TensorflowDependencyError)
def init_tensorflow(seed: int) -> None:
    """Initializes Tensorflow to ensure reproducibility.

    This task plugin **must** be run before any other features from Tensorflow are
    used to ensure reproducibility.

    The only action performed is seeding Tensorflow's random number generator via
    `tf.random.set_seed`.

    Args:
        seed: The seed to use for Tensorflow's random number generator.

    Raises:
        TensorflowDependencyError: Presumably raised by the `require_package`
            decorator when the `tensorflow` package is not available; the
            decorator is defined outside this module, so confirm against its
            implementation.
    """
    tf.random.set_seed(seed)