masskit_ai.spectrum.small_mol package

Subpackages

Submodules

masskit_ai.spectrum.small_mol.small_mol_datasets module

class masskit_ai.spectrum.small_mol.small_mol_datasets.TandemArrowSearchDataset(*args: Any, **kwargs: Any)

Bases: SpectrumDataset

class for accessing a tandem dataframe of spectra

How workers are set up requires some explanation:

  • if there is more than one gpu, each gpu has a corresponding main process.

  • the numbering of this gpu within a node is given by the environment variable LOCAL_RANK

  • if there is more than one node, the numbering of the node is given by the NODE_RANK environment variable

  • the number of nodes times the number of gpus is given by the WORLD_SIZE environment variable

  • the number of gpus on the current node can be found by parsing the PL_TRAINER_GPUS environment variable

  • these environment variables are only available when using DDP. Otherwise sharding should be done using id and num_workers from torch.utils.data.get_worker_info()

  • each main process creates an instance of Dataset. This instance is NOT initialized by worker_init_fn, only the constructor.

  • each worker is created by forking the main process, giving each worker a copy of Dataset, already constructed.
    • each of these forked Datasets is initialized by worker_init_fn

    • the global torch.utils.data.get_worker_info() contains a reference to the forked Dataset and other info

    • these workers then take turns feeding minibatches into the training process

    • *important* since each worker is a copy, __init__() is only called once, only in the main process

  • the dataset in the main process is used by other steps, such as the validation callback
    • this means that if any important initialization is done in worker_init_fn, it must explicitly be applied to the main-process Dataset as well

  • alternative sources of parameters:
    • global_rank = trainer.node_rank * trainer.num_processes + process_idx

    • world_size = trainer.num_nodes * trainer.num_processes
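The environment-variable bookkeeping above can be sketched as follows. `compute_shard`, its fallback defaults, and passing `worker_id`/`num_workers` as arguments (in a real Dataset they come from `torch.utils.data.get_worker_info()`) are all illustrative, not part of the package:

```python
import os


def compute_shard(worker_id: int = 0, num_workers: int = 1):
    """Return (shard_id, num_shards) for the current dataloader worker.

    worker_id / num_workers come from torch.utils.data.get_worker_info()
    inside a worker process; the defaults cover the main process.
    """
    if "WORLD_SIZE" in os.environ:  # ddp is active
        world_size = int(os.environ["WORLD_SIZE"])
        # PL_TRAINER_GPUS lists the gpus on this node, e.g. "0,1,2,3"
        gpus_per_node = len(os.environ.get("PL_TRAINER_GPUS", "0").split(","))
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        node_rank = int(os.environ.get("NODE_RANK", "0"))
        # global rank of this gpu's main process across all nodes
        global_rank = node_rank * gpus_per_node + local_rank
        return global_rank * num_workers + worker_id, world_size * num_workers
    # no ddp: shard only across dataloader workers
    return worker_id, num_workers
```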

property data
get_data_row(index)

given the index, return the corresponding data

  • input is index in query db (on self.data)

  • search on the index and score the results. return n hits, where n is the hitlist length set in the configuration

  • if a hit scores below the threshold, its score is set to 0

  • output is a dict with keys query, hits, and tanimoto, where each value is an array of length hitlist

  • query and hits hold spectra; tanimoto holds floats
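A minimal sketch of assembling the returned dict once the search has produced scored hits; the function name, argument layout, and the 0.2 threshold are hypothetical (in the dataset the hitlist comes from the search and the threshold from the config):

```python
def make_data_row(query, hits, tanimoto, threshold=0.2):
    """Assemble the output dict described above (illustrative names)."""
    # zero out scores for hits that fall below the threshold
    scores = [t if t >= threshold else 0.0 for t in tanimoto]
    return {
        "query": [query] * len(hits),  # the query repeated per hitlist row
        "hits": list(hits),
        "tanimoto": scores,
    }
```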

get_x(data_row)

given the data row, return the input to the network

the input in this case is a pairwise tensor of (query spectrum, hit spectrum) pairs with length hitlist. An embedding is created that turns the two arrays of spectra into a tensor of size (2, mz_max)
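The pairing step can be illustrated with plain NumPy; `pairwise_input` is a hypothetical helper, and the real get_x works on the dataset's own spectrum arrays:

```python
import numpy as np


def pairwise_input(query_arrays, hit_arrays):
    """Stack query/hit spectrum arrays into a (hitlist, 2, mz_max) input.

    query_arrays, hit_arrays: sequences of hitlist arrays of length mz_max
    (as produced by spectrum2array). The axis of size 2 holds the query
    at index 0 and the hit at index 1.
    """
    return np.stack([np.stack([q, h]) for q, h in zip(query_arrays, hit_arrays)])
```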

get_y(data_row)

returns the tanimoto scores, a tensor of length hitlist

Parameters:

data_row – the data row

Returns:

torch tensor

property index
property row2id
spectrum2array(spectrum)

given a spectrum, create a spectrum array

Parameters:

spectrum – the spectrum

Returns:

the numpy spectrum array
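A rough sketch of unit-binning a peak list into a fixed-length intensity array; the actual method takes a masskit spectrum object, and its binning scheme, resolution, and mz_max may differ:

```python
import numpy as np


def spectrum2array(mz, intensity, mz_max=2000):
    """Bin (mz, intensity) peaks into a fixed-length numpy array (a sketch)."""
    out = np.zeros(mz_max, dtype=np.float32)
    for m, i in zip(mz, intensity):
        bin_ = int(round(m))            # nearest-integer m/z bin
        if 0 <= bin_ < mz_max:
            out[bin_] = max(out[bin_], i)  # keep the largest peak per bin
    return out
```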

masskit_ai.spectrum.small_mol.small_mol_lightning module

class masskit_ai.spectrum.small_mol.small_mol_lightning.SearchLightningModule(*args: Any, **kwargs: Any)

Bases: BaseSpectrumLightningModule

pytorch lightning module used to train on search results

pairwise_search_model(batch)

feed batched hitlist data and corresponding spectra into the model. This is done by reshaping data of the form (batch, hitlist, query/hit, spectrum) into (batch, spectrum) vectors; the model output (batch, fingerprint) is then converted back to (batch, hitlist, query/hit, fingerprint), and the cosine score is computed between each query/hit pair of fingerprints

Parameters:
  • batch – the input batch

  • output – expected output

Returns:

results from the model
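The reshape-and-score flow above can be sketched in NumPy (the real module works on torch tensors; `pairwise_search` and the identity model in the test are illustrative):

```python
import numpy as np


def pairwise_search(batch, model):
    """Flatten, fingerprint, and cosine-score query/hit pairs.

    batch: array of shape (batch, hitlist, 2, spectrum_len); model maps a
    (n, spectrum_len) array to (n, fingerprint_len) fingerprints.
    """
    b, h, two, s = batch.shape
    flat = batch.reshape(b * h * two, s)       # (batch*hitlist*2, spectrum)
    fp = model(flat).reshape(b, h, two, -1)    # (batch, hitlist, 2, fingerprint)
    q, hit = fp[:, :, 0], fp[:, :, 1]          # split query/hit fingerprints
    # cosine similarity between each query/hit pair of fingerprints
    num = (q * hit).sum(-1)
    denom = np.linalg.norm(q, axis=-1) * np.linalg.norm(hit, axis=-1)
    return num / np.maximum(denom, 1e-12)
```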

training_step(batch, batch_idx)
validation_test_step(batch, batch_idx, loop)

step shared with test and validation loops

Parameters:
  • batch – batch

  • batch_idx – index into data for batch

  • loop – the name of the loop

Returns:

loss

class masskit_ai.spectrum.small_mol.small_mol_lightning.SmallMolSearchDataModule(*args: Any, **kwargs: Any)

Bases: MasskitDataModule

data loader for tandem small molecule search

get_subsets(set_to_load)

create datasets

Parameters:

set_to_load – train, valid or test dataset

Returns:

a list of datasets

masskit_ai.spectrum.small_mol.small_mol_losses module

class masskit_ai.spectrum.small_mol.small_mol_losses.SearchLoss(*args: Any, **kwargs: Any)

Bases: BaseLoss

loss based on search hitlist

forward(output, batch, params=None) torch.Tensor

Computes the hitlist loss. Assumes:

  • batch.y < 0 means the hitlist row is invalid

  • batch.y >= 1.0 means an identical match

  • batch.y and output.y_prime are structured as (batch, …, hitlistrow)

Parameters:
  • output – ModelOutput

  • batch – ModelInput

  • params – optional parameters, defaults to None

Returns:

loss
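The masking conventions can be sketched as a masked mean-squared error; the real BaseLoss subclass operates on torch tensors and its loss function and reduction may differ:

```python
import numpy as np


def hitlist_loss(y, y_prime):
    """Masked squared error over valid hitlist rows (a sketch).

    y, y_prime: arrays of shape (batch, ..., hitlist_row).
    """
    valid = y >= 0                    # y < 0 marks an invalid hitlist row
    err = (y_prime - y) ** 2
    # average only over valid rows; guard against an all-invalid batch
    return (err * valid).sum() / max(valid.sum(), 1)
```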

Module contents