masskit_ai.spectrum.small_mol package
Subpackages
- masskit_ai.spectrum.small_mol.models package
Submodules
masskit_ai.spectrum.small_mol.small_mol_datasets module
- class masskit_ai.spectrum.small_mol.small_mol_datasets.TandemArrowSearchDataset(*args: Any, **kwargs: Any)
Bases: SpectrumDataset
Class for accessing a tandem dataframe of spectra.
How workers are set up requires some explanation:
- If there is more than one GPU, each GPU has a corresponding main process.
- The index of this GPU within a node is given by the LOCAL_RANK environment variable.
- If there is more than one node, the index of the node is given by the NODE_RANK environment variable.
- The number of nodes times the number of GPUs per node is given by the WORLD_SIZE environment variable.
- The number of GPUs on the current node can be found by parsing the PL_TRAINER_GPUS environment variable.
- These environment variables are only available when using DDP. Otherwise, sharding should be done using the id and num_workers attributes of torch.utils.data.get_worker_info().
- Each main process creates an instance of Dataset. Only the constructor runs on this instance; it is NOT initialized by worker_init_fn.
- Each worker is created by forking the main process, which gives each worker a copy of the already-constructed Dataset:
  - Each of these forked Datasets is initialized by worker_init_fn.
  - torch.utils.data.get_worker_info() returns an object that holds a reference to the forked Dataset and other information.
  - These workers then take turns feeding minibatches into the training process.
  - *Important*: since each worker is a copy, __init__() is called only once, and only in the main process.
- The Dataset in the main process is used by other steps, such as the validation callback. This means that any important initialization done in worker_init_fn must also be done explicitly to the main-process Dataset.
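The lifecycle above means that any state created in worker_init_fn also has to be creatable on demand in the main process. The following is a minimal sketch of that pattern. HypotheticalSearchDataset and init_data() are illustrative names, not masskit_ai API. Only torch.utils.data.get_worker_info() and its dataset attribute are real PyTorch.

```python
class HypotheticalSearchDataset:
    """Illustrative stand-in for a torch Dataset with deferred initialization."""

    def __init__(self, n_rows):
        # __init__ runs once, in the main process, before workers are forked,
        # so only cheap setup belongs here.
        self.n_rows = n_rows
        self.data = None  # created later by init_data()

    def init_data(self):
        # Hypothetical expensive setup (e.g. opening a file or index).
        # Idempotent, so it can be called from worker_init_fn in each worker
        # and also explicitly on the main-process copy.
        if self.data is None:
            self.data = list(range(self.n_rows))

    def __len__(self):
        return self.n_rows

    def __getitem__(self, index):
        return self.data[index]


def worker_init_fn(worker_id):
    # Called inside each forked DataLoader worker. get_worker_info().dataset
    # is this worker's own copy of the dataset, so initializing it here does
    # not touch the main-process copy.
    from torch.utils.data import get_worker_info

    get_worker_info().dataset.init_data()
```

In the main process, the same method would be called explicitly (for example, dataset.init_data()) before the dataset is handed to a validation callback. The worker function itself would be passed as DataLoader(dataset, num_workers=..., worker_init_fn=worker_init_fn).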
Alternative sources of parameters:
  global_rank = trainer.node_rank * trainer.num_processes + process_idx
  world_size = trainer.num_nodes * trainer.num_processes
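The environment variables and the trainer attributes give the same global rank: the node rank times the processes per node, plus the local index. The following is a minimal standard-library sketch. It assumes gpus_per_node is passed in rather than parsed from PL_TRAINER_GPUS, because that variable's format isn't specified here. ddp_rank_info and shard_indices are hypothetical helpers. The round-robin sharding across DDP processes and DataLoader workers is one common choice, not necessarily what TandemArrowSearchDataset does.

```python
import os


def ddp_rank_info(gpus_per_node, env=None):
    """Return (global_rank, world_size) from DDP environment variables.

    Falls back to (0, 1) when WORLD_SIZE is absent, i.e. when not running
    under DDP.
    """
    env = os.environ if env is None else env
    if "WORLD_SIZE" not in env:
        return 0, 1
    local_rank = int(env.get("LOCAL_RANK", 0))
    node_rank = int(env.get("NODE_RANK", 0))
    # Same formula as trainer.node_rank * trainer.num_processes + process_idx.
    global_rank = node_rank * gpus_per_node + local_rank
    return global_rank, int(env["WORLD_SIZE"])


def shard_indices(n_rows, global_rank, world_size, worker_id=0, num_workers=1):
    """Row indices for one DataLoader worker inside one DDP process.

    Inside a worker, worker_id and num_workers would come from the id and
    num_workers attributes of torch.utils.data.get_worker_info().
    """
    shard = global_rank * num_workers + worker_id
    n_shards = world_size * num_workers
    # Round-robin assignment: every row goes to exactly one shard.
    return list(range(shard, n_rows, n_shards))
```

For example, with 2 GPUs per node, LOCAL_RANK=1 on node 1 gives global rank 3. Outside DDP, the helpers reduce to sharding by worker id alone.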
- property data
- property data_search
- get_data_row(index)
Given the index, return the corresponding data for that index.
The input is an index into the query database (self.data).
A search is run for the query at this index and the hits are scored; n hits are returned, where n is set in the configuration.
Hits below the threshold are set to 0.
The output is a dict with query, hits, and tanimoto entries, each an array of length hitlist.
The query and hit entries are spectra; the tanimoto entries are floats.
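The hitlist bookkeeping described above can be sketched in plain Python. The sketch makes the following assumptions:
- Molecular fingerprints are represented as sets of on-bit indices.
- The threshold applies to the Tanimoto score.
- Short hitlists are padded with a negative score, matching the y < 0 "invalid row" convention documented for SearchLoss.

build_hitlist and its arguments are hypothetical. It returns library indices where the real method returns spectra.

```python
def tanimoto(fp_a, fp_b):
    """Tanimoto (Jaccard) similarity of two fingerprints given as sets of on-bits."""
    union = len(fp_a | fp_b)
    return len(fp_a & fp_b) / union if union else 0.0


def build_hitlist(query_idx, query_fp, ranked_hits, library_fps, n_hits,
                  threshold, pad=-1.0):
    """Hypothetical hitlist assembly for a single query.

    ranked_hits: library indices already sorted by search score, best first.
    Returns a dict of three lists, each of length n_hits.
    """
    hits = list(ranked_hits[:n_hits])
    scores = []
    for lib_idx in hits:
        score = tanimoto(query_fp, library_fps[lib_idx])
        # Scores below the threshold are zeroed rather than dropped.
        scores.append(score if score >= threshold else 0.0)
    n_missing = n_hits - len(hits)
    # Pad short hitlists so every row has the same length; pad < 0 marks
    # the padded rows as invalid.
    return {
        "query": [query_idx] * n_hits,
        "hits": hits + [None] * n_missing,
        "tanimoto": scores + [pad] * n_missing,
    }
```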
- get_x(data_row)
Given the data row, return the input to the network.
The input in this case is a pairwise tensor of (query spectrum, hit spectrum) pairs of length hitlist. An embedding turns the two arrays of spectra into a tensor of size (2, mz_max) for each pair.
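The following sketches the pairing step. It assumes each spectrum has already been converted to a dense array of length mz_max. pair_spectra is a hypothetical helper that builds the (hitlist, 2, mz_max) layout as nested lists, which could then be passed to torch.tensor. Substituting zeros for a missing hit is also an assumption.

```python
def pair_spectra(query_array, hit_arrays, mz_max):
    """Return a nested list of shape (hitlist, 2, mz_max).

    Each entry is [query_array, hit_array]; a missing hit (None) is
    replaced by an all-zero array so the shape stays rectangular.
    """
    if len(query_array) != mz_max:
        raise ValueError("query array must have length mz_max")
    pairs = []
    for hit in hit_arrays:
        hit_row = [0.0] * mz_max if hit is None else list(hit)
        if len(hit_row) != mz_max:
            raise ValueError("hit arrays must have length mz_max")
        # Index 0 is the query, index 1 the hit (the "query/hit" axis).
        pairs.append([list(query_array), hit_row])
    return pairs
```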
- get_y(data_row)
Returns the tanimoto scores as a tensor of length hitlist.
- Parameters:
data_row – the data row
- Returns:
torch tensor
- property index
- property row2id
- property row2id_search
- spectrum2array(spectrum)
Given a spectrum, create a spectrum array.
- Parameters:
spectrum – the spectrum
- Returns:
the numpy spectrum array
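One common way to turn a peak list into a fixed-length array is to bin by m/z. The sketch below assumes unit-width bins by default, drops peaks at or above mz_max, and keeps the maximum intensity when peaks share a bin. spectrum_to_array and these choices are assumptions, not the documented behavior of spectrum2array. The sketch also returns a plain list, whereas spectrum2array returns a numpy array.

```python
import math


def spectrum_to_array(mz, intensity, mz_max, bin_size=1.0):
    """Bin peaks (mz, intensity) into a dense list of length mz_max / bin_size."""
    n_bins = int(mz_max / bin_size)
    array = [0.0] * n_bins
    for m, i in zip(mz, intensity):
        b = math.floor(m / bin_size)
        if 0 <= b < n_bins:  # peaks outside [0, mz_max) are dropped
            # Keep the most intense peak when several fall in one bin.
            array[b] = max(array[b], float(i))
    return array
```

Summing intensities within a bin would be an equally reasonable choice. Taking the maximum simply keeps a single dominant peak from being inflated by nearby noise.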
masskit_ai.spectrum.small_mol.small_mol_lightning module
- class masskit_ai.spectrum.small_mol.small_mol_lightning.SearchLightningModule(*args: Any, **kwargs: Any)
Bases: BaseSpectrumLightningModule
PyTorch Lightning module used to train on search results.
- pairwise_search_model(batch)
Feeds batched hitlist data and the corresponding spectra into the model. It does this by reshaping data of shape (batch, hitlist, query/hit, spectrum) into (batch, spectrum) vectors, i.e. flattening the batch, hitlist, and query/hit dimensions into a single batch dimension. The model output of shape (batch, fingerprint) is then reshaped to (batch, hitlist, query/hit, fingerprint), and the cosine score is computed between each query/hit pair of fingerprints.
- Parameters:
batch – the input batch
output – expected output
- Returns:
results from the model
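The reshape-and-score logic can be illustrated without PyTorch. The sketch makes two assumptions:
- batch is a nested list of shape (batch, hitlist, query/hit, spectrum), with the same hitlist length in every row.
- model is any callable that maps a list of spectra to a list of fingerprints.

Both names are hypothetical. A tensor implementation would presumably use reshape and torch.nn.functional.cosine_similarity instead of the explicit loops.

```python
import math


def cosine(a, b):
    """Cosine similarity of two vectors; 0.0 if either has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def pairwise_search_scores(batch, model):
    """Return cosine scores of shape (batch, hitlist)."""
    n_batch, n_hits = len(batch), len(batch[0])
    # Flatten (batch, hitlist, query/hit) into one batch dimension so the
    # model sees every spectrum in a single call.
    flat = [spec for row in batch for pair in row for spec in pair]
    fingerprints = model(flat)  # one fingerprint per spectrum, same order
    scores = []
    for b in range(n_batch):
        row = []
        for h in range(n_hits):
            i = (b * n_hits + h) * 2  # flat index of the query fingerprint
            row.append(cosine(fingerprints[i], fingerprints[i + 1]))
        scores.append(row)
    return scores
```

Flattening first means the query and hit spectra pass through the same fingerprint network in one call. The cosine score therefore compares fingerprints produced by identical weights.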
- training_step(batch, batch_idx)
- validation_test_step(batch, batch_idx, loop)
Step shared by the test and validation loops.
- Parameters:
batch – batch
batch_idx – index into data for batch
loop – the name of the loop
- Returns:
loss
- class masskit_ai.spectrum.small_mol.small_mol_lightning.SmallMolSearchDataModule(*args: Any, **kwargs: Any)
Bases: MasskitDataModule
Data loader for tandem small-molecule search.
- get_subsets(set_to_load)
Create datasets.
- Parameters:
set_to_load – train, valid or test dataset
- Returns:
a list of datasets
masskit_ai.spectrum.small_mol.small_mol_losses module
- class masskit_ai.spectrum.small_mol.small_mol_losses.SearchLoss(*args: Any, **kwargs: Any)
Bases: BaseLoss
Loss based on the search hitlist.
- forward(output, batch, params=None) → torch.Tensor
Computes the hitlist loss. Assumes:
  - batch.y < 0 means the hitlist row is invalid
  - batch.y >= 1.0 means an identical match
  - batch.y and output.y_prime are structured as (batch, …, hitlistrow)
- Parameters:
output – ModelOutput
batch – ModelInput
params – optional parameters, defaults to None
- Returns:
loss
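The y < 0 convention amounts to masking invalid rows out of the loss. The following is a minimal sketch over flat lists, using a mean squared error between y_prime and y. The MSE is a stand-in, since the reduction SearchLoss actually uses is not documented here. The sketch also does not treat the y >= 1.0 identical-match case specially.

```python
def masked_hitlist_mse(y_prime, y):
    """Mean squared error over valid hitlist rows only.

    y and y_prime are flat lists of equal length (one entry per hitlist row).
    Rows with y < 0 are treated as invalid and excluded. Returns 0.0 when no
    row is valid.
    """
    pairs = [(p, t) for p, t in zip(y_prime, y) if t >= 0]
    if not pairs:
        return 0.0
    return sum((p - t) ** 2 for p, t in pairs) / len(pairs)
```

Averaging over valid rows only, rather than over the full padded hitlist, keeps queries with short hitlists from being down-weighted by their padding.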