# Source code for AFL.automation.instrument.VirtualSAS

import warnings

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import lazy_loader as lazy

from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder

sasmodels = lazy.load("sasmodels", require="AFL-automation[sas-analysis]")


from AFL.automation.APIServer.Driver import Driver
from AFL.automation.shared.utilities import mpl_plot_to_bytes
AFLagent = lazy.load("AFL.agent", require="AFL-agent")

try:
    from AFL.agent.xarray_extensions import *
except ImportError:
    warnings.warn('AFL-agent xarray_extensions import failed! Some functionality may not work.  Install afl-agent',stacklevel=2)

class VirtualSAS(Driver):
    """Virtual SAS instrument driver: simulates scattering data from SASView
    models and classifies phase membership with a RandomForestClassifier."""

    # Default configuration; consumed by gather_defaults() in __init__.
    defaults = {
        'noise': 0.0,                 # noise level applied in generate()
        'ternary': False,             # enable ternary -> xy coordinate transform
        'fast_locate': True,
        # RFC hyperparameters
        'rfc_n_estimators': 100,
        'rfc_max_depth': None,
        'rfc_random_state': 42,
        'rfc_min_samples_split': 2,
        'rfc_min_samples_leaf': 1,
        # Boundary datasets structure
        'boundary_datasets': {},
        # Components list
        'components': [],
        # Reference data configurations: list of dicts with keys: q, I, dI, dq
        'reference_data': [],
        # SASView model configurations: {label: {'model_name': str, 'model_kw': dict}}
        'sasview_models': {},
    }
    def __init__(self,overrides=None):
        '''
        Virtual SAS instrument driver.

        Generates simulated scattering data from the SASView models configured
        in ``config['sasview_models']`` and classifies phase membership with a
        RandomForestClassifier trained on ``config['boundary_datasets']``.

        Parameters
        ----------
        overrides : dict or None
            Config values that override the class defaults (forwarded to
            ``Driver.__init__``).
        '''
        self.app = None
        Driver.__init__(self,name='VirtualSAS_theory',defaults=self.gather_defaults(),overrides=overrides)

        # Importing the submodules forces the lazily-loaded sasmodels package
        # (lazy_loader at module top) to resolve at startup rather than on
        # first use — presumably to fail fast if the extra is not installed.
        import sasmodels.data
        import sasmodels.core
        import sasmodels.direct_model
        import sasmodels.bumps_model

        # Machine learning attributes for phase classification
        self.classifier = None       # RandomForestClassifier; set by train_classifier()
        self.label_encoder = LabelEncoder()
        self.X_train = None          # training coordinates, shape (n_samples, n_dims)
        self.y_train = None          # training labels, parallel to X_train
        self.phase_labels = None     # ordered list of phase labels used in training

        # Loaded reference data and models (created from config)
        self._reference_data = []    # sasmodels Data1D objects
        self._sasmodels = {}         # {label: {'name', 'kw', 'calculators', 'sasdata'}}

        # Keep boundary_dataset for backward compatibility
        self.boundary_dataset = None

        # Load reference data and models from config at startup
        self.load_reference_data()
        self.load_sasview_models()
[docs] def status(self): status = [] status.append(f'Configurations Loaded={len(self._reference_data)}') status.append(f'SASView Models={len(self._sasmodels)}') status.append(f'Components={self.config["components"]}') if self.classifier is not None: status.append(f'Classifier Trained=True') status.append(f'Phases={len(self.phase_labels)} {self.phase_labels}') status.append(f'Training Samples={len(self.X_train) if self.X_train is not None else 0}') else: status.append(f'Classifier Trained=False') status.append(f'Noise Level={self.config["noise"]}') status.append(f'Ternary={self.config["ternary"]}') if self.classifier is not None: status.append(f'RFC n_estimators={self.config["rfc_n_estimators"]}') status.append(f'RFC max_depth={self.config["rfc_max_depth"]}') return status
[docs] def load_reference_data(self): ''' Load reference data from config into sasmodels Data1D objects. Config format: reference_data = [ {'q': [...], 'I': [...], 'dI': [...], 'dq': [...]}, {'q': [...], 'I': [...], 'dI': [...], 'dq': [...]}, ] ''' import sasmodels.data self._reference_data = [] for ref_config in self.config.get('reference_data', []): if not all(k in ref_config for k in ['q', 'I', 'dI', 'dq']): raise ValueError( 'Each reference_data entry must have keys: q, I, dI, dq. ' f'Got: {list(ref_config.keys())}' ) data = sasmodels.data.Data1D( x=np.array(ref_config['q']), y=np.array(ref_config['I']), dy=np.array(ref_config['dI']), dx=np.array(ref_config['dq']), ) self._reference_data.append(data)
[docs] def load_sasview_models(self): ''' Load SASView models from config. Config format: sasview_models = { 'phase_A': {'model_name': 'sphere', 'model_kw': {'radius': 50, ...}}, 'phase_B': {'model_name': 'cylinder', 'model_kw': {'radius': 20, ...}}, } ''' import sasmodels.core import sasmodels.direct_model self._sasmodels = {} for label, model_config in self.config.get('sasview_models', {}).items(): if 'model_name' not in model_config or 'model_kw' not in model_config: raise ValueError( f'SASView model "{label}" must have "model_name" and "model_kw" keys' ) calculators = [] sasdatas = [] for sasdata in self._reference_data: model_info = sasmodels.core.load_model_info(model_config['model_name']) kernel = sasmodels.core.build_model(model_info) calculator = sasmodels.direct_model.DirectModel(sasdata, kernel) calculators.append(calculator) sasdatas.append(sasdata) self._sasmodels[label] = { 'name': model_config['model_name'], 'kw': model_config['model_kw'], 'calculators': calculators, 'sasdata': sasdatas, }
[docs] def validate_boundary_datasets_config(self): ''' Validate boundary_datasets config structure and dimensionality. Returns ------- bool True if valid, raises ValueError otherwise Raises ------ ValueError If config structure is invalid ''' boundary_datasets = self.config.get('boundary_datasets', {}) if not isinstance(boundary_datasets, dict): raise ValueError('boundary_datasets must be a dictionary') if not boundary_datasets: raise ValueError('boundary_datasets is empty') for phase_label, phase_data in boundary_datasets.items(): if not isinstance(phase_data, dict): raise ValueError(f'Phase "{phase_label}" data must be a dictionary') if 'points' not in phase_data: raise ValueError(f'Phase "{phase_label}" missing "points" key') points = np.array(phase_data['points']) if points.ndim != 2: raise ValueError(f'Phase "{phase_label}" points must be 2D array') # Validate ternary mode requirements if self.config['ternary']: # Ternary transformation requires exactly 3D input if points.shape[1] != 3: raise ValueError( f'Ternary mode enabled but phase "{phase_label}" has {points.shape[1]} ' f'dimensions. Ternary coordinate transformation requires 3D data.' ) return True
    def migrate_boundary_dataset_to_config(self):
        '''
        Convert old boundary_dataset xarray object to new config format.

        This is a migration helper for transitioning from the old shapely-based
        system to the new RFC-based system.

        Returns
        -------
        dict
            Boundary datasets in new config format:
            ``{label: {'points': [...], 'component_names': [...]}}``

        Raises
        ------
        ValueError
            If boundary_dataset is None or invalid
        '''
        if self.boundary_dataset is None:
            raise ValueError('boundary_dataset is None, nothing to migrate')

        boundary_datasets = {}

        # Name of the per-sample label variable, stored in the dataset attrs
        label_variable = self.boundary_dataset.attrs.get('labels')
        if label_variable is None:
            raise ValueError('boundary_dataset missing "labels" attribute')

        for label, sds in self.boundary_dataset.groupby(label_variable):
            # Extract components using same logic as old trace_boundaries:
            # select the component variables and move 'component' to the last dim
            comps = sds[sds.attrs['components']].transpose(..., 'component')
            points = comps.values.tolist()

            # Extract component names if available
            if 'components_dim' in sds.attrs:
                component_names = list(self.boundary_dataset[sds.attrs['components_dim']].values)
            else:
                # Fall back to the component variable names themselves
                component_names = list(sds.attrs.get('components', []))

            boundary_datasets[label] = {
                'points': points,
                'component_names': component_names
            }

        # Update config so train_classifier() can consume the migrated data
        self.config['boundary_datasets'] = boundary_datasets

        return boundary_datasets
[docs] def train_classifier(self, reset=True, drop_phases=None): ''' Train RandomForestClassifier on boundary datasets from config. Parameters ---------- reset : bool If True, reinitialize classifier before training drop_phases : list or None Phase labels to exclude from training Raises ------ ValueError If boundary_datasets not configured or invalid format ''' if drop_phases is None: drop_phases = [] if reset: self.classifier = None self.X_train = None self.y_train = None # Validate config structure boundary_datasets = self.config.get('boundary_datasets', {}) if not boundary_datasets: raise ValueError( 'Must set boundary_datasets in config before training! ' 'Expected format: {"phase_A": {"points": [[x1,y1], [x2,y2], ...]}, ...}' ) # Extract training data from config X_list = [] y_list = [] phase_labels = [] for phase_label, phase_data in boundary_datasets.items(): if phase_label in drop_phases: continue # Validate phase data structure if 'points' not in phase_data: raise ValueError( f'Phase "{phase_label}" missing "points" key. ' f'Expected format: {{"points": [[x1,y1], [x2,y2], ...]}}' ) points = np.array(phase_data['points']) # Validate dimensionality if points.ndim != 2: raise ValueError( f'Phase "{phase_label}" points must be 2D array, got shape {points.shape}' ) # Apply ternary coordinate transformation if enabled if self.config['ternary']: # Ternary transformation requires 3D input and produces 2D output if points.shape[1] != 3: raise ValueError( f'Ternary mode enabled but phase "{phase_label}" has {points.shape[1]} ' f'dimensions. Ternary coordinate transformation requires exactly 3D data.' ) # Convert ternary to xy for classifier xy = AFLagent.util.ternary_to_xy(points[:, [2, 0, 1]]) # Match old coordinate system else: # Use data as-is (can be any dimension: 2D, 3D, etc.) xy = points # Warn about insufficient samples if len(xy) < 3: warnings.warn( f'Phase "{phase_label}" has only {len(xy)} training samples. ' f'RandomForestClassifier may not perform well. 
Recommend >= 10 samples per phase.', stacklevel=2 ) X_list.append(xy) y_list.extend([phase_label] * len(xy)) phase_labels.append(phase_label) if not X_list: raise ValueError('No training data available after filtering drop_phases') # Combine all training data self.X_train = np.vstack(X_list) self.y_train = np.array(y_list) self.phase_labels = phase_labels # Encode labels self.label_encoder.fit(self.y_train) y_encoded = self.label_encoder.transform(self.y_train) # Initialize and train RFC self.classifier = RandomForestClassifier( n_estimators=self.config['rfc_n_estimators'], max_depth=self.config['rfc_max_depth'], random_state=self.config['rfc_random_state'], min_samples_split=self.config['rfc_min_samples_split'], min_samples_leaf=self.config['rfc_min_samples_leaf'], ) self.classifier.fit(self.X_train, y_encoded)
# ds = xr.Dataset() # ds.attrs['classifier_trained'] = True # ds.attrs['n_training_samples'] = len(self.X_train) # ds.attrs['phase_labels'] = phase_labels # return ds
    def locate(self, composition):
        '''
        Predict phase membership using trained RandomForestClassifier.

        Trains the classifier on first use if it has not been trained yet.

        Parameters
        ----------
        composition : array-like
            Composition vector. Dimensionality depends on ternary mode:
            - ternary=True: Must be 3D (will be transformed to 2D for RFC)
            - ternary=False: Any dimension matching training data

        Returns
        -------
        ds : xr.Dataset
            Dataset containing:
            - 'phase': predicted phase label
            - 'probability': confidence score
            - attrs['all_probabilities']: probability for every phase label

        Notes
        -----
        locate_with_uncertainty() returns the same prediction, also as an
        xr.Dataset.
        '''
        composition = np.array(composition)

        if self.classifier is None:
            # NOTE(review): self._sasmodels is initialized to {} in __init__
            # and never set to None, so the second clause can never be True;
            # this guard reduces to `self.boundary_dataset is None`.  Confirm
            # whether an empty-dict check (`not self._sasmodels`) was intended.
            if (self.boundary_dataset is None) or (self._sasmodels is None):
                raise ValueError('Must call load_reference_data() and load_sasview_models() before locate()')
            else:
                self.train_classifier()

        # Convert to 2D array if needed (composition might be 1D)
        if composition.ndim == 1:
            composition = composition.reshape(1, -1)

        # Apply ternary coordinate transformation if enabled
        if self.config['ternary']:
            if composition.shape[1] != 3:
                raise ValueError(
                    f'Ternary mode enabled but composition has {composition.shape[1]} '
                    f'dimensions. Ternary transformation requires exactly 3D data.'
                )
            xy = AFLagent.util.ternary_to_xy(composition)
        else:
            # Use composition as-is (must match training data dimensionality)
            xy = composition

        # Validate dimensionality matches training data
        if xy.shape[1] != self.X_train.shape[1]:
            raise ValueError(
                f'Composition has {xy.shape[1]} dimensions after transformation, '
                f'but classifier was trained on {self.X_train.shape[1]}D data'
            )

        # Predict using RFC
        y_pred_encoded = self.classifier.predict(xy)
        y_pred = self.label_encoder.inverse_transform(y_pred_encoded)

        # Get probabilities for uncertainty
        y_proba = self.classifier.predict_proba(xy)
        max_proba = np.max(y_proba, axis=1)

        # Create xarray Dataset
        ds = xr.Dataset()
        ds['phase'] = str(y_pred[0])  # Convert to native Python string
        ds['probability'] = float(max_proba[0])

        # Store all probabilities in attrs (convert keys to native Python strings)
        ds.attrs['all_probabilities'] = {
            str(label): float(prob)
            for label, prob in zip(self.label_encoder.classes_, y_proba[0])
        }

        return ds
    def locate_with_uncertainty(self, composition):
        '''
        Predict phase with uncertainty estimation via predict_proba.

        Unlike locate(), this never triggers training; the classifier must
        already be trained.

        Parameters
        ----------
        composition : array-like
            Composition vector. Dimensionality depends on ternary mode:
            - ternary=True: Must be 3D (will be transformed to 2D for RFC)
            - ternary=False: Any dimension matching training data

        Returns
        -------
        ds : xr.Dataset
            Dataset with 'phase' (predicted label, str) and 'probability'
            (confidence for the predicted phase, 0-1);
            attrs['all_probabilities'] maps every phase label to its
            probability.

        Raises
        ------
        ValueError
            If the classifier has not been trained, or dimensionality of the
            composition does not match the training data.
        '''
        composition = np.array(composition)

        if self.classifier is None:
            raise ValueError('Must call train_classifier() before locate_with_uncertainty()')

        # Convert to 2D array if needed
        if composition.ndim == 1:
            composition = composition.reshape(1, -1)

        # Apply ternary coordinate transformation if enabled
        if self.config['ternary']:
            if composition.shape[1] != 3:
                raise ValueError(
                    f'Ternary mode enabled but composition has {composition.shape[1]} '
                    f'dimensions. Ternary transformation requires exactly 3D data.'
                )
            xy = AFLagent.util.ternary_to_xy(composition)
        else:
            # Use composition as-is
            xy = composition

        # Validate dimensionality matches training data
        if xy.shape[1] != self.X_train.shape[1]:
            raise ValueError(
                f'Composition has {xy.shape[1]} dimensions after transformation, '
                f'but classifier was trained on {self.X_train.shape[1]}D data'
            )

        # Predict class and probabilities
        y_pred_encoded = self.classifier.predict(xy)
        y_pred = self.label_encoder.inverse_transform(y_pred_encoded)
        y_proba = self.classifier.predict_proba(xy)

        phase = str(y_pred[0])  # Convert to native Python string
        probability = float(np.max(y_proba[0]))
        all_probabilities = {
            str(label): float(prob)  # Convert keys to native Python strings
            for label, prob in zip(self.label_encoder.classes_, y_proba[0])
        }

        # Create xarray Dataset
        ds = xr.Dataset()
        ds['phase'] = phase
        ds['probability'] = probability
        ds.attrs['all_probabilities'] = all_probabilities

        return ds
[docs] def generate(self, label): ''' Generate scattering data for a given phase label. Parameters ---------- label : str Phase label (must exist in sasview_models config) Returns ------- ds : xr.Dataset Dataset containing: - 'q': scattering vector - 'I': scattered intensity (with noise) - 'I_noiseless': scattered intensity (without noise) - 'dI': uncertainty - attrs: phase label, model name ''' if label not in self._sasmodels: raise ValueError( f'Phase label "{label}" not found in sasview_models config. ' f'Available: {list(self._sasmodels.keys())}' ) kw = self._sasmodels[label]['kw'] calculators = self._sasmodels[label]['calculators'] sasdatas = self._sasmodels[label]['sasdata'] noise = self.config['noise'] q_list = [] I_noiseless_list = [] I_list = [] dI_list = [] for sasdata, calc in zip(sasdatas, calculators): I_noiseless = calc(**kw) dI_model = sasdata.dy * np.sqrt(I_noiseless / sasdata.y) mean_var = np.mean(dI_model * dI_model / I_noiseless) dI = sasdata.dy * noise / mean_var I = np.random.normal(loc=I_noiseless, scale=dI) q_list.append(sasdata.x) I_noiseless_list.append(I_noiseless) I_list.append(I) dI_list.append(dI) # Concatenate and sort by q q_all = np.concatenate(q_list) I_all = np.concatenate(I_list) I_noiseless_all = np.concatenate(I_noiseless_list) dI_all = np.concatenate(dI_list) # Sort by q sort_idx = np.argsort(q_all) q_sorted = q_all[sort_idx] I_sorted = I_all[sort_idx] I_noiseless_sorted = I_noiseless_all[sort_idx] dI_sorted = dI_all[sort_idx] # Create xarray Dataset ds = xr.Dataset( { 'I': ('q', I_sorted), 'I_noiseless': ('q', I_noiseless_sorted), 'dI': ('q', dI_sorted), }, coords={'q': q_sorted} ) # Store metadata ds.attrs['phase'] = label ds.attrs['model_name'] = self._sasmodels[label]['name'] return ds
[docs] def expose(self, *args, **kwargs): ''' Mimic the expose command from other instrument servers. Returns ------- ds : xr.Dataset Combined dataset from locate() and generate() with composition info ''' # Get components from config components = self.config.get('components', []) if not components: raise ValueError( 'components not configured. Set config["components"] = [...] with component names' ) # Validate all components are available in self.data if self.data is None: raise ValueError('self.data is None. Cannot extract sample_composition.') missing_components = [] for component in components: if component not in self.data.get('sample_composition', {}): missing_components.append(component) if missing_components: raise ValueError( f'Components {missing_components} not found in self.data["sample_composition"]. ' f'Available: {list(self.data.get("sample_composition", {}).keys())}' ) # Extract composition from self.data composition = [self.data['sample_composition'][component]['value'] for component in components] # Get phase prediction from locate() ds_locate = self.locate(composition) label = str(ds_locate['phase'].item()) # Ensure native Python string # Generate scattering data ds_generate = self.generate(label) # Merge results ds = ds_generate.copy() # Add composition information ds['components'] = ('component', components) ds['composition'] = ('component', composition) # Merge locate results into attrs ds.attrs['phase'] = label ds.attrs['prediction_probability'] = ds_locate['probability'].item() ds.attrs['all_probabilities'] = ds_locate.attrs['all_probabilities'] ds.attrs['components'] = components return ds
[docs] @Driver.unqueued(render_hint='precomposed_svg') def plot_decision_boundaries(self, grid_resolution=200, **kwargs): ''' Plot RFC decision boundaries with training data overlay. Parameters ---------- grid_resolution : int Number of grid points per axis for decision boundary mesh ''' matplotlib.use('Agg') # very important fig, ax = plt.subplots(figsize=(10, 8)) if self.classifier is None: plt.text(0.5, 0.5, 'No classifier trained. Run train_classifier()', ha='center', va='center', fontsize=14) plt.xlim(0, 1) plt.ylim(0, 1) else: # Create mesh grid for decision boundary x_min, x_max = self.X_train[:, 0].min() - 0.05, self.X_train[:, 0].max() + 0.05 y_min, y_max = self.X_train[:, 1].min() - 0.05, self.X_train[:, 1].max() + 0.05 xx, yy = np.meshgrid( np.linspace(x_min, x_max, grid_resolution), np.linspace(y_min, y_max, grid_resolution) ) # Predict on grid grid_points = np.c_[xx.ravel(), yy.ravel()] Z_encoded = self.classifier.predict(grid_points) Z_numeric = Z_encoded.reshape(xx.shape) # Plot decision boundary as contourf n_classes = len(self.label_encoder.classes_) contour = ax.contourf(xx, yy, Z_numeric, alpha=0.3, levels=n_classes - 1, cmap='viridis') # Overlay training data points y_train_encoded = self.label_encoder.transform(self.y_train) scatter = ax.scatter( self.X_train[:, 0], self.X_train[:, 1], c=y_train_encoded, cmap='viridis', edgecolors='black', s=50, alpha=0.8 ) # Add legend handles = [] for i, label in enumerate(self.label_encoder.classes_): color_val = i / max(1, (n_classes - 1)) handle = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=plt.cm.viridis(color_val), markersize=10, label=label, markeredgecolor='black') handles.append(handle) ax.legend(handles=handles, title='Phase', loc='best') ax.set_xlabel('X coordinate' if not self.config['ternary'] else 'Ternary X') ax.set_ylabel('Y coordinate' if not self.config['ternary'] else 'Ternary Y') ax.set_title(f'RFC Decision Boundaries (n={len(self.X_train)} samples)') svg = 
mpl_plot_to_bytes(fig, format='svg') plt.close(fig) return svg
[docs] @Driver.unqueued(render_hint='precomposed_svg') def plot_hulls(self, **kwargs): ''' DEPRECATED: Use plot_decision_boundaries() instead. This wrapper is provided for backward compatibility. ''' warnings.warn( 'plot_hulls() is deprecated and will be removed in future versions. ' 'Use plot_decision_boundaries() instead.', DeprecationWarning, stacklevel=2 ) return self.plot_decision_boundaries(**kwargs)
    @Driver.unqueued(render_hint='precomposed_svg')
    def plot_boundary_data(self, **kwargs):
        '''Plot boundary training data in ternary projection (if ternary mode enabled)'''
        matplotlib.use('Agg')  # very important

        if self.config['ternary']:
            # NOTE(review): the 'ternary' projection requires the mpltern
            # package to be registered with matplotlib; mpltern is not imported
            # in this module — confirm it is imported elsewhere before enabling
            # ternary mode, otherwise this raises a projection error.
            fig, ax = plt.subplots(subplot_kw={'projection': 'ternary'})
        else:
            fig, ax = plt.subplots()

        if self.classifier is None:
            # Untrained: render a placeholder message instead of data
            plt.text(0.5, 0.5, 'No classifier trained. Run train_classifier()',
                     ha='center', va='center')
            if not self.config['ternary']:
                plt.xlim(0, 1)
                plt.ylim(0, 1)
        else:
            # Plot training data from config
            boundary_datasets = self.config.get('boundary_datasets', {})
            for phase_label, phase_data in boundary_datasets.items():
                points = np.array(phase_data['points'])

                if self.config['ternary']:
                    # Plot in ternary space (assumes points are 3D ternary)
                    ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                               label=phase_label, alpha=0.7, s=30)
                else:
                    # Plot in 2D
                    ax.scatter(points[:, 0], points[:, 1],
                               label=phase_label, alpha=0.7, s=30)

            ax.legend(title='Phase')
            ax.set_title('Boundary Training Data')
            if not self.config['ternary']:
                ax.set_xlabel('Component 1')
                ax.set_ylabel('Component 2')

        svg = mpl_plot_to_bytes(fig, format='svg')
        plt.close(fig)
        return svg
if __name__ == '__main__':
    # Running this module directly launches an APIServer hosting this driver.
    from AFL.automation.shared.launcher import *  # This allows the file to be run directly to start a server for this driver