# SPDX-License-Identifier: MIT
"""
Plotting utilities for visualizing evolutionary progress.
This module provides functions for visualizing:
- Fitness statistics (best, mean, median, std)
- Diversity over time
- Mutation probability and strength trends
- Fitness comparison
"""
from __future__ import annotations
import os
import tempfile
import warnings
from typing import Iterable, Literal, Optional, Sequence, Union
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from evonet.core import Nnet
from PIL import Image
from evolib.core.population import Pop
PopOrDf = Union["Pop", pd.DataFrame]
# Helper type: (x, y, label, style_dict)
ExtraLines = Iterable[
tuple[
Sequence[float] | np.ndarray,
Sequence[float] | np.ndarray,
Optional[str],
Optional[dict],
]
]
def _as_history_df(obj: PopOrDf) -> pd.DataFrame:
return obj.history_df if hasattr(obj, "history_df") else obj
[docs]
def plot_history(
histories: Sequence[PopOrDf] | Pop,
*,
metrics: list[str] = ["best_fitness"],
labels: Optional[list[str]] = None,
title: str = "Evolutionary Metrics",
xlabel: str = "Generation",
ylabel: Optional[str] = None,
show: bool = True,
log_y: bool = False,
save_path: Optional[str] = None,
with_std: bool = True,
figsize: tuple = (10, 6),
) -> None:
"""
General-purpose plotting function for evolutionary history metrics.
Supports multiple metrics (e.g. fitness, diversity, mutation )probability and
comparison across runs.
Args:
histories (pd.DataFrame | list[pd.DataFrame]): Single or multiple history
DataFrames.
metrics (list[str]): List of metric column names to plot (e.g., 'best_fitness').
labels (list[str], optional): Optional list of labels for each run.
title (str): Title of the plot.
xlabel (str): Label for the x-axis.
ylabel (str | None): Label for the y-axis (auto-generated if None).
show (bool): Whether to display the plot interactively.
log_y (bool): Apply logarithmic scale to the y-axis.
save_path (str | None): Optional file path to save the plot.
with_std (bool): If True, plot standard deviation shading when available.
figsize (tuple): Size of the figure (width, height).
"""
# Normalize input to list
if isinstance(histories, Pop):
histories = [histories]
elif isinstance(histories, pd.DataFrame):
histories = [histories]
dfs = [_as_history_df(h) for h in histories]
if labels is None:
labels = [f"Run {i+1}" for i in range(len(dfs))]
plt.figure(figsize=figsize)
if log_y:
plt.yscale("log")
for hist, label in zip(dfs, labels):
generations = hist["generation"]
for metric in metrics:
if metric not in hist.columns:
warnings.warn(
f"Metric '{metric}' not found in history for '{label}' — skipping."
)
continue
line_label = f"{label} - {metric}" if len(dfs) > 1 else metric
plt.plot(generations, hist[metric], label=line_label)
# Optional standard deviation band if available
if with_std and "mean" in metric:
std_col = metric.replace("mean", "std")
if std_col in hist.columns:
lower = hist[metric] - hist[std_col]
upper = hist[metric] + hist[std_col]
plt.fill_between(generations, lower, upper, alpha=0.2)
plt.xlabel(xlabel)
if ylabel:
plt.ylabel(ylabel)
elif len(metrics) == 1:
plt.ylabel(metrics[0].replace("_", " ").capitalize())
else:
plt.ylabel("Metric Value")
plt.title(title)
plt.legend()
plt.grid(True)
plt.tight_layout()
if save_path:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, dpi=300)
print(f"Plot saved to '{save_path}'")
if show:
plt.show()
else:
plt.close()
[docs]
def plot_fitness(
history: PopOrDf,
*,
title: str = "Fitness over Generations",
show: bool = True,
log: bool = False,
save_path: Optional[str] = None,
) -> None:
"""Wrapper to plot best, mean, and median fitness with optional std band."""
df = _as_history_df(history)
plot_history(
df,
metrics=["best_fitness", "mean_fitness", "median_fitness"],
title=title,
log_y=log,
ylabel="Fitness",
save_path=save_path,
show=show,
with_std=True,
)
[docs]
def plot_diversity(
history: PopOrDf,
*,
title: str = "Population Diversity",
show: bool = True,
log: bool = False,
save_path: Optional[str] = None,
) -> None:
"""Wrapper to plot diversity over generations."""
df = _as_history_df(history)
if "diversity" not in df.columns:
warnings.warn("Column 'diversity' not found in history.")
return
plot_history(
df,
metrics=["diversity"],
title=title,
log_y=log,
ylabel="Diversity",
save_path=save_path,
show=show,
with_std=False,
)
[docs]
def plot_mutation_trends(
history: PopOrDf,
*,
title: str = "Mutation Parameter Trends",
show: bool = True,
log: bool = False,
save_path: Optional[str] = None,
) -> None:
"""Wrapper to plot mutation and/or strength trends over time."""
df = _as_history_df(history)
metrics = []
if "mutation_probability_mean" in df.columns:
metrics.append("mutation_probability_mean")
if "mutation_strength_mean" in df.columns:
metrics.append("mutation_strength_mean")
if not metrics:
warnings.warn("No mutation-related columns found in history.")
return
plot_history(
df,
metrics=metrics,
title=title,
log_y=log,
ylabel="Mutation Parameters",
save_path=save_path,
show=show,
with_std=False,
)
[docs]
def plot_fitness_comparison(
histories: Sequence[PopOrDf],
*,
labels: Optional[list[str]] = None,
metric: str = "best_fitness",
title: str = "Fitness Comparison over Generations",
show: bool = True,
log: bool = False,
save_path: Optional[str] = None,
) -> None:
if not isinstance(histories, (list, tuple)):
raise TypeError("histories must be a List or a Tuple")
dfs = [_as_history_df(h) for h in histories]
# Skip histories that lack the requested metric
filtered = []
filtered_labels = []
for i, hist in enumerate(dfs):
if metric in hist.columns:
filtered.append(hist)
filtered_labels.append(labels[i] if labels else f"Run {i+1}")
else:
warnings.warn(f"Metric '{metric}' not found in run {i+1}. Skipping.")
if not filtered:
warnings.warn("No valid runs to plot.")
return
plot_history(
filtered,
metrics=[metric],
labels=filtered_labels,
title=title,
log_y=log,
save_path=save_path,
show=show,
with_std=True,
)
[docs]
def plot_approximation(
y_pred: Sequence[float] | np.ndarray,
y_true: Sequence[float] | np.ndarray,
*,
title: str = "Approximation",
show: bool = True,
show_grid: bool = True,
legend_location: str = "upper right",
save_path: Optional[str] = None,
pred_label: str = "Prediction",
true_label: str = "Target",
pred_marker: str | None = None,
true_marker: str | None = None,
pred_lw: int = 2,
true_lw: int = 2,
pred_ls: str = "--",
true_ls: str = "-",
x_vals: Sequence[float] | np.ndarray | None = None,
y_limits: tuple[float, float] | None = None,
dpi: int = 150,
size: tuple[float, float] = (6, 4),
fitness: float | None = None,
support_points: (
tuple[Sequence[float] | np.ndarray, Sequence[float] | np.ndarray] | None
) = None,
support_label: str = "Support points",
support_marker: str = "o",
support_size: float = 30.0,
xlabel: str | None = None,
ylabel: str | None = None,
residuals_style: Optional[Literal["overlay", "subplot"]] = None,
residuals_label: str = "Residuals (target - prediction)",
residuals_alpha: float = 0.5,
residuals_linewidth: float = 1.0,
residuals_ylimits: tuple[float, float] | None = None,
extra_lines: ExtraLines | None = None,
) -> None:
"""
Plot predicted values against targets with optional support points and residual
visualization.
Args:
y_pred: Predicted values (1D).
y_true: Target values (1D), same length as ``y_pred``.
title: Plot title. If ``fitness`` is given, it is appended.
show: If True, show the plot interactively.
show_grid: Toggle background grid.
legend_location: Matplotlib legend location string.
save_path: If provided, save the figure to this path.
pred_label: Legend label for predictions.
true_label: Legend label for targets.
x_vals: Optional x-axis values; defaults to ``range(len(y_true))``.
y_limits: Optional (ymin, ymax) for the main plot; if None, computed
with padding.
dpi: Figure DPI.
size: Figure size (width, height) in inches.
fitness: Optional fitness to append to the title (e.g. MSE).
support_points: Optional (x_support, y_support) to visualize discrete
support points.
support_label: Legend label for support points.
support_marker: Marker symbol for support points.
support_size: Marker size for support points.
xlabel, ylabel: Axis labels for the main plot.
residuals_style: "subplot" for a second axes below, or "overlay" to draw error
segments in the main axes.
residuals_label: Legend label for residuals (subplot mode).
residuals_alpha: Transparency for residual drawing.
residuals_linewidth: Line width for residual drawing.
extra_lines: Iterable of (x, y, label, style_dict). Each will be plotted via
ax.plot(x, y, **style_dict). Use it e.g. for a noisy
target line: (X, Y_noisy, "Noisy target",
{"ls": ":", "alpha": 0.6})
"""
y_true = np.asarray(y_true, dtype=float).ravel()
y_pred = np.asarray(y_pred, dtype=float).ravel()
if y_true.shape != y_pred.shape:
raise ValueError(
f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}"
)
if x_vals is None:
x_vals = np.arange(len(y_true), dtype=float)
else:
x_vals = np.asarray(x_vals, dtype=float).ravel()
if x_vals.shape[0] != y_true.shape[0]:
raise ValueError(
f"Length mismatch: x_vals {x_vals.shape[0]} vs y {y_true.shape[0]}"
)
if fitness is not None:
title = f"{title} (fitness={fitness:.4f})"
# figure/axes creation
if residuals_style == "subplot":
fig, (ax, ax_res) = plt.subplots(
2,
1,
sharex=True,
figsize=(size[0], size[1] * 1.4),
dpi=dpi,
gridspec_kw={"height_ratios": [3, 1], "hspace": 0.05},
constrained_layout=True,
)
use_tight = False
else:
fig, ax = plt.subplots(figsize=size, dpi=dpi)
ax_res = None
use_tight = True
# main plot
ax.set_title(title)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.plot(
x_vals,
y_true,
color="black",
label=true_label,
marker=true_marker,
lw=true_lw,
ls=true_ls,
)
ax.plot(
x_vals,
y_pred,
color="red",
label=pred_label,
marker=pred_marker,
lw=pred_lw,
ls=pred_ls,
)
if support_points is not None:
sx, sy = support_points
sx = np.asarray(sx, dtype=float).ravel()
sy = np.asarray(sy, dtype=float).ravel()
if sx.shape != sy.shape:
raise ValueError(
f"Support points shape mismatch: x {sx.shape} vs y {sy.shape}"
)
ax.scatter(
sx, sy, label=support_label, s=support_size, marker=support_marker, zorder=3
)
ax.legend(loc=legend_location)
ax.grid(show_grid)
# y-limits (main)
if y_limits is not None:
ax.set_ylim(*y_limits)
else:
y_all = [y_true, y_pred]
if support_points is not None:
y_all.append(np.asarray(support_points[1], dtype=float).ravel())
y_all = np.concatenate(y_all)
pad = 0.05 * (y_all.max() - y_all.min() + 1e-12)
ax.set_ylim(y_all.min() - pad, y_all.max() + pad)
# Extra lines
if extra_lines is not None:
for xs, ys, lab, style in extra_lines:
xs_arr = np.asarray(xs, dtype=float).ravel()
ys_arr = np.asarray(ys, dtype=float).ravel()
if xs_arr.shape[0] != ys_arr.shape[0]:
raise ValueError(
f"extra_lines length mismatch: x {xs_arr.shape[0]} "
f"vs y {ys_arr.shape[0]}"
)
ax.plot(
xs_arr,
ys_arr,
color="grey",
label=lab if lab is not None else None,
**(style or {}),
)
# residuals
if residuals_style is not None:
residuals = y_true - y_pred
if residuals_style == "subplot":
ax_res.plot(
x_vals,
residuals,
color="grey",
lw=1.5,
alpha=residuals_alpha,
label=residuals_label,
)
if residuals_ylimits is not None:
ax_res.set_ylim(*residuals_ylimits)
else:
m = float(np.max(np.abs(residuals)))
pad = 0.05 * (m + 1e-12)
ax_res.set_ylim(-(m + pad), +(m + pad))
ax_res.axhline(0.0, color="black", lw=1.0, alpha=0.7)
ax_res.set_ylabel("resid")
ax_res.grid(show_grid)
ax_res.legend(loc="upper right")
elif residuals_style == "overlay":
y_min = np.minimum(y_true, y_pred)
y_max = np.maximum(y_true, y_pred)
ax.vlines(
x_vals,
y_min,
y_max,
alpha=residuals_alpha,
linewidth=residuals_linewidth,
)
else:
raise ValueError('residuals_style must be "subplot" or "overlay"')
# layout handling
if use_tight:
fig.tight_layout()
# saving
if save_path:
folder = os.path.dirname(save_path)
if folder:
os.makedirs(folder, exist_ok=True)
fig.savefig(save_path, dpi=dpi)
# show/close
if show:
plt.show()
else:
plt.close(fig)
[docs]
def save_combined_net_plot(
net: Nnet,
X: np.ndarray,
Y_true: np.ndarray,
Y_pred: np.ndarray,
path: str,
*,
dpi: int = 100,
title: str = "Approximation",
xlabel: str = "x",
ylabel: str = "y",
true_label: str = "Target",
pred_label: str = "Network Output",
show_grid: bool = True,
) -> None:
"""
Save a side-by-side image of the network graph and an approximation plot.
Parameters
----------
net : Nnet
EvoNet instance; must provide `plot(out_path_base, ...)`.
X : np.ndarray
X values used for plotting (1D).
Y_true : np.ndarray
Ground-truth targets aligned with X.
Y_pred : np.ndarray
Model predictions aligned with X.
path : str
Output PNG file.
dpi : int, optional
DPI for the approximation subplot.
title, xlabel, ylabel : str, optional
Plot annotations.
true_label, pred_label : str, optional
Legend labels for target and prediction.
show_grid : bool, optional
Toggle grid in the approximation subplot.
"""
# ensure folder exists
folder = os.path.dirname(path)
if folder:
os.makedirs(folder, exist_ok=True)
with tempfile.TemporaryDirectory() as tmpdir:
# left: network graph
graph_path = os.path.join(tmpdir, "net")
net.plot(graph_path, fillcolors_on=True, thickness_on=True)
img_net = Image.open(graph_path + ".png")
# right: approximation plot
fig, ax = plt.subplots(figsize=(5, 4))
ax.plot(X, Y_true, label=true_label, linestyle="-")
ax.plot(X, Y_pred, label=pred_label, linestyle="--")
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_title(title)
ax.grid(show_grid)
ax.legend()
fig.tight_layout()
plot_path = os.path.join(tmpdir, "plot.png")
fig.savefig(plot_path, dpi=dpi)
plt.close(fig)
img_plot = Image.open(plot_path)
# combine
total_width = img_net.width + img_plot.width
height = max(img_net.height, img_plot.height)
combined = Image.new("RGB", (total_width, height), "white")
combined.paste(img_net, (0, 0))
combined.paste(img_plot, (img_net.width, 0))
combined.save(path)
[docs]
def plot_bit_prediction(
pred_values: list[float],
true_bits: list[int],
save_path: str,
input_bits: list[int] | None = None,
title: str = "Bit Prediction",
show: bool = False,
dpi: int = 150,
pred_name: str = "Prediction",
show_target_raster: bool = True,
) -> None:
"""
Visualize a bit-sequence task over time.
This plot combines two views:
- A raster plot (top) showing discrete input/target/predicted bits
- A line/scatter plot (bottom) showing target vs. predicted values
Parameters
----------
true_bits : list[int]
Ground truth sequence of bits (0/1).
pred_values : list[float]
Model output values. Shown as continuous in the line plot and rounded
to 0/1 in the raster plot.
input_bits : list[int], optional
Optional input sequence of bits (0/1). If given, it is shown in the raster
plot above target and prediction.
save_path : str
File path to save the plot image.
title : str, optional
Title of the figure. Default is "Bit Prediction".
show : bool, optional
If True, show the plot window interactively. Default: False.
dpi : int, optional
Resolution for saving the figure.
pred_name : str, optional
Label used for the predicted/output series (e.g., "Prediction" or "Echo").
show_target_raster : bool, optional
If True, include the target bits in the raster plot.
Notes
-----
- For general regression/function approximation, use `plot_approximation`.
"""
pred_round = np.rint(pred_values).astype(int)
# Build raster data
rows: list[np.ndarray] = []
row_labels: list[str] = []
if input_bits is not None:
rows.append(np.asarray(input_bits, dtype=int))
row_labels.append("Input")
if show_target_raster:
rows.append(np.asarray(true_bits, dtype=int))
row_labels.append("Target")
rows.append(pred_round)
row_labels.append(pred_name if pred_name != "Prediction" else "Pred")
grid = np.vstack(rows)
fig, axes = plt.subplots(2, 1, figsize=(10, 6), constrained_layout=True)
# Raster view
axes[0].imshow(grid, aspect="auto", interpolation="nearest", cmap="Greys")
axes[0].set_yticks(range(len(row_labels)))
axes[0].set_yticklabels(row_labels)
axes[0].set_xticks([])
axes[0].set_title(title)
# Line/scatter view
x = np.arange(len(true_bits))
axes[1].plot(
x,
true_bits,
color="black",
marker="o",
linestyle="None",
label="Target",
markersize=4,
)
axes[1].plot(
x,
pred_values,
color="red",
linestyle="-",
alpha=0.7,
label=pred_name,
)
axes[1].set_xlim(0, len(true_bits))
axes[1].set_ylim(-0.2, 1.2)
axes[1].set_xlabel("Time step")
axes[1].set_ylabel("Value")
axes[1].legend(loc="upper right")
# Threshold line
axes[1].axhline(0.5, color="gray", linestyle="--", linewidth=1, alpha=0.6)
if save_path:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, dpi=dpi)
if show:
plt.show()
plt.close()
[docs]
def save_current_plot(filename: str, dpi: int = 300) -> None:
"""Save the current matplotlib figure to file."""
plt.savefig(filename, dpi=dpi)
print(f"Plot saved to '{filename}'")