# Source code for myforestplot.vis_utils

from typing import Union, Optional, List, Dict, Tuple, Any
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec


def obtain_indexes_from_category_item(
    ser_cate: pd.Series,
    ser_item: pd.Series,
) -> Tuple[np.array, np.array]:
    """Create y positions for category headers and for their items.

    Positions are consecutive non-positive integers that start at 0 and
    decrease by one per drawn row, so labels and error bars can be stacked
    vertically from top to bottom. The first time a category value appears
    it receives its own header row; the item of that row is placed on the
    row directly below it.

    Args:
        ser_cate: Series of categories, aligned row by row with ``ser_item``.
            A value that already appeared earlier in the series, or a
            missing value, continues the current category and does not
            produce a new header row.
        ser_item: Series of items. Only its length is used: iteration stops
            at the shorter of the two series.

    Returns:
        A tuple ``(y_index_cate, y_index)``: the positions of the category
        header rows and the positions of the item rows.
    """
    # Work on a boolean mask instead of overwriting duplicates with NaN:
    # the input series is left untouched and non-float dtypes (e.g. integer
    # categories) are not silently upcast.
    continues_category = (ser_cate.duplicated() | ser_cate.isna()).to_numpy()

    y_index = []
    y_index_cate = []
    index = 0
    for is_continuation, _ in zip(continues_category, ser_item):
        if not is_continuation:
            # New category: reserve one row for its header first.
            y_index_cate.append(index)
            index -= 1
        y_index.append(index)
        index -= 1

    return np.array(y_index_cate), np.array(y_index)
def errorbar_forestplot(
    ax: plt.Axes,
    y_index: np.array,
    df: Optional[pd.DataFrame] = None,
    risk: str = "risk",
    lower: Union[str, int] = 0,
    upper: Union[str, int] = 1,
    y_adj: float = 0,
    errorbar_kwds: Optional[dict] = None,
    ref_kwds: Optional[dict] = None,
    errorbar_color: Optional[str] = None,
    ref_color: Optional[str] = None,
    label: Optional[str] = None,
    log_scale: bool = False,
):
    """Draw the error bars and reference markers of a forest plot.

    Rows with a non-null ``risk`` value are drawn with ``ax.errorbar``.
    Rows whose ``risk`` value is null are drawn with ``ax.scatter`` at the
    reference value (1, or 0 when ``log_scale`` is set).

    Args:
        ax: Axis to be drawn on.
        y_index: Vertical positions, one per row of ``df``.
        df: Dataframe holding the estimates. Required despite the ``None``
            default; it is not modified.
        risk: Column name for risk.
        lower: Column name for lower confidence interval.
        upper: Column name for upper confidence interval.
        y_adj: Offset added to every value of ``y_index``.
        errorbar_kwds: Passed to ``ax.errorbar``. Missing keys are filled
            with the defaults defined below. The dict is not modified.
        ref_kwds: Passed to ``ax.scatter``. Missing keys are filled with
            the defaults defined below. The dict is not modified.
        errorbar_color: If given, overrides both ``ecolor`` and ``color``
            of the error bars.
        ref_color: If given, overrides ``color`` of the reference markers.
        label: Label for stratified drawings. Passed to ``ax.errorbar``.
        log_scale: Plot risk in log scale (``np.log``).

    Raises:
        ValueError: If ``df`` is not provided.
    """
    if df is None:
        raise ValueError("df must be provided.")

    # Merge into copies so the caller's dicts are never mutated, and so the
    # colour overrides below also work when no kwds dict was passed.
    errorbar_kwds = set_default_keywords(
        dict(errorbar_kwds) if errorbar_kwds is not None else {},
        dict(fmt="o",
             capsize=5,
             markeredgecolor="black",
             ecolor="black",
             color="white"),
    )
    ref_kwds = set_default_keywords(
        dict(ref_kwds) if ref_kwds is not None else {},
        dict(marker="s", s=20, color="black"),
    )
    # Explicit colour arguments win over both user kwds and defaults.
    if errorbar_color is not None:
        errorbar_kwds["ecolor"] = errorbar_color
        errorbar_kwds["color"] = errorbar_color
    if ref_color is not None:
        ref_kwds["color"] = ref_color

    y_pos = np.asarray(y_index) + y_adj

    # Use local series instead of adding temporary columns to a full copy
    # of the dataframe.
    risk_v, lower_v, upper_v = df[risk], df[lower], df[upper]
    if log_scale:
        risk_v, lower_v, upper_v = np.log(risk_v), np.log(lower_v), np.log(upper_v)

    has_estimate = risk_v.notnull().to_numpy()

    # xerr rows: distance from the estimate to the lower and upper bounds.
    xerr = np.vstack([
        (risk_v - lower_v).to_numpy()[has_estimate],
        (upper_v - risk_v).to_numpy()[has_estimate],
    ])
    ax.errorbar(risk_v[has_estimate], y_pos[has_estimate],
                xerr=xerr,
                label=label,
                zorder=5,
                **errorbar_kwds
                )

    # Reference rows get a marker at the null-effect value; all other rows
    # are NaN so nothing is drawn for them.
    ref_v = 0 if log_scale else 1
    ref_x = np.where(has_estimate, np.nan, ref_v)
    ax.scatter(ref_x, y_pos, zorder=5, **ref_kwds)
def embed_strings_forestplot(
    ax: plt.Axes,
    ser: pd.Series,
    y_index: np.array,
    x: float,
    header: str = "",
    fontsize: int = None,
    y_header: float = 1.0,
    text_kwds: Optional[dict] = None,
    header_kwds: Optional[dict] = None,
    replace: Optional[dict] = None,
):
    """Write one column of values, topped by a header, onto an axis.

    Every text is left-aligned horizontally and centered vertically at
    the given position.

    Args:
        ax: Axis to be drawn on.
        ser: Values of this series will be embedded, one per ``y_index``.
        y_index: Vertical positions of the values.
        x: x axis value of text position, ranging from 0 to 1.
        header: Header text, drawn at ``(x, y_header)``.
        fontsize: Font size used for both the header and the values.
        y_header: Vertical position of the header.
        text_kwds: Extra keyword arguments passed to ``ax.text`` for values.
        header_kwds: Extra keyword arguments passed to ``ax.text`` for the
            header.
        replace: If given, passed to ``ser.replace`` before drawing.
    """
    header_style = {} if header_kwds is None else header_kwds
    cell_style = {} if text_kwds is None else text_kwds
    labels = ser if replace is None else ser.replace(replace)

    # Placement options shared by the header and every cell.
    shared = dict(ha="left", va="center", fontsize=fontsize)

    ax.text(x, y_header, header, **shared, **header_style)
    for y_pos, label in zip(y_index, labels):
        ax.text(x, y_pos, label, **shared, **cell_style)
def set_default_keywords(kwds: Optional[dict], def_kwds: dict) -> dict:
    """Return keyword arguments with defaults filled in.

    Args:
        kwds: User-supplied keyword arguments, or None. Values given here
            take precedence over the defaults. The dict is not modified.
        def_kwds: Default keyword arguments used for any key missing from
            ``kwds``.

    Returns:
        A new dict containing every key of ``kwds`` plus the missing keys
        from ``def_kwds``.
    """
    # Copy first so repeated calls with the same user dict never leak
    # defaults (or later edits of the result) back into the caller's object.
    merged = dict(kwds) if kwds is not None else {}
    for key, value in def_kwds.items():
        merged.setdefault(key, value)
    return merged
def get_multiple_y_adjs(n: int, scale: float) -> np.array:
    """Compute evenly spaced ``y_adj`` offsets for stratified plotting.

    When several stratifications are drawn on the same rows, each one
    needs its own vertical offset for ``y_index``.

    Args:
        n: Number of stratifications to be plotted.
        scale: [-scale, scale] is set to be a range of y_adj.

    Returns:
        Array of ``n`` offsets running from ``scale`` down to ``-scale``.
        A single stratification gets one offset of 0; ``n <= 0`` yields an
        empty array.
    """
    if n == 1:
        # One stratification needs no shift; the general formula would
        # divide by zero here.
        return np.zeros(1)
    y_adjs = [0.5 - 1/(n-1)*i for i in range(n)]
    return np.array(y_adjs)*2*scale