Skip to content

mecfs_bio.build_system.task.lcv.lcv_clustermap

Task to create a heatmap plot illustrating the results of Latent Causal Variable (LCV) analysis for a collection of upstream / downstream trait pairs.

The primary value plotted is the posterior mean of the Genetic Causality Proportion (GCP). Significantly non-zero entries (under a Bonferroni correction to the LCV GCP=0 p-values) are marked with an asterisk.

Partly implemented by asking Claude to mimic the logic of genetic_correlation_clustermap_task.py.

Classes:

Functions:

  • gcp_plot

    Produce a plotly heatmap figure showing LCV GCP estimates for each upstream / downstream trait pair.

  • get_sig

    Get a binary array indicating which elements of the LCV result matrix are significant under the given significance scheme.

  • load_xr_lcv_dataset

    Retrieve the LCV results. Return in the form of an xarray dataset.

Attributes:

LCVPlotMode module-attribute

LCVPlotMode = GCPWithAsterisk

NUM_PAIRS module-attribute

NUM_PAIRS = 'num_pairs'

SigMode module-attribute

SigMode = BonferoniSig

XR_DOWNSTREAM_TRAIT_DIM module-attribute

XR_DOWNSTREAM_TRAIT_DIM = 'downstream_trait'

XR_GCP_ARRAY module-attribute

XR_GCP_ARRAY = 'gcp'

XR_LCV_P_VALUE_ARRAY module-attribute

XR_LCV_P_VALUE_ARRAY = 'lcv_p_value'

XR_LCV_RHO_ARRAY module-attribute

XR_LCV_RHO_ARRAY = 'lcv_rho'

XR_UPSTREAM_TRAIT_DIM module-attribute

XR_UPSTREAM_TRAIT_DIM = 'upstream_trait'

logger module-attribute

logger = get_logger()

BonferoniSig

Attributes:

alpha class-attribute instance-attribute

alpha: float = 0.05

GCPWithAsterisk

Attributes:

color_scale class-attribute instance-attribute

color_scale: str = 'RdBu_r'

sig_mode class-attribute instance-attribute

sig_mode: SigMode = BonferoniSig()

LCVClustermapTask

Bases: Task

Task to generate a heatmap of LCV results

Methods:

Attributes:

deps property

deps: list[Task]

meta instance-attribute

meta: Meta

plot_options instance-attribute

plot_options: LCVPlotMode

save_mode class-attribute instance-attribute

save_mode: bool | str = 'cdn'

source instance-attribute

source: LCVSource

xr_pipe instance-attribute

xr_pipe: XRDataPipe

create_std_with_clustering classmethod

create_std_with_clustering(
    asset_id: str,
    source: LCVSource,
    plot_options: LCVPlotMode,
)
Source code in mecfs_bio/build_system/task/lcv/lcv_clustermap.py
@classmethod
def create_std_with_clustering(
    cls,
    asset_id: str,
    source: LCVSource,
    plot_options: LCVPlotMode,
):
    """
    Alternate constructor that wires in the standard clustering pipeline.

    The output metadata describes an ".html" asset with the given asset_id,
    and takes its trait and project from the metadata of the source task
    (which must be a ResultTableMeta).

    The xarray pipeline applies two XRCluster steps to the GCP array, both
    with the "euclidean" metric: first along the upstream-trait dimension,
    then along the downstream-trait dimension.
    """
    upstream_meta = source.task.meta
    # Narrow the type: only result-table metadata carries trait / project.
    assert isinstance(upstream_meta, ResultTableMeta)
    plot_meta = GWASPlotFileMeta(
        trait=upstream_meta.trait,
        project=upstream_meta.project,
        extension=".html",
        id=AssetId(asset_id),
    )
    # One clustering step per axis, upstream axis first.
    cluster_steps = [
        XRCluster(array_name=XR_GCP_ARRAY, dim=trait_dim, metric="euclidean")
        for trait_dim in (XR_UPSTREAM_TRAIT_DIM, XR_DOWNSTREAM_TRAIT_DIM)
    ]
    return cls(
        meta=plot_meta,
        xr_pipe=XRCompositePipe(cluster_steps),
        source=source,
        plot_options=plot_options,
    )

execute

execute(scratch_dir: Path, fetch: Fetch, wf: WF) -> Asset
Source code in mecfs_bio/build_system/task/lcv/lcv_clustermap.py
def execute(self, scratch_dir: Path, fetch: Fetch, wf: WF) -> Asset:
    """
    Build the LCV heatmap and write it to disk.

    Loads the LCV results from self.source, passes them through
    self.xr_pipe, renders them with gcp_plot, and writes the figure to
    "result.html" inside scratch_dir.  self.save_mode is forwarded to
    plotly as the include_plotlyjs option.

    Returns a FileAsset pointing at the written HTML file.
    The wf argument is not used by this task.
    """
    dataset = self.xr_pipe.process(
        load_xr_lcv_dataset(src=self.source, fetch=fetch)
    )
    figure = gcp_plot(ds=dataset, plot_mode=self.plot_options)
    html_path = scratch_dir / "result.html"
    figure.write_html(html_path, include_plotlyjs=self.save_mode)
    return FileAsset(html_path)

LCVSource

Describe a dataframe from which to load LCV data

Attributes:

cols property

cols: list[str]

df_pipe class-attribute instance-attribute

df_pipe: DataProcessingPipe = IdentityPipe()

downstream_trait_col class-attribute instance-attribute

downstream_trait_col: str = DOWNSTREAM_TRAIT_COL

gcp_col class-attribute instance-attribute

gcp_col: str = LCV_MEAN_GCP_COL

p_col class-attribute instance-attribute

p_col: str = LCV_PVAL_ZERO_COL

rho_col class-attribute instance-attribute

rho_col: str = LCV_RHO_EST_COL

task instance-attribute

task: Task

upstream_trait_col class-attribute instance-attribute

upstream_trait_col: str = UPSTREAM_TRAIT_COL

gcp_plot

gcp_plot(ds: Dataset, plot_mode: LCVPlotMode) -> Figure

Produce a plotly heatmap figure showing LCV GCP estimates for each upstream / downstream trait pair.

Source code in mecfs_bio/build_system/task/lcv/lcv_clustermap.py
def gcp_plot(ds: xr.Dataset, plot_mode: LCVPlotMode) -> Figure:
    """
    Render the LCV GCP estimates as a plotly heatmap, one cell per
    upstream / downstream trait pair.

    Cells flagged by get_sig (using plot_mode.sig_mode and the NUM_PAIRS
    entry of the dataset) are overlaid with a star.  The hover box shows
    the GCP value together with the genetic correlation and the GCP=0
    p-value for the cell.

    Raises NotImplementedError for any plot mode other than GCPWithAsterisk.
    """
    if not isinstance(plot_mode, GCPWithAsterisk):
        raise NotImplementedError()

    gcp_table = ds[XR_GCP_ARRAY].to_pandas()
    pvals = ds[XR_LCV_P_VALUE_ARRAY].values
    rhos = ds[XR_LCV_RHO_ARRAY].values

    significant = get_sig(
        p_value_matrix=pvals,
        num_pairs=ds[NUM_PAIRS].values.item(),
        sig_mode=plot_mode.sig_mode,
    )
    star_labels = np.where(significant, "★", "")
    # Per-cell hover payload: index 0 is the p-value, index 1 is rho.
    hover_payload = np.stack([pvals, rhos], axis=-1)

    heatmap = go.Heatmap(
        z=gcp_table,
        x=gcp_table.columns,
        y=gcp_table.index,
        text=star_labels,
        texttemplate="%{text}",
        textfont={"size": 20, "color": "black"},
        colorscale=plot_mode.color_scale,
        # Fixed, symmetric colour range.
        zmin=-1,
        zmax=1,
        customdata=hover_payload,
        hovertemplate=(
            "Upstream Trait: %{y}<br>"
            "Downstream Trait: %{x}<br>"
            "GCP: %{z}<br>"
            "Genetic correlation (rho): %{customdata[1]}<br>"
            "p value (GCP=0): %{customdata[0]}<br>"
            "<extra></extra>"
        ),
        showscale=True,
        hovertemplatefallback="None",
    )

    figure = go.Figure(data=heatmap)
    figure.update_layout(xaxis=dict(side="top"))
    # Flip the y axis so that the origin sits in the top left corner.
    figure.update_yaxes(autorange="reversed")
    return figure

get_sig

get_sig(
    p_value_matrix: ndarray,
    num_pairs: int,
    sig_mode: SigMode,
) -> np.ndarray

Get a binary array indicating which elements of the LCV result matrix are significant under the given significance scheme.

Source code in mecfs_bio/build_system/task/lcv/lcv_clustermap.py
def get_sig(
    p_value_matrix: np.ndarray,
    num_pairs: int,
    sig_mode: SigMode,
) -> np.ndarray:
    """
    Flag the significant entries of the LCV p-value matrix.

    For BonferoniSig, an entry is significant when its p-value is at most
    sig_mode.alpha / num_pairs.  The result is an elementwise comparison of
    p_value_matrix against that cutoff.

    Raises ValueError for any other significance mode.
    """
    if not isinstance(sig_mode, BonferoniSig):
        raise ValueError(f"Invalid mode {sig_mode}")

    # Bonferroni: divide the family-wise alpha by the number of tests.
    cutoff = sig_mode.alpha / num_pairs
    logger.debug(
        f"Bonferoni significance threshold when alpha={sig_mode.alpha} "
        f"and there are {num_pairs} tests is {cutoff}"
    )
    return p_value_matrix <= cutoff

load_xr_lcv_dataset

load_xr_lcv_dataset(
    src: LCVSource, fetch: Fetch
) -> xr.Dataset

Retrieve the LCV results. Return in the form of an xarray dataset.

Source code in mecfs_bio/build_system/task/lcv/lcv_clustermap.py
def load_xr_lcv_dataset(
    src: LCVSource,
    fetch: Fetch,
) -> xr.Dataset:
    """
    Load the LCV result table described by src and reshape it into an
    xarray dataset.

    The table is fetched via the source task's asset id, run through
    src.df_pipe, and restricted to src.cols.  The GCP, p-value and rho
    columns are each pivoted into an upstream-trait by downstream-trait
    matrix and stored under XR_GCP_ARRAY, XR_LCV_P_VALUE_ARRAY and
    XR_LCV_RHO_ARRAY.  The number of rows in the loaded table is stored
    under NUM_PAIRS.
    """
    source_task = src.task
    scanned = scan_dataframe_asset(
        fetch(source_task.asset_id),
        meta=source_task.meta,
    )
    table = src.df_pipe.process(scanned).select(*src.cols).collect()
    long_df = table.to_pandas()

    trait_dims = (XR_UPSTREAM_TRAIT_DIM, XR_DOWNSTREAM_TRAIT_DIM)

    def _as_trait_matrix(value_col: str) -> xr.DataArray:
        # Rows are upstream traits, columns are downstream traits.
        wide = long_df.pivot(
            index=src.upstream_trait_col,
            columns=src.downstream_trait_col,
            values=value_col,
        )
        return xr.DataArray(wide, dims=trait_dims)

    ds = xr.Dataset(
        {
            XR_GCP_ARRAY: _as_trait_matrix(src.gcp_col),
            XR_LCV_P_VALUE_ARRAY: _as_trait_matrix(src.p_col),
            XR_LCV_RHO_ARRAY: _as_trait_matrix(src.rho_col),
        }
    )
    # Row count of the long table; consumed downstream as the number of tests.
    ds[NUM_PAIRS] = len(table)
    return ds