Task to create a heatmap plot illustrating the results of Latent Causal Variable
(LCV) analysis for a collection of upstream / downstream trait pairs.
The primary value plotted is the posterior mean of the Genetic Causality
Proportion (GCP). Significantly non-zero entries (under a Bonferroni correction
to the LCV GCP=0 p-values) are marked with a star symbol (★).
Partly implemented by asking Claude to mimic the logic of genetic_correlation_clustermap_task.py.
Classes:
Functions:
-
gcp_plot
–
Produce a plotly heatmap figure showing LCV GCP estimates for each upstream / downstream trait pair.
-
get_sig
–
Get a binary array indicating which elements of the LCV result matrix are significant under the given significance scheme.
-
load_xr_lcv_dataset
–
Retrieve the LCV results. Return in the form of an xarray dataset.
Attributes:
LCVPlotMode
module-attribute
LCVPlotMode = GCPWithAsterisk
NUM_PAIRS
module-attribute
XR_DOWNSTREAM_TRAIT_DIM
module-attribute
XR_DOWNSTREAM_TRAIT_DIM = 'downstream_trait'
XR_GCP_ARRAY
module-attribute
XR_LCV_P_VALUE_ARRAY
module-attribute
XR_LCV_P_VALUE_ARRAY = 'lcv_p_value'
XR_LCV_RHO_ARRAY
module-attribute
XR_LCV_RHO_ARRAY = 'lcv_rho'
XR_UPSTREAM_TRAIT_DIM
module-attribute
XR_UPSTREAM_TRAIT_DIM = 'upstream_trait'
BonferoniSig
Attributes:
alpha
class-attribute
instance-attribute
GCPWithAsterisk
Attributes:
color_scale
class-attribute
instance-attribute
color_scale: str = 'RdBu_r'
sig_mode
class-attribute
instance-attribute
sig_mode: SigMode = BonferoniSig()
LCVClustermapTask
Bases: Task
Task to generate a heatmap of LCV results
Methods:
Attributes:
plot_options
instance-attribute
plot_options: LCVPlotMode
save_mode
class-attribute
instance-attribute
save_mode: bool | str = 'cdn'
source
instance-attribute
xr_pipe
instance-attribute
create_std_with_clustering
classmethod
create_std_with_clustering(
asset_id: str,
source: LCVSource,
plot_options: LCVPlotMode,
)
Source code in mecfs_bio/build_system/task/lcv/lcv_clustermap.py
@classmethod
def create_std_with_clustering(
    cls,
    asset_id: str,
    source: LCVSource,
    plot_options: LCVPlotMode,
):
    """
    Build a task whose GCP matrix is hierarchically re-ordered along both
    trait axes (Euclidean metric) before being plotted.

    :param asset_id: identifier of the HTML asset this task produces.
    :param source: where to read the LCV results from. Its task's meta must
        be a ``ResultTableMeta``; its trait and project are reused for the
        output meta.
    :param plot_options: how the heatmap should be rendered.
    """
    upstream_meta = source.task.meta
    assert isinstance(upstream_meta, ResultTableMeta)
    output_meta = GWASPlotFileMeta(
        trait=upstream_meta.trait,
        project=upstream_meta.project,
        extension=".html",
        id=AssetId(asset_id),
    )
    # One clustering step per axis: upstream (rows) first, then downstream (columns).
    clustering_steps = [
        XRCluster(
            array_name=XR_GCP_ARRAY,
            dim=trait_dim,
            metric="euclidean",
        )
        for trait_dim in (XR_UPSTREAM_TRAIT_DIM, XR_DOWNSTREAM_TRAIT_DIM)
    ]
    return cls(
        meta=output_meta,
        xr_pipe=XRCompositePipe(clustering_steps),
        source=source,
        plot_options=plot_options,
    )
|
execute
execute(scratch_dir: Path, fetch: Fetch, wf: WF) -> Asset
Source code in mecfs_bio/build_system/task/lcv/lcv_clustermap.py
def execute(self, scratch_dir: Path, fetch: Fetch, wf: WF) -> Asset:
    """
    Load the LCV results, run them through ``self.xr_pipe``, render the GCP
    heatmap and write it as ``result.html`` inside ``scratch_dir``.

    ``self.save_mode`` is forwarded as plotly's ``include_plotlyjs`` option.
    ``wf`` is accepted for interface compatibility but not used here.
    """
    html_path = scratch_dir / "result.html"
    processed = self.xr_pipe.process(
        load_xr_lcv_dataset(src=self.source, fetch=fetch)
    )
    heatmap = gcp_plot(ds=processed, plot_mode=self.plot_options)
    heatmap.write_html(html_path, include_plotlyjs=self.save_mode)
    return FileAsset(html_path)
|
LCVSource
Describe a dataframe from which to load LCV data
Attributes:
df_pipe
class-attribute
instance-attribute
df_pipe: DataProcessingPipe = IdentityPipe()
downstream_trait_col
class-attribute
instance-attribute
downstream_trait_col: str = DOWNSTREAM_TRAIT_COL
gcp_col
class-attribute
instance-attribute
gcp_col: str = LCV_MEAN_GCP_COL
p_col
class-attribute
instance-attribute
p_col: str = LCV_PVAL_ZERO_COL
rho_col
class-attribute
instance-attribute
rho_col: str = LCV_RHO_EST_COL
upstream_trait_col
class-attribute
instance-attribute
upstream_trait_col: str = UPSTREAM_TRAIT_COL
gcp_plot
gcp_plot(ds: Dataset, plot_mode: LCVPlotMode) -> Figure
Produce a plotly heatmap figure showing LCV GCP estimates for each
upstream / downstream trait pair.
Source code in mecfs_bio/build_system/task/lcv/lcv_clustermap.py
def gcp_plot(ds: xr.Dataset, plot_mode: LCVPlotMode) -> Figure:
    """
    Produce a plotly heatmap figure showing LCV GCP estimates for each
    upstream / downstream trait pair.

    Rows are upstream traits and columns are downstream traits. Cell colour is
    the GCP value on a fixed [-1, 1] scale. Cells judged significant by
    ``plot_mode.sig_mode`` (applied to the GCP=0 p-values) are labelled with
    a star. Hover text also shows the genetic correlation (rho) and the GCP=0
    p-value.

    :param ds: dataset laid out as by ``load_xr_lcv_dataset`` (GCP, p-value
        and rho matrices plus the ``NUM_PAIRS`` test count), possibly
        re-ordered by an xarray pipe.
    :param plot_mode: plotting options; only ``GCPWithAsterisk`` is supported.
    :returns: the plotly figure.
    :raises NotImplementedError: if ``plot_mode`` is any other mode.
    """
    if not isinstance(plot_mode, GCPWithAsterisk):
        # Name the offending mode so the failure is actionable, mirroring get_sig.
        raise NotImplementedError(f"Unsupported LCV plot mode: {plot_mode!r}")

    gcp_df = ds[XR_GCP_ARRAY].to_pandas()
    p_values = ds[XR_LCV_P_VALUE_ARRAY].values
    rho_values = ds[XR_LCV_RHO_ARRAY].values
    sig = get_sig(
        p_value_matrix=p_values,
        num_pairs=ds[NUM_PAIRS].values.item(),
        sig_mode=plot_mode.sig_mode,
    )
    # Text label per cell: a star when significant, empty otherwise.
    asterisk_matrix = np.where(sig, "★", "")
    # Per-cell hover payload: customdata[0] = p-value, customdata[1] = rho.
    customdata = np.stack([p_values, rho_values], axis=-1)
    fig = go.Figure(
        data=go.Heatmap(
            z=gcp_df,
            x=gcp_df.columns,
            y=gcp_df.index,
            text=asterisk_matrix,
            texttemplate="%{text}",
            textfont={"size": 20, "color": "black"},
            colorscale=plot_mode.color_scale,
            zmin=-1,
            zmax=1,
            customdata=customdata,
            hovertemplate="Upstream Trait: %{y}<br>"
            + "Downstream Trait: %{x}<br>"
            + "GCP: %{z}<br>"
            + "Genetic correlation (rho): %{customdata[1]}<br>"
            + "p value (GCP=0): %{customdata[0]}<br>"
            + "<extra></extra>",
            showscale=True,
            hovertemplatefallback="None",
        )
    )
    fig.update_layout(
        xaxis=dict(side="top"),
    )
    fig.update_yaxes(autorange="reversed")  # want the origin in top left corner
    return fig
|
get_sig
get_sig(
p_value_matrix: ndarray,
num_pairs: int,
sig_mode: SigMode,
) -> np.ndarray
Get a binary array indicating which elements of the LCV result matrix are
significant under the given significance scheme.
Source code in mecfs_bio/build_system/task/lcv/lcv_clustermap.py
def get_sig(
    p_value_matrix: np.ndarray,
    num_pairs: int,
    sig_mode: SigMode,
) -> np.ndarray:
    """
    Get a binary array indicating which elements of the LCV result matrix are
    significant under the given significance scheme.

    :param p_value_matrix: LCV GCP=0 p-values. NaN entries (e.g. trait pairs
        with no result) compare as not significant.
    :param num_pairs: number of tests performed; used as the Bonferroni divisor.
    :param sig_mode: significance scheme; currently only ``BonferoniSig``.
    :returns: boolean array, element-wise ``p <= threshold``.
    :raises ValueError: if ``sig_mode`` is unsupported, or if ``num_pairs``
        is not positive under a Bonferroni scheme.
    """
    if isinstance(sig_mode, BonferoniSig):
        if num_pairs <= 0:
            # Without this guard, 0 surfaces as a bare ZeroDivisionError and a
            # negative count silently yields a negative threshold.
            raise ValueError(
                "Bonferroni correction requires a positive number of tests, "
                f"got num_pairs={num_pairs}"
            )
        thresh = sig_mode.alpha / num_pairs
        logger.debug(
            f"Bonferoni significance threshold when alpha={sig_mode.alpha} "
            f"and there are {num_pairs} tests is {thresh}"
        )
        return p_value_matrix <= thresh
    raise ValueError(f"Invalid mode {sig_mode}")
|
load_xr_lcv_dataset
load_xr_lcv_dataset(
src: LCVSource, fetch: Fetch
) -> xr.Dataset
Retrieve the LCV results. Return in the form of an xarray dataset.
Source code in mecfs_bio/build_system/task/lcv/lcv_clustermap.py
def load_xr_lcv_dataset(
    src: LCVSource,
    fetch: Fetch,
) -> xr.Dataset:
    """
    Retrieve the LCV results. Return in the form of an xarray dataset.

    The GCP, p-value and rho columns are each pivoted into an
    (upstream trait x downstream trait) matrix and stored as data variables.
    The number of rows remaining after ``src.df_pipe`` is stored under
    ``NUM_PAIRS``; downstream this is used as the multiple-testing count.
    """
    raw_asset = fetch(src.task.asset_id)
    collected = (
        src.df_pipe.process(scan_dataframe_asset(raw_asset, meta=src.task.meta))
        .select(*src.cols)
        .collect()
    )
    pair_count = len(collected)
    pandas_df = collected.to_pandas()

    # Output data-variable name -> source column holding its values.
    value_columns = {
        XR_GCP_ARRAY: src.gcp_col,
        XR_LCV_P_VALUE_ARRAY: src.p_col,
        XR_LCV_RHO_ARRAY: src.rho_col,
    }
    dataset = xr.Dataset(
        {
            var_name: xr.DataArray(
                pandas_df.pivot(
                    index=src.upstream_trait_col,
                    columns=src.downstream_trait_col,
                    values=column,
                ),
                dims=(XR_UPSTREAM_TRAIT_DIM, XR_DOWNSTREAM_TRAIT_DIM),
            )
            for var_name, column in value_columns.items()
        }
    )
    dataset[NUM_PAIRS] = pair_count
    return dataset
|