# Source code for AFL.automation.shared.DataLabelerWidget

from typing import List, Optional

import ipywidgets  # type: ignore
import numpy as np
import plotly.express as px  # type: ignore
import plotly.graph_objects as go  # type: ignore
import xarray as xr
from math import sqrt
from sklearn.preprocessing import OrdinalEncoder  # type: ignore


[docs] class DataLabelerWidget:
[docs] def __init__( self, input_dataset: xr.Dataset, sas_variable: str, composition_variable: str | List[str], sample_dim: str = "sample", fit_variable: Optional[str] = None, ): """ Parameters ---------- dataset: xr.Dataset Dataset from AFL sas_variable: str Name of data variable in the `xarray.Dataset` that holds the scattering data composition_variable: str | List[str] Name of data variable in the `xarray.Dataset` that holds the composition. If the composition is split across multiple variables, pass in a list of variables. sample_dim: str The name of the `xarray` dimension corresponding to each sample or measurement. This is typically named 'sample' in much of the AFL agent codebase. fit_variable: Optional[str] If not none, this data will be plotted along with the sas_variable data. This data variable should have the same shape as sas_variable. """ # preprocess the dataset before sending to the data model dataset = xr.Dataset() dataset["sas"] = input_dataset[sas_variable].transpose(sample_dim, ...) if isinstance(composition_variable, list): dataset["composition"] = ( input_dataset[composition_variable] .to_array("component") .transpose(..., "component") ) else: dataset["composition"] = input_dataset[composition_variable].transpose( sample_dim, ... ) if fit_variable is not None: dataset['fit'] = input_dataset[fit_variable].transpose(sample_dim,...) self.data_view = DataLabelerView() self.data_model = DataLabelerModel(dataset) self.data_index = 0
[docs] def next_button_callback(self, click): self.data_index += 1 self.data_view.current_index.value = self.data_index self.update_plot()
[docs] def prev_button_callback(self, click): self.data_index -= 1 self.data_view.current_index.value = self.data_index self.update_plot()
[docs] def goto_callback(self, click): index = self.data_view.current_index.value self.data_index = index self.update_plot()
[docs] def composition_click_callback(self, figure, location, click): index = location.point_inds[0] self.data_index = index self.data_view.current_index.value = self.data_index self.update_plot()
[docs] def update_plot(self): saxs_data = self.data_model.sas_data[self.data_index] composition_data = self.data_model.composition_data[self.data_index] self.data_view.update_plot( x=saxs_data[self.data_model.q_variable].values, y=saxs_data.values, composition=composition_data, ) self.data_view.current_label.value = self.data_model.phase_labels[ self.data_index ]
[docs] def draw_peaks(self, peaks): self.data_view.output.clear_output() with self.data_view.output: for i, (vl, (spacing, peak_loc)) in enumerate( zip(self.data_view.fig1.layout.shapes, peaks.items()) ): print("{}q* = {}".format(spacing, peak_loc)) vl.update({"x0": peak_loc, "x1": peak_loc, "visible": True}) for ( i, vl, ) in enumerate(self.data_view.fig1.layout.shapes): if i > len(peaks) - 1: vl.visible = False
[docs] def change_qstar_callback(self, figure, location, click): model = self.data_view.phase_dropdown.value n_orders = self.data_view.n_orders.value peaks = self.data_model.get_peaks( model, qstar=location.xs[0], max_order=n_orders ) self.draw_peaks(peaks)
[docs] def change_model_callback(self, data): model = self.data_view.phase_dropdown.value n_orders = self.data_view.n_orders.value peaks = self.data_model.get_peaks(model, max_order=n_orders) self.draw_peaks(peaks)
[docs] def change_norder_callback(self, data): model = self.data_view.phase_dropdown.value n_orders = self.data_view.n_orders.value peaks = self.data_model.get_peaks(model, max_order=n_orders) self.draw_peaks(peaks)
[docs] def label(self, label): self.data_model.phase_labels[self.data_index] = label self.data_view.update_composition_colors(self.data_model.ordinal_phase_labels()) self.next_button_callback(None)
[docs] def run(self): saxs_data = self.data_model.sas_data[self.data_index] composition_data = self.data_model.composition_data[self.data_index] components = self.data_model.composition_data[ self.data_model.component_variable ] widget = self.data_view.run( x=saxs_data[self.data_model.q_variable].values, y=saxs_data.values, all_compositions=self.data_model.composition_data, composition=composition_data, models=list(self.data_model.models.keys()), ternary=self.data_model.ternary, components=components, ) self.data_view.intensity.on_click(self.change_qstar_callback) self.data_view.phase_dropdown.observe(self.change_model_callback) self.data_view.current_label.value = self.data_model.phase_labels[ self.data_index ] self.data_view.current_index.value = str(self.data_index) self.data_view.bnext.on_click(self.next_button_callback) self.data_view.bprev.on_click(self.prev_button_callback) self.data_view.bgoto.on_click(self.goto_callback) self.data_view.all_composition.on_click(self.composition_click_callback) self.data_view.n_orders.observe(self.change_norder_callback) self.data_view.b0.on_click(lambda click: self.label("A")) self.data_view.b1.on_click(lambda click: self.label("B")) self.data_view.b2.on_click(lambda click: self.label("C")) self.data_view.b3.on_click(lambda click: self.label("D")) self.data_view.b4.on_click(lambda click: self.label("E")) return widget
################## ### Data Model ### ##################
class DataLabelerModel:
    """Data, reference peak models and phase labels behind the labeler widget.

    Expects a dataset with a ``"sas"`` variable (sample x q) and a
    ``"composition"`` variable (sample x component).
    """

    def __init__(self, dataset: "xr.Dataset", sample_dim: str = "sample"):
        """
        Parameters
        ----------
        dataset: xr.Dataset
            Must contain ``"sas"`` and ``"composition"`` data variables that
            both carry the ``sample_dim`` dimension plus one other dimension.
        sample_dim: str
            Name of the per-sample dimension (default ``"sample"``).

        Raises
        ------
        ValueError
            If a variable lacks ``sample_dim`` or a second dimension, or if
            the composition does not have exactly two or three components.
        """
        self.dataset = dataset
        self.sas_data = dataset["sas"]
        self.composition_data = dataset["composition"]

        # Infer the q axis / component axis as the first non-sample dimension.
        self.q_variable = self._other_dim(self.sas_data, sample_dim, "sas")
        self.component_variable = self._other_dim(
            self.composition_data, sample_dim, "composition"
        )

        n_components = self.composition_data.sizes[self.component_variable]
        if n_components == 2:
            self.ternary = False
        elif n_components == 3:
            self.ternary = True
        else:
            raise ValueError(
                "Can only handle two or three components. You passed "
                f"{n_components} along dimension '{self.component_variable}'."
            )

        self.phase_labels = ["Unlabeled"] * self.composition_data.sizes[sample_dim]
        # Initial q* used by get_peaks until the user clicks a new one.
        self.qstar = 0.02
        self.init_models()

    @staticmethod
    def _other_dim(data, sample_dim, name):
        """Return the first dimension of ``data`` that is not ``sample_dim``."""
        if sample_dim not in data.dims:
            raise ValueError(
                f"'{name}' has no '{sample_dim}' dimension (dims: {tuple(data.dims)})"
            )
        others = [dim for dim in data.dims if dim != sample_dim]
        if not others:
            raise ValueError(
                f"'{name}' needs a dimension besides '{sample_dim}' "
                f"(dims: {tuple(data.dims)})"
            )
        return others[0]

    def ordinal_phase_labels(self):
        """Encode ``phase_labels`` as ordinal floats (one per sample) for colouring."""
        enc = OrdinalEncoder()
        return enc.fit_transform(np.asarray(self.phase_labels)[:, np.newaxis]).flatten()

    def init_models(self):
        """Populate ``self.models`` with the reference peak-ratio tables.

        Each entry maps a structure name to ``{"labels": [...], "spacings": [...]}``;
        ``spacings[i]`` is the ratio of peak ``i`` to the first peak (always 1)
        and ``labels[i]`` is the text printed for that peak.
        """
        lam_orders = range(1, 21)
        gyr_indices = [
            6, 8, 14, 16, 20, 22, 24, 26, 30, 32, 34, 38, 40, 42, 46, 48,
            50, 52, 56, 62, 64, 66, 68, 70, 72, 74, 78, 80, 84, 86, 88, 90,
        ]
        self.models = {
            "q*": {"labels": ["1"], "spacings": [1]},
            "LAM": {
                "labels": [str(n) for n in lam_orders],
                "spacings": list(lam_orders),
            },
            # NOTE: most labels below carry a leading space; they are kept
            # verbatim because they are printed and used as dict keys.
            "HCP CYL": {
                "labels": [
                    "1", " sqrt(3)", " 2", " sqrt(7)", " 3", " sqrt(12)",
                    " sqrt(13)", " 4", " sqrt(19)", " sqrt(21)", " 5",
                    " sqrt(27)", " sqrt(28)", " sqrt(31)", " 6", " sqrt(37)",
                    " sqrt(39)", " sqrt(43)", " sqrt(48)", "7",
                ],
                "spacings": [
                    1, sqrt(3), 2, sqrt(7), 3, sqrt(12), sqrt(13), 4, sqrt(19),
                    sqrt(21), 5, sqrt(27), sqrt(28), sqrt(31), 6, sqrt(37),
                    sqrt(39), sqrt(43), sqrt(48), 7,
                ],
            },
            "SC": {
                "labels": [
                    "1", "sqrt(2)", "sqrt(3)", "2", "sqrt(5)", "sqrt(6)",
                    "sqrt(8)", "3",
                ],
                "spacings": [1, sqrt(2), sqrt(3), 2, sqrt(5), sqrt(6), sqrt(8), 3],
            },
            "BCC": {
                "labels": [
                    "1", "sqrt(2)", "sqrt(3)", "2", "sqrt(5)", "sqrt(6)",
                    "sqrt(7)", "sqrt(8)", "3",
                ],
                "spacings": [
                    1, sqrt(2), sqrt(3), 2, sqrt(5), sqrt(6), sqrt(7), sqrt(8), 3,
                ],
            },
            "FCC": {
                "labels": [
                    "sqrt(3)", "2", "sqrt(8)", "sqrt(11)", "sqrt(12)", "4",
                    "sqrt(19)",
                ],
                "spacings": [
                    1,
                    2 / sqrt(3),
                    sqrt(8) / sqrt(3),
                    sqrt(11) / sqrt(3),
                    sqrt(12) / sqrt(3),
                    4 / sqrt(3),
                    sqrt(19) / sqrt(3),
                ],
            },
            "SPH": {
                "labels": ["1", "2/sqrt(3)", "2*sqrt(2)/sqrt(3)", "2"],
                "spacings": [1, 2 / sqrt(3), 2 * sqrt(2) / sqrt(3), 2],
            },
            "HCP SPH": {
                "labels": [
                    "sqrt(32)", "6", "sqrt(41)", "sqrt(68)", "sqrt(96)",
                    "sqrt(113)",
                ],
                "spacings": [
                    1,
                    6 / sqrt(32),
                    sqrt(41) / sqrt(32),
                    sqrt(68) / sqrt(32),
                    sqrt(96) / sqrt(32),
                    sqrt(113) / sqrt(32),
                ],
            },
            "DD": {
                "labels": [
                    "sqrt(2)", " sqrt(3)", " 2", " sqrt(6)", " sqrt(8)", " 3",
                    " sqrt(10)", " sqrt(11)",
                ],
                "spacings": [
                    1,
                    sqrt(3) / sqrt(2),
                    2 / sqrt(2),
                    sqrt(6) / sqrt(2),
                    sqrt(8) / sqrt(2),
                    3 / sqrt(2),
                    sqrt(10) / sqrt(2),
                    sqrt(11) / sqrt(2),
                ],
            },
            "GYR": {
                "labels": [f"sqrt({n})" for n in gyr_indices],
                "spacings": [sqrt(n) / sqrt(6) for n in gyr_indices],
            },
        }

    def get_peaks(self, model, qstar=None, max_order=4):
        """Return ``{label: q}`` for the first ``max_order`` peaks of ``model``.

        If ``qstar`` is given it becomes the stored q*; otherwise the stored
        q* is used. A non-positive ``max_order`` yields an empty dict.
        """
        if qstar is None:
            qstar = self.qstar
        else:
            self.qstar = qstar
        spec = self.models[model]
        n = max(max_order, 0)  # a negative slice bound would mean "all but"
        return {
            label: spacing * qstar
            for spacing, label in zip(spec["spacings"][:n], spec["labels"][:n])
        }
################# ### Data View ### #################
class DataLabelerView:
    """plotly/ipywidgets front end for the data labeler.

    Holds the figures and controls only. :meth:`run` must be called before
    any method that touches the figures.
    """

    def __init__(self):
        self.intensity = None
        # Number of pre-allocated vertical peak markers on the scattering
        # plot; also the upper bound of the ``n_orders`` control.
        self.nverts = 8

    def update_plot(self, x, y, composition):
        """Replace the displayed curve and move the current-sample marker."""
        self.intensity.update({"x": x, "y": y})
        if self.ternary:
            self.composition.update(
                {"a": (composition[0],), "b": (composition[1],), "c": (composition[2],)}
            )
        else:
            self.composition.update({"x": (composition[0],), "y": (composition[1],)})

    def remove_vertical_lines(self):
        """Delete every shape from the scattering plot."""
        self.fig1.layout["shapes"] = []

    def add_vertical_line(self, x, y0=0, y1=128, row=1, col=1, line_kw=None):
        """Add a full-height dotted vertical line at ``x`` to the scattering plot.

        NOTE: ``y0``, ``y1``, ``row`` and ``col`` are accepted but currently
        ignored; the line always spans the plot height (paper coords 0..1).
        """
        if line_kw is None:
            line_kw = dict(color="red", dash="dot", width=0.3)
        self.fig1.add_shape(
            name="vertical",
            xref="x",
            yref="paper",
            x0=x,
            x1=x,
            y0=0,
            y1=1,
            line=line_kw,
        )

    def update_composition_colors(self, colors):
        """Set the per-sample marker colour values of the composition map."""
        self.all_composition.marker["color"] = colors

    def run(self, x, y, all_compositions, composition, models, ternary, components):
        """Build the figures and controls and return the top-level widget.

        Parameters
        ----------
        x, y:
            q values and intensities of the initially displayed curve.
        all_compositions:
            Per-sample compositions, indexed ``[sample, component]``.
        composition:
            Composition of the initially displayed sample.
        models:
            Peak-model names offered in the dropdown.
        ternary: bool
            Draw a ternary (3 components) instead of a Cartesian
            (2 components) composition map.
        components:
            Component names; ``components[i].values[()]`` titles axis ``i``.

        Returns
        -------
        ipywidgets.VBox
        """
        self.ternary = ternary
        n_samples = len(all_compositions)

        # --- scattering plot -------------------------------------------
        self.fig1 = go.FigureWidget(go.Scatter(x=x, y=y, mode="markers"))
        self.intensity = self.fig1.data[0]
        self.fig1.update_yaxes(type="log")
        self.fig1.update_xaxes(type="log")
        # Ranges of log axes are given in log10 units.
        self.fig1.update_xaxes({"range": (np.log10(0.005), np.log10(1.0))})
        self.fig1.update_layout(
            height=300, width=400, margin=dict(t=10, b=10, l=10, r=0)
        )
        # Hidden peak markers; they are moved and shown when peaks are drawn.
        for i in range(self.nverts):
            self.add_vertical_line(0.02 * i)
            self.fig1.layout.shapes[i].visible = False

        # --- composition map -------------------------------------------
        # One colour entry per *sample* (not per q point of the curve).
        all_marker = {
            "color": ["black"] * n_samples,
            "colorscale": px.colors.qualitative.Prism,
        }
        current_marker = {"color": "red", "symbol": "hexagon-open", "size": 10}
        sample_ids = list(range(n_samples))
        if self.ternary:
            traces = [
                go.Scatterternary(
                    a=all_compositions[:, 0],
                    b=all_compositions[:, 1],
                    c=all_compositions[:, 2],
                    mode="markers",
                    marker=all_marker,
                    customdata=sample_ids,
                    opacity=1.0,
                    showlegend=False,
                ),
                go.Scatterternary(
                    a=(composition[0],),
                    b=(composition[1],),
                    c=(composition[2],),
                    mode="markers",
                    showlegend=False,
                    marker=current_marker,
                ),
            ]
            # Ternary plots title their a/b/c axes, not x/y.
            axis_titles = dict(
                ternary=dict(
                    aaxis=dict(title=dict(text=components[0].values[()])),
                    baxis=dict(title=dict(text=components[1].values[()])),
                    caxis=dict(title=dict(text=components[2].values[()])),
                )
            )
        else:
            traces = [
                go.Scatter(
                    x=all_compositions[:, 0],
                    y=all_compositions[:, 1],
                    mode="markers",
                    marker=all_marker,
                    customdata=sample_ids,
                    opacity=1.0,
                    showlegend=False,
                ),
                go.Scatter(
                    x=(composition[0],),
                    y=(composition[1],),
                    mode="markers",
                    showlegend=False,
                    marker=current_marker,
                ),
            ]
            axis_titles = dict(
                xaxis_title=components[0].values[()],
                yaxis_title=components[1].values[()],
            )
        self.fig2 = go.FigureWidget(traces)
        self.all_composition = self.fig2.data[0]
        self.composition = self.fig2.data[1]
        self.fig2.update_layout(
            height=300, width=500, margin=dict(t=25, b=35, l=10), **axis_titles
        )
        plot_box = ipywidgets.HBox([self.fig1, self.fig2])

        # --- controls --------------------------------------------------
        self.b0 = ipywidgets.Button(description="PhaseA")
        self.b1 = ipywidgets.Button(description="PhaseB")
        self.b2 = ipywidgets.Button(description="PhaseC")
        self.b3 = ipywidgets.Button(description="PhaseD")
        self.b4 = ipywidgets.Button(description="PhaseE")
        self.bprev = ipywidgets.Button(description="Prev")
        self.bnext = ipywidgets.Button(description="Next")
        self.bgoto = ipywidgets.Button(description="GoTo")
        # IntText has no ``min`` trait; BoundedIntText enforces the range.
        self.current_index = ipywidgets.BoundedIntText(
            description="Data Index:", value=0, min=0, max=max(n_samples - 1, 0)
        )
        # Cannot show more peaks than there are pre-allocated markers.
        self.n_orders = ipywidgets.BoundedIntText(
            description="n_orders", min=1, max=self.nverts, value=4
        )
        self.phase_dropdown = ipywidgets.Dropdown(options=models)
        self.current_label_label = ipywidgets.Label("Current Label:")
        self.current_label = ipywidgets.Label("")

        phase_box = ipywidgets.HBox(
            [
                self.phase_dropdown,
                self.n_orders,
                self.current_label_label,
                self.current_label,
            ]
        )
        buttons_hbox1 = ipywidgets.HBox([self.b0, self.b1, self.b2, self.b3, self.b4])
        buttons_hbox2 = ipywidgets.HBox(
            [self.current_index, self.bgoto, self.bprev, self.bnext]
        )
        self.output = ipywidgets.Output()
        return ipywidgets.VBox(
            [buttons_hbox1, buttons_hbox2, phase_box, plot_box, self.output]
        )