diff --git a/src/probeinterface/plotting.py b/src/probeinterface/plotting.py index da84dc1b..2830ca0b 100644 --- a/src/probeinterface/plotting.py +++ b/src/probeinterface/plotting.py @@ -252,6 +252,8 @@ def plot_probegroup(probegroup, same_axes: bool = True, **kargs): import matplotlib.pyplot as plt + figsize = kargs.pop("figsize", None) + n = len(probegroup.probes) if same_axes: @@ -259,16 +261,16 @@ def plot_probegroup(probegroup, same_axes: bool = True, **kargs): ax = kargs.pop("ax") else: if probegroup.ndim == 2: - fig, ax = plt.subplots() + fig, ax = plt.subplots(figsize=figsize) else: - fig = plt.figure() + fig = plt.figure(figsize=figsize) ax = fig.add_subplot(1, 1, 1, projection="3d") axs = [ax] * n else: if "ax" in kargs: raise ValueError("when same_axes=False, an axes object cannot be passed into this function.") if probegroup.ndim == 2: - fig, axs = plt.subplots(ncols=n, nrows=1) + fig, axs = plt.subplots(ncols=n, nrows=1, figsize=figsize) if n == 1: axs = [axs] else: