3. 3D Flex: Custom Latent Trajectory
Note
Read the tutorial to learn more about 3D Flexible Refinement.
This example covers the following:
Load particle latent coordinates from a 3D Flex Training job
Plot the latent coordinates
Interactively draw a trajectory through the latent space
Output the trajectory as a new output in CryoSPARC
The resulting trajectory may be used as input to the 3D Flex Generator job to generate a volume series along the trajectory. In this way, you can visualize specific regions or pathways through the latent conformational distribution of the particle.
import numpy as n
from cryosparc.tools import CryoSPARC
cs = CryoSPARC(host="cryoem5", base_port=40000)
assert cs.test_connection()
Connection succeeded to CryoSPARC command_core at http://cryoem5:40002
Connection succeeded to CryoSPARC command_vis at http://cryoem5:40003
Load the particles dataset from the 3D Flex Training job.
import pandas as pd
project = cs.find_project("P251")
particles = project.find_job("J21").load_output("particles")
pd.DataFrame(particles.rows())
| components_mode_0/component | components_mode_0/value | components_mode_1/component | components_mode_1/value | uid |
---|---|---|---|---|---|
0 | 0 | 0.016000 | 1 | 0.314667 | 3070975014664207456 |
1 | 0 | -0.122667 | 1 | -0.293333 | 4618435520677801850 |
2 | 0 | -0.496000 | 1 | 0.048000 | 5473205460946538064 |
3 | 0 | 0.357333 | 1 | 0.133333 | 16310523440862460071 |
4 | 0 | 0.698667 | 1 | 0.048000 | 13340786341065263892 |
... | ... | ... | ... | ... | ... |
83995 | 0 | 0.229333 | 1 | 0.314667 | 16055293458929423443 |
83996 | 0 | -0.208000 | 1 | 0.784000 | 4138809144408497890 |
83997 | 0 | 0.069333 | 1 | 0.005333 | 12856011472364466207 |
83998 | 0 | -0.176000 | 1 | 0.272000 | 14511870028664140377 |
83999 | 0 | 0.890667 | 1 | 0.048000 | 17110459322516273008 |
84000 rows × 5 columns
Create an interactive plot that responds to on-click events.
Once the plot is drawn, click repeatedly on the plot along a trajectory you wish to sample in the latent space. The points along this trajectory will form the output of the notebook and be used in 3D Flex Generator.
from matplotlib import pyplot as plt
fig = plt.figure(figsize=(5, 5))
def do_plot():
    # Scatter every 10th particle to keep the plot responsive
    plt.plot(
        particles["components_mode_0/value"][::10],
        particles["components_mode_1/value"][::10],
        ".",
        alpha=0.5,
        color="gray",
    )
    plt.grid()
do_plot()
chain = True
pts = []
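def onclick(event):
    # Ignore clicks outside the axes, where xdata/ydata are None
    if event.xdata is None or event.ydata is None:
        return
    fig.clf()
    do_plot()
    pts.append([event.xdata, event.ydata])
    apts = n.array(pts)
    if chain:
        # One continuous path; the starting point is marked with a black "x"
        plt.plot(apts[0, 0], apts[0, 1], "xk")
        plt.plot(apts[:, 0], apts[:, 1], ".-r")
    else:
        # Separate segments; each consecutive pair of clicks forms one segment
        for k, i in enumerate(range(0, len(apts), 2)):
            plt.plot(apts[i, 0], apts[i, 1], "xk")
            plt.plot(apts[i : i + 2, 0], apts[i : i + 2, 1], ".-", label=str(k))
        plt.legend()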
cid = plt.gcf().canvas.mpl_connect("button_press_event", onclick)
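If an interactive Matplotlib backend is not available, the list of points can also be filled in without clicking. The following is a minimal sketch, not part of the original notebook: `axis_sweep` is a hypothetical helper that varies one latent coordinate over a range while holding the others at a fixed value. It uses only plain Python so that it runs on its own.

```python
def axis_sweep(axis, lo, hi, num, fixed=0.0, ndim=2):
    """Return `num` points that sweep latent dimension `axis` from `lo` to `hi`.

    All other latent dimensions are held at `fixed`.
    Hypothetical helper for illustration; not part of cryosparc-tools.
    """
    if num < 2:
        raise ValueError("num must be at least 2")
    if not 0 <= axis < ndim:
        raise ValueError("axis must be in the range [0, ndim)")
    step = (hi - lo) / (num - 1)
    trajectory = []
    for i in range(num):
        point = [fixed] * ndim
        point[axis] = lo + i * step
        trajectory.append(point)
    return trajectory
```

Assigning, for example, `pts = axis_sweep(0, -0.5, 0.5, 11)` in place of clicking should let the remaining cells run unchanged, since they only convert `pts` to an array.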
Print out the selected trajectory points.
latent_pts = n.array(pts)
for pt in latent_pts:
    print(f"{pt[0]:5.2f}, {pt[1]:5.2f}")
0.05, -0.52
-0.13, -0.47
-0.26, -0.32
-0.34, -0.14
-0.35, -0.06
-0.35, 0.08
-0.35, 0.18
-0.28, 0.23
-0.16, 0.27
-0.06, 0.27
0.00, 0.27
0.23, 0.27
0.25, 0.31
0.25, 0.40
0.21, 0.45
0.16, 0.46
0.06, 0.49
-0.07, 0.51
-0.20, 0.49
-0.34, 0.37
-0.47, 0.28
-0.52, -0.00
-0.50, -0.20
-0.45, -0.26
-0.26, -0.27
-0.23, -0.27
-0.07, -0.30
0.19, -0.32
0.26, -0.26
0.29, -0.20
0.29, -0.20
0.29, -0.20
0.30, -0.11
0.38, -0.07
0.50, -0.04
0.56, 0.05
0.56, 0.06
0.56, 0.06
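Hand-clicked points tend to be unevenly spaced, and some of the points printed above coincide at two decimal places. One optional clean-up step is sketched below, assuming even spacing along the path is desired: drop consecutive near-duplicate clicks, then resample the path at equal arc-length intervals. The helper names and the tolerance are hypothetical choices for illustration, and the code uses only the standard library.

```python
import math


def drop_near_duplicates(points, tol=1e-2):
    """Drop points lying within `tol` of the previously kept point."""
    kept = []
    for x, y in points:
        if not kept or math.hypot(x - kept[-1][0], y - kept[-1][1]) > tol:
            kept.append((float(x), float(y)))
    return kept


def resample_evenly(points, num):
    """Return `num` points spaced evenly by arc length along the polyline."""
    if num < 2 or len(points) < 2:
        raise ValueError("need at least 2 input points and num >= 2")
    # Cumulative arc length at each input point
    cum = [0.0]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        cum.append(cum[-1] + math.hypot(x1 - x0, y1 - y0))
    total = cum[-1]
    out = []
    seg = 0
    for i in range(num):
        target = total * i / (num - 1)
        # Advance to the segment that contains the target arc length
        while seg < len(points) - 2 and cum[seg + 1] < target:
            seg += 1
        span = cum[seg + 1] - cum[seg]
        t = 0.0 if span == 0 else min((target - cum[seg]) / span, 1.0)
        (x0, y0), (x1, y1) = points[seg], points[seg + 1]
        out.append((x0 + t * (x1 - x0), y0 + t * (y1 - y0)))
    return out
```

To use it, one could set `latent_pts = n.array(resample_evenly(drop_near_duplicates(pts), 40))` before allocating the output; the count of 40 is an arbitrary example value.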
Set up an external job to save the custom latent components. Connect it to the 3D Flex Training job so that the output fields are passed through.
job = project.create_external_job("W4", "Custom Latents")
job.connect("particles", "J21", "particles", slots=["components_mode_%d" % k for k in range(2)])
Add and allocate an output for the job to store the custom latent components.
latents_dset = job.add_output(
    type="particle",
    name="latents",
    slots=[{"prefix": "components_mode_%d" % k, "dtype": "components", "required": True} for k in range(2)],
    title="Latents",
    alloc=len(latent_pts),
)
Save the points into the allocated dataset.
for k in range(2):
    latents_dset["components_mode_%d/component" % k] = k
    latents_dset["components_mode_%d/value" % k] = latent_pts[:, k]
Save the output.
with job.run():
    job.save_output("latents", latents_dset)