{ "cells": [ { "cell_type": "markdown", "id": "5af8d526-b9fa-4568-9a58-fc34ee569243", "metadata": {}, "source": [ "# Tutorial 5: Unbalanced and Fused Gromov-Wasserstein distances\n", "The Gromov-Wasserstein distance is highly useful for quantifying differences in cell morphology, and a number of variants of Gromov-Wasserstein distance have been proposed in the literature. Here we introduce two such variants, \"Unbalanced Gromov-Wasserstein\" and \"Fused Gromov-Wasserstein\", and discuss their applications to neuron taxonomy. We will see that their technical advantages improve their ability to recapitulate known labels such as the RNA family. Additional background can be found on the \"Variants of Gromov-Wasserstein\" page. We will use the same Patch-seq data from 645 neurons from the mouse motor cortex that was studied in Tutorial 4, and the same sampled points.\n", "All data can be downloaded from [this link](https://www.dropbox.com/scl/fo/a5b2t4rkek0j5xvjrt5un/ALrhJWIU0zYWuk2QShiGjLs?rlkey=qt79k4qzy2oeo5rnvik7jimu1&st=bu6yzcuw&dl=0).\n", "\n", "## Intuition for Unbalanced Gromov-Wasserstein\n", "The big-picture idea behind unbalanced Gromov–Wasserstein (UGW) is that it is less sensitive than ordinary GW to small changes in morphology. Whereas ordinary GW focuses on how well we can align two entire cell morphologies, UGW effectively asks how well we can partially align two cell morphologies, i.e. allowing for unmatched regions.\n", "In situations where it is acceptable to discard or down-weight small, negligible parts of each cell (so as not to disrupt the global morphology), UGW can be more robust than ordinary GW. Thus, UGW is expected to be less sensitive to tracing errors and missing morphological data.\n", "\n", "The definition of Gromov–Wasserstein distance involves searching through all possible 'couplings' between two cells, where each cell is treated as having total unit mass. 
In this strictly mass-conserving framework, all the mass in the first cell must be paired exactly with mass in the second cell. Concretely, if two neurons are each modeled by a point cloud of 100 points, each point is assigned mass 0.01, and any valid coupling must pair the entire 0.01 mass from each point in one neuron with 0.01 mass distributed across the points in the other neuron.\n", "\n", "For example, suppose we have two neurons that are identical except for one additional dendrite in the second neuron. This extra dendrite is biologically meaningful, and considering embeddings of the first neuron into a portion of the second would capture important structural similarities. However, Gromov–Wasserstein does not recognize such partial embeddings as valid couplings, because they violate the requirement of strict 'conservation of mass': every point of the second neuron, including those on the extra dendrite, must receive its full share of mass from the first neuron, so the extra dendrite cannot simply be left unmatched. As a result, the optimal Gromov–Wasserstein transport plan would likely fail to reflect the structural equivalence between the first neuron and most of the second.\n", "\n", "The Unbalanced Gromov–Wasserstein distance allows such embeddings by permitting transport plans that create or destroy mass, at the cost of an additional penalty. The size of this penalty is controlled by a user-supplied parameter $\\rho$. When $\\rho$ is large, the solution remains close to a 'perfect coupling,' but as $\\rho$ is reduced, the algorithm becomes more tolerant of deviations and allows looser fits. In their paper on Unbalanced Gromov–Wasserstein, Séjourné, Vialard, and Peyré provide several examples illustrating how this extra flexibility helps account for small differences between objects.\n", "\n", "Choosing a specific numerical value for $\\rho$ can be challenging because it is not immediately clear what order of magnitude $\\rho$ should take to produce sensible results. 
Instead, we introduce a more intuitive and interpretable control parameter: a lower bound on the fraction of mass retained during alignment. Specifically, the user can set `mass_kept=0.90` to ensure that, when two neurons are aligned, at least 90% of the mass in each neuron remains matched, and at most 10% can be discarded to improve the fit.\n", "\n", "Let us demonstrate how to use the implementation. We assume all the data is in the folder `/home/jovyan/tutorial5`." ] }, { "cell_type": "code", "execution_count": 1, "id": "b6b7e64d", "metadata": {}, "outputs": [], "source": [ "from os.path import join\n", "from cajal.ugw import _multicore, UGW # Substitute _single_core for single-threaded usage, useful if you want to parallelize at the level of Python processes\n", "UGW_multicore = UGW(_multicore) # For GPU backends, the constructor has to negotiate a connection to the GPU, so it may take a long time to initialize.\n", "\n", "bd = \"/home/jovyan/tutorial5\" # Base directory" ] }, { "cell_type": "markdown", "id": "822c4abc", "metadata": {}, "source": [ "The appropriate parameter values are sensitive to the absolute scales of the data. To choose suitable coefficients, you can first run the ordinary GW computation to estimate these scales. See the ‘Variants of Gromov–Wasserstein’ page for details.\n", "\n", "Note that the algorithm for UGW is more computationally intensive than the algorithm for classical GW, and we do not recommend running it during the tutorial." 
] }, { "cell_type": "code", "execution_count": 5, "id": "36765c5a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Done first pass, cleaning up errors.\n" ] } ], "source": [ "eps = 100.0\n", "UGW_dmat = UGW_multicore.ugw_armijo_pairwise(\n", " mass_kept = 0.80,\n", " eps=eps,\n", " dmats=join(bd,\"geodesic_100_icdm.csv\")\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "id": "0a14b977-2b08-44f5-8bed-6ab2822a899a", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "np.save(join(bd,\"UGW_dmat_mass_80_eps_100.npy\"), UGW_dmat) # This file is available, pre-computed, in the folder linked above." ] }, { "cell_type": "markdown", "id": "f09b79aa", "metadata": {}, "source": [ "One can use the resulting morphological dissimilarity matrix as we have shown in other tutorials, for example, to derive morphological neuronal types or establish associations with molecular information. Here, we test the degree to which the resulting cell morphology space preserves information about the transcriptomic type of the neurons, similar to what we did in Tutorials 1 and 4. For that purpose, we train a nearest-neighbor classifier on the morphological dissimilarity matrix and evaluate its accuracy using the Matthews correlation coefficient (MCC) and leave-one-out cross-validation." 
] }, { "cell_type": "code", "execution_count": 3, "id": "78e048c4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MCC: 0.48718805743696014\n" ] } ], "source": [ "import pandas as pd\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.model_selection import LeaveOneOut\n", "from cajal.utilities import cell_iterator_csv\n", "from sklearn.model_selection import cross_val_predict\n", "from sklearn.metrics import matthews_corrcoef\n", "\n", "cells, idcms = zip(*cell_iterator_csv(intracell_csv_loc=join(bd,\"geodesic_100_icdm.csv\")))\n", "metadata = pd.read_csv(join(bd,'m1_patchseq_meta_data.csv'),sep='\\t',index_col='Cell').loc[pd.Series(cells)]\n", "RNA_family = metadata['RNA family']\n", "hq = RNA_family != 'low quality' # Filter down to the cells that have a well-defined RNA family.\n", "\n", "clf = KNeighborsClassifier(metric=\"precomputed\", n_neighbors=10, weights=\"distance\")\n", "cv = LeaveOneOut()\n", "cvp = cross_val_predict(clf, X= UGW_dmat[hq,:][:,hq], y=RNA_family.loc[hq], cv=cv)\n", "print(\"MCC: \", matthews_corrcoef(cvp, RNA_family.loc[hq]))" ] }, { "cell_type": "markdown", "id": "bb4f42b5", "metadata": {}, "source": [ "For comparison, we can perform the same analysis using the standard Gromov-Wasserstein morphological distance:" ] }, { "cell_type": "code", "execution_count": 4, "id": "46a4c5d8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MCC: 0.41437665858208905\n" ] } ], "source": [ "import cajal.utilities\n", "from sklearn.model_selection import cross_val_score\n", "\n", "_, classical_gw_dists = cajal.utilities.read_gw_dists(join(bd, 'swc_bdad_100pts_geodesic_gw.csv'), header=True)\n", "classical_gw_dmat = cajal.utilities.dist_mat_of_dict(classical_gw_dists, metadata.index[hq].to_list())\n", "\n", "gw_results = cross_val_score(clf, X=classical_gw_dmat, y=RNA_family[hq],cv=cv)\n", "cvp = cross_val_predict(clf, X=classical_gw_dmat, y=RNA_family[hq], 
cv=cv)\n", "print(\"MCC: \", matthews_corrcoef(cvp, RNA_family[hq]))\n" ] }, { "cell_type": "markdown", "id": "6170bcc7", "metadata": {}, "source": [ "Thus, the MCC obtained by using the unbalanced Gromov-Wasserstein distance in these analyses is approximately 18% higher than that obtained with the standard Gromov-Wasserstein distance. Naturally, these statistics depend on the sampling distribution from which the neurons are drawn, so using a different set of neurons might yield different results." ] }, { "cell_type": "markdown", "id": "4e53e054", "metadata": {}, "source": [ "## Using Fused Gromov-Wasserstein\n", "The mathematical basis of Fused Gromov–Wasserstein distance is discussed in detail on the “Variants of Gromov–Wasserstein” page. Here, we focus on its application to neuronal morphological reconstructions. In this context, the Fused Gromov–Wasserstein distance can be used to penalize couplings that map nodes of one type in one neuron (for example, apical dendrite) to nodes of a different type in the other neuron (for example, axon). The interface for the Fused Gromov–Wasserstein distance is similar to that for the standard Gromov–Wasserstein distance, but it requires a few additional pieces of information:\n", "\n", "* `swc_node_types`: A path to the location of the SWC node type identifiers for the sampled points.\n", "\n", "* `soma_dendrite_penalty` and `basal_apical_penalty`: Setting `soma_dendrite_penalty` to a high value means the algorithm will try to avoid pairing soma nodes from one neuron with dendrite nodes of another. Similarly, `basal_apical_penalty` indicates the penalty for pairing a basal dendrite node from one neuron with an apical dendrite node of the other.\n", "\n", "* `penalty_dictionary` (optional): This argument overrides both `soma_dendrite_penalty` and `basal_apical_penalty`. 
Users can directly specify the penalty for each pair of node types, which is most useful for SWC files containing structure IDs outside the commonly used range (0–4).\n", "\n", "* `worst_case_gw_increase` (optional): By default, node penalties are absolute. For example, if you specify a soma-to-dendrite penalty of 5.0, the FGW cost will increase by 5.0 whenever a soma node is paired with a dendrite node. Setting `worst_case_gw_increase` makes the node penalties relative: only the ratios between penalties matter, and all penalties are rescaled by a common constant. This ensures that the median increase in GW cost from node-type penalties does not exceed `worst_case_gw_increase`." ] }, { "cell_type": "code", "execution_count": 8, "id": "e10f2269", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "163276.75404087408\n", "163276.75404087408\n", "0.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 207690/207690 [04:31<00:00, 764.60it/s] \n" ] } ], "source": [ "from cajal.fused_gw_swc import fused_gromov_wasserstein_parallel\n", "\n", "fused_gw_dmat = fused_gromov_wasserstein_parallel(\n", " intracell_csv_loc=join(bd,'geodesic_100_icdm.csv'),\n", " swc_node_types=join(bd,\"geodesic_100_node_types.npy\"),\n", " fgw_dist_csv_loc=join(bd,\"geodesic_100_fgw.csv\"),\n", " num_processes=14,\n", " soma_dendrite_penalty= 1., # The cost of aligning a soma node to a dendrite node is initialized to be 1.0, but it will be rescaled based on the value of `worst_case_gw_increase`\n", " basal_apical_penalty=0., # The data set we are using doesn't distinguish basal and apical dendrites, so this parameter has no effect\n", " # penalty_dictionary: Optional[dict[tuple[int, int], float]] = None,\n", " chunksize = 100,\n", " worst_case_gw_increase= 0.10, # We want the GW cost to go up at most 50% for the pair of cells at the 15th percentile in the dataset. 
\n", ")\n", "fused_gw_dmat = fused_gw_dmat[hq, :][:, hq]" ] }, { "cell_type": "code", "execution_count": null, "id": "788fb487", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.5667189952904239\n", "MCC: 0.4704884461886154\n" ] } ], "source": [ "fgw_results = cross_val_score(clf, X=fused_gw_dmat, y=RNA_family[hq],cv=cv)\n", "cvp = cross_val_predict(clf, X=fused_gw_dmat, y=RNA_family[hq], cv=cv)\n", "print(\"MCC: \", matthews_corrcoef(cvp, RNA_family[hq]))" ] }, { "cell_type": "markdown", "id": "2010a161", "metadata": {}, "source": [ "Thus, by incorporating node-type information (here, the distinction between soma and dendrite nodes, since this data set does not separate basal from apical dendrites), Fused Gromov-Wasserstein attains an MCC of about 0.47, compared with about 0.41 for classical Gromov-Wasserstein and about 0.49 for Unbalanced Gromov-Wasserstein." ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }