Fused Gromov-Wasserstein Module
- fused_gromov_wasserstein(cell1_dmat: ndarray[Any, dtype[float64]], cell1_distribution: ndarray[Any, dtype[float64]], cell1_node_types: ndarray[Any, dtype[int32]], cell2_dmat: ndarray[Any, dtype[float64]], cell2_distribution: ndarray[Any, dtype[float64]], cell2_node_types: ndarray[Any, dtype[int32]], penalty_dictionary: dict[tuple[int, int], float], worst_case_gw_increase: Optional[float] = None, **kwargs)
Compute the fused Gromov-Wasserstein distance between cells.
Penalties for mismatched node types should be supplied by the user.
- Parameters
cell1_dmat (ndarray[Any, dtype[float64]]) – A squareform distance matrix of shape (n,n).
cell1_distribution (ndarray[Any, dtype[float64]]) – A probability distribution of shape (n).
cell1_node_types (ndarray[Any, dtype[int32]]) – A vector of integer structure IDs (type labels) of shape (n).
cell2_dmat (ndarray[Any, dtype[float64]]) – A squareform distance matrix of shape (m,m).
cell2_distribution (ndarray[Any, dtype[float64]]) – A probability distribution of shape (m).
cell2_node_types (ndarray[Any, dtype[int32]]) – A vector of integer structure IDs (type labels) of shape (m).
penalty_dictionary (dict[tuple[int, int], float]) – A dictionary whose keys are pairs (i,j) of distinct structure IDs of points occurring in the sample data (with i < j), and whose values are non-negative floating-point numbers representing the “fused” penalty for aligning a node of type i with a node of type j. Pairs (i,j) that are not in the penalty dictionary have penalty weight 0.
worst_case_gw_increase (Optional[float]) –
This parameter offers a more interpretable way to control the fused GW penalty when it is difficult to judge, a priori, the appropriate order of magnitude for the values in the penalty matrix. Fused GW is a compromise between minimizing distortion (GW cost) and minimizing label misalignment (conflicts between node type labels). Adding a penalty for label misalignment tends to increase the distortion of the transport plan, because the algorithm must now balance both considerations: the higher the penalty, the more the algorithm prioritizes aligning node types, and the higher the distortion of the resulting plan. This parameter therefore bounds the increase in distortion, above the distortion of the ordinary (classical) GW transport plan, that is caused by the additional pressure to align labels.
If worst_case_gw_increase is None (the default), the values in penalty_dictionary are taken as absolute penalties, and the fused GW cost matrix is computed directly from the supplied weights. If worst_case_gw_increase is a non-negative floating-point number, the values in penalty_dictionary are interpreted relatively, so that only their ratios are meaningful. For example, using the SWC convention that type 1 is soma, type 3 is basal dendrite and type 4 is apical dendrite, if penalty_dictionary[(1,3)] = 10.0 and penalty_dictionary[(3,4)] = 2.0, then in the resulting fused GW cost matrix, aligning a soma node with a basal dendrite node is 5 times as costly as aligning a basal dendrite node with an apical dendrite node. The absolute scale of the fused GW cost matrix is then chosen by a heuristic which guarantees that the GW cost of the fused transport plan is at most (1 + worst_case_gw_increase) times the GW cost of the ordinary GW transport plan. For instance, if the user supplies worst_case_gw_increase = 0.50, then the transport plan found by the fused GW algorithm is guaranteed to have a GW cost at most 50% higher than the transport plan found by ordinary GW.
kwargs – This function wraps the implementation of fused GW provided by the Python Optimal Transport (POT) library, and all additional keyword arguments supplied by the user are passed through to that function. See the POT documentation for the keyword arguments that can be used to customize the behavior of the algorithm.
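The rules above can be illustrated with a short, self-contained NumPy sketch. It is not CAJAL's implementation: the helpers fused_cost_matrix, gw_cost and within_budget are hypothetical, the square-loss form of the GW cost is an assumption, and the heuristic that chooses the absolute scale when worst_case_gw_increase is set is not described in this documentation, so it is represented only by a free scale factor. The sketch builds a fused cost matrix from node type labels (missing pairs cost 0, each key applies in both directions), computes the distortion of a transport plan, and checks the (1 + worst_case_gw_increase) budget.

```python
import numpy as np


def fused_cost_matrix(types1, types2, penalty_dictionary, scale=1.0):
    """Hypothetical helper: M[a, b] is the penalty for aligning node a of
    cell 1 (type types1[a]) with node b of cell 2 (type types2[b])."""
    types1 = np.asarray(types1)
    types2 = np.asarray(types2)
    M = np.zeros((types1.shape[0], types2.shape[0]))
    for (i, j), penalty in penalty_dictionary.items():
        # Keys are stored with i < j; the penalty applies in both directions.
        M[np.ix_(types1 == i, types2 == j)] = penalty
        M[np.ix_(types1 == j, types2 == i)] = penalty
    # Pairs absent from the dictionary (including equal types) keep penalty 0.
    # With worst_case_gw_increase set, only ratios matter; `scale` stands in
    # for the (unspecified) heuristic that picks the absolute size.
    return scale * M


def gw_cost(A, B, T):
    """Assumed square-loss GW cost (distortion) of a transport plan T:
    sum over i, j, k, l of (A[i, k] - B[j, l])**2 * T[i, j] * T[k, l]."""
    p = T.sum(axis=1)  # first marginal of T
    q = T.sum(axis=0)  # second marginal of T
    # Expanding the square avoids building an (n, m, n, m) tensor.
    return float(p @ (A ** 2) @ p + q @ (B ** 2) @ q
                 - 2.0 * np.sum((A @ T @ B.T) * T))


def within_budget(fused_plan_cost, gw_plan_cost, worst_case_gw_increase):
    """True if the fused plan's GW cost is at most (1 + w) times the
    GW cost of the ordinary GW plan."""
    return fused_plan_cost <= (1.0 + worst_case_gw_increase) * gw_plan_cost


# SWC type ids: 1 = soma, 3 = basal dendrite, 4 = apical dendrite.
penalties = {(1, 3): 10.0, (3, 4): 2.0}
M = fused_cost_matrix([1, 3, 4], [1, 3, 4], penalties)
```

Here M[0, 1] is 10.0 and M[1, 2] is 2.0; any positive scale multiplies both entries, so the 5:1 ratio from the example above is preserved.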
- fused_gromov_wasserstein_parallel(intracell_csv_loc: str, swc_node_types: str, fgw_dist_csv_loc: str, num_processes: int, soma_dendrite_penalty: float, basal_apical_penalty: float, penalty_dictionary: Optional[dict[tuple[int, int], float]] = None, chunksize: int = 20, sample_points_npz: Optional[str] = None, worst_case_gw_increase: Optional[float] = None, dynamically_adjust: bool = False, sample_size: int = 100, quantile: float = 0.15, **kwargs)
Compute pairwise fused GW distances between many neurons in parallel.
- Parameters
intracell_csv_loc (str) – The path to the file where the sampled points are stored.
swc_node_types (str) – The path to the SWC node type file, expected to be in .npy format and consistent with the files written by functions in the sample_swc module.
fgw_dist_csv_loc (str) – The path where the fused GW distances will be written.
num_processes (int) – The number of parallel processes to use.
soma_dendrite_penalty (float) – The penalty paid by the transport plan for aligning a soma node with a dendrite node. Choosing this coefficient sufficiently large makes the algorithm favor transport plans which align soma nodes to soma nodes and dendrite nodes to dendrite nodes. Choosing it too large may be counterproductive, since the algorithm then prioritizes label alignment at the expense of higher distortion (GW cost).
basal_apical_penalty (float) – The penalty paid by the transport plan for aligning a basal dendrite node with an apical dendrite node, if this distinction is indeed captured in the morphological reconstructions.
penalty_dictionary (Optional[dict[tuple[int, int], float]]) – For the meaning of this parameter, see the documentation for cajal.fused_gw_swc.fused_gromov_wasserstein(). If penalty_dictionary is None, it is constructed automatically from the arguments soma_dendrite_penalty and basal_apical_penalty. If this parameter is supplied, it overrides those two parameters, which are then ignored; the user can reproduce the default behavior by adding penalty keys for (1,3), (1,4) and (3,4) appropriately.
chunksize (int) – A parallelization parameter: the number of jobs fed to each process at a time.
worst_case_gw_increase (Optional[float]) – The meaning of this parameter is closely related to the parameter of the same name documented in cajal.fused_gw_swc.fused_gromov_wasserstein(); see also the documentation for dynamically_adjust.
dynamically_adjust (bool) – If dynamically_adjust is True, then the argument worst_case_gw_increase is passed directly to cajal.fused_gw_swc.fused_gromov_wasserstein() for each pair of cells. However, this means that a different cost matrix is computed for each pair of cells, so one is not computing the same notion of fused GW throughout the data. We regard it as more statistically appropriate to use the same fixed parameters for fused GW across all cell pairs in the data. If dynamically_adjust is False (the default), then the effect of worst_case_gw_increase is to set a single global cost matrix for all cell pairs, chosen so that for a pair of cells in the same neighborhood of the GW space, the relative increase in distortion will be at most worst_case_gw_increase. (This is a statistical heuristic, not a guarantee.)
sample_size (int) – Only relevant if dynamically_adjust is False. The number of cell pairs to sample in order to estimate the distribution of GW costs.
quantile (float) – Only relevant if dynamically_adjust is False. This determines the notion of “in the same neighborhood” described under dynamically_adjust: the cost matrix is constructed by looking at cell pairs whose GW cost is below the given quantile of the sampled distribution.
kwargs – See the documentation for cajal.fused_gw_swc.fused_gromov_wasserstein().
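Two documented ingredients of this function can be sketched in NumPy: building penalty_dictionary from soma_dendrite_penalty and basal_apical_penalty, and the quantile-based notion of "same neighborhood" used with sample_size and quantile. This is a hedged sketch, not CAJAL's code: the helper names are hypothetical, the assignment of the two penalties to the keys (1,3), (1,4) and (3,4) is our reading of "appropriately", and uniform random sampling of distinct cell pairs is an assumption. How the global cost matrix is then derived from the selected pairs is not specified here and is not reproduced.

```python
import numpy as np


def default_penalty_dictionary(soma_dendrite_penalty, basal_apical_penalty):
    """Hypothetical reconstruction of the penalties used when
    penalty_dictionary is None (SWC ids: 1 = soma, 3 = basal dendrite,
    4 = apical dendrite)."""
    return {
        (1, 3): soma_dendrite_penalty,   # soma vs. basal dendrite
        (1, 4): soma_dendrite_penalty,   # soma vs. apical dendrite
        (3, 4): basal_apical_penalty,    # basal vs. apical dendrite
    }


def sample_cell_pairs(n_cells, sample_size=100, seed=0):
    """Draw up to sample_size distinct unordered cell pairs (a, b), a < b,
    uniformly at random (an assumed sampling scheme)."""
    rng = np.random.default_rng(seed)
    pairs = [(a, b) for a in range(n_cells) for b in range(a + 1, n_cells)]
    k = min(sample_size, len(pairs))
    chosen = rng.choice(len(pairs), size=k, replace=False)
    return [pairs[c] for c in chosen]


def nearby_pairs(sampled_gw_costs, quantile=0.15):
    """Indices of sampled pairs whose GW cost lies strictly below the given
    quantile of the sampled costs ("same neighborhood")."""
    costs = np.asarray(sampled_gw_costs, dtype=float)
    threshold = np.quantile(costs, quantile)
    return np.flatnonzero(costs < threshold)


penalties = default_penalty_dictionary(soma_dendrite_penalty=5.0,
                                       basal_apical_penalty=1.0)
pairs = sample_cell_pairs(n_cells=10, sample_size=20)
```

Under this reading, passing the output of default_penalty_dictionary explicitly as penalty_dictionary should behave like leaving it as None. The sampled GW costs given to nearby_pairs would come from ordinary GW runs on the pairs returned by sample_cell_pairs.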