Subcellular Protein Localization (Deep Learning)

make_NN_training_data(save_path, cell_objects, reference_cell_object, mapped_channel_distributions, channel, center='cell', rescale=True, shape=(64, 64))

Generate training data from CellAligner_Cell objects for neural network training.

Creates paired cell images and their corresponding mapped versions for training deep learning models. Images are aligned, normalized, and saved as numpy arrays.

Parameters
  • save_path (str) – Directory path to save processed images. Cell images saved to ‘<save_path>/cell_images’ and mapped images to ‘<save_path>/mapped_cell_images’.

  • cell_objects (list) – List of CellAligner_Cell objects or paths to pickled CellAligner_Cell objects.

  • reference_cell_object (CellAligner_Cell or str) – Reference cell object or path to pickled reference cell object used as template for mapped distributions.

  • mapped_channel_distributions (numpy.ndarray) – Array of mapped protein distributions for each cell.

  • channel (str) – Channel name to use for image processing.

  • center (str) – Centering method for image alignment: ‘cell’ or ‘nucleus’. Default is ‘cell’.

  • rescale (bool) – Whether to rescale images to a fixed size. Default is True.

  • shape (tuple of int) – Target shape (height, width) for resizing images. Default is (64, 64).

Returns

None. Images are saved to disk as .npy files.

Return type

None

class dCellAlignerNetwork(input_channels=3, embedding_size=50, image_size=64)

Deep CellAligner Network.

A complete neural network architecture that combines feature extraction, distance computation, and image reconstruction in a multi-task learning framework. Designed for approximating CellAligner mapping to anchor cell morphologies and learning embeddings that preserve metrics for quantifying differences in protein localization after mapping.

Parameters
  • input_channels (int, optional) – Number of input image channels. Default is 3.

  • embedding_size (int, optional) – Dimensionality of the feature embedding space. Default is 50.

  • image_size (int, optional) – Size of input/output images (assumed square). Default is 64.

class PairedDataset(image_dir, mapped_image_dir, distances, image_pairs, transform=None, augment_transform=None, n_augment=1)

PyTorch Dataset for loading paired cell images with distance labels.

Loads pairs of cell images and their mapped counterparts from numpy files for training distance-based models. Supports data augmentation and lazy loading for memory efficiency.

Parameters
  • image_dir (str) – Path to directory containing cell image .npy files with naming convention ‘cell_{index}.npy’.

  • mapped_image_dir (str) – Path to directory containing mapped cell image .npy files with naming convention ‘mapped_cell_{index}.npy’.

  • distances (list of float) – Distance values corresponding to each image pair for supervised learning.

  • image_pairs (list of tuple) – List of (index1, index2) tuples specifying which images to pair.

  • transform (callable, optional) – Transform function to apply to all images. Default is None.

  • augment_transform (callable, optional) – Additional augmentation transform for data augmentation. Default is None.

  • n_augment (int, optional) – Number of augmented copies to create for each pair. Default is 1.

Note

The dataset expects the following file naming conventions:

  • Cell images: ‘cell_{index}.npy’

  • Mapped images: ‘mapped_cell_{index}.npy’

class RandomHorizontalRescale(min_relative_width=0.1, max_relative_width=1.0)

Data augmentation transform that randomly rescales image width.

Randomly rescales the horizontal axis of cell images to achieve uniform distribution of cell mask widths. Uses appropriate interpolation methods for different channel types (bilinear for intensity, nearest for masks).

Parameters
  • min_relative_width (float, optional) – Minimum relative width of the cell mask as fraction of image width. Default is 0.1.

  • max_relative_width (float, optional) – Maximum relative width of the cell mask as fraction of image width. Default is 1.0.

Note

  • Channel 0: Resized with bilinear interpolation (intensity/probability)

  • Other channels: Resized with nearest neighbor interpolation (binary masks)

The transform maintains the original image width by padding or cropping after rescaling.
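
The rescale-then-restore-width behaviour can be illustrated on a single image row. The following is a minimal sketch in plain Python, not the class's implementation: it uses nearest-neighbor sampling only (the method described above for mask channels), and the centered padding and cropping is an assumption, since the documentation does not say where the rescaled content is placed. All helper names are hypothetical.

```python
def resize_row_nearest(row, new_width):
    """Resample a 1D row to new_width samples using nearest-neighbor lookup."""
    old_width = len(row)
    # Map each output pixel centre back to a source index.
    return [row[min(old_width - 1, int((x + 0.5) * old_width / new_width))]
            for x in range(new_width)]


def fit_to_width(row, width, fill=0):
    """Centre-pad or centre-crop a row so that it has exactly `width` samples."""
    extra = width - len(row)
    if extra >= 0:
        left = extra // 2
        return [fill] * left + list(row) + [fill] * (extra - left)
    start = (-extra) // 2
    return list(row[start:start + width])


def rescale_row(row, scale):
    """Rescale a row horizontally by `scale`, keeping the original width."""
    new_width = max(1, round(len(row) * scale))
    return fit_to_width(resize_row_nearest(row, new_width), len(row))


mask_row = [0, 0, 1, 1, 1, 1, 0, 0]
narrower = rescale_row(mask_row, 0.5)  # mask shrinks, then is zero-padded
wider = rescale_row(mask_row, 1.5)     # mask grows, then is centre-cropped
```

In both cases the output row has the same length as the input; only the width of the mask inside it changes.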

generate_dataset_split_pairs(indices, n_pairs, proportions=None, seed=None)

Generate cell pairs for Deep CellAligner dataset splits.

Creates stratified cell pairs for deep learning model training. Supports both random sampling across all cells and stratified sampling from predefined groups to ensure balanced representation across train/validation/test splits.

Parameters
  • indices (list or array-like) – List of available cell indices corresponding to processed cell images.

  • n_pairs (list of int) – N-length list specifying number of pairs to generate for each dataset split. Typically [n_train_pairs, n_val_pairs, n_test_pairs].

  • proportions (list of float, optional) – N-length list of proportions that sum to 1.0 for stratified sampling. If provided, cell indices are split into N groups according to these proportions, and pairs are drawn only within each group. This ensures train/val/test sets use disjoint cell populations. If None, all pairs are drawn randomly from all available cells. Default is None.

  • seed (int, optional) – Random seed for reproducible dataset splits. Default is None.

Returns

N-length list where element i is a 2D array of shape (n_pairs[i], 2) containing the cell index pairs for that dataset split (for example train, val, test).

Return type

list of numpy.ndarray
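
The effect of the proportions argument can be sketched in plain Python. This is an illustration under assumptions, not the library's code: it shuffles the indices before splitting, never pairs a cell with itself, allows repeated pairs, and returns lists of tuples instead of NumPy arrays. The function name sketch_split_pairs is hypothetical.

```python
import random


def sketch_split_pairs(indices, n_pairs, proportions=None, seed=None):
    """Draw index pairs per split; with proportions, splits use disjoint cells."""
    rng = random.Random(seed)
    indices = list(indices)
    if proportions is None:
        # Every split samples from the full pool of cells.
        groups = [indices] * len(n_pairs)
    else:
        rng.shuffle(indices)
        groups, start = [], 0
        for k, p in enumerate(proportions):
            # The last group takes the remainder so no cell is dropped.
            is_last = k == len(proportions) - 1
            end = len(indices) if is_last else start + round(p * len(indices))
            groups.append(indices[start:end])
            start = end
    # rng.sample picks two distinct cells, so a cell is never paired with itself.
    return [[tuple(rng.sample(group, 2)) for _ in range(n)]
            for group, n in zip(groups, n_pairs)]


train, val, test = sketch_split_pairs(range(100), [20, 5, 5],
                                      proportions=[0.8, 0.1, 0.1], seed=0)
```

Because pairs are drawn only within each group, no cell that appears in a training pair can appear in a validation or test pair.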

pretrain_model(paired_dataset, model, dataset_name, save_path=None, batch_size=64, epochs=10, lr=0.001, device=None, return_model=True)

Pretrain a model using paired input and target images provided as a PairedDataset.

This function accepts a PairedDataset object and first converts it into a PretrainPairedDataset by collecting the unique image indices referenced in the paired dataset. For each unique image index i, the input path is ‘<image_dir>/cell_i.npy’ and the target path is ‘<mapped_image_dir>/mapped_cell_i.npy’. After the conversion, training proceeds exactly as in the previous implementation.

Parameters
  • paired_dataset (PairedDataset) – Dataset containing paired indices and directory information. Must have attributes image_dir, mapped_image_dir and image_pairs.

  • model (torch.nn.Module) – Neural network model to pretrain. Must have a forward method that takes two identical inputs and returns reconstructions.

  • dataset_name (str) – Name prefix for saved model files.

  • save_path (str, optional) – Directory path to save the pretrained model. If None, model is not saved. Default is None.

  • batch_size (int, optional) – Batch size for training. Default is 64.

  • epochs (int, optional) – Number of training epochs. Default is 10.

  • lr (float, optional) – Learning rate for the Adam optimizer. Default is 1e-3.

  • device (torch.device, optional) – Device to run training on. If None, automatically selects GPU if available. Default is None.

  • return_model (bool, optional) – Whether to return the trained model. If False, returns None. Default is True.

Returns

The pretrained model if return_model is True, otherwise None.

Return type

torch.nn.Module or None
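
The conversion from image pairs to per-image input/target files described above can be sketched as follows. This is a hedged illustration: the helper name unique_pretrain_files is hypothetical, the directory names in the usage lines are placeholders, and sorting the indices is an assumption made only to keep the output deterministic.

```python
import os


def unique_pretrain_files(image_dir, mapped_image_dir, image_pairs):
    """Return one (input_path, target_path) tuple per unique image index."""
    # Flatten the (index1, index2) pairs and drop duplicate indices.
    unique_indices = sorted({i for pair in image_pairs for i in pair})
    return [
        (os.path.join(image_dir, f"cell_{i}.npy"),
         os.path.join(mapped_image_dir, f"mapped_cell_{i}.npy"))
        for i in unique_indices
    ]


files = unique_pretrain_files("data/cell_images", "data/mapped_cell_images",
                              [(0, 3), (3, 7), (0, 7)])
# The three pairs reference only indices 0, 3 and 7, so three files pairs result.
```

Each image therefore contributes once to pretraining, no matter how many pairs it appears in.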

train_dCellAligner(train_dataset, valid_dataset, test_dataset, save_path, dataset_name, embedding_size=50, image_shape=(64, 64), batch_size=100, epochs=100, device=None, learning_rate=0.001, dist_weight=1.0, early_stopping=True, patience=3, weight_decay=1e-05, lr_gamma=0.95, sparsity_weight=0.0, sparsity_target=0.05, pretrained_path=None, show_loss_components=False)

Trains a Deep CellAligner model using multi-task learning with distance prediction and image reconstruction objectives. Supports early stopping, learning rate scheduling, and optional sparsity constraints.

Parameters
  • train_dataset (Dataset) – PyTorch dataset used for training.

  • valid_dataset (Dataset) – Validation dataset.

  • test_dataset (Dataset) – Test dataset.

  • save_path (str) – Directory path to save the trained model and checkpoints.

  • dataset_name (str) – Name prefix for saved model files.

  • embedding_size (int, optional) – Dimensionality of the feature embedding space. Default is 50.

  • image_shape (tuple of int, optional) – Shape of input images as (height, width). Default is (64, 64).

  • batch_size (int, optional) – Batch size for training. Default is 100.

  • epochs (int, optional) – Maximum number of training epochs. Default is 100.

  • device (torch.device, optional) – Device for training. If None, automatically selects GPU if available.

  • learning_rate (float, optional) – Initial learning rate for Adam optimizer. Default is 0.001.

  • dist_weight (float, optional) – Weight for distance loss vs reconstruction loss in total loss. Default is 1.0.

  • early_stopping (bool, optional) – Whether to use early stopping based on validation loss. Default is True.

  • patience (int, optional) – Number of epochs to wait for improvement before stopping. Default is 3.

  • weight_decay (float, optional) – L2 regularization weight for optimizer. Default is 1e-5.

  • lr_gamma (float, optional) – Decay factor for exponential learning rate scheduler. Default is 0.95.

  • sparsity_weight (float, optional) – Weight for sparsity constraint loss. Default is 0.0 (disabled).

  • sparsity_target (float, optional) – Target sparsity level for hidden activations. Default is 0.05.

  • pretrained_path (str, optional) – Path to pretrained model weights to initialize from. Default is None.

  • show_loss_components (bool, optional) – Whether to display individual loss components (distance, reconstruction, sparsity) during training. Default is False.

Returns

Tuple containing the trained model, training loss history, and validation loss history.

Return type

tuple (model: torch.nn.Module, train_losses: list of float, val_losses: list of float)
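
To show how patience and lr_gamma interact, the sketch below models both in plain Python. It assumes the scheduler decays the learning rate once per epoch and that early stopping triggers after patience consecutive epochs without a strictly lower validation loss; the function's exact criteria may differ. The helper names and the loss values are made up for illustration.

```python
def exponential_lr(initial_lr, gamma, epoch):
    """Learning rate after `epoch` decay steps of an exponential schedule."""
    return initial_lr * gamma ** epoch


def early_stopping_epoch(val_losses, patience):
    """Return the 0-based epoch at which training would stop, or None."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Illustrative validation losses: improvement stalls after epoch 2.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.50]
stop_epoch = early_stopping_epoch(val_losses, patience=3)
lr_at_stop = exponential_lr(0.001, 0.95, stop_epoch)
```

With the default patience of 3, a later improvement (the final 0.50 here) is never reached, which is worth keeping in mind when validation loss is noisy.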

load_dCellAligner_model(checkpoint_path, device=None)

Load a Deep CellAligner model from a checkpoint containing state dict and config.

Parameters
  • checkpoint_path (str) – Path to the checkpoint file containing both state_dict and config.

  • device (torch.device, optional) – Device to load the model on. If None, uses GPU if available.

Returns

Loaded dCellAligner model ready for inference or further training.

Return type

torch.nn.Module

extract_embeddings(model, data, batch_size=64, device=None, process_info_path=None, channel=None)

Extract latent embeddings from a trained Deep CellAligner model.

Processes input data through the feature extractor to obtain latent embeddings. Supports CellAligner_Cell objects, NumPy arrays of images, PyTorch datasets, and PairedDatasets as inputs.

Parameters
  • model (torch.nn.Module) – Trained dCellAligner model with a feature_extractor attribute.

  • data (list, tuple, numpy.ndarray, torch.utils.data.Dataset, or PairedDataset) – Input data to extract embeddings from. Can be:
      - PairedDataset
      - PyTorch Dataset
      - List of CellAligner_Cell objects
      - NumPy array with shape (N, H, W, 3) or (H, W, 3)

  • batch_size (int, optional) – Batch size for processing. Default is 64.

  • device (torch.device, optional) – Device to run computation on. If None, uses model’s current device.

  • process_info_path (str or None, optional) – Path to the saved cell image processing JSON or the directory containing it. Used when data contains CellAligner_Cell objects.

  • channel (str or None, optional) – Channel name to use when data contains CellAligner_Cell objects or file paths.

Returns

Extracted embeddings of shape (N, embedding_size) where N is the number of input images.

Return type

numpy.ndarray

predict_distances(model, data, batch_size=64, device=None, process_info_path=None, channel=None)

Predict distances from a trained Deep CellAligner model.

Extracts embeddings for all unique images in the input data and computes pairwise distances. Supports CellAligner_Cell objects, NumPy arrays of images, PyTorch datasets, and PairedDatasets as inputs.

Parameters
  • model (torch.nn.Module) – Trained dCellAligner model with a feature_extractor attribute.

  • data (PairedDataset, Dataset, numpy.ndarray, list, or tuple) – Input data to predict distances for. Can be:
      - PairedDataset
      - PyTorch Dataset
      - List of CellAligner_Cell objects
      - NumPy array with shape (N, H, W, 3) or (H, W, 3)

  • batch_size (int, optional) – Batch size for processing embeddings. Default is 64.

  • device (torch.device, optional) – Device to run computation on. If None, uses model’s current device.

  • process_info_path (str or None, optional) – Path to the saved cell image processing JSON or the directory containing it. Used when data contains CellAligner_Cell objects.

  • channel (str or None, optional) – Channel name to use when data contains CellAligner_Cell objects or file paths.

Returns

If data is a PairedDataset, returns distances of shape (len(data),) corresponding to each pair. Otherwise returns a square distance matrix of shape (N, N) for the N input images, with zeros on the diagonal.

Return type

numpy.ndarray
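
The two return shapes can be illustrated with a small plain-Python sketch. Euclidean distance between embeddings is an assumption made here for illustration; the documentation does not state which distance the model computes. The helper names are hypothetical, and nested lists stand in for NumPy arrays.

```python
import math


def pairwise_distance_matrix(embeddings):
    """Square matrix of Euclidean distances between all embeddings."""
    return [[math.dist(a, b) for b in embeddings] for a in embeddings]


def distances_for_pairs(embeddings, image_pairs):
    """One distance per (index1, index2) pair, as for a paired dataset."""
    return [math.dist(embeddings[i], embeddings[j]) for i, j in image_pairs]


embeddings = [[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]
matrix = pairwise_distance_matrix(embeddings)                       # shape (N, N)
pair_distances = distances_for_pairs(embeddings, [(0, 1), (0, 2)])  # shape (n_pairs,)
```

The matrix is symmetric with zeros on the diagonal, while the paired form returns only the requested entries.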

plot_distance_predictions(model, paired_dataset, batch_size=64, device=None, figsize=(8, 8), return_plot=False, title=None, alpha=0.6, s=20)

Plot predicted vs true distances for a PairedDataset.

Creates a scatter plot comparing model predictions against ground truth distances with a diagonal reference line and correlation metrics.

Parameters
  • model (torch.nn.Module) – Trained dCellAligner model with a feature_extractor attribute.

  • paired_dataset (PairedDataset) – Dataset containing paired images with known distances.

  • batch_size (int, optional) – Batch size for processing embeddings. Default is 64.

  • device (torch.device, optional) – Device to run computation on. If None, uses model’s current device.

  • figsize (tuple, optional) – Figure size as (width, height). Default is (8, 8).

  • return_plot (bool, optional) – Whether to return the matplotlib figure and axes objects. Default is False.

  • title (str, optional) – Custom title for the plot. If None, uses default with correlation metrics.

  • alpha (float, optional) – Transparency of scatter points. Default is 0.6.

  • s (int, optional) – Size of scatter points. Default is 20.

Returns

If return_plot is False: displays the plot and returns None. If return_plot is True: returns (fig, ax) matplotlib objects.

Return type

None or tuple
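
The documentation does not name the correlation metrics shown on the plot. As one plausible example, the sketch below computes a Pearson correlation coefficient between predicted and true distances in plain Python; treating Pearson's r as the reported metric is an assumption, and the distance values are made up for illustration.

```python
import math


def pearson_r(predicted, true):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(predicted)
    mean_p = sum(predicted) / n
    mean_t = sum(true) / n
    # Covariance and variances, each left unnormalized since the factors cancel.
    cov = sum((p - mean_p) * (t - mean_t) for p, t in zip(predicted, true))
    var_p = sum((p - mean_p) ** 2 for p in predicted)
    var_t = sum((t - mean_t) ** 2 for t in true)
    return cov / math.sqrt(var_p * var_t)


true_distances = [0.1, 0.4, 0.5, 0.9]
predicted_distances = [0.15, 0.35, 0.55, 0.85]
r = pearson_r(predicted_distances, true_distances)
```

Points lying close to the diagonal reference line give r near 1; a high r with points off the diagonal would instead indicate a systematic scale or offset error.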

plot_reconstruction_comparison(model, paired_dataset, n_cells=5, device=None, figsize=None, seed=None)

Plot comparison of original mapped protein distributions vs model reconstructions.

Randomly selects cell images from the dataset, shows the original mapped protein distributions in the top row and their reconstructions from the model in the bottom row.

Parameters
  • model (torch.nn.Module) – Trained dCellAligner model with reconstruction capabilities.

  • paired_dataset (PairedDataset) – Dataset containing paired images for reconstruction.

  • n_cells (int, optional) – Number of image pairs to display. Default is 5.

  • device (torch.device, optional) – Device to run model on. If None, uses model’s current device.

  • figsize (tuple, optional) – Figure size as (width, height). If None, automatically calculated based on number of images.

  • seed (int, optional) – Random seed for reproducible image selection. Default is None.

Returns

None. Displays the plot using matplotlib.

Return type

None