Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions docs/colorbars_legends.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,44 @@
ax = axs[1]
ax.legend(hs2, loc="b", ncols=3, center=True, title="centered rows")
axs.format(xlabel="xlabel", ylabel="ylabel", suptitle="Legend formatting demo")
# %% [raw] raw_mimetype="text/restructuredtext"
# .. _ug_guides_decouple:
#
# Decoupling legend content and location
# --------------------------------------
#
# Sometimes you may want to generate a legend using handles from specific axes
# but place it relative to other axes. In UltraPlot, you can achieve this by passing
# both the `ax` and `ref` keywords to :func:`~ultraplot.figure.Figure.legend`
# (or :func:`~ultraplot.figure.Figure.colorbar`). The `ax` keyword specifies the
# axes used to generate the legend handles, while the `ref` keyword specifies the
# reference axes used to determine the legend location.
#
# For example, to draw a legend based on the handles in the second row of subplots
# but place it below the first row of subplots, you can use
# ``fig.legend(ax=axs[1, :], ref=axs[0, :], loc='bottom')``. If ``ref`` is a list
# of axes, UltraPlot intelligently infers the span (width or height) and anchors
# the legend to the appropriate outer edge (e.g., the bottom-most axis for ``loc='bottom'``
# or the right-most axis for ``loc='right'``).

# %%
import numpy as np

import ultraplot as uplt

fig, axs = uplt.subplots(nrows=2, ncols=2, refwidth=2, share=False)
axs.format(abc="A.", suptitle="Decoupled legend location demo")

# Plot data on all axes
state = np.random.RandomState(51423)
data = (state.rand(20, 4) - 0.5).cumsum(axis=0)
for ax in axs:
ax.plot(data, cycle="mplotcolors", labels=list("abcd"))

# Legend 1: Content from Row 2 (ax=axs[1, :]), Location below Row 1 (ref=axs[0, :])
# This places a legend describing the bottom row data underneath the top row.
fig.legend(ax=axs[1, :], ref=axs[0, :], loc="bottom", title="Data from Row 2")

# Legend 2: Content from Row 1 (ax=axs[0, :]), Location below Row 2 (ref=axs[1, :])
# This places a legend describing the top row data underneath the bottom row.
fig.legend(ax=axs[0, :], ref=axs[1, :], loc="bottom", title="Data from Row 1")
235 changes: 214 additions & 21 deletions ultraplot/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2594,6 +2594,8 @@ def colorbar(
"""
# Backwards compatibility
ax = kwargs.pop("ax", None)
ref = kwargs.pop("ref", None)
loc_ax = ref if ref is not None else ax
cax = kwargs.pop("cax", None)
if isinstance(values, maxes.Axes):
cax = _not_none(cax_positional=values, cax=cax)
Expand All @@ -2613,20 +2615,102 @@ def colorbar(
with context._state_context(cax, _internal_call=True): # do not wrap pcolor
cb = super().colorbar(mappable, cax=cax, **kwargs)
# Axes panel colorbar
elif ax is not None:
elif loc_ax is not None:
# Check if span parameters are provided
has_span = _not_none(span, row, col, rows, cols) is not None

# Infer span from loc_ax if it is a list and no span provided
if (
not has_span
and np.iterable(loc_ax)
and not isinstance(loc_ax, (str, maxes.Axes))
):
loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"])
side = (
loc_trans
if loc_trans in ("left", "right", "top", "bottom")
else None
)

if side:
r_min, r_max = float("inf"), float("-inf")
c_min, c_max = float("inf"), float("-inf")
valid_ax = False
for axi in loc_ax:
if not hasattr(axi, "get_subplotspec"):
continue
ss = axi.get_subplotspec()
if ss is None:
continue
ss = ss.get_topmost_subplotspec()
r1, r2, c1, c2 = ss._get_rows_columns()
r_min = min(r_min, r1)
r_max = max(r_max, r2)
c_min = min(c_min, c1)
c_max = max(c_max, c2)
valid_ax = True

if valid_ax:
if side in ("left", "right"):
rows = (r_min + 1, r_max + 1)
else:
cols = (c_min + 1, c_max + 1)
has_span = True

# Extract a single axes from array if span is provided
# Otherwise, pass the array as-is for normal colorbar behavior
if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)):
try:
ax_single = next(iter(ax))
if (
has_span
and np.iterable(loc_ax)
and not isinstance(loc_ax, (str, maxes.Axes))
):
# Pick the best axis to anchor to based on the colorbar side
loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"])
side = (
loc_trans
if loc_trans in ("left", "right", "top", "bottom")
else None
)

except (TypeError, StopIteration):
ax_single = ax
best_ax = None
best_coord = float("-inf")

# If side is determined, search for the edge axis
if side:
for axi in loc_ax:
if not hasattr(axi, "get_subplotspec"):
continue
ss = axi.get_subplotspec()
if ss is None:
continue
ss = ss.get_topmost_subplotspec()
r1, r2, c1, c2 = ss._get_rows_columns()

if side == "right":
val = c2 # Maximize column index
elif side == "left":
val = -c1 # Minimize column index
elif side == "bottom":
val = r2 # Maximize row index
elif side == "top":
val = -r1 # Minimize row index
else:
val = 0

if val > best_coord:
best_coord = val
best_ax = axi

# Fallback to first axis
if best_ax is None:
try:
ax_single = next(iter(loc_ax))
except (TypeError, StopIteration):
ax_single = loc_ax
else:
ax_single = best_ax
else:
ax_single = ax
ax_single = loc_ax

# Pass span parameters through to axes colorbar
cb = ax_single.colorbar(
Expand Down Expand Up @@ -2700,27 +2784,136 @@ def legend(
matplotlib.axes.Axes.legend
"""
ax = kwargs.pop("ax", None)
ref = kwargs.pop("ref", None)
loc_ax = ref if ref is not None else ax

# Axes panel legend
if ax is not None:
if loc_ax is not None:
content_ax = ax if ax is not None else loc_ax
# Check if span parameters are provided
has_span = _not_none(span, row, col, rows, cols) is not None
# Extract a single axes from array if span is provided
# Otherwise, pass the array as-is for normal legend behavior
# Automatically collect handles and labels from spanned axes if not provided
if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)):
# Auto-collect handles and labels if not explicitly provided
if handles is None and labels is None:
handles, labels = [], []
for axi in ax:

# Automatically collect handles and labels from content axes if not provided
# Case 1: content_ax is a list (we must auto-collect)
# Case 2: content_ax != loc_ax (we must auto-collect because loc_ax.legend won't find content_ax handles)
must_collect = (
np.iterable(content_ax)
and not isinstance(content_ax, (str, maxes.Axes))
) or (content_ax is not loc_ax)

if must_collect and handles is None and labels is None:
handles, labels = [], []
# Handle list of axes
if np.iterable(content_ax) and not isinstance(
content_ax, (str, maxes.Axes)
):
for axi in content_ax:
h, l = axi.get_legend_handles_labels()
handles.extend(h)
labels.extend(l)
try:
ax_single = next(iter(ax))
except (TypeError, StopIteration):
ax_single = ax
# Handle single axis
else:
handles, labels = content_ax.get_legend_handles_labels()

# Infer span from loc_ax if it is a list and no span provided
if (
not has_span
and np.iterable(loc_ax)
and not isinstance(loc_ax, (str, maxes.Axes))
):
loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"])
side = (
loc_trans
if loc_trans in ("left", "right", "top", "bottom")
else None
)

if side:
r_min, r_max = float("inf"), float("-inf")
c_min, c_max = float("inf"), float("-inf")
valid_ax = False
for axi in loc_ax:
if not hasattr(axi, "get_subplotspec"):
continue
ss = axi.get_subplotspec()
if ss is None:
continue
ss = ss.get_topmost_subplotspec()
r1, r2, c1, c2 = ss._get_rows_columns()
r_min = min(r_min, r1)
r_max = max(r_max, r2)
c_min = min(c_min, c1)
c_max = max(c_max, c2)
valid_ax = True

if valid_ax:
if side in ("left", "right"):
rows = (r_min + 1, r_max + 1)
else:
cols = (c_min + 1, c_max + 1)
has_span = True

# Extract a single axes from array if span is provided (or if ref is a list)
# Otherwise, pass the array as-is for normal legend behavior (only if loc_ax is list)
if (
has_span
and np.iterable(loc_ax)
and not isinstance(loc_ax, (str, maxes.Axes))
):
# Pick the best axis to anchor to based on the legend side
loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"])
side = (
loc_trans
if loc_trans in ("left", "right", "top", "bottom")
else None
)

best_ax = None
best_coord = float("-inf")

# If side is determined, search for the edge axis
if side:
for axi in loc_ax:
if not hasattr(axi, "get_subplotspec"):
continue
ss = axi.get_subplotspec()
if ss is None:
continue
ss = ss.get_topmost_subplotspec()
r1, r2, c1, c2 = ss._get_rows_columns()

if side == "right":
val = c2 # Maximize column index
elif side == "left":
val = -c1 # Minimize column index
elif side == "bottom":
val = r2 # Maximize row index
elif side == "top":
val = -r1 # Minimize row index
else:
val = 0

if val > best_coord:
best_coord = val
best_ax = axi

# Fallback to first axis if no best axis found (or side is None)
if best_ax is None:
try:
ax_single = next(iter(loc_ax))
except (TypeError, StopIteration):
ax_single = loc_ax
else:
ax_single = best_ax

else:
ax_single = ax
ax_single = loc_ax
if isinstance(ax_single, list):
try:
ax_single = pgridspec.SubplotGrid(ax_single)
except ValueError:
ax_single = ax_single[0]

leg = ax_single.legend(
handles,
labels,
Expand Down
11 changes: 10 additions & 1 deletion ultraplot/gridspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,12 @@ def _encode_indices(self, *args, which=None, panel=False):
nums = []
idxs = self._get_indices(which=which, panel=panel)
for arg in args:
if isinstance(arg, (list, np.ndarray)):
try:
nums.append([idxs[int(i)] for i in arg])
except (IndexError, TypeError):
raise ValueError(f"Invalid gridspec index {arg}.")
continue
try:
nums.append(idxs[arg])
except (IndexError, TypeError):
Expand Down Expand Up @@ -1612,10 +1618,13 @@ def __getitem__(self, key):
>>> axs[:, 0] # a SubplotGrid containing the subplots in the first column
"""
# Allow 1D list-like indexing
if isinstance(key, int):
if isinstance(key, (Integral, np.integer)):
return list.__getitem__(self, key)
elif isinstance(key, slice):
return SubplotGrid(list.__getitem__(self, key))
elif isinstance(key, (list, np.ndarray)):
# NOTE: list.__getitem__ does not support numpy integers
return SubplotGrid([list.__getitem__(self, int(i)) for i in key])

# Allow 2D array-like indexing
# NOTE: We assume this is a 2D array of subplots, because this is
Expand Down
Loading
Loading