API reference

Auto-generated from the package docstrings.

Search engine

mhc_tp.engine.search

Search Gibbs cluster matrices against the reference array.

Method (how a cluster is matched to an allotype)

  1. Each GibbsCluster motif and each reference allotype is represented as a position-specific scoring matrix (PSSM): n_positions x 20 amino-acid weights. Gibbs matrices are zero-padded to a common position count so they can be batched.
  2. For every (cluster, allotype) pair the two PSSMs are compared by Pearson correlation of their flattened weights, computed only over the cells V = {k : g_k != 0 and g_k not NaN} that are informative in the cluster matrix (so padding and empty positions do not dilute the score):

    PCC(g, r) = Σ_{k∈V}(g_k - ḡ)(r_k - r̄) / ( |V| · σ_g · σ_r ) ∈ [-1, 1]

     with means and standard deviations taken over V. It is scale- and offset-invariant, so it scores motif shape, not absolute magnitudes. Full derivation and numerical guards: see compute_all_correlations in mhc_tp.engine.kernels.
  3. Per cluster the allotypes are ranked by correlation (PCC, -1..1; 1.0 = identical motif). Selection is then either threshold-gated (default) or pure top-N (always_top_n); see search.

The correlation is a motif-shape similarity: it rewards matching the relative preference pattern across positions, not the absolute weight magnitudes.
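
The masked comparison described above can be sketched in plain NumPy. This is a minimal illustration rather than the package's JIT kernel: the helper name `masked_pcc` and the toy matrices are assumptions made for the example.

```python
import numpy as np


def masked_pcc(g: np.ndarray, r: np.ndarray) -> float:
    """Pearson correlation over the cells where g is non-zero and not NaN."""
    g_flat = g.ravel().astype(np.float64)
    r_flat = r.ravel().astype(np.float64)
    valid = ~(np.isnan(g_flat) | (g_flat == 0.0))
    gv, rv = g_flat[valid], r_flat[valid]
    num = np.mean((gv - gv.mean()) * (rv - rv.mean()))
    return float(num / (gv.std() * rv.std()))


rng = np.random.default_rng(0)
motif = rng.normal(size=(9, 20))          # toy 9-position PSSM

padded = np.zeros((12, 20))
padded[:9] = motif                        # cluster matrix, zero-padded to 12 positions

reference = np.zeros((12, 20))
reference[:9] = 3.0 * motif + 1.5         # same shape, rescaled and shifted
reference[9:] = rng.normal(size=(3, 20))  # values where the cluster is only padding

score = masked_pcc(padded, reference)
print(round(score, 6))
```

The reference differs from the cluster in scale, in offset, and in the rows that are only padding on the cluster side, yet the score comes out at 1.0 up to floating-point error, because only the cluster's informative cells enter the statistics.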

search

search(
    reference,
    gibbs_matrices,
    threshold=0.7,
    top_n=3,
    hla_filter=None,
    always_top_n=False,
)

Return {(gibbs_name, ref_formatted): correlation} for top-N hits.

By default a hit must score >= threshold to be returned, so a cluster may yield fewer than top_n rows (or none). When always_top_n is set, every cluster returns its top_n best matches regardless of threshold — the threshold then only annotates confidence downstream, it never drops a row.
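
The difference between the two selection modes can be seen on a single row of scores. The sketch below is a simplified stand-in for the selection step: `select_hits` is a hypothetical helper, the scores are made up, and the HLA mask and the kernel's -1.0 sentinel are ignored.

```python
import numpy as np


def select_hits(scores, threshold=0.7, top_n=3, always_top_n=False):
    """Return the indices of the selected references, best first."""
    order = np.argsort(scores)[::-1]  # descending by correlation
    if always_top_n:
        # Pure top-N: the threshold plays no role in selection.
        return [int(j) for j in order[:top_n]]
    # Threshold-gated: drop anything below the cut, then keep at most top_n.
    return [int(j) for j in order if scores[j] >= threshold][:top_n]


scores = np.array([0.91, 0.42, 0.73, 0.10, 0.65])  # toy correlations for one cluster
gated = select_hits(scores)
ungated = select_hits(scores, always_top_n=True)
print(gated, ungated)
```

With the default threshold only two references qualify, so the gated call returns fewer than `top_n` hits, while the ungated call also keeps the 0.65 match.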

Source code in src/mhc_tp/engine/search.py
def search(
    reference: pd.DataFrame,
    gibbs_matrices: dict[str, np.ndarray],
    threshold: float = 0.70,
    top_n: int = 3,
    hla_filter: list[str] | None = None,
    always_top_n: bool = False,
) -> dict[tuple[str, str], float]:
    """Return ``{(gibbs_name, ref_formatted): correlation}`` for top-N hits.

    By default a hit must score ``>= threshold`` to be returned, so a cluster
    may yield fewer than ``top_n`` rows (or none). When ``always_top_n`` is set,
    every cluster returns its ``top_n`` best matches regardless of threshold —
    the threshold then only annotates confidence downstream, it never drops a
    row.
    """
    ref_arr, max_positions = build_reference_array(reference)
    names = list(gibbs_matrices.keys())

    padded = np.zeros((len(names), max_positions, N_AMINO_ACIDS), dtype=np.float32)
    for i, name in enumerate(names):
        m = gibbs_matrices[name]
        padded[i, : m.shape[0], : m.shape[1]] = m

    mask = np.ones(len(reference), dtype=np.bool_)
    if hla_filter:
        mask = reference["formatted"].isin(hla_filter).to_numpy()

    # In always-top-N mode, store every valid correlation (kernel keeps a -1.0
    # sentinel for cells below its threshold), then rank in Python.
    kernel_threshold = -2.0 if always_top_n else threshold
    corr, _invalid = compute_all_correlations(
        padded, ref_arr.astype(np.float32), mask, kernel_threshold
    )

    formatted = reference["formatted"].to_numpy()
    out: dict[tuple[str, str], float] = {}
    for i, name in enumerate(names):
        row = corr[i, :]
        order = np.argsort(row)[::-1]
        if always_top_n:
            # Top-N among computed (non-sentinel, unmasked) cells, any score.
            hits = [j for j in order if mask[j] and row[j] > -1.0][:top_n]
        else:
            hits = [j for j in order if row[j] >= threshold][:top_n]
        for j in hits:
            out[(name, str(formatted[j]))] = float(row[j])
    return out

mhc_tp.engine.kernels

Numba JIT correlation kernel (ported from the proven NumbaSearch engine).

compute_all_correlations

compute_all_correlations(
    gibbs_matrices, ref_matrices, hla_mask, threshold
)

All-pairs flattened Pearson correlation, parallel over Gibbs matrices.

Each PSSM is flattened to a vector. Only the cells that are informative in the Gibbs matrix are scored: the valid set is

V = { k : g_k != 0 and g_k is not NaN }

Restricted to those cells, the score for a (Gibbs g, reference r) pair is the Pearson correlation coefficient

          (1/|V|) * Σ_{k in V} (g_k - ḡ)(r_k - r̄)
PCC(g,r) = ----------------------------------------
                        σ_g · σ_r

where ḡ, r̄ are the means and σ_g, σ_r the population standard deviations taken over V:

ḡ   = (1/|V|) Σ g_k ,                σ_g = sqrt( (1/|V|) Σ (g_k - ḡ)^2 )
r̄   = (1/|V|) Σ r_k ,                σ_r = sqrt( (1/|V|) Σ (r_k - r̄)^2 )

PCC lies in [-1, 1] (1 = identical motif shape) and is scale/offset invariant, so it measures the pattern of position preferences rather than absolute weight magnitudes.

Guards: a Gibbs matrix with |V| < 10 or σ_g = 0 is flagged invalid (its row is skipped); a reference with σ_r = 0 is skipped for that pair. A score is stored only when PCC >= threshold; otherwise the cell keeps the -1.0 sentinel. Returns (correlations, invalid_flags).
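
Because the 1/|V| factor in the numerator pairs with population standard deviations in the denominator, the expression is the ordinary Pearson coefficient. The following check, on made-up vectors, compares the formula as written against NumPy's `np.corrcoef`; it is a sanity sketch, not part of the package.

```python
import numpy as np

rng = np.random.default_rng(1)
g = rng.normal(size=180)            # stands in for the valid cells of a Gibbs matrix
r = 0.6 * g + rng.normal(size=180)  # a partially correlated reference

# Formula from the docstring, using population statistics (np.std defaults to ddof=0).
num = np.mean((g - g.mean()) * (r - r.mean()))
pcc = num / (np.std(g) * np.std(r))

# NumPy's built-in Pearson coefficient for comparison.
expected = np.corrcoef(g, r)[0, 1]
print(abs(pcc - expected) < 1e-12)
```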

Source code in src/mhc_tp/engine/kernels.py
@jit(nopython=True, parallel=True, cache=True)
def compute_all_correlations(gibbs_matrices, ref_matrices, hla_mask, threshold):
    r"""All-pairs flattened Pearson correlation, parallel over Gibbs matrices.

    Each PSSM is flattened to a vector. Only the cells that are informative in
    the Gibbs matrix are scored: the valid set is

        V = { k : g_k != 0 and g_k is not NaN }

    Restricted to those cells, the score for a (Gibbs g, reference r) pair is
    the Pearson correlation coefficient

                  (1/|V|) * Σ_{k in V} (g_k - ḡ)(r_k - r̄)
        PCC(g,r) = ----------------------------------------
                                σ_g · σ_r

    where ḡ, r̄ are the means and σ_g, σ_r the population standard deviations
    taken over V:

        ḡ   = (1/|V|) Σ g_k ,                σ_g = sqrt( (1/|V|) Σ (g_k - ḡ)^2 )
        r̄   = (1/|V|) Σ r_k ,                σ_r = sqrt( (1/|V|) Σ (r_k - r̄)^2 )

    PCC lies in [-1, 1] (1 = identical motif shape) and is scale/offset
    invariant, so it measures the *pattern* of position preferences rather than
    absolute weight magnitudes.

    Guards: a Gibbs matrix with |V| < 10 or σ_g = 0 is flagged invalid (its row
    is skipped); a reference with σ_r = 0 is skipped for that pair. A score is
    stored only when PCC >= ``threshold``; otherwise the cell keeps the -1.0
    sentinel. Returns (correlations, invalid_flags).
    """
    n_gibbs = gibbs_matrices.shape[0]
    n_refs = ref_matrices.shape[0]
    correlations = np.full((n_gibbs, n_refs), -1.0, dtype=np.float32)
    invalid_flags = np.zeros(n_gibbs, dtype=np.int32)

    for i in prange(n_gibbs):
        gibbs_flat = gibbs_matrices[i].flatten()
        valid = ~(np.isnan(gibbs_flat) | (gibbs_flat == 0.0))
        gibbs_clean = gibbs_flat[valid]
        if len(gibbs_clean) < 10:
            invalid_flags[i] = 1
            continue
        g_mean = np.mean(gibbs_clean)
        g_std = np.std(gibbs_clean)
        if g_std == 0.0:
            invalid_flags[i] = 1
            continue
        for j in range(n_refs):
            if not hla_mask[j]:
                continue
            ref_clean = ref_matrices[j].flatten()[valid]
            r_mean = np.mean(ref_clean)
            r_std = np.std(ref_clean)
            if r_std == 0.0:
                continue
            num = np.mean((gibbs_clean - g_mean) * (ref_clean - r_mean))
            corr = num / (g_std * r_std)
            if corr >= threshold:
                correlations[i, j] = corr
    return correlations, invalid_flags

Reference data

mhc_tp.refdata.fetch

Resolve and fetch prebuilt reference parquets to a per-user data dir.

End users run mhc-tp fetch to download the prebuilt class I and class II reference parquets (with embedded Seq2Logo reference logos) instead of building them. The download URLs and SHA-256 checksums live in the packaged reference_manifest.tsv; the maintainer fills them in on each release.

data_dir

data_dir()

User data dir for reference files. Overridable via MHC_TP_DATA_DIR.

Source code in src/mhc_tp/refdata/fetch.py
def data_dir() -> Path:
    """User data dir for reference files. Overridable via ``MHC_TP_DATA_DIR``."""
    override = os.environ.get("MHC_TP_DATA_DIR")
    return Path(override) if override else Path(platformdirs.user_data_dir(_APP))

reference_path

reference_path(species)

Expected path of a species reference parquet in the data dir.

Source code in src/mhc_tp/refdata/fetch.py
def reference_path(species: str) -> Path:
    """Expected path of a species reference parquet in the data dir."""
    return data_dir() / f"{species.lower()}.parquet"

resolve_reference

resolve_reference(species, override=None)

Return the reference parquet path, raising a helpful error if absent.

Source code in src/mhc_tp/refdata/fetch.py
def resolve_reference(species: str, override: str | None = None) -> Path:
    """Return the reference parquet path, raising a helpful error if absent."""
    if override:
        p = Path(override)
        if not p.exists():
            raise FileNotFoundError(f"reference not found: {p}")
        return p
    p = reference_path(species)
    if not p.exists():
        raise FileNotFoundError(
            f"No {species} reference at {p}. Run: mhc-tp fetch --species {species} "
            f"(or pass --reference <path>)."
        )
    return p

load_manifest

load_manifest()

Parse the packaged reference manifest (species/filename/sha256/url).

Source code in src/mhc_tp/refdata/fetch.py
def load_manifest() -> list[dict]:
    """Parse the packaged reference manifest (species/filename/sha256/url)."""
    text = files("mhc_tp.refdata").joinpath("reference_manifest.tsv").read_text()
    rows = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or line.startswith("species\t"):
            continue
        species, filename, sha256, url = line.split("\t")
        rows.append(
            {"species": species, "filename": filename, "sha256": sha256, "url": url}
        )
    return rows
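
The parser implies a four-column, tab-separated layout with optional `#` comment lines and a header row starting with `species`. The sketch below applies the same parsing steps to an invented manifest; the rows, filenames, and URL are hypothetical and do not reflect the packaged file.

```python
SAMPLE_MANIFEST = (
    "# hypothetical manifest contents, for illustration only\n"
    "species\tfilename\tsha256\turl\n"
    "human\thuman.parquet\t-\t-\n"
    "mouse\tmouse.parquet\t-\thttps://example.org/mouse.parquet\n"
)


def parse_manifest(text: str) -> list[dict]:
    """Same steps as load_manifest, but reading from a string."""
    rows = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or line.startswith("species\t"):
            continue
        species, filename, sha256, url = line.split("\t")
        rows.append(
            {"species": species, "filename": filename, "sha256": sha256, "url": url}
        )
    return rows


rows = parse_manifest(SAMPLE_MANIFEST)
# Species whose download URL is still a placeholder.
pending = [row["species"] for row in rows if row["url"] in ("", "-")]
print(pending)
```

Each line is stripped before it is split, so an empty trailing field would disappear and the four-way unpack would fail. Using `-` as the placeholder for a missing checksum or URL, as the sample does, avoids that.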

fetch

fetch(species='all', dest=None)

Download the reference parquet(s) into the data dir; verify checksums.

Source code in src/mhc_tp/refdata/fetch.py
def fetch(species: str = "all", dest: str | None = None) -> list[Path]:
    """Download the reference parquet(s) into the data dir; verify checksums."""
    out_dir = Path(dest) if dest else data_dir()
    out_dir.mkdir(parents=True, exist_ok=True)
    fetched = []
    for row in load_manifest():
        if species != "all" and row["species"] != species:
            continue
        if not row["url"] or row["url"] == "-":
            raise RuntimeError(
                f"No download URL configured for the {row['species']} reference yet. "
                f"Build it locally with `mhc-tp build-ref`, or point --reference "
                f"at a parquet."
            )
        target = out_dir / row["filename"]
        urllib.request.urlretrieve(
            row["url"], target
        )  # noqa: S310 (trusted manifest URL)
        expected = row["sha256"]
        if expected and expected != "-":
            actual = hashlib.sha256(target.read_bytes()).hexdigest()
            if actual != expected:
                target.unlink(missing_ok=True)
                raise RuntimeError(f"checksum mismatch for {row['filename']}")
        fetched.append(target)
    return fetched
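
The checksum step above reads the whole downloaded file into memory before hashing. For large parquets, a chunked hash yields the same digest with bounded memory. The sketch below shows that alternative; `sha256_of` is a hypothetical helper, not a function of the package.

```python
import hashlib
import tempfile
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA-256 of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        # iter() with a sentinel stops when read() returns empty bytes (EOF).
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "sample.bin"
    sample.write_bytes(b"mhc-tp" * 1000)
    streamed = sha256_of(sample, chunk_size=256)
    whole = hashlib.sha256(sample.read_bytes()).hexdigest()

print(streamed == whole)
```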

Report

mhc_tp.report.render

Assemble the standalone HTML report.

render_report

render_report(
    correlation_dict,
    reference_df,
    gibbs_matrices,
    output_dir,
    kld_df=None,
    version="",
    gibbs_dir=None,
    logo_map=None,
    name_map=None,
    top_n=3,
    threshold=0.7,
    always_top_n=False,
)

Write <output_dir>/clust_result/mhc-tp-result.html and return its path.

logo_map ({formatted: png_bytes}) supplies reference logos when the reference DataFrame was loaded without the heavy logo column. name_map ({formatted: display}) supplies pretty allele labels.
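
For a standalone HTML file, PNG bytes such as those in logo_map have to be embedded inline. One common way is a base64 data URI, sketched below. The package's own png_bytes_to_data_uri helper is not shown on this page, so `to_png_data_uri` is a hypothetical stand-in and the encoding it uses is an assumption.

```python
import base64


def to_png_data_uri(png_bytes: bytes) -> str:
    """Encode raw PNG bytes as a data URI usable in an <img src="..."> attribute."""
    encoded = base64.b64encode(png_bytes).decode("ascii")
    return f"data:image/png;base64,{encoded}"


PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"  # the first eight bytes of any PNG file
uri = to_png_data_uri(PNG_SIGNATURE)
print(uri)
```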

Source code in src/mhc_tp/report/render.py
def render_report(
    correlation_dict,
    reference_df,
    gibbs_matrices,
    output_dir,
    kld_df: pd.DataFrame | None = None,
    version: str = "",
    gibbs_dir: str | None = None,
    logo_map: dict | None = None,
    name_map: dict | None = None,
    top_n: int = 3,
    threshold: float = 0.70,
    always_top_n: bool = False,
) -> str:
    """Write <output_dir>/clust_result/mhc-tp-result.html and return its path.

    ``logo_map`` ({formatted: png_bytes}) supplies reference logos when the
    reference DataFrame was loaded without the heavy ``logo`` column.
    ``name_map`` ({formatted: display}) supplies pretty allele labels.
    """
    logo_map = logo_map or {}
    name_map = name_map or {}
    ref_by_fmt = {r.formatted: r for r in reference_df.itertuples()}

    table_rows = datatable_rows(correlation_dict, kld_df, name_map, threshold)
    pcc_json = json.dumps(pcc_records(correlation_dict, name_map))

    # Best HLA per cluster id, grouped by the number of clusters N, with both
    # logos rendered from the matrices.
    sections: dict[int, list] = {}
    seen: set[str] = set()
    for (gibbs_name, hla), corr in sorted(
        correlation_dict.items(), key=lambda kv: -kv[1]
    ):
        cid, group, nclust = parse_cluster_id(gibbs_name)
        ref = ref_by_fmt.get(hla)
        if ref is None:
            continue
        if cid in seen:
            continue
        seen.add(cid)
        hla_display = name_map.get(hla, hla)
        # Reference logo: embedded Seq2Logo PNG from the parquet if present,
        # else the logomaker fallback rendered from the matrix.
        ref_logo_bytes = logo_map.get(hla) or getattr(ref, "logo", None)
        if ref_logo_bytes:
            ref_logo = png_bytes_to_data_uri(ref_logo_bytes)
        else:
            ref_mat = np.asarray(ref.matrix, dtype=np.float32).reshape(
                int(ref.n_positions), N_AMINO_ACIDS
            )
            ref_logo = _render_logo(ref_mat, title=hla_display)

        # Cluster logo: GibbsCluster's own Seq2Logo output if available, else fallback.
        cluster_logo = find_cluster_logo(gibbs_dir, cid) if gibbs_dir else None
        if not cluster_logo:
            gibbs_mat = gibbs_matrices.get(gibbs_name)
            cluster_logo = (
                _render_logo(gibbs_mat, title=cid) if gibbs_mat is not None else ""
            )

        sections.setdefault(nclust, []).append(
            {
                "cid": cid,
                "group": group,
                "hla": hla_display,
                "correlation": round(float(corr), 3),
                "below": float(corr) < threshold,
                "kld": _kld(kld_df, group, nclust),
                "ref_logo": ref_logo,
                "cluster_logo": cluster_logo,
            }
        )

    cluster_sections = [
        {
            "n_clusters": n,
            "groups": sorted(sections[n], key=lambda c: c["group"]),
        }
        for n in sorted(sections)
    ]

    env = Environment(
        loader=FileSystemLoader(str(_TEMPLATES)),
        autoescape=select_autoescape(["html", "j2"]),
    )
    template = env.get_template("report.html.j2")
    html = template.render(
        version=version,
        pcc_json=pcc_json,
        table_rows=table_rows,
        cluster_sections=cluster_sections,
        top_n=top_n,
        threshold=threshold,
        always_top_n=always_top_n,
    )

    out_dir = Path(output_dir) / "clust_result"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "mhc-tp-result.html"
    out_path.write_text(html, encoding="utf-8")
    return str(out_path)
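
The per-cluster sections rely on a simple pattern: sort all hits by descending correlation and keep only the first one seen for each cluster. The sketch below isolates that pattern with invented scores. It keys on the cluster name directly instead of the id returned by parse_cluster_id, which is a simplification.

```python
THRESHOLD = 0.70

# Toy (cluster, allele) -> correlation mapping; the values are made up.
correlation_dict = {
    ("cluster_a", "HLA-A3301"): 0.62,
    ("cluster_a", "HLA-A25:08"): 0.88,
    ("cluster_b", "DRB1_0101"): 0.45,
}

best = {}
for (cluster, allele), corr in sorted(correlation_dict.items(), key=lambda kv: -kv[1]):
    if cluster in best:
        continue  # a better-scoring hit for this cluster was already kept
    best[cluster] = {"hla": allele, "correlation": corr, "below": corr < THRESHOLD}

print(best)
```

Here `cluster_b` keeps its only hit even though it scores under the threshold; the `below` flag marks it, mirroring how the report annotates low-confidence matches instead of hiding them.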

mhc_tp.naming

Display formatting for allele names.

The reference parquet stores each allotype in the very string that was burned into its Seq2Logo motif title (e.g. HLA-A25:08, HLA-A3301, DRB1_0101, H-2-IAb). To keep the report's text labels consistent with those embedded reference logo titles, the canonical display name is the source allotype verbatim.

pretty_allele

pretty_allele(allotype)

Canonical display name for an allotype, matching its reference logo title.

The source allotype is already the name shown on the embedded Seq2Logo motif (HLA-A25:08, HLA-A3301, DRB1_0101 ...), so it is returned verbatim apart from surrounding whitespace.

    HLA-A25:08 -> HLA-A25:08
    HLA-A3301  -> HLA-A3301
    DRB1_0101  -> DRB1_0101
    H-2-IAb    -> H-2-IAb

Source code in src/mhc_tp/naming.py
def pretty_allele(allotype: str | None) -> str:
    """Canonical display name for an allotype, matching its reference logo title.

    The source allotype is already the name shown on the embedded Seq2Logo
    motif (``HLA-A25:08``, ``HLA-A3301``, ``DRB1_0101`` ...), so it is returned
    verbatim apart from surrounding whitespace.

        ``HLA-A25:08`` -> ``HLA-A25:08``
        ``HLA-A3301``  -> ``HLA-A3301``
        ``DRB1_0101``  -> ``DRB1_0101``
        ``H-2-IAb``    -> ``H-2-IAb``
    """
    return (allotype or "").strip()