From d5e2811c388747933fdd6c4f72e21b4d86eb21e1 Mon Sep 17 00:00:00 2001
From: Till Hoffmann
Date: Mon, 23 Feb 2026 17:30:02 +0000
Subject: [PATCH] Add kernel inference for exponents on (0, 2).

---
 recipe.py               | 28 ++++++++++++++--------------
 .../config.py           | 20 ++++++++++++++++++++
 .../scripts/train_nn.py |  2 ++
 3 files changed, 36 insertions(+), 14 deletions(-)

diff --git a/recipe.py b/recipe.py
index 5052f71..ae97f18 100644
--- a/recipe.py
+++ b/recipe.py
@@ -265,20 +265,20 @@
 
 
 # Inference for trees using a different method to compare with.
-config = "gn_graph"
-test_data = ROOT / config / "data/test"
-target = ROOT / config / "cantwell/result.pkl"
-args = [
-    "python",
-    "-m",
-    "simulation_based_graph_inference.scripts.infer_tree_kernel",
-] + dict2args(test=test_data, result=target)
-create_task(
-    f"{config}/cantwell",
-    dependencies=[test_data / "meta.json"],
-    targets=[target],
-    action=args,
-)
+for config in ["gn_graph02", "gn_graph"]:
+    test_data = ROOT / config / "data/test"
+    target = ROOT / config / "cantwell/result.pkl"
+    args = [
+        "python",
+        "-m",
+        "simulation_based_graph_inference.scripts.infer_tree_kernel",
+    ] + dict2args(test=test_data, result=target)
+    create_task(
+        f"{config}/cantwell",
+        dependencies=[test_data / "meta.json"],
+        targets=[target],
+        action=args,
+    )
 
 
 # Profiling targets.
diff --git a/src/simulation_based_graph_inference/config.py b/src/simulation_based_graph_inference/config.py
index 5fdd7e7..1bb3d69 100644
--- a/src/simulation_based_graph_inference/config.py
+++ b/src/simulation_based_graph_inference/config.py
@@ -86,6 +86,22 @@ def create_estimator(self) -> th.nn.ModuleDict:
                     scale=th.nn.LazyLinear(1),
                     constraint_transforms={"scale": softplus},
                 )
+            elif isinstance(constraint, constraints.interval):
+                estimator[name] = DistributionModule(
+                    th.distributions.Beta,
+                    concentration0=th.nn.LazyLinear(1),
+                    concentration1=th.nn.LazyLinear(1),
+                    constraint_transforms={
+                        "concentration0": softplus,
+                        "concentration1": softplus,
+                    },
+                    transforms=[
+                        th.distributions.AffineTransform(
+                            loc=constraint.lower_bound,
+                            scale=constraint.upper_bound - constraint.lower_bound,
+                        )
+                    ],
+                )
             else:
                 raise NotImplementedError(f"{constraint} constraint is not supported")
         return th.nn.ModuleDict(estimator)
@@ -194,6 +210,10 @@ def _gn_graph(num_nodes: int, gamma: float, **kwargs) -> nx.Graph:
         {"gamma": th.distributions.Beta(1, 1)},
         _gn_graph,
     ),
+    "gn_graph02": Configuration(
+        {"gamma": th.distributions.Uniform(0, 2)},
+        _gn_graph,
+    ),
     # "latent_space_graph": Configuration(
     #     {"bias": th.distributions.Normal(1, 1), "scale": th.distributions.Gamma(2, 2)},
     #     generators.latent_space_graph,
diff --git a/src/simulation_based_graph_inference/scripts/train_nn.py b/src/simulation_based_graph_inference/scripts/train_nn.py
index cbbbf8f..8baf72c 100644
--- a/src/simulation_based_graph_inference/scripts/train_nn.py
+++ b/src/simulation_based_graph_inference/scripts/train_nn.py
@@ -1,5 +1,6 @@
 import contextlib
 from datetime import datetime
+import numpy as np
 import pickle
 import torch as th
 from torch_geometric import nn as tgnn
@@ -303,6 +304,7 @@ def __main__(argv: typing.Optional[list[str]] = None) -> None:
             train_loss = run_epoch(model, train_loader, args.epsilon, optimizer)[
                 "epoch_loss"
             ]
+            assert np.isfinite(train_loss), f"Loss is not finite: {train_loss}"
 
             # Print parameter count after first epoch (when lazy layers are initialized)
             if not params_printed:
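
Note (illustrative sketch, not part of the patch): the new
`constraints.interval` branch estimates an interval-constrained
parameter with a Beta distribution pushed through an affine map, so the
estimated posterior has the same support as the prior, e.g. (0, 2) for
`gn_graph02`. The snippet below shows the mechanism using PyTorch's
stock distributions API; the concentration values here are made up for
the example, whereas in the patch they come from lazy linear layers
passed through a softplus.

    import torch as th

    # x ~ Beta lives on (0, 1); y = lower + (upper - lower) * x maps it
    # onto the constraint interval, here (0, 2) as used by gn_graph02.
    base = th.distributions.Beta(th.tensor(2.0), th.tensor(2.0))
    affine = th.distributions.AffineTransform(loc=0.0, scale=2.0)
    dist = th.distributions.TransformedDistribution(base, [affine])

    y = dist.sample()            # strictly inside (0, 2)
    log_prob = dist.log_prob(y)  # includes the Jacobian term, -log(2)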