Skip to content

Commit ef489f2

Browse files
authored
[Fabric] Ignore Keyword Arguments Outside of Callback Signature (#21258)
1 parent 5ea509a commit ef489f2

File tree

4 files changed

+126
-9
lines changed

4 files changed

+126
-9
lines changed

docs/source-fabric/guide/callbacks.rst

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,30 @@ The :meth:`~lightning.fabric.fabric.Fabric.call` calls the callback objects in t
8383
Not all objects registered via ``Fabric(callbacks=...)`` must implement a method with the given name.
8484
The ones that have a matching method name will get called.
8585

86+
The different callbacks can have different method signatures. Fabric automatically filters keyword arguments based on
87+
each callback's function signature, allowing callbacks with different signatures to work together seamlessly.
88+
89+
.. code-block:: python
90+
91+
class TrainingMetricsCallback:
92+
def on_train_epoch_end(self, train_loss):
93+
print(f"Training loss: {train_loss:.4f}")
94+
95+
class ValidationMetricsCallback:
96+
def on_train_epoch_end(self, val_accuracy):
97+
print(f"Validation accuracy: {val_accuracy:.4f}")
98+
99+
class ComprehensiveCallback:
100+
def on_train_epoch_end(self, epoch, **kwargs):
101+
print(f"Epoch {epoch} complete with metrics: {kwargs}")
102+
103+
fabric = Fabric(
104+
callbacks=[TrainingMetricsCallback(), ValidationMetricsCallback(), ComprehensiveCallback()]
105+
)
106+
107+
# Each callback receives only the arguments it can handle
108+
fabric.call("on_train_epoch_end", epoch=5, train_loss=0.1, val_accuracy=0.95, learning_rate=0.001)
109+
86110
87111
----
88112

src/lightning/fabric/CHANGELOG.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,21 @@ All notable changes to this project will be documented in this file.
44

55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

7+
8+
## [unreleased] - YYYY-MM-DD
9+
10+
### Added
11+
12+
- Added kwargs-filtering for `Fabric.call` to support different callback method signatures ([#21258](https://github.com/Lightning-AI/pytorch-lightning/pull/21258))
13+
14+
15+
### Removed
16+
17+
-
18+
19+
20+
---
21+
722
## [2.6.0] - 2025-11-21
823

924
### Changed

src/lightning/fabric/fabric.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,34 @@ def train_function(fabric):
985985
)
986986
return self._wrap_and_launch(function, self, *args, **kwargs)
987987

988+
def _filter_kwargs_for_callback(self, method: Callable, kwargs: dict[str, Any]) -> dict[str, Any]:
989+
"""Filter keyword arguments to only include those that match the callback method's signature.
990+
991+
Args:
992+
method: The callback method to inspect
993+
kwargs: The keyword arguments to filter
994+
995+
Returns:
996+
A filtered dictionary of keyword arguments that match the method's signature
997+
998+
"""
999+
try:
1000+
sig = inspect.signature(method)
1001+
except (ValueError, TypeError):
1002+
# If we can't inspect the signature, pass all kwargs to maintain backward compatibility
1003+
return kwargs
1004+
1005+
filtered_kwargs = {}
1006+
for name, param in sig.parameters.items():
1007+
# If the method accepts **kwargs, pass all original kwargs directly
1008+
if param.kind == inspect.Parameter.VAR_KEYWORD:
1009+
return kwargs
1010+
# If the parameter exists in the incoming kwargs, add it to filtered_kwargs
1011+
if name in kwargs:
1012+
filtered_kwargs[name] = kwargs[name]
1013+
1014+
return filtered_kwargs
1015+
9881016
def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
9891017
r"""Trigger the callback methods with the given name and arguments.
9901018
@@ -994,7 +1022,9 @@ def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
9941022
Args:
9951023
hook_name: The name of the callback method.
9961024
*args: Optional positional arguments that get passed down to the callback method.
997-
**kwargs: Optional keyword arguments that get passed down to the callback method.
1025+
**kwargs: Optional keyword arguments that get passed down to the callback method. Keyword arguments
1026+
that are not present in the callback's signature will be filtered out automatically, allowing
1027+
callbacks to have different signatures for the same hook.
9981028
9991029
Example::
10001030
@@ -1016,13 +1046,8 @@ def on_train_epoch_end(self, results):
10161046
)
10171047
continue
10181048

1019-
method(*args, **kwargs)
1020-
1021-
# TODO(fabric): handle the following signatures
1022-
# method(self, fabric|trainer, x, y=1)
1023-
# method(self, fabric|trainer, *args, x, y=1)
1024-
# method(self, *args, y=1)
1025-
# method(self, *args, **kwargs)
1049+
filtered_kwargs = self._filter_kwargs_for_callback(method, kwargs)
1050+
method(*args, **filtered_kwargs)
10261051

10271052
def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
10281053
"""Log a scalar to all loggers that were added to Fabric.

tests/tests_fabric/test_fabric.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import inspect
1415
import os
1516
import warnings
1617
from contextlib import nullcontext
@@ -20,7 +21,6 @@
2021

2122
import pytest
2223
import torch
23-
import torch.distributed
2424
import torch.nn.functional
2525
from lightning_utilities.test.warning import no_warning_call
2626
from torch import nn
@@ -1294,3 +1294,56 @@ def test_verify_launch_called():
12941294
fabric.launch()
12951295
assert fabric._launched
12961296
fabric._validate_launched()
1297+
1298+
1299+
def test_callback_kwargs_filtering():
1300+
"""Test that callbacks receive only the kwargs they can handle based on their signature."""
1301+
1302+
class CallbackWithLimitedKwargs:
1303+
def on_train_epoch_end(self, epoch: int):
1304+
self.epoch = epoch
1305+
1306+
class CallbackWithVarKeywords:
1307+
def on_train_epoch_end(self, epoch: int, **kwargs):
1308+
self.epoch = epoch
1309+
self.kwargs = kwargs
1310+
1311+
class CallbackWithNoParams:
1312+
def on_train_epoch_end(self):
1313+
self.called = True
1314+
1315+
callback1 = CallbackWithLimitedKwargs()
1316+
callback2 = CallbackWithVarKeywords()
1317+
callback3 = CallbackWithNoParams()
1318+
fabric = Fabric(callbacks=[callback1, callback2, callback3])
1319+
fabric.call("on_train_epoch_end", epoch=5, loss=0.1, metrics={"acc": 0.9})
1320+
1321+
assert callback1.epoch == 5
1322+
assert not hasattr(callback1, "loss")
1323+
assert callback2.epoch == 5
1324+
assert callback2.kwargs == {"loss": 0.1, "metrics": {"acc": 0.9}}
1325+
assert callback3.called is True
1326+
1327+
1328+
def test_callback_kwargs_filtering_signature_inspection_failure():
1329+
"""Test behavior when signature inspection fails - should fallback to passing all kwargs."""
1330+
callback = Mock()
1331+
fabric = Fabric(callbacks=[callback])
1332+
original_signature = inspect.signature
1333+
1334+
def mock_signature(obj):
1335+
if hasattr(obj, "_mock_name") or hasattr(obj, "_mock_new_name"):
1336+
raise ValueError("Cannot inspect mock signature")
1337+
return original_signature(obj)
1338+
1339+
# Temporarily replace signature function in fabric module
1340+
import lightning.fabric.fabric
1341+
1342+
lightning.fabric.fabric.inspect.signature = mock_signature
1343+
1344+
try:
1345+
# Should still work by passing all kwargs when signature inspection fails
1346+
fabric.call("on_test_hook", arg1="value1", arg2="value2")
1347+
callback.on_test_hook.assert_called_with(arg1="value1", arg2="value2")
1348+
finally:
1349+
lightning.fabric.fabric.inspect.signature = original_signature

0 commit comments

Comments
 (0)