Source code for AFL.automation.shared.DatasetWidget
import ast
import re
from collections import defaultdict
from typing import Optional, Dict, List
import ipywidgets # type: ignore
import xarray as xr
import numpy as np
import plotly.express as px # type: ignore
import plotly.graph_objects as go # type: ignore
[docs]
class DatasetWidget:
[docs]
def __init__(
self,
dataset: xr.Dataset,
sample_dim: str = "sample",
scatt_variables: Optional[List[str]] = None,
comps_variable: Optional[str] = None,
comps_color_variable: Optional[str] = None,
xmin: float = 0.001,
xmax: float = 1.0,
):
"""Interactive widget for viewing compositionally varying scattering data
Parameters
----------
dataset: xr.Dataset
`xarray.Dataset` containing scattering data and compositions to be plotted.
sample_dim: str
The name of the `xarray` dimension corresponding to sample variation, typically "sample"
comps_variable: Optional[str]
The name of the `xarray` variable to plot as compositional data. Optional, if not specified, can be
customized in the GUI.
Only the first two columns of the data will be used in the plot. If the compositions are in separate
`xarray.DataArray`s, they should be grouped into
single `xarray.DataArray`s like so:
```python
ds['comps'] = ds[['A','B','C']].to_array('component').transpose(...,'component')
```
comps_color_variable: Optional[str]
The name of the `xarray` variable to use as the colorscale of the compositional data scatter plot. Optional,
if not specified, can be customized in the GUI.
xmin, xmax: float
Set the default q-range of the scattering data. Can be customized in the GUI
Usage
-----
```python
widget = DatasetWidget(ds)
widget.run()
```
"""
# preprocess the dataset before sending to the data model
self.data_view = DatasetWidget_View(
initial_scatt_variables=scatt_variables,
initial_comps_variable=comps_variable,
initial_comps_color_variable=comps_color_variable,
)
self.data_model = DatasetWidget_Model(dataset, sample_dim)
self.data_index = 0
self.initial_xmin = xmin
self.initial_xmax = xmax
[docs]
def next_button_callback(self, *args):
self.data_index += 1
self.data_view.text_input["index"].value = self.data_index
self.update_plots()
[docs]
def prev_button_callback(self, *args):
self.data_index -= 1
self.data_view.text_input["index"].value = self.data_index
self.update_plots()
[docs]
def goto_callback(self, *args):
index = self.data_view.text_input["index"].value
self.data_index = index
self.update_plots()
[docs]
def composition_click_callback(self, figure, location, click):
index = location.point_inds[0]
self.data_index = int(index)
self.data_view.text_input["index"].value = self.data_index
self.update_plots()
[docs]
def update_composition_plot(self):
(
x,
y,
z,
xname,
yname,
zname,
) = self.get_comps()
if z is None:
self.data_view.update_selected(x=(x[self.data_index],), y=(y[self.data_index],))
else:
self.data_view.update_selected(
x=(x[self.data_index],), y=(y[self.data_index],), z=(z[self.data_index],)
)
[docs]
def update_scattering_plot(self):
if len(self.data_view.dropdown["scatter"].value)>0:
append = False
for scatt_variable in self.data_view.dropdown["scatter"].value:
if scatt_variable != "None":
x, y = self.data_model.get_scattering(scatt_variable, self.data_index)
self.data_view.plot_sas(x, y, name=scatt_variable, append=append)
append=True
[docs]
def get_comps(self):
composition_variable = self.data_view.dropdown["composition"].value
x, y, z, xname, yname, zname = self.data_model.get_composition(composition_variable)
return x, y, z, xname, yname, zname
[docs]
def initialize_plots(self, *args):
self.update_scattering_plot()
# need to plot comps manually, so we don't redraw "all comps" every time
x, y, z, xname, yname, zname = self.get_comps()
if self.data_view.dropdown["composition_color"].value != "None":
colors = self.data_model.dataset[
self.data_view.dropdown["composition_color"].value
].values
else:
colors = None
if z is None:
self.data_view.plot_comp(x=x, y=y, xname=xname, yname=yname, colors=colors)
else:
self.data_view.plot_comp(x=x, y=y, z=z, xname=xname, yname=yname, zname=zname, colors=colors)
self.data_view.comp_fig.data[0].on_click(self.composition_click_callback)
[docs]
def update_colors(self, *args):
if self.data_view.dropdown["composition_color"].value != "None":
colors = self.data_model.dataset[
self.data_view.dropdown["composition_color"].value
].values
else:
colors = None
self.data_view.update_colorscale(colors)
[docs]
def apply_sel(self, *args):
key = self.data_view.dropdown["sel"].value
value = ast.literal_eval(self.data_view.text_input["sel"].value)
self.data_model.apply_sel({key: value})
self.data_view.dataset_html.value = self.data_model.dataset._repr_html_()
self.update_dropdowns()
[docs]
def apply_isel(self, *args):
key = self.data_view.dropdown["sel"].value
value = ast.literal_eval(self.data_view.text_input["sel"].value)
self.data_model.apply_isel({key: value})
self.data_view.dataset_html.value = self.data_model.dataset._repr_html_()
self.update_dropdowns()
[docs]
def extract_var(self, *args):
extract_from_var = self.data_view.dropdown["extract_from_var"].value
extract_from_coord = self.data_view.dropdown["extract_from_coord"].value
self.data_model.extract_var(extract_from_var,extract_from_coord)
self.data_view.dataset_html.value = self.data_model.dataset._repr_html_()
self.update_dropdowns()
[docs]
def combine_vars(self, *args):
combined_var = self.data_view.text_input["combined_var"].value
to_combine_vars = ast.literal_eval(
self.data_view.text_input["to_combine"].value
)
self.data_model.combine_vars(
combined_var=combined_var, to_combine_vars=to_combine_vars
)
self.data_view.dataset_html.value = self.data_model.dataset._repr_html_()
self.update_dropdowns()
[docs]
def reset_dataset(self, *args):
self.data_model.reset_dataset()
self.data_view.dataset_html.value = self.data_model.dataset._repr_html_()
[docs]
def update_dropdowns(self, *args):
sample_vars, comp_vars, scatt_vars = self.data_model.split_vars()
self.data_view.update_dropdowns(
sample_vars=sample_vars,
scatt_vars=scatt_vars,
comp_vars=comp_vars,
)
[docs]
def update_sample_dim(self, *args):
# self.data_model.sample_dim = self.data_view.text_input["sample_dim"].value
self.data_model.sample_dim = self.data_view.dropdown["sample_dim"].value
self.data_view.initial_comps_variable = "None"
self.data_view.initial_comps_color_variable = "None"
self.data_index = 0
self.data_view.text_input['index'].value = self.data_index
self.update_dropdowns()
[docs]
def update_extract_coords(self, change):
if change["type"] == "change" and change["name"] == "value":
extract_from_var = self.data_view.dropdown["extract_from_var"].value
dims = self.data_model.get_non_sample_dims(extract_from_var)
self.data_view.dropdown["extract_from_coord"].options = (
self.data_model.dataset[dims[0]].values
)
[docs]
def run(self):
widget = self.data_view.run()
self.update_dropdowns()
self.data_view.dataset_html.value = self.data_model.dataset._repr_html_()
self.data_view.dropdown["sample_dim"].options = list(self.data_model.dataset.sizes.keys())
self.data_view.dropdown["sample_dim"].value = self.data_model.sample_dim
self.data_view.dropdown["sample_dim"].observe(self.update_sample_dim)
self.data_view.text_input["xmin"].value = self.initial_xmin
self.data_view.text_input["xmax"].value = self.initial_xmax
self.data_view.text_input["cmin"].observe(self.update_colors)
self.data_view.text_input["cmax"].observe(self.update_colors)
self.data_view.button["update_plot"].on_click(self.initialize_plots)
self.data_view.button["next"].on_click(self.next_button_callback)
self.data_view.button["prev"].on_click(self.prev_button_callback)
self.data_view.button["sel"].on_click(self.apply_sel)
self.data_view.button["isel"].on_click(self.apply_isel)
self.data_view.button["reset_dataset"].on_click(self.reset_dataset)
self.data_view.button["combine"].on_click(self.combine_vars)
self.data_view.button["extract"].on_click(self.extract_var)
self.data_view.dropdown["extract_from_var"].observe(self.update_extract_coords)
return widget
##################
### Data Model ###
##################
[docs]
class DatasetWidget_Model:
[docs]
def __init__(self, dataset: xr.Dataset, sample_dim: str):
self.original_dataset = dataset
self.working_dataset = dataset.copy()
self.sample_dim = sample_dim
@property
def dataset(self):
return self.working_dataset
@dataset.setter
def dataset(self, value):
self.working_dataset = value
[docs]
def split_vars(self):
"""Heuristically try to split vars into categories"""
vars = self.dataset.keys()
sample_vars = []
comp_vars = []
scatt_vars = []
for var in vars:
if len(self.dataset[var].dims) == 1 and (
self.dataset[var].dims[0] == self.sample_dim
):
sample_vars.append(var)
else:
try:
other_dim = (
self.dataset[var].transpose(self.sample_dim, ...).dims[1]
)
except ValueError:
continue
if (
self.dataset.sizes[other_dim] < 10
): # stupid guess at compositions, hopefully this is always 2
comp_vars.append(var)
else:
scatt_vars.append(var)
return sample_vars, comp_vars, scatt_vars
[docs]
def get_non_sample_dims(self, var: str):
dims = self.dataset[var].transpose(self.sample_dim, ...).dims[1:]
return dims
[docs]
def apply_sel(self, kw):
temp_dataset = self.dataset.copy()
for k, v in kw.items():
temp_dataset = temp_dataset.set_index({self.sample_dim: k}).sel(
{self.sample_dim: v}
)
self.dataset = temp_dataset
[docs]
def apply_isel(self, kw):
temp_dataset = self.dataset.copy()
for k, v in kw.items():
temp_dataset = temp_dataset.set_index({self.sample_dim: k}).isel(
{self.sample_dim: v}
)
self.dataset = temp_dataset
[docs]
def combine_vars(self, combined_var: str, to_combine_vars: List[str]):
# need to figure out dim name...
reg = re.compile("component([0-9]*)")
dims = [reg.findall(str(k)) for k in self.dataset.dims]
dim_nums = [
int(d[0]) for d in dims if len(d) == 1 and d[0]
] # dim num should be length1 and not empty
try:
new_dim = f"component{max(dim_nums)+1}"
except ValueError:
new_dim = f"component1"
self.dataset[combined_var] = (
self.dataset[to_combine_vars].to_array(new_dim).transpose(..., new_dim)
)
[docs]
def extract_var(self, extract_from_var: str, extract_from_coord: str):
var_name = f'{extract_from_var}_{extract_from_coord}'
dim = self.get_non_sample_dims(extract_from_var)[0]
self.dataset[var_name] = self.dataset[extract_from_var].sel({dim:extract_from_coord})
[docs]
def get_composition(self, variable):
dataset = self.dataset.transpose(self.sample_dim, ...)
x = dataset[variable][:, 0].values
y = dataset[variable][:, 1].values
if dataset[variable].values.shape[1]>2:
z = dataset[variable][:, 2].values
else:
z = None
component_dim = dataset[variable].transpose(self.sample_dim, ...).dims[1]
if z is None:
xname, yname = dataset[variable][component_dim].values[:2]
zname = None
else:
xname, yname, zname = dataset[variable][component_dim].values[:3]
return x, y, z, xname, yname, zname
[docs]
def get_scattering(self, variable, index):
sds = self.dataset[variable].isel(**{self.sample_dim: index})
x = sds[sds.squeeze().dims[0]].values
y = sds.values
return x, y
#################
### Data View ###
#################
[docs]
class DatasetWidget_View:
[docs]
def __init__(
self,
initial_scatt_variables: Optional[List[str]] = None,
initial_comps_variable: Optional[str] = None,
initial_comps_color_variable: Optional[str] = None,
):
self.scatt_fig = None
self.comp_fig = None
self.initial_scatt_variables = initial_scatt_variables
self.initial_comps_variable = initial_comps_variable
self.initial_comps_color_variable = initial_comps_color_variable
self.tabs: ipywidgets.Tab = ipywidgets.Tab()
self.dropdown: Dict[str, ipywidgets.Dropdown] = {}
self.button: Dict[str, ipywidgets.Button] = {}
self.checkbox: Dict[str, ipywidgets.Checkbox] = {}
self.text_input: Dict[
str, ipywidgets.FloatText | ipywidgets.IntText | ipywidgets.Text
] = {}
# keep track of dropdowns in categories in case options need to be updated
self.dropdown_categories: Dict[str, List] = defaultdict(list)
[docs]
def update_colorscale(self,colors=None):
if len(self.comp_fig.data) == 0:
return
if colors is not None:
self.comp_fig.data[0]["marker"]["color"] = colors
#self.comp_fig.data[0]["marker"]["customdata"] = colors
self.comp_fig.data[0]["marker"]["cmin"] = self.text_input["cmin"].value
self.comp_fig.data[0]["marker"]["cmax"] = self.text_input["cmax"].value
[docs]
def update_dropdowns(self, sample_vars=None, scatt_vars=None, comp_vars=None):
if sample_vars is not None:
for dropdown in self.dropdown_categories["sample"]:
dropdown.options = sample_vars
# set the default value if possible
if "Colors" in dropdown.description:
if self.initial_comps_color_variable is None:
self.initial_comps_color_variable = "None"
dropdown.options = ["None"] + list(dropdown.options)
dropdown.value = self.initial_comps_color_variable
if scatt_vars is not None:
for dropdown in self.dropdown_categories["scatter"]:
dropdown.options = ["None"] + scatt_vars
if self.initial_scatt_variables is None:
initial_scatt_variables = ["None"]
else:
initial_scatt_variables = self.initial_scatt_variables
dropdown.value = initial_scatt_variables
if comp_vars is not None:
for dropdown in self.dropdown_categories["composition"]:
dropdown.options = ["None"] + comp_vars
# set the default value if possible
if self.initial_comps_variable is None:
initial_comps_variable = "None"
else:
initial_comps_variable = self.initial_comps_variable
dropdown.value = initial_comps_variable
[docs]
def plot_sas(self, x, y, name="SAS", append=False):
scatt1 = go.Scatter(x=x, y=y, name=name, mode="markers")
if not append:
self.scatt_fig.data = []
self.scatt_fig.add_trace(scatt1)
# update xaxis
if self.checkbox["logx"].value:
self.scatt_fig.update_xaxes(type="log")
xrange = (
np.log10(self.text_input["xmin"].value),
np.log10(self.text_input["xmax"].value),
)
else:
self.scatt_fig.update_xaxes(type="linear")
xrange = (
self.text_input["xmin"].value,
self.text_input["xmax"].value,
)
self.scatt_fig.update_xaxes({"range": xrange})
# update yaxis
if self.checkbox["logy"].value:
self.scatt_fig.update_yaxes(type="log")
else:
self.scatt_fig.update_yaxes(type="linear")
[docs]
def plot_comp(self, x, y, z=None, xname="x", yname="y", zname="z", colors=None):
if colors is None:
colors = ([0] * len(x),)
else:
self.text_input["cmin"].value = min(colors)
self.text_input["cmax"].value = max(colors)
if z is None:
scatt1 = go.Scatter(
x=x,
y=y,
mode="markers",
marker={
"color": colors,
"showscale": True,
"cmin": self.text_input["cmin"].value,
"cmax": self.text_input["cmax"].value,
"colorscale": px.colors.get_colorscale(
self.dropdown["composition_colorscale"].value
),
"colorbar": dict(thickness=15, outlinewidth=0),
},
opacity=1.0,
showlegend=False,
customdata=colors,
hovertemplate=(
f"""{xname}: %{{x:3.2f}} <br>"""
f"""{yname}: %{{y:3.2f}} <br>"""
"""color: %{customdata:3.2f}"""
),
)
scatt2 = go.Scatter(
x=(x[0],),
y=(y[0],),
mode="markers",
showlegend=False,
marker={
"color": "red",
"symbol": "hexagon-open",
"size": 10,
},
)
self.comp_fig.update_layout(xaxis_title=xname, yaxis_title=yname)
else:
scatt1 = go.Scatter3d(
x=x,
y=y,
z=z,
mode="markers",
marker={
"color": colors,
"showscale": True,
"cmin": self.text_input["cmin"].value,
"cmax": self.text_input["cmax"].value,
"colorscale": px.colors.get_colorscale(
self.dropdown["composition_colorscale"].value
),
"colorbar": dict(thickness=15, outlinewidth=0),
},
opacity=1.0,
showlegend=False,
customdata=colors,
hovertemplate=(
f"""{xname}: %{{x:3.2f}} <br>"""
f"""{yname}: %{{y:3.2f}} <br>"""
f"""{zname}: %{{z:3.2f}} <br>"""
"""color: %{customdata:3.2f}"""
),
)
scatt2 = go.Scatter3d(
x=(x[0],),
y=(y[0],),
z=(z[0],),
mode="markers",
showlegend=False,
marker={
"color": "red",
"symbol": "circle-open",
"size": 10,
},
)
self.comp_fig.update_layout(
scene=dict(
xaxis_title=xname,
yaxis_title=yname,
zaxis_title=zname
)
)
if hasattr(self.comp_fig, "data"):
self.comp_fig.data = []
self.comp_fig.add_trace(scatt1)
self.comp_fig.add_trace(scatt2)
self.comp_fig.update_scenes(aspectmode="cube")
[docs]
def init_plots(self):
self.scatt_fig = go.FigureWidget(
[],
layout=dict(
xaxis_title="q",
yaxis_title="I",
height=300,
width=500,
legend=dict(yanchor="top", xanchor="right", y=0.99, x=0.99),
),
)
self.scatt_fig.update_layout(margin=dict(t=10, b=10, l=10, r=10))
self.scatt_fig.update_yaxes(type="log")
self.scatt_fig.update_xaxes(type="log")
self.scatt_fig.update_xaxes(
{
"range": (
np.log10(self.text_input["xmin"].value),
np.log10(self.text_input["xmax"].value),
)
}
)
self.comp_fig = go.FigureWidget(
[],
layout=dict(
height=300,
width=500,
),
)
self.comp_fig.update_layout(margin=dict(t=10, b=10, l=10, r=10))
[docs]
def init_buttons(self):
self.button["prev"] = ipywidgets.Button(description="Previous")
self.button["next"] = ipywidgets.Button(description="Next")
self.button["update_plot"] = ipywidgets.Button(description="Plot")
self.button["sel"] = ipywidgets.Button(description="Apply sel")
self.button["isel"] = ipywidgets.Button(description="Apply isel")
self.button["reset_dataset"] = ipywidgets.Button(description="Reset Dataset")
self.button["combine"] = ipywidgets.Button(description="Combine Vars")
self.button["extract"] = ipywidgets.Button(description="Extract Var")
[docs]
def init_checkboxes(self):
self.checkbox["logx"] = ipywidgets.Checkbox(description="log x", value=True)
self.checkbox["logy"] = ipywidgets.Checkbox(description="log y", value=True)
[docs]
def init_dropdowns(self):
self.dropdown["scatter"] = ipywidgets.SelectMultiple(
options=[],
layout=ipywidgets.Layout(height="250px"),
)
self.dropdown_categories["scatter"].append(self.dropdown["scatter"])
self.dropdown["composition"] = ipywidgets.Select(
options=[],
layout=ipywidgets.Layout(height="250px"),
)
self.dropdown_categories["composition"].append(self.dropdown["composition"])
self.dropdown["composition_color"] = ipywidgets.Dropdown(
options=[],
description="Colors",
)
self.dropdown_categories["sample"].append(self.dropdown["composition_color"])
self.dropdown["composition_colorscale"] = ipywidgets.Dropdown(
options=px.colors.named_colorscales(),
description="Colorscale",
value="bluered",
)
self.dropdown["sel"] = ipywidgets.Dropdown(options=[])
self.dropdown_categories["sample"].append(self.dropdown["sel"])
self.dropdown["extract_from_var"] = ipywidgets.Dropdown(options=[])
self.dropdown_categories["composition"].append(
self.dropdown["extract_from_var"]
) # this is a hack...
self.dropdown["extract_from_coord"] = ipywidgets.Dropdown(options=[])
self.dropdown["sample_dim"] = ipywidgets.Dropdown(
description="Sample Dim",
options=[]
)
[docs]
def init_inputs(self):
self.text_input["cmin"] = ipywidgets.FloatText(
value=0.0,
layout=ipywidgets.Layout(width='100px')
)
self.text_input["cmax"] = ipywidgets.FloatText(
value=1.0,
layout = ipywidgets.Layout(width='100px')
)
self.text_input["index"] = ipywidgets.IntText(
description="Data Index:", value=0, min=0
)
self.text_input["xmin"] = ipywidgets.FloatText(
description="xmin",
value=0.001,
)
self.text_input["xmax"] = ipywidgets.FloatText(
description="xmax",
value=1.0,
)
self.text_input["sel"] = ipywidgets.Text(placeholder="e.g, 0, 0.75, or 'T1'")
self.text_input["combined_var"] = ipywidgets.Text(
placeholder="'comps'",
)
self.text_input["to_combine"] = ipywidgets.Text(
placeholder="e.g. ['conc_A','conc_B']"
)
[docs]
def run(self):
self.init_dropdowns()
self.init_checkboxes()
self.init_buttons()
self.init_inputs()
self.init_plots()
# Plot Tab
plot_top_control_box = ipywidgets.VBox(
[
ipywidgets.HBox(
[
self.dropdown["sample_dim"],
]
),
ipywidgets.HBox([
self.dropdown['composition_color'],
ipywidgets.Label("Color min/max:"),
self.text_input["cmin"],
self.text_input["cmax"],
]),
]
)
plot_box = ipywidgets.VBox([
ipywidgets.HBox([self.dropdown['scatter'],self.scatt_fig]),
ipywidgets.HBox([self.dropdown['composition'],self.comp_fig]),
])
plot_bottom_control_box = ipywidgets.HBox(
[
self.text_input["index"],
#self.button["goto"],
self.button["update_plot"],
self.button["next"],
self.button["prev"],
]
)
plot_box = ipywidgets.VBox(
[plot_top_control_box, plot_bottom_control_box,plot_box]
)
# Config Tab
config_tab = ipywidgets.VBox(
[
self.dropdown["composition_colorscale"],
self.dropdown["sample_dim"],
self.text_input["xmin"],
self.text_input["xmax"],
self.checkbox["logx"],
self.checkbox["logy"],
self.button["update_plot"],
]
)
# select_tab
select_tab = ipywidgets.VBox(
[
self.button["reset_dataset"],
ipywidgets.HBox(
[
self.dropdown["sel"],
self.text_input["sel"],
self.button["sel"],
self.button["isel"],
]
),
ipywidgets.HBox(
[
self.text_input["combined_var"],
self.text_input["to_combine"],
self.button["combine"],
]
),
ipywidgets.HBox(
[
self.dropdown["extract_from_var"],
self.dropdown["extract_from_coord"],
# self.text_input["extract_coord_value"],
self.button["extract"],
]
),
]
)
# Dataset HTML Tab
self.dataset_html = ipywidgets.HTML()
dataset_tab = ipywidgets.VBox(
[
select_tab,
self.dataset_html,
]
)
# Build Tabs
self.tabs = ipywidgets.Tab([dataset_tab, plot_box, config_tab])
self.tabs.titles = ["Dataset", "Plot", "Config"]
self.tabs.selected_index = 1
#self.output = ipywidgets.Output()
#output_hbox = ipywidgets.HBox([self.output],layout=Layout(height='100px', overflow_y='auto'))
#out = ipywidgets.VBox([self.tabs,output_hbox])
return self.tabs