{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 3D Flex: Custom Latent Trajectory\n", "\n", "```{note}\n", "[Read the tutorial](https://guide.cryosparc.com/processing-data/tutorials-and-case-studies/tutorial-3d-flexible-refinement) to learn more about 3D Flexible Refinement.\n", "```\n", "\n", "This example covers the following:\n", "\n", "* Load particle latent coordinates from a 3D Flex Training job\n", "* Plot the latent coordinates\n", "* Interactively draw a trajectory through the latent space\n", "* Output the trajectory as a new output in CryoSPARC\n", "\n", "The resulting trajectory may be used as input to the 3D Flex Generator job to generate a volume series along the trajectory. In this way, you can visualize specific regions or pathways through the latent conformational distribution of the particle." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Connection succeeded to CryoSPARC command_core at http://cryoem0.sbi:40002\n", "Connection succeeded to CryoSPARC command_vis at http://cryoem0.sbi:40003\n", "Connection succeeded to CryoSPARC command_rtp at http://cryoem0.sbi:40005\n" ] } ], "source": [ "import json\n", "from pathlib import Path\n", "\n", "import numpy as n\n", "\n", "from cryosparc.tools import CryoSPARC\n", "\n", "with open(Path(\"~\", \"instance-info.json\").expanduser(), \"r\") as f:\n", " credentials = json.load(f)\n", "\n", "cs = CryoSPARC(**credentials)\n", "assert cs.test_connection()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load the particles dataset from the 3D Flex Training job." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<p>10 rows × 110 columns</p>
" ], "text/plain": [ " alignments2D/alpha alignments2D/alpha_min alignments2D/class \\\n", "0 1.0 1.067518 8 \n", "1 1.0 1.067518 8 \n", "2 1.0 0.868515 2 \n", "3 1.0 0.868515 2 \n", "4 1.0 1.158005 0 \n", "5 1.0 1.158005 0 \n", "6 1.0 0.992612 6 \n", "7 1.0 0.992612 6 \n", "8 1.0 0.992270 8 \n", "9 1.0 0.992270 8 \n", "\n", " alignments2D/class_ess alignments2D/class_posterior \\\n", "0 1.000261 0.999869 \n", "1 1.000261 0.999869 \n", "2 1.006742 0.996642 \n", "3 1.006742 0.996642 \n", "4 1.656703 0.749064 \n", "5 1.656703 0.749064 \n", "6 1.032084 0.984298 \n", "7 1.032084 0.984298 \n", "8 1.000796 0.999602 \n", "9 1.000796 0.999602 \n", "\n", " alignments2D/cross_cor alignments2D/error alignments2D/error_min \\\n", "0 76.271484 9163.180664 0.0 \n", "1 76.271484 9163.180664 0.0 \n", "2 62.421875 8964.066406 0.0 \n", "3 62.421875 8964.066406 0.0 \n", "4 47.288086 8603.187500 0.0 \n", "5 47.288086 8603.187500 0.0 \n", "6 39.340820 8956.469727 0.0 \n", "7 39.340820 8956.469727 0.0 \n", "8 71.224609 8801.692383 0.0 \n", "9 71.224609 8801.692383 0.0 \n", "\n", " alignments2D/image_pow alignments2D/pose ... pick_stats/power \\\n", "0 9203.728516 4.744446 ... 781.205933 \n", "1 9203.728516 4.744446 ... 781.205933 \n", "2 8990.552734 3.205707 ... 596.804443 \n", "3 8990.552734 3.205707 ... 596.804443 \n", "4 8630.057617 5.385587 ... 758.841248 \n", "5 8630.057617 5.385587 ... 758.841248 \n", "6 8975.994141 0.000000 ... 601.065002 \n", "7 8975.994141 0.000000 ... 601.065002 \n", "8 8837.027344 0.128228 ... 669.910767 \n", "9 8837.027344 0.128228 ... 
669.910767 \n", "\n", " pick_stats/template_idx sym_expand/helix_num_rises \\\n", "0 3 0 \n", "1 3 0 \n", "2 3 0 \n", "3 3 0 \n", "4 3 0 \n", "5 3 0 \n", "6 3 0 \n", "7 3 0 \n", "8 3 0 \n", "9 3 0 \n", "\n", " sym_expand/helix_rise_A sym_expand/helix_twist_rad sym_expand/idx \\\n", "0 0.0 0.0 0 \n", "1 0.0 0.0 1 \n", "2 0.0 0.0 0 \n", "3 0.0 0.0 1 \n", "4 0.0 0.0 0 \n", "5 0.0 0.0 1 \n", "6 0.0 0.0 0 \n", "7 0.0 0.0 1 \n", "8 0.0 0.0 0 \n", "9 0.0 0.0 1 \n", "\n", " sym_expand/is_helix sym_expand/src_uid sym_expand/symmetry \\\n", "0 0 6947839038024507105 C2 \n", "1 0 6947839038024507105 C2 \n", "2 0 18174470766076440683 C2 \n", "3 0 18174470766076440683 C2 \n", "4 0 10868047345897440647 C2 \n", "5 0 10868047345897440647 C2 \n", "6 0 11087656592019060700 C2 \n", "7 0 11087656592019060700 C2 \n", "8 0 10759805808414285613 C2 \n", "9 0 10759805808414285613 C2 \n", "\n", " uid \n", "0 6947839038024507105 \n", "1 13932325671802011693 \n", "2 18174470766076440683 \n", "3 16990665930294442579 \n", "4 10868047345897440647 \n", "5 13431911868430011898 \n", "6 11087656592019060700 \n", "7 3521966667339501106 \n", "8 10759805808414285613 \n", "9 2825550586744122481 \n", "\n", "[10 rows x 110 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "project = cs.find_project(\"P312\")\n", "particles = project.find_job(\"J243\").load_output(\"particles\")\n", "\n", "# only display the first ten rows, since creating an entire dataframe is slow\n", "pd.DataFrame(particles.rows()[:10])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we create a plot of two of the dimensions and click to generate the trajectory through latent space. This approach necessarily allows only the creation of a trajectory through two of the latent coordinates. If you need to change three or more coordinates simultaneously, you will have to manually create a list of the coordinates you want to use." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "selected_dimensions = [0, 2]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we create an interactive plot that responds to click events.\n", "\n", "Once the plot is drawn, click repeatedly on the plot along the trajectory you wish to sample in the latent space. The points along this trajectory will form the output of the notebook and be used in 3D Flex Generator.\n", "\n", "```{note}\n", "Using interactive plots in Jupyter Notebooks can present some challenges.\n", "Connecting to a remote notebook from VS Code requires the ipympl package\n", "and the `%matplotlib widget` \"magic\" line.\n", "See [the matplotlib documentation](https://matplotlib.org/stable/users/explain/figure/interactive.html)\n", "for more detail.\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib widget\n", "from matplotlib import pyplot as plt\n", "\n", "fig = plt.figure(figsize=(5, 5))\n", "\n", "\n", "def do_plot():\n", "    # plot every tenth particle to keep the scatter plot responsive\n", "    plt.plot(\n", "        particles[f\"components_mode_{selected_dimensions[0]}/value\"][::10],\n", "        particles[f\"components_mode_{selected_dimensions[1]}/value\"][::10],\n", "        \".\",\n", "        alpha=0.5,\n", "        color=\"gray\",\n", "    )\n", "    plt.grid()\n", "\n", "\n", "do_plot()\n", "pts = []\n", "\n", "\n", "def onclick(event):\n", "    # clicks outside the axes have no data coordinates, so ignore them\n", "    if event.xdata is None or event.ydata is None:\n", "        return\n", "    fig.clf()\n", "    do_plot()\n", "    pts.append([event.xdata, event.ydata])\n", "    apts = n.array(pts)\n", "    # mark the first point with a black cross, then draw the trajectory in red\n", "    plt.plot(apts[0, 0], apts[0, 1], \"xk\")\n", "    plt.plot(apts[:, 0], apts[:, 1], \".-r\")\n", "\n", "\n", "cid = plt.gcf().canvas.mpl_connect(\"button_press_event\", onclick)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "![custom-trajectory.png](attachments/custom-trajectory.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Print out the selected trajectory points. 
Each column corresponds to the position along that selected coordinate, i.e., in this example column 1 is coordinate 0 and column 2 is coordinate 2." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-0.35, 0.67\n", "-0.78, 0.09\n", "-0.50, -0.67\n", " 0.77, -0.86\n", " 1.06, -0.18\n", " 0.95, 0.56\n", " 0.04, 0.43\n", "-0.27, -0.11\n" ] } ], "source": [ "latent_pts = n.array(pts)\n", "for pt in latent_pts:\n", " print(f\"{pt[0]:5.2f}, {pt[1]:5.2f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set up an external job to save the custom latent components. Connect to the train job to ensure output fields get passed through." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# each component has a components_mode_n/component and components_mode_n/value field,\n", "# so we need to divide the components_mode fields by two to get the total number of components\n", "num_components = int(len([x for x in particles.fields() if \"components_mode\" in x]) / 2)\n", "\n", "slot_spec = [{\"dtype\": \"components\", \"prefix\": f\"components_mode_{k}\", \"required\": True} for k in range(num_components)]\n", "job = project.create_external_job(\"W5\", \"Custom Latents\")\n", "job.connect(\"particles\", \"J243\", \"particles\", slots=slot_spec)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Add and allocate an output for the job to store the custom latent components." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "latents_dset = job.add_output(\n", "    type=\"particle\",\n", "    name=\"latents\",\n", "    slots=slot_spec,\n", "    title=\"Latents\",\n", "    alloc=len(latent_pts),\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Save the points into the allocated dataset." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "for k in range(num_components):\n", "    latents_dset[f\"components_mode_{k}/component\"] = k\n", "    try:\n", "        latents_dset[f\"components_mode_{k}/value\"] = latent_pts[:, selected_dimensions.index(k)]\n", "    except ValueError:\n", "        # if a coordinate is not in our selected_dimensions, set the trajectory to zero in that coordinate\n", "        latents_dset[f\"components_mode_{k}/value\"] = 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Save the output." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "with job.run():\n", "    job.save_output(\"latents\", latents_dset)" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 4 }