Source code for dioptra.rq.tasks.run_mlflow

# This Software (Dioptra) is being made available as a public service by the
# National Institute of Standards and Technology (NIST), an Agency of the United
# States Department of Commerce. This software was developed in part by employees of
# NIST and in part by NIST contractors. Copyright in portions of this software that
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
# to Title 17 United States Code Section 105, works of NIST employees are not
# subject to copyright protection in the United States. However, NIST may hold
# international copyright in software created by its employees and domestic
# copyright (or licensing rights) in portions of software that were assigned or
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
# being made available under the Creative Commons Attribution 4.0 International
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
# of the software developed or licensed by NIST.
#
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
import os
import shlex
import subprocess
from pathlib import Path
from subprocess import CompletedProcess
from tempfile import TemporaryDirectory
from typing import List, Optional

import boto3
import structlog
from botocore.client import BaseClient
from rq.job import Job as RQJob
from rq.job import get_current_job
from structlog.stdlib import BoundLogger

from dioptra.sdk.utilities.s3.uri import s3_uri_to_bucket_prefix
from dioptra.worker.s3_download import s3_download

LOGGER: BoundLogger = structlog.stdlib.get_logger()
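
# Environment variables read by this module (all consumed in run_mlflow_task
# below):
#
#   MLFLOW_S3_ENDPOINT_URL        - endpoint URL for the default boto3 S3 client
#   DIOPTRA_PLUGINS_S3_URI        - S3 URI of the builtin task plugins
#   DIOPTRA_CUSTOM_PLUGINS_S3_URI - S3 URI of the custom task plugins
#   DIOPTRA_PLUGIN_DIR            - local directory the plugins are downloaded into
#   DIOPTRA_WORKDIR               - optional parent directory for per-job
#                                   temporary directories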


def _download_workflow(s3: BaseClient, dest_dir: str, workflow_uri: str) -> None:
    """
    Download the file pointed to by workflow_uri into the directory dest_dir.
    The directory structure implied by workflow_uri is not mirrored in the
    local filesystem; only the final path component is used as the filename.

    Args:
        s3: A boto3 S3 client object
        dest_dir: A directory, as a str, which must already exist
        workflow_uri: An S3 URI referring to a file
    """
    bucket, key = s3_uri_to_bucket_prefix(workflow_uri)
    dest_file = Path(key).name  # get the last path component (a filename)
    dest_path = Path(dest_dir) / dest_file

    s3.download_file(bucket, key, str(dest_path))
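
# Illustration of the path handling above (hypothetical URI): given
#
#     workflow_uri = "s3://workflow/5e4c3a/workflow.tar.gz"
#
# s3_uri_to_bucket_prefix returns ("workflow", "5e4c3a/workflow.tar.gz"), and
# the file is written to <dest_dir>/workflow.tar.gz; the "5e4c3a/" prefix is
# not recreated locally.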


def run_mlflow_task(
    workflow_uri: str,
    entry_point: str,
    experiment_id: str,
    conda_env: str = "base",
    entry_point_kwargs: Optional[str] = None,
    s3: Optional[BaseClient] = None,
) -> CompletedProcess:
    """
    Download a packaged workflow from S3 and execute one of its MLFlow entry
    points in a subprocess.

    Args:
        workflow_uri: An S3 URI referring to the packaged workflow file
        entry_point: The name of the entry point to run
        experiment_id: The ID of the MLFlow experiment to attach the run to
        conda_env: The name of the conda environment to run the job in
        entry_point_kwargs: Optional additional arguments for the entry
            point, as a single shell-quoted string
        s3: An optional boto3 S3 client object; if omitted, one is created
            from the MLFLOW_S3_ENDPOINT_URL environment variable

    Returns:
        The CompletedProcess object describing the finished job process
    """
    mlflow_s3_endpoint_url = os.getenv("MLFLOW_S3_ENDPOINT_URL")
    dioptra_plugins_s3_uri = os.getenv("DIOPTRA_PLUGINS_S3_URI")
    dioptra_custom_plugins_s3_uri = os.getenv("DIOPTRA_CUSTOM_PLUGINS_S3_URI")
    dioptra_plugin_dir = os.getenv("DIOPTRA_PLUGIN_DIR")

    # For mypy; assume correct environment variables
    assert mlflow_s3_endpoint_url
    assert dioptra_plugins_s3_uri
    assert dioptra_custom_plugins_s3_uri
    assert dioptra_plugin_dir

    if not s3:
        s3 = boto3.client("s3", endpoint_url=mlflow_s3_endpoint_url)

    cmd: List[str] = [
        "/usr/local/bin/run-mlflow-job.sh",
        "--s3-workflow",
        workflow_uri,
        "--entry-point",
        entry_point,
        "--conda-env",
        conda_env,
        "--experiment-id",
        experiment_id,
    ]

    env = os.environ.copy()
    rq_job: Optional[RQJob] = get_current_job()
    if rq_job is not None:
        env["DIOPTRA_RQ_JOB_ID"] = rq_job.get_id()

    log: BoundLogger = LOGGER.new(rq_job_id=env.get("DIOPTRA_RQ_JOB_ID"))

    if entry_point_kwargs is not None:
        # Split using shell quoting rules, e.g. '-P batch_size=32'
        # becomes ["-P", "batch_size=32"]
        cmd.extend(shlex.split(entry_point_kwargs))

    with TemporaryDirectory(dir=os.getenv("DIOPTRA_WORKDIR")) as tmpdir:
        log.info("Downloading workflow: %s", workflow_uri)
        _download_workflow(s3, tmpdir, workflow_uri)

        log.info("Downloading plugins")
        s3_download(
            s3,
            dioptra_plugin_dir,
            True,
            True,
            dioptra_plugins_s3_uri,
            dioptra_custom_plugins_s3_uri,
        )

        log.info("Executing MLFlow job", cmd=" ".join(cmd))
        p = subprocess.run(args=cmd, cwd=tmpdir, env=env)

    if p.returncode > 0:
        # p.stderr is None here because the job's output was not captured;
        # it streams directly to the worker's own stdout/stderr.
        log.warning(
            "MLFlow job stopped unexpectedly",
            returncode=p.returncode,
            stderr=p.stderr,
        )

    return p
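
# Minimal usage sketch (illustrative only): this task is meant to be enqueued
# on an RQ worker rather than called directly. The queue name, Redis
# connection, and argument values below are hypothetical.
#
#     from redis import Redis
#     from rq import Queue
#
#     queue = Queue("tensorflow_cpu", connection=Redis())
#     job = queue.enqueue(
#         run_mlflow_task,
#         workflow_uri="s3://workflow/5e4c3a/workflow.tar.gz",
#         entry_point="train",
#         experiment_id="1",
#         entry_point_kwargs="-P batch_size=32",
#     )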