# Source code for AFL.automation.APIServer.DriverWebAppsMixin

import datetime
import io
import json
import pathlib
from collections import defaultdict

from flask import render_template
from tiled.client import from_uri
from tiled.queries import Contains, In


class DriverWebAppsMixin:
    """Mixin adding Tiled-backed web-app endpoints to an APIServer driver.

    Serves the Tiled browser/plot/Gantt HTML pages and proxies data,
    metadata, upload, distinct-value, and concatenation requests to a Tiled
    server configured via ``~/.afl/config.json``.
    """

    # Name of the Tiled container node that holds run documents.
    TILED_RUN_DOCUMENTS_NODE = 'run_documents'
def tiled_browser(self, **kwargs):
    """Render the HTML page for browsing the Tiled database."""
    template = 'tiled_browser/tiled_browser.html'
    return render_template(template)
def tiled_plot(self, **kwargs):
    """Render the HTML page for plotting selected Tiled entries."""
    template = 'tiled_browser/tiled_plot.html'
    return render_template(template)
def tiled_gantt(self, **kwargs):
    """Render the HTML page for the Gantt chart of selected Tiled entries."""
    template = 'tiled_browser/tiled_gantt.html'
    return render_template(template)
def _read_tiled_config(self):
    """Internal helper to read Tiled config from ~/.afl/config.json.

    Returns:
        dict with status and config values or error message
    """
    config_path = pathlib.Path.home() / '.afl' / 'config.json'
    if not config_path.exists():
        return {
            'status': 'error',
            'message': 'Config file not found at ~/.afl/config.json. Please create this file with tiled_server and tiled_api_key settings.'
        }
    try:
        with open(config_path, 'r') as f:
            config_data = json.load(f)
    except (json.JSONDecodeError, ValueError) as e:
        return {
            'status': 'error',
            'message': f'Invalid JSON in config file: {str(e)}'
        }
    # Search through config entries (newest first) to find tiled settings
    if not config_data:
        return {
            'status': 'error',
            'message': 'Config file is empty.'
        }
    # Try entries in reverse sorted order to find one with tiled config
    # Use datetime parsing to properly sort date keys (format: YY/DD/MM HH:MM:SS.ffffff)
    datetime_key_format = '%y/%d/%m %H:%M:%S.%f'
    try:
        keys = sorted(
            config_data.keys(),
            key=lambda k: datetime.datetime.strptime(k, datetime_key_format),
            reverse=True
        )
    except ValueError:
        # Fallback to lexicographic sort if datetime parsing fails.
        # NOTE(review): sorted(key=...) raises on the FIRST unparsable key,
        # so a single malformed key falls back for the whole set — confirm intended.
        keys = sorted(config_data.keys(), reverse=True)
    tiled_server = ''
    tiled_api_key = ''
    # Take the newest entry that provides BOTH server and API key.
    for key in keys:
        entry = config_data[key]
        if isinstance(entry, dict):
            server = entry.get('tiled_server', '')
            api_key = entry.get('tiled_api_key', '')
            if server and api_key:
                tiled_server = server
                tiled_api_key = api_key
                break
    if not tiled_server:
        return {
            'status': 'error',
            'message': 'tiled_server not configured in ~/.afl/config.json. Please add a tiled_server URL to your config.'
        }
    if not tiled_api_key:
        return {
            'status': 'error',
            'message': 'tiled_api_key not configured in ~/.afl/config.json. Please add your Tiled API key to the config.'
        }
    return {
        'status': 'success',
        'tiled_server': tiled_server,
        'tiled_api_key': tiled_api_key
    }
[docs] def tiled_config(self, **kwargs): """Return Tiled server configuration from shared config file. Reads tiled_server and tiled_api_key from ~/.afl/config.json. Returns dict with status and config values or helpful error message. """ return self._read_tiled_config()
def tiled_upload_dataset(
    self,
    dataset=None,
    upload_bytes=None,
    filename='',
    file_format='',
    coordinate_column='',
    metadata=None,
    delimiter='',
    comment_prefix='',
    last_comment_as_header='',
    **kwargs,
):
    """Upload a dataset to Tiled from xarray, NetCDF bytes, CSV, TSV, or DAT.

    Args:
        dataset: Optional xarray.Dataset provided directly by Python callers.
        upload_bytes: Optional bytes payload for file uploads.
        filename: Original upload filename (used for format inference).
        file_format: Optional explicit format ('xarray', 'nc', 'csv', 'tsv', 'dat').
        coordinate_column: Optional column name used to populate sample coordinate.
        metadata: Optional dict of metadata merged into dataset attrs.
        delimiter: Optional delimiter override for table formats.
        comment_prefix: Optional table comment prefix (e.g. '#').
        last_comment_as_header: If truthy, use the last comment row as headers.

    Returns:
        dict with status/message and dataset summary.
    """
    import numpy as np
    import pandas as pd
    import xarray as xr
    from tiled.client.xarray import write_xarray_dataset

    queued_time = datetime.datetime.now()
    start_time = datetime.datetime.now()

    client = self._get_tiled_client()
    if isinstance(client, dict) and client.get('status') == 'error':
        return client

    # Normalize caller metadata: accept a dict directly or a JSON-encoded string.
    normalized_metadata = {}
    if isinstance(metadata, dict):
        normalized_metadata.update(metadata)
    elif isinstance(metadata, str) and metadata.strip():
        try:
            parsed = json.loads(metadata)
            if isinstance(parsed, dict):
                normalized_metadata.update(parsed)
        except Exception:
            return {
                'status': 'error',
                'message': 'Invalid metadata JSON payload.',
            }

    # Merge kwargs metadata fields, excluding routing/internal keys.
    excluded_keys = {
        'dataset',
        'upload_bytes',
        'filename',
        'file_format',
        'coordinate_column',
        'metadata',
        'delimiter',
        'comment_prefix',
        'last_comment_as_header',
        'self',
    }
    for key, value in kwargs.items():
        if key in excluded_keys:
            continue
        if value is None:
            continue
        if isinstance(value, str) and value == '':
            continue
        normalized_metadata[key] = value

    # Infer upload format from the explicit hint, else the filename extension.
    inferred_format = (file_format or '').strip().lower()
    filename_lower = (filename or '').strip().lower()
    if not inferred_format and filename_lower:
        if filename_lower.endswith('.nc'):
            inferred_format = 'nc'
        elif filename_lower.endswith('.csv'):
            inferred_format = 'csv'
        elif filename_lower.endswith('.tsv'):
            inferred_format = 'tsv'
        elif filename_lower.endswith('.dat'):
            inferred_format = 'dat'

    if dataset is not None:
        if not isinstance(dataset, xr.Dataset):
            return {
                'status': 'error',
                'message': 'Provided dataset is not an xarray.Dataset.',
            }
        # Deep copy so attrs mutation below never touches the caller's object.
        dataset_to_write = dataset.copy(deep=True)
    else:
        if upload_bytes is None:
            return {
                'status': 'error',
                'message': 'No dataset object or file payload provided.',
            }
        if not inferred_format:
            return {
                'status': 'error',
                'message': 'Could not infer upload format. Please provide file_format.',
            }
        if inferred_format in ('xarray', 'nc', 'netcdf'):
            try:
                dataset_to_write = xr.open_dataset(io.BytesIO(upload_bytes)).load()
            except Exception as exc:
                return {
                    'status': 'error',
                    'message': f'Failed to read NetCDF upload: {str(exc)}',
                }
        elif inferred_format in ('csv', 'tsv', 'dat'):
            delimiter_token = (delimiter or '').strip().lower()
            normalized_comment_prefix = '#' if comment_prefix is None else str(comment_prefix).strip()
            header_from_last_comment = str(last_comment_as_header).strip().lower() in (
                '1', 'true', 't', 'yes', 'y', 'on'
            )
            try:
                # utf-8-sig strips a BOM if present.
                text = upload_bytes.decode('utf-8-sig')
                nonempty_lines = [line for line in text.splitlines() if line.strip()]
                comment_header_line = ''
                if normalized_comment_prefix:
                    noncomment_lines = []
                    comment_lines = []
                    for line in nonempty_lines:
                        stripped = line.lstrip()
                        if stripped.startswith(normalized_comment_prefix):
                            comment_lines.append(stripped[len(normalized_comment_prefix):].strip())
                        else:
                            noncomment_lines.append(line)
                    nonempty_lines = noncomment_lines
                    if header_from_last_comment and comment_lines:
                        comment_header_line = comment_lines[-1]

                def _resolve_table_separator():
                    # Returns (separator, pandas engine) for pd.read_csv.
                    if delimiter_token:
                        if delimiter_token in ('whitespace', 'space', r'\s+', 'ws'):
                            return (r'\s+', 'python')
                        return (delimiter, None)
                    if inferred_format == 'csv':
                        return (',', None)
                    # TSV/DAT defaults to tab, but support whitespace-delimited text.
                    first = comment_header_line or (nonempty_lines[0] if nonempty_lines else '')
                    sample_data_lines = nonempty_lines[1:11] if len(nonempty_lines) > 1 else []
                    # Count actual tab usage in data rows, not just header row.
                    data_lines_with_tabs = sum(1 for line in sample_data_lines if '\t' in line)
                    if '\t' in first and data_lines_with_tabs > 0:
                        return ('\t', None)
                    # Support mixed files where header may be tab-separated
                    # but data rows are whitespace-separated.
                    probe_lines = sample_data_lines if sample_data_lines else [first]
                    if any(len(line.split()) > 1 for line in probe_lines):
                        return (r'\s+', 'python')
                    return ('\t', None)

                separator, parser_engine = _resolve_table_separator()
                read_text = '\n'.join(nonempty_lines)
                if comment_header_line:
                    read_text = f'{comment_header_line}\n{read_text}' if read_text else comment_header_line
                # Read as raw strings first to avoid pandas NA coercion turning
                # valid coordinate tokens into NaN.
                df = pd.read_csv(
                    io.StringIO(read_text),
                    sep=separator,
                    engine=parser_engine,
                    dtype=str,
                    keep_default_na=False,
                )
            except Exception as exc:
                return {
                    'status': 'error',
                    'message': f'Failed to parse table upload: {str(exc)}',
                }
            if df.empty:
                return {
                    'status': 'error',
                    'message': 'Table upload contains no rows.',
                }
            if coordinate_column and coordinate_column not in df.columns:
                return {
                    'status': 'error',
                    'message': f'Coordinate column "{coordinate_column}" is not in uploaded table headers.',
                }
            # Normalize whitespace in column names and cell contents.
            df.columns = [str(col).strip() for col in df.columns]
            for column in df.columns:
                df[column] = df[column].map(lambda x: x.strip() if isinstance(x, str) else x)

            # Best-effort numeric coercion that preserves non-empty strings.
            def _coerce_series(series):
                numeric = pd.to_numeric(series, errors='coerce')
                nonempty_mask = series.astype(str).str.strip() != ''
                # Use numeric values when at least one non-empty value parsed and
                # not all parsed values are NaN.
                if nonempty_mask.any() and not numeric[nonempty_mask].isna().all():
                    return numeric
                return series

            for column in df.columns:
                df[column] = _coerce_series(df[column])
            # Build a 1-D dataset indexed either by the chosen coordinate
            # column or by a synthetic integer 'sample' axis.
            dim_name = coordinate_column if coordinate_column else 'sample'
            coords = {dim_name: (df[coordinate_column].to_numpy() if coordinate_column else np.arange(len(df)))}
            data_vars = {}
            for column in df.columns:
                if coordinate_column and column == coordinate_column:
                    continue
                data_vars[column] = ((dim_name,), df[column].to_numpy())
            dataset_to_write = xr.Dataset(data_vars=data_vars, coords=coords)
        else:
            return {
                'status': 'error',
                'message': f'Unsupported file format "{inferred_format}". Supported formats: nc, csv, tsv, dat.',
            }

    # Tiled+dask cannot auto-rechunk object dtype arrays. Coerce object
    # variables/coordinates to strings for robust uploads.
    for var_name in list(dataset_to_write.data_vars.keys()):
        var = dataset_to_write[var_name]
        if getattr(var.dtype, 'kind', None) == 'O':
            dataset_to_write[var_name] = var.astype(str)
    for coord_name in list(dataset_to_write.coords.keys()):
        coord = dataset_to_write.coords[coord_name]
        if getattr(coord.dtype, 'kind', None) == 'O':
            dataset_to_write = dataset_to_write.assign_coords({coord_name: coord.astype(str)})

    if not hasattr(dataset_to_write, 'attrs') or dataset_to_write.attrs is None:
        dataset_to_write.attrs = {}
    dataset_to_write.attrs.update(normalized_metadata)

    end_time = datetime.datetime.now()
    run_time = end_time - start_time
    # FIX: use total_seconds() rather than timedelta.seconds — .seconds drops
    # whole days from the duration, silently under-reporting long uploads.
    generated_meta = {
        'queued': queued_time.strftime('%m/%d/%y %H:%M:%S-%f %Z%z'),
        'started': start_time.strftime('%m/%d/%y %H:%M:%S-%f %Z%z'),
        'ended': end_time.strftime('%m/%d/%y %H:%M:%S-%f %Z%z'),
        'run_time_seconds': int(run_time.total_seconds()),
        'run_time_minutes': run_time.total_seconds() / 60,
        'exit_state': 'Success!',
        'return_val': 'xarray.Dataset',
    }
    if not hasattr(dataset_to_write, 'attrs') or dataset_to_write.attrs is None:
        dataset_to_write.attrs = {}
    # Merge generated timing metadata into any pre-existing 'meta' dict.
    existing_meta = {}
    if isinstance(dataset_to_write.attrs.get('meta'), dict):
        existing_meta.update(dataset_to_write.attrs.get('meta', {}))
    existing_meta.update(generated_meta)
    dataset_to_write.attrs['meta'] = existing_meta

    try:
        run_documents = self._get_tiled_run_documents_container(create=True)
        write_result = write_xarray_dataset(run_documents, dataset_to_write)
    except Exception as exc:
        error_msg = str(exc) if str(exc) else repr(exc)
        self.app.logger.error(f'Tiled upload error: {error_msg}', exc_info=True)
        return {
            'status': 'error',
            'message': f'Failed to write dataset to Tiled: {error_msg}',
        }

    # Best-effort extraction of the new entry's ID from the write result.
    entry_id = ''
    try:
        if hasattr(write_result, 'item'):
            entry_id = str(write_result.item.get('id', ''))
        elif hasattr(write_result, 'metadata') and isinstance(write_result.metadata, dict):
            entry_id = str(write_result.metadata.get('id', ''))
    except Exception:
        entry_id = ''

    return {
        'status': 'success',
        'message': 'Dataset uploaded to Tiled.',
        'entry_id': entry_id,
        'dataset_summary': {
            'dims': {k: int(v) for k, v in dataset_to_write.sizes.items()},
            'data_vars': sorted(list(dataset_to_write.data_vars.keys())),
            'coords': sorted(list(dataset_to_write.coords.keys())),
        },
    }
def _get_tiled_client(self):
    """Get or create cached Tiled client.

    Returns:
        Tiled client or dict with error status
    """
    # Reuse the cached connection when one already exists.
    if self._tiled_client is not None:
        return self._tiled_client
    # Get config using internal method (avoids decorator issues)
    config = self._read_tiled_config()
    if config['status'] == 'error':
        return config
    try:
        # Create and cache client
        self._tiled_client = from_uri(
            config['tiled_server'],
            api_key=config['tiled_api_key'],
            structure_clients="dask",
        )
        return self._tiled_client
    except Exception as e:
        return {
            'status': 'error',
            'message': f'Failed to connect to Tiled: {str(e)}'
        }

def _get_tiled_run_documents_container(self, create=False):
    """Get run_documents container, optionally creating it when missing.

    Raises:
        RuntimeError: when the Tiled client could not be created.
    """
    client = self._get_tiled_client()
    if isinstance(client, dict) and client.get('status') == 'error':
        raise RuntimeError(client.get('message', 'Failed to connect to Tiled.'))
    try:
        return client[self.TILED_RUN_DOCUMENTS_NODE]
    except Exception:
        # Node missing (or lookup failed): optionally create it, else signal None.
        if not create:
            return None
        client.create_container(
            key=self.TILED_RUN_DOCUMENTS_NODE,
            metadata={'type': self.TILED_RUN_DOCUMENTS_NODE},
        )
        return client[self.TILED_RUN_DOCUMENTS_NODE]

def _normalize_run_document_entry_id(self, entry_id):
    """Strip an optional 'run_documents/' prefix and surrounding slashes."""
    entry_id = str(entry_id or '').strip()
    prefix = f'{self.TILED_RUN_DOCUMENTS_NODE}/'
    if entry_id.startswith(prefix):
        entry_id = entry_id[len(prefix):]
    return entry_id.strip('/')

def _get_tiled_run_document_item(self, entry_id):
    """Resolve an entry ID to (normalized_id, tiled item).

    Raises:
        KeyError: when the ID is empty, the container is missing, or the
            entry is not present in the container.
    """
    normalized_id = self._normalize_run_document_entry_id(entry_id)
    if not normalized_id:
        raise KeyError(entry_id)
    container = self._get_tiled_run_documents_container(create=False)
    if container is None:
        raise KeyError(normalized_id)
    return normalized_id, container[normalized_id]

def _read_tiled_item(self, item):
    """Read a Tiled item, disabling wide-table optimization when supported."""
    try:
        return item.read(optimize_wide_table=False)
    except TypeError as exc:
        # Some non-xarray clients do not accept optimize_wide_table.
        message = str(exc)
        if 'optimize_wide_table' in message or 'unexpected keyword' in message:
            return item.read()
        raise
def tiled_get_data(self, entry_id, **kwargs):
    """Proxy endpoint to get xarray HTML representation from Tiled.

    Args:
        entry_id: Tiled entry ID

    Returns:
        dict with status and html, or error message
    """
    # Get cached Tiled client
    client = self._get_tiled_client()
    if isinstance(client, dict) and client.get('status') == 'error':
        return client
    try:
        normalized_id, item = self._get_tiled_run_document_item(entry_id)
        # Try to get xarray dataset representation
        try:
            dataset = self._read_tiled_item(item)
            # Get HTML representation
            if hasattr(dataset, '_repr_html_'):
                html = dataset._repr_html_()
            else:
                # Fallback to string representation
                html = f'<pre>{str(dataset)}</pre>'
            return {
                'status': 'success',
                'html': html
            }
        except Exception as e:
            # If can't read as dataset, provide basic info
            html = '<div class="data-display">'
            html += f'<p><strong>Entry ID:</strong> {normalized_id}</p>'
            html += f'<p><strong>Type:</strong> {type(item).__name__}</p>'
            if hasattr(item, 'metadata'):
                html += '<h4>Metadata:</h4>'
                html += f'<pre>{json.dumps(dict(item.metadata), indent=2)}</pre>'
            html += f'<p><em>Could not load data representation: {str(e)}</em></p>'
            html += '</div>'
            # Note: status stays 'success' — a fallback page is still a valid response.
            return {
                'status': 'success',
                'html': html
            }
    except KeyError:
        return {
            'status': 'error',
            'message': f'Entry "{entry_id}" not found'
        }
    except Exception as e:
        return {
            'status': 'error',
            'message': f'Error fetching data: {str(e)}'
        }
[docs] def tiled_get_xarray_html(self, entry_ids, **kwargs): """Return xarray _repr_html_() for one or more Tiled entries. Reuses the combined-dataset cache shared with the plot manifest endpoint, so if the plot has already been rendered this call is effectively free. Args: entry_ids: JSON-encoded list of entry IDs Returns: dict with 'status' and 'html' (xarray HTML string) """ try: entry_ids_list = self._parse_entry_ids_param(entry_ids) except (json.JSONDecodeError, ValueError) as e: return {'status': 'error', 'message': f'Invalid entry_ids: {str(e)}'} if not entry_ids_list: return {'status': 'error', 'message': 'entry_ids must contain at least one entry'} try: ds = self._get_or_create_combined_dataset(entry_ids_list) except Exception as e: return {'status': 'error', 'message': f'Error loading dataset: {str(e)}'} try: html = ds._repr_html_() except Exception: html = f'<pre>{str(ds)}</pre>' return {'status': 'success', 'html': html}
def tiled_get_full_json(self, entry_id, **kwargs):
    """Proxy endpoint to get JSON-serializable full data payload for one entry.

    This endpoint is intended as a same-origin fallback for browser clients
    when direct browser->Tiled CORS access is not available.
    """
    import numpy as np
    import pandas as pd
    client = self._get_tiled_client()
    if isinstance(client, dict) and client.get('status') == 'error':
        return client
    try:
        _, item = self._get_tiled_run_document_item(entry_id)
        payload = self._read_tiled_item(item)
        # xarray datasets: flatten into column->list payload
        if hasattr(payload, 'to_dataframe'):
            dataframe = payload.to_dataframe().reset_index()
            data = {}
            for column in dataframe.columns:
                series = dataframe[column]
                converted = []
                # Convert numpy scalars to JSON-safe Python values; NaN -> None.
                # NOTE(review): Python bool is a subclass of int, so native bools
                # hit the integer branch (become 0/1); the bool branch only
                # catches np.bool_ — confirm this is intended.
                for value in series.tolist():
                    if isinstance(value, (np.floating, float)):
                        converted.append(None if np.isnan(value) else float(value))
                    elif isinstance(value, (np.integer, int)):
                        converted.append(int(value))
                    elif isinstance(value, (np.bool_, bool)):
                        converted.append(bool(value))
                    elif isinstance(value, np.ndarray):
                        converted.append(value.tolist())
                    else:
                        converted.append(value)
                data[str(column)] = converted
            return {
                'status': 'success',
                'data': data
            }
        # pandas tables
        if isinstance(payload, pd.DataFrame):
            dataframe = payload.reset_index()
            return {
                'status': 'success',
                'data': {
                    str(col): dataframe[col].tolist() for col in dataframe.columns
                }
            }
        # arrays and scalars
        if isinstance(payload, np.ndarray):
            return {
                'status': 'success',
                'data': {'value': payload.tolist()}
            }
        return {
            'status': 'success',
            'data': {'value': payload}
        }
    except Exception as e:
        return {
            'status': 'error',
            'message': f'Error fetching full data: {str(e)}'
        }
[docs] def tiled_get_metadata(self, entry_id, **kwargs): """Proxy endpoint to get metadata from Tiled. Args: entry_id: Tiled entry ID Returns: dict with status and metadata, or error message """ # Get cached Tiled client client = self._get_tiled_client() if isinstance(client, dict) and client.get('status') == 'error': return client try: _, item = self._get_tiled_run_document_item(entry_id) # Extract metadata metadata = dict(item.metadata) if hasattr(item, 'metadata') else {} return { 'status': 'success', 'metadata': metadata } except KeyError: return { 'status': 'error', 'message': f'Entry "{entry_id}" not found' } except Exception as e: return { 'status': 'error', 'message': f'Error fetching metadata: {str(e)}' }
[docs] def tiled_get_distinct_values(self, field, **kwargs): """Get distinct/unique values for a metadata field using Tiled's distinct() method. Args: field: Metadata field name (e.g., 'sample_name', 'sample_uuid', 'AL_campaign_name', 'AL_uuid') Returns: dict with status and list of unique values, or error message """ # Get cached Tiled client client = self._get_tiled_client() if isinstance(client, dict) and client.get('status') == 'error': return client try: run_documents = self._get_tiled_run_documents_container(create=False) if run_documents is None: return { 'status': 'success', 'field': field, 'values': [], 'count': 0 } # Use Tiled's distinct() method scoped to run_documents. distinct_result = run_documents.distinct(field) # Extract the values from the metadata # distinct() returns {'metadata': {field: [{'value': ..., 'count': ...}, ...]}} if 'metadata' in distinct_result and field in distinct_result['metadata']: values_list = distinct_result['metadata'][field] # Extract just the 'value' field from each entry unique_values = [item['value'] for item in values_list if item.get('value') is not None] else: unique_values = [] return { 'status': 'success', 'field': field, 'values': unique_values, 'count': len(unique_values) } except Exception as e: return { 'status': 'error', 'message': f'Error getting distinct values for field "{field}": {str(e)}' }
def _fetch_single_tiled_entry(self, entry_id):
    """Fetch a single entry from Tiled and extract metadata.

    Parameters
    ----------
    entry_id : str
        Tiled entry ID to fetch

    Returns
    -------
    tuple
        (dataset, metadata_dict) where metadata_dict contains:
        - entry_id: str - The Tiled entry ID
        - sample_name: str - Sample name (from metadata, attrs, or entry_id)
        - sample_uuid: str - Sample UUID (from metadata, attrs, or '')
        - sample_composition: Optional[Dict] - Parsed composition with structure:
          {'components': List[str], 'values': List[float]}

    Raises
    ------
    ValueError
        If Tiled client cannot be obtained
        If entry_id is not found in Tiled
        If dataset cannot be read
    """
    import xarray as xr
    try:
        normalized_id, item = self._get_tiled_run_document_item(entry_id)
    except Exception:
        raise ValueError(f'Entry "{entry_id}" not found in tiled {self.TILED_RUN_DOCUMENTS_NODE}') from None
    # Fetch dataset with wide-table optimization disabled for xarray clients.
    dataset = self._read_tiled_item(item)
    # Extract metadata from tiled item
    tiled_metadata = dict(item.metadata) if hasattr(item, 'metadata') else {}
    # Also check dataset attrs for metadata
    ds_attrs = dict(dataset.attrs) if hasattr(dataset, 'attrs') else {}
    # Build metadata dict, preferring tiled metadata over dataset attrs
    # Include ALL metadata fields for Gantt chart
    metadata = {
        'entry_id': normalized_id,
        'sample_name': tiled_metadata.get('sample_name') or ds_attrs.get('sample_name') or normalized_id,
        'sample_uuid': tiled_metadata.get('sample_uuid') or ds_attrs.get('sample_uuid') or '',
        'sample_composition': None,
        # Add full metadata for Gantt chart and other uses
        'attrs': tiled_metadata.get('attrs', {}) or ds_attrs.get('attrs', {}),
        'meta': tiled_metadata.get('meta', {}) or tiled_metadata.get('attrs', {}).get('meta', {}) or ds_attrs.get('meta', {}),
        'AL_campaign_name': tiled_metadata.get('AL_campaign_name') or tiled_metadata.get('attrs', {}).get('AL_campaign_name') or ds_attrs.get('AL_campaign_name', ''),
        'AL_uuid': tiled_metadata.get('AL_uuid') or tiled_metadata.get('attrs', {}).get('AL_uuid') or ds_attrs.get('AL_uuid', ''),
        'task_name': tiled_metadata.get('task_name') or tiled_metadata.get('attrs', {}).get('task_name') or ds_attrs.get('task_name', ''),
        'driver_name': tiled_metadata.get('driver_name') or tiled_metadata.get('attrs', {}).get('driver_name') or ds_attrs.get('driver_name', ''),
    }
    # Extract sample_composition - be fault tolerant if it doesn't exist
    comp_dict = tiled_metadata.get('sample_composition') or ds_attrs.get('sample_composition')
    if comp_dict and isinstance(comp_dict, dict):
        # Parse composition dict to extract components and values
        components = []
        values = []
        for comp_name, comp_data in comp_dict.items():
            # Skip non-component keys like 'units', 'components', etc.
            if comp_name in ('units', 'conc_units', 'mass_units', 'components'):
                continue
            try:
                if isinstance(comp_data, dict):
                    # Handle both 'value' (scalar) and 'values' (array) cases
                    if 'value' in comp_data:
                        values.append(float(comp_data['value']))
                        components.append(comp_name)
                    elif 'values' in comp_data:
                        val = comp_data['values']
                        if isinstance(val, (list, tuple)) and len(val) > 0:
                            values.append(float(val[0]))
                        else:
                            values.append(float(val) if val is not None else 0.0)
                        components.append(comp_name)
                elif isinstance(comp_data, (int, float)):
                    # Direct numeric value
                    values.append(float(comp_data))
                    components.append(comp_name)
            except (ValueError, TypeError):
                # Skip components that can't be converted to float
                continue
        if components:
            metadata['sample_composition'] = {
                'components': components,
                'values': values
            }
    return dataset, metadata

def _detect_sample_dimension(self, dataset, allow_size_fallback=True):
    """Detect the sample dimension from a dataset.

    Looks for dimensions matching patterns like '*_sample' or 'sample'.
    Optionally falls back to the first dimension with size > 1.

    Parameters
    ----------
    dataset : xr.Dataset
        Dataset to inspect.
    allow_size_fallback : bool, default=True
        If True, use the first dimension with size > 1 when no explicit
        sample-like dimension name is found. If False, return None when
        no explicit sample-like dimension name is present.

    Returns
    -------
    str or None
        The detected sample dimension name, or None if not found
    """
    import re
    # Pattern priority: exact 'sample', then '*_sample', then first multi-valued dim
    dims = list(dataset.sizes.keys())
    # Check for exact 'sample' first
    if 'sample' in dims:
        return 'sample'
    # Check for *_sample pattern
    sample_pattern = re.compile(r'.*_sample$')
    for dim in dims:
        if sample_pattern.match(dim):
            return dim
    if allow_size_fallback:
        # Fallback: first dimension with size > 1
        for dim in dims:
            if dataset.sizes[dim] > 1:
                return dim
        # Last resort: first dimension
        return dims[0] if dims else None
    return None
def tiled_concat_datasets(self, entry_ids, concat_dim='index', variable_prefix=''):
    """Gather datasets from Tiled entries and concatenate them along a dimension.

    This method fetches multiple datasets from a Tiled server, extracts
    metadata (sample_name, sample_uuid, sample_composition), and concatenates
    them along the specified dimension. It also supports prefixing variable names.

    For a single entry, the dataset is returned as-is without concatenation,
    and the sample dimension is auto-detected from existing dimensions.

    Parameters
    ----------
    entry_ids : List[str]
        List of Tiled entry IDs to fetch and concatenate
    concat_dim : str, default="index"
        Dimension name along which to concatenate the datasets
        (ignored for single entry)
    variable_prefix : str, default=""
        Optional prefix to prepend to variable, coordinate, and dimension
        names (except the concat_dim itself)

    Returns
    -------
    xr.Dataset
        For single entry: The original dataset with metadata added as attributes
        For multiple entries: Concatenated dataset with:
        - All original data variables and coordinates from individual datasets
        - Additional coordinates along concat_dim:
          - sample_name: Sample name from metadata or entry_id
          - sample_uuid: Sample UUID from metadata or empty string
          - entry_id: The Tiled entry ID for each dataset
        - If sample_composition metadata exists:
          - composition: DataArray with dims [concat_dim, "components"]
            containing composition values for each sample

    Raises
    ------
    ValueError
        If entry_ids is empty
        If any entry_id is not found in Tiled
        If datasets cannot be fetched or concatenated
    """
    import xarray as xr
    import numpy as np
    if not entry_ids:
        raise ValueError("entry_ids list cannot be empty")
    # Fetch all entry datasets and metadata
    datasets = []
    metadata_list = []
    for entry_id in entry_ids:
        try:
            ds, metadata = self._fetch_single_tiled_entry(entry_id)
            datasets.append(ds)
            metadata_list.append(metadata)
        except Exception as e:
            raise ValueError(f"Failed to fetch entry '{entry_id}': {str(e)}")
    if not datasets:
        raise ValueError("No datasets fetched")
    # SINGLE ENTRY CASE: Return dataset as-is with metadata added
    if len(datasets) == 1:
        dataset = datasets[0]
        metadata = metadata_list[0]
        # Detect the sample dimension from the dataset
        # For single entries, avoid guessing a random axis as "sample".
        sample_dim = self._detect_sample_dimension(dataset, allow_size_fallback=False)
        if sample_dim is None:
            # Ensure plotter paths have a sample dimension even for one entry.
            # Use concat_dim for consistency with multi-entry flow.
            sample_dim = concat_dim
            dataset = dataset.expand_dims({sample_dim: [0]})
            dataset = dataset.assign_coords({
                'sample_name': (sample_dim, [metadata['sample_name']]),
                'sample_uuid': (sample_dim, [metadata['sample_uuid']]),
                'entry_id': (sample_dim, [metadata['entry_id']]),
            })
        # Add metadata as dataset attributes (not coordinates, since we don't have a new dim)
        dataset.attrs['sample_name'] = metadata['sample_name']
        dataset.attrs['sample_uuid'] = metadata['sample_uuid']
        dataset.attrs['entry_id'] = metadata['entry_id']
        dataset.attrs['_detected_sample_dim'] = sample_dim
        # If sample_composition exists, add it as a DataArray along the sample dimension
        if metadata['sample_composition'] and sample_dim:
            components = metadata['sample_composition']['components']
            values = metadata['sample_composition']['values']
            # Check if composition already exists in dataset (common case)
            # If not, we could add it, but for single entry this is usually already there
            if 'composition' not in dataset.data_vars:
                # Create composition array - but we need to match the sample dimension size
                # This is tricky for single entry since composition is per-sample
                # For now, store in attrs
                dataset.attrs['sample_composition'] = {
                    'components': components,
                    'values': values
                }
        # Apply variable prefix if specified
        if variable_prefix:
            rename_dict = {}
            for var_name in list(dataset.data_vars):
                if not var_name.startswith(variable_prefix):
                    rename_dict[var_name] = variable_prefix + var_name
            for coord_name in list(dataset.coords):
                if coord_name not in dataset.dims and not coord_name.startswith(variable_prefix):
                    rename_dict[coord_name] = variable_prefix + coord_name
            for dim_name in list(dataset.dims):
                if not dim_name.startswith(variable_prefix):
                    rename_dict[dim_name] = variable_prefix + dim_name
            if rename_dict:
                dataset = dataset.rename(rename_dict)
        return dataset
    # MULTIPLE ENTRIES CASE: Concatenate along concat_dim
    # Collect metadata values for each entry
    sample_names = [m['sample_name'] for m in metadata_list]
    sample_uuids = [m['sample_uuid'] for m in metadata_list]
    entry_id_values = [m['entry_id'] for m in metadata_list]
    # Build compositions DataArray before concatenation
    # Collect all unique components across all entries
    all_components = set()
    for m in metadata_list:
        if m['sample_composition']:
            all_components.update(m['sample_composition']['components'])
    all_components = sorted(list(all_components))
    # Create composition data array if we have components
    if all_components:
        n_samples = len(datasets)
        n_components = len(all_components)
        # Missing components default to 0.0 for that sample.
        comp_data = np.zeros((n_samples, n_components))
        for i, m in enumerate(metadata_list):
            if m['sample_composition']:
                for j, comp_name in enumerate(all_components):
                    if comp_name in m['sample_composition']['components']:
                        idx = m['sample_composition']['components'].index(comp_name)
                        comp_data[i, j] = m['sample_composition']['values'][idx]
        # Create the compositions DataArray
        compositions = xr.DataArray(
            data=comp_data,
            dims=[concat_dim, "components"],
            coords={
                concat_dim: range(n_samples),
                "components": all_components
            },
            name="composition"
        )
    else:
        compositions = None
    # Concatenate along new dimension
    # Use coords="minimal" to avoid conflict with compat="override"
    concatenated = xr.concat(datasets, dim=concat_dim, coords="minimal", compat='override')
    # Assign 1D coordinates along concat_dim
    concatenated = concatenated.assign_coords({
        'sample_name': (concat_dim, sample_names),
        'sample_uuid': (concat_dim, sample_uuids),
        'entry_id': (concat_dim, entry_id_values)
    })
    # Add compositions if we have it
    if compositions is not None:
        concatenated = concatenated.assign(composition=compositions)
    # Prefix names (data vars, coords, dims) but NOT the concat_dim itself
    if variable_prefix:
        rename_dict = {}
        # Rename data variables
        for var_name in list(concatenated.data_vars):
            if not var_name.startswith(variable_prefix):
                rename_dict[var_name] = variable_prefix + var_name
        # Rename coordinates (but not concat_dim)
        for coord_name in list(concatenated.coords):
            if coord_name == concat_dim:
                continue  # Don't rename the concat_dim coordinate
            if coord_name not in concatenated.dims:
                # Non-dimension coordinates
                if not coord_name.startswith(variable_prefix):
                    rename_dict[coord_name] = variable_prefix + coord_name
        # Rename dimensions but NOT concat_dim
        for dim_name in list(concatenated.dims):
            if dim_name == concat_dim:
                continue  # Don't rename the concat_dim
            if not dim_name.startswith(variable_prefix):
                rename_dict[dim_name] = variable_prefix + dim_name
        # Apply all renames
        if rename_dict:
            concatenated = concatenated.rename(rename_dict)
    return concatenated
def _parse_entry_ids_param(self, entry_ids):
    """Parse entry_ids parameter from JSON string or list.

    Raises:
        json.JSONDecodeError: when a string argument is not valid JSON.
        ValueError: when the decoded value is not a list.
    """
    if isinstance(entry_ids, str):
        parsed = json.loads(entry_ids)
    else:
        parsed = entry_ids
    if not isinstance(parsed, list):
        raise ValueError('entry_ids must be a JSON array or list')
    return parsed

def _entry_ids_cache_key(self, entry_ids_list):
    """Create a stable cache key from ordered entry IDs."""
    # Compact separators keep the key short and deterministic.
    return json.dumps(entry_ids_list, separators=(',', ':'))

def _cache_put(self, cache, order, key, value, max_items):
    """Put value into bounded insertion-ordered cache."""
    if key in cache:
        # Refresh an existing key: move it to the most-recent end of `order`.
        cache[key] = value
        if key in order:
            order.remove(key)
        order.append(key)
        return
    cache[key] = value
    order.append(key)
    # Evict oldest entries until the cache fits within max_items.
    while len(order) > max_items:
        old_key = order.pop(0)
        cache.pop(old_key, None)

def _get_or_create_combined_dataset(self, entry_ids_list):
    """Get combined dataset from cache or create and cache it."""
    key = self._entry_ids_cache_key(entry_ids_list)
    cached = self._combined_dataset_cache.get(key)
    if cached is not None:
        return cached
    # Cache miss: build the combined dataset and remember it.
    combined_dataset = self.tiled_concat_datasets(
        entry_ids=entry_ids_list,
        concat_dim='index',
        variable_prefix=''
    )
    self._cache_put(
        self._combined_dataset_cache,
        self._combined_dataset_cache_order,
        key,
        combined_dataset,
        self._max_combined_dataset_cache
    )
    return combined_dataset

def _sanitize_for_json(self, obj):
    """Recursively replace NaN/Inf with None for JSON compatibility."""
    import math
    if isinstance(obj, list):
        return [self._sanitize_for_json(x) for x in obj]
    if isinstance(obj, float):
        if math.isnan(obj) or math.isinf(obj):
            return None
        return obj
    # Non-float, non-list values pass through unchanged (no dict recursion).
    return obj

def _safe_tolist(self, arr):
    """Convert numpy array to JSON-serializable list."""
    import numpy as np
    import pandas as pd
    if not isinstance(arr, np.ndarray):
        arr = np.asarray(arr)
    if np.issubdtype(arr.dtype, np.datetime64):
        # Render timestamps as strings for JSON transport.
        return pd.to_datetime(arr).astype(str).tolist()
    if np.issubdtype(arr.dtype, np.timedelta64):
        # Express timedeltas as float seconds.
        return (arr / np.timedelta64(1, 's')).tolist()
    return self._sanitize_for_json(arr.tolist())