Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Apr 19, 2024
1 parent fbb2a50 commit 52e4c22
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 16 deletions.
11 changes: 8 additions & 3 deletions .github/workflows/full_testing.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
name: Full Testing

# on: # yamllint disable-line rule:truthy
# workflow_dispatch:
# schedule:
# - cron: "0 6 * * *" # Everyday at 6:00am UTC/10:00pm PST
on: # yamllint disable-line rule:truthy
workflow_dispatch:
schedule:
- cron: "0 6 * * *" # Everyday at 6:00am UTC/10:00pm PST
push:
branches:
- master
pull_request:

jobs:

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend="flit_core.buildapi"

[project]
name="torch_geometric"
version="2.5.2"
version="2.5.3"
authors=[
{name="Matthias Fey", email="[email protected]"},
]
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')

__version__ = '2.5.2'
__version__ = '2.5.3'

__all__ = [
'EdgeIndex',
Expand Down
38 changes: 27 additions & 11 deletions torch_geometric/nn/conv/message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,15 @@ def __init__(
# Optimize `propagate()` via `*.jinja` templates:
if not self.propagate.__module__.startswith(jinja_prefix):
try:
if 'propagate' in self.__class__.__dict__:
raise ValueError("Cannot compile custom 'propagate' "
"method")
module = module_from_template(
module_name=f'{jinja_prefix}_propagate',
template_path=osp.join(root_dir, 'propagate.jinja'),
tmp_dirname='message_passing',
# Keyword arguments:
module=self.inspector._modules,
modules=self.inspector._modules,
collect_name='collect',
signature=self._get_propagate_signature(),
collect_param_dict=self.inspector.get_flat_param_dict(
Expand All @@ -198,6 +201,9 @@ def __init__(
if (self.inspector.implements('edge_update')
and not self.edge_updater.__module__.startswith(jinja_prefix)):
try:
if 'edge_updater' in self.__class__.__dict__:
raise ValueError("Cannot compile custom 'edge_updater' "
"method")
module = module_from_template(
module_name=f'{jinja_prefix}_edge_updater',
template_path=osp.join(root_dir, 'edge_updater.jinja'),
Expand Down Expand Up @@ -227,6 +233,7 @@ def __init__(
self._apply_sigmoid: bool = True

# Inference Decomposition:
self._decomposed_layers = 1
self.decomposed_layers = decomposed_layers

def reset_parameters(self) -> None:
Expand Down Expand Up @@ -711,16 +718,20 @@ def decomposed_layers(self, decomposed_layers: int) -> None:
raise ValueError("Inference decomposition of message passing "
"modules is only supported on the Python module")

if decomposed_layers == self._decomposed_layers:
return # Abort early if nothing to do.

self._decomposed_layers = decomposed_layers

if decomposed_layers != 1:
self.propagate = self.__class__._orig_propagate.__get__(
self, MessagePassing)
if hasattr(self.__class__, '_orig_propagate'):
self.propagate = self.__class__._orig_propagate.__get__(
self, MessagePassing)

elif ((self.explain is None or self.explain is False)
and not self.propagate.__module__.endswith('_propagate')):
self.propagate = self.__class__._jinja_propagate.__get__(
self, MessagePassing)
elif self.explain is None or self.explain is False:
if hasattr(self.__class__, '_jinja_propagate'):
self.propagate = self.__class__._jinja_propagate.__get__(
self, MessagePassing)

# Explainability ##########################################################

Expand All @@ -734,6 +745,9 @@ def explain(self, explain: Optional[bool]) -> None:
raise ValueError("Explainability of message passing modules "
"is only supported on the Python module")

if explain == self._explain:
return # Abort early if nothing to do.

self._explain = explain

if explain is True:
Expand All @@ -744,16 +758,18 @@ def explain(self, explain: Optional[bool]) -> None:
funcs=['message', 'explain_message', 'aggregate', 'update'],
exclude=self.special_args,
)
self.propagate = self.__class__._orig_propagate.__get__(
self, MessagePassing)
if hasattr(self.__class__, '_orig_propagate'):
self.propagate = self.__class__._orig_propagate.__get__(
self, MessagePassing)
else:
self._user_args = self.inspector.get_flat_param_names(
funcs=['message', 'aggregate', 'update'],
exclude=self.special_args,
)
if self.decomposed_layers == 1:
self.propagate = self.__class__._jinja_propagate.__get__(
self, MessagePassing)
if hasattr(self.__class__, '_jinja_propagate'):
self.propagate = self.__class__._jinja_propagate.__get__(
self, MessagePassing)

def explain_message(
self,
Expand Down
1 change: 1 addition & 0 deletions torch_geometric/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def module_from_template(
delete=False,
) as tmp:
tmp.write(module_repr)
tmp.flush()

spec = importlib.util.spec_from_file_location(module_name, tmp.name)
assert spec is not None
Expand Down

0 comments on commit 52e4c22

Please sign in to comment.