import copy
import itertools
import warnings
from typing import List, Optional, Dict, Set, Any
import numpy as np
from scipy.optimize import lsq_linear, Bounds
from AFL.automation.mixing.Context import Context
from AFL.automation.APIServer.Driver import Driver
from AFL.automation.mixing.PipetteAction import PipetteAction
from AFL.automation.mixing.Solution import Solution
from AFL.automation.shared.units import enforce_units
# --- Shared utility functions ---
def _extract_masses(solution: "Solution", components: List[str], array: Optional[np.ndarray] = None, unit: str = 'g') -> np.ndarray:
    """Write the mass of each listed component of *solution* into an array.

    Parameters
    ----------
    solution : Solution
        Solution whose component masses are read.
    components : list of str
        Component names; ``array[i]`` receives the mass of ``components[i]``.
    array : numpy.ndarray, optional
        Output buffer of length ``len(components)``, filled in place. When
        ``None`` a new zero-filled float array is allocated.
    unit : str
        Mass unit each value is converted to before taking its magnitude.

    Returns
    -------
    numpy.ndarray
        The filled array (the same object as *array* when one was given).
        Components not contained in *solution* are set to 0.
    """
    if array is None:
        array = np.zeros(len(components))
    for i, component in enumerate(components):
        if solution.contains(component):
            array[i] = solution[component].mass.to(unit).magnitude
        else:
            array[i] = 0
    # Returning the buffer makes the ``array=None`` path usable; previously the
    # freshly allocated array was filled and then thrown away.
    return array
def _extract_mass_fractions(stocks: List["Solution"], components: List[str], matrix: Optional[np.ndarray] = None) -> np.ndarray:
    """Fill a (component x stock) mass-fraction matrix.

    Parameters
    ----------
    stocks : list of Solution
        Stocks; column ``j`` describes ``stocks[j]``.
    components : list of str
        Component names; row ``i`` describes ``components[i]``.
    matrix : numpy.ndarray, optional
        Output buffer of shape ``(len(components), len(stocks))``, filled in
        place. When ``None`` a new zero-filled float array is allocated.

    Returns
    -------
    numpy.ndarray
        The filled matrix (the same object as *matrix* when one was given);
        ``matrix[i, j]`` is the dimensionless mass fraction of
        ``components[i]`` in ``stocks[j]``, or 0 when the stock lacks it.
    """
    if matrix is None:
        matrix = np.zeros((len(components), len(stocks)))
    for i, component in enumerate(components):
        for j, stock in enumerate(stocks):
            if stock.contains(component):
                matrix[i, j] = stock.mass_fraction[component].to('').magnitude
            else:
                matrix[i, j] = 0
    return matrix
def _make_balanced_target(mass_transfers, target):
    """Simulate executing *mass_transfers* into *target*'s location.

    Parameters
    ----------
    mass_transfers : dict
        Maps each stock ``Solution`` to the mass to transfer from it (a
        quantity string such as ``'0.5 g'``, as produced by ``_balance``).
    target : Solution
        The requested target; supplies the destination ``location``, the base
        name and the list of components.

    Returns
    -------
    Solution
        The summed mixture, named ``<target.name>-balanced``. Its ``protocol``
        attribute holds one ``PipetteAction`` per transfer (in
        ``mass_transfers`` order), with the measured volume in microlitres.
        Components of *target* that the mixture does not contain are added
        with a mass of ``'0.0 g'``.
    """
    balanced_target = Solution(name="")
    # Build the protocol locally and attach it once at the end. ``balanced_target``
    # is rebound by ``+`` on every iteration, so appending via
    # ``balanced_target.protocol`` would depend on Solution.__add__ carrying
    # this ad-hoc attribute over to the new object.
    protocol = []
    for stock, mass in mass_transfers.items():
        measured = stock.measure_out(mass)
        balanced_target = balanced_target + measured
        protocol.append(
            PipetteAction(
                source=stock.location,
                dest=target.location,
                volume=measured.volume.to('ul').magnitude,
            )
        )
    balanced_target.protocol = protocol
    balanced_target.name = target.name + "-balanced"
    for name, component in target:
        if not balanced_target.contains(name):
            balanced_target[name] = component.copy()
            balanced_target[name].mass = '0.0 g'
    return balanced_target
def _balance(mass_fraction_matrix: np.ndarray, target_masses: np.ndarray, bounds: Bounds, stocks: List[Solution]) -> List[Dict[Solution, str]]:
result = lsq_linear(mass_fraction_matrix, target_masses, bounds=bounds)
base_mass_transfer = {stock: f'{mass} g' for stock, mass in zip(stocks, result.x)}
mass_transfers = [base_mass_transfer]
negative_one_indices = [i for i, x in enumerate(result.active_mask) if x == -1]
for combination in itertools.product(negative_one_indices):
adjusted_transfer = base_mass_transfer.copy()
for idx in combination:
adjusted_transfer[stocks[idx]] = '0 g'
mass_transfers.append(adjusted_transfer)
return mass_transfers
# --- MassBalance Base Class ---
class MassBalanceBase:
    """Shared mass-balance logic for the Context- and Driver-based front ends.

    Subclasses must set ``self.stocks`` and ``self.targets`` (lists of
    ``Solution``) and implement :attr:`stock_components`,
    :attr:`target_components` and :meth:`_set_bounds`.
    """

    def __init__(self):
        # Filled by balance(): one dict per target with the keys
        # 'target', 'balanced_target' and 'transfers'.
        self.balanced = []
        # scipy Bounds on the per-stock transfer masses; set by _set_bounds().
        self.bounds = None

    @property
    def components(self) -> Set[str]:
        """All component names appearing in any stock or target."""
        return self.stock_components.union(self.target_components)

    @property
    def stock_components(self) -> Set[str]:
        """Component names present in the stocks; implemented by subclasses."""
        raise NotImplementedError

    @property
    def target_components(self) -> Set[str]:
        """Component names present in the targets; implemented by subclasses."""
        raise NotImplementedError

    def mass_fraction_matrix(self, components: Optional[List[str]] = None) -> np.ndarray:
        """Return the (component x stock) mass-fraction matrix.

        Parameters
        ----------
        components : list of str, optional
            Row order. Defaults to ``sorted(self.components)`` so rows come out
            in a reproducible order. Pass the same list used to build mass
            vectors so the rows line up with them.

        Returns
        -------
        numpy.ndarray
            ``matrix[i, j]`` is the mass fraction of ``components[i]`` in
            ``self.stocks[j]``, or 0 when the stock lacks that component.
        """
        if components is None:
            components = sorted(self.components)
        matrix = np.zeros((len(components), len(self.stocks)))
        _extract_mass_fractions(self.stocks, components, matrix)
        return matrix

    def make_target_names(self, n_letters: int = 2, components=None, name_map: Optional[Dict] = None):
        """Rename every target from its component concentrations.

        Each name is the concatenation of ``<label><concentration:.2f>`` for
        every component, followed by ``'-mgml'``, with concentrations in mg/ml.

        Parameters
        ----------
        n_letters : int
            Number of leading characters of a component name used as its label
            when it is not in *name_map*.
        components : iterable of str, optional
            Components to include, in order. Defaults to
            ``sorted(self.components)``.
        name_map : dict, optional
            Explicit component -> label overrides.
        """
        if components is None:
            # Sorted because set iteration order of strings changes between
            # interpreter runs (hash randomization), which made names unstable.
            components = sorted(self.components)
        if name_map is None:
            name_map = {}
        for target in self.targets:
            parts = []
            for component in components:
                label = name_map.get(component, component[:n_letters])
                concentration = target.concentration[component].to("mg/ml").magnitude
                parts.append(f'{label}{concentration:.2f}')
            target.name = ''.join(parts) + '-mgml'

    @staticmethod
    def _relative_error(actual: np.ndarray, expected: np.ndarray) -> np.ndarray:
        """Elementwise ``|actual - expected| / |expected|``.

        Where *expected* is 0 the error is 0 if *actual* is also 0 and ``inf``
        otherwise. This replaces the NaN (and RuntimeWarning) that a plain
        division would produce.
        """
        abs_diff = np.abs(actual - expected)
        error = np.where(abs_diff == 0, 0.0, np.inf)
        np.divide(abs_diff, np.abs(expected), out=error, where=expected != 0)
        return error

    def balance(self, tol=0.05):
        """Find stock transfers reproducing each target; results go to ``self.balanced``.

        For every target, ``_balance`` proposes candidate transfer sets. Each
        candidate is simulated with ``_make_balanced_target`` and accepted when
        every component's relative mass error is below *tol*. Among the
        accepted candidates, the one with the smallest summed relative error is
        kept. A target with no acceptable candidate triggers a warning and is
        recorded with ``balanced_target`` and ``transfers`` set to ``None``.

        Parameters
        ----------
        tol : float
            Maximum allowed ``|balanced - target| / target`` per component. A
            component with 0 target mass must also have 0 balanced mass.

        Raises
        ------
        ValueError
            If any stock has no ``location``.
        NotImplementedError
            If a subclass does not implement the required hooks.
        """
        if any(stock.location is None for stock in self.stocks):
            raise ValueError("Some stocks don't have a location specified. This should be specified when the stocks are instantiated")
        self._set_bounds()
        # A single explicit component order is shared by the matrix rows and
        # both mass vectors. Rebuilding the set separately for each would rely
        # on independently built sets iterating in the same order.
        components = sorted(self.components)
        # The matrix depends only on the stocks, so build it once, not per target.
        matrix = self.mass_fraction_matrix(components)
        target_masses = np.zeros(len(components))
        balanced_masses = np.zeros(len(components))
        self.balanced = []
        for target in self.targets:
            _extract_masses(target, components, array=target_masses)
            mass_transfers = _balance(matrix, target_masses, self.bounds, self.stocks)
            candidates = []
            for transfers in mass_transfers:
                balanced_target = _make_balanced_target(transfers, target)
                _extract_masses(balanced_target, components, array=balanced_masses)
                # Absolute relative error. A signed error let arbitrarily large
                # undershoots pass the tolerance, and min() then favoured them.
                difference = self._relative_error(balanced_masses, target_masses)
                if np.all(difference < tol):
                    candidates.append({
                        'target': balanced_target,
                        'difference': difference,
                        'transfers': transfers,
                    })
            if not candidates:
                warnings.warn(f'No suitable mass balance found for {target.name}')
                self.balanced.append({
                    'target': target,
                    'balanced_target': None,
                    'transfers': None,
                })
            else:
                best = min(candidates, key=lambda candidate: float(np.sum(candidate['difference'])))
                self.balanced.append({
                    'target': target,
                    'balanced_target': best['target'],
                    'transfers': best['transfers'],
                })

    def _set_bounds(self):
        """Populate ``self.bounds`` for the current stocks; implemented by subclasses."""
        raise NotImplementedError
# --- MassBalance (Context-based) ---
class MassBalance(MassBalanceBase, Context):
    """Mass-balance calculator that is also a ``Context``.

    Stocks and targets are kept directly as ``Solution`` objects in
    ``self.stocks`` and ``self.targets``. Each stock's lower transfer bound is
    the mass of ``minimum_volume`` of that stock.
    """

    def __init__(self, name='MassBalance', minimum_volume='20 ul'):
        Context.__init__(self, name=name)
        MassBalanceBase.__init__(self)
        self.context_type = 'MassBalance'
        self.stocks = []
        self.targets = []
        self.minimum_volume = enforce_units(minimum_volume, 'volume')
        # Records the caller's raw minimum_volume argument, not the unit-checked value.
        self.config = dict(stocks=[], targets=[], minimum_volume=minimum_volume)

    def __call__(self, reset=False, reset_stocks=False, reset_targets=False):
        """Optionally empty the stock and/or target lists, then return ``self``."""
        clear_stocks = reset or reset_stocks
        clear_targets = reset or reset_targets
        if clear_stocks:
            self.stocks.clear()
        if clear_targets:
            self.targets.clear()
        return self

    @property
    def stock_components(self) -> Set[str]:
        """Names of every component found in at least one stock."""
        names = set()
        for stock in self.stocks:
            names.update(stock.components)
        return names

    @property
    def target_components(self) -> Set[str]:
        """Names of every component found in at least one target."""
        names = set()
        for target in self.targets:
            names.update(target.components)
        return names

    def _set_bounds(self):
        """Bound each stock's transfer mass below by ``minimum_volume`` of it (in g), unbounded above."""
        stock_count = len(self.stocks)
        lower = [s.measure_out(self.minimum_volume).mass.to('g').magnitude for s in self.stocks]
        upper = [np.inf] * stock_count
        self.bounds = Bounds(lb=lower, ub=upper, keep_feasible=False)
# --- MassBalanceDriver ---
class MassBalanceDriver(MassBalanceBase, Driver):
    """APIServer ``Driver`` front end for mass balancing.

    Stocks and targets are stored as plain keyword dicts in
    ``self.config['stocks']`` / ``self.config['targets']``. They are turned into
    ``Solution`` objects by :meth:`process_stocks` / :meth:`process_targets`,
    which :meth:`balance` calls before solving.
    """
    defaults = {'minimum_volume': '20 ul', 'stocks': [], 'targets': [], 'tol': 1e-3}

    def __init__(self, overrides=None):
        # Initialise the shared balancing state from the real base class. The
        # Context-based MassBalance is a sibling, not a parent; calling its
        # __init__ here also ran Context.__init__ on this Driver.
        MassBalanceBase.__init__(self)
        Driver.__init__(self, name='MassBalance', defaults=self.gather_defaults(), overrides=overrides)
        self.minimum_transfer_volume = None
        self.stocks = []
        self.targets = []

    @property
    def stock_components(self) -> Set[str]:
        """Component names present in the processed stocks.

        Raises
        ------
        ValueError
            If no stocks have been processed yet.
        """
        if not self.stocks:
            raise ValueError('No stocks have been added; Must call process_stocks before accessing components')
        return {component for stock in self.stocks for component in stock.components}

    @property
    def target_components(self) -> Set[str]:
        """Component names present in the processed targets.

        Raises
        ------
        ValueError
            If no targets have been processed yet.
        """
        if not self.targets:
            raise ValueError('No targets have been added; Must call process_targets before accessing components')
        return {component for target in self.targets for component in target.components}

    def process_stocks(self):
        """Rebuild ``self.stocks`` as ``Solution(**cfg)`` for each dict in ``config['stocks']``."""
        self.stocks = [Solution(**stock_config) for stock_config in self.config['stocks']]

    def process_targets(self):
        """Rebuild ``self.targets`` as ``Solution(**cfg)`` for each dict in ``config['targets']``."""
        self.targets = [Solution(**target_config) for target_config in self.config['targets']]

    def add_stock(self, solution: Dict, reset: bool = False):
        """Append a stock description (``Solution`` keyword dict) to the config.

        If *reset* is true, previously added stocks are cleared first.
        """
        if reset:
            self.reset_stocks()
        # Reassign a new list instead of appending in place -- presumably so the
        # Driver's config object registers the change (TODO confirm).
        self.config['stocks'] = self.config['stocks'] + [solution]

    def add_target(self, target: Dict, reset: bool = False):
        """Append a target description (``Solution`` keyword dict) to the config.

        If *reset* is true, previously added targets are cleared first.
        """
        if reset:
            self.reset_targets()
        # Reassigned rather than mutated in place; see add_stock.
        self.config['targets'] = self.config['targets'] + [target]

    def reset_stocks(self):
        """Remove all stock descriptions from the config."""
        self.config['stocks'] = []

    def reset_targets(self):
        """Remove all target descriptions from the config."""
        self.config['targets'] = []

    def _set_bounds(self):
        """Bound each stock's transfer mass below by ``config['minimum_volume']`` of it (in g), unbounded above."""
        self.minimum_transfer_volume = enforce_units(self.config['minimum_volume'], 'volume')
        self.bounds = Bounds(
            lb=[stock.measure_out(self.minimum_transfer_volume).mass.to('g').magnitude for stock in self.stocks],
            ub=[np.inf] * len(self.stocks),
            keep_feasible=False,
        )

    def balance(self):
        """Process the configured stocks and targets, then balance with ``config['tol']``."""
        self.process_stocks()
        self.process_targets()
        super().balance(tol=self.config['tol'])