Commit 55959de

ensemble: fix for grad runtime error
- The dummy param must be linked to the loss so that it becomes part of the gradient graph.
- Fixes: `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`
Parent: 4d3f4f6
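
For context, the `RuntimeError` above is raised when `backward()` is called on a scalar that has no `grad_fn`, i.e. nothing in its history requires gradients. A minimal sketch of the failure and of the fix this commit applies (the tensor values are made up for illustration):

```python
import torch

# A scalar "loss" with no autograd history: calling backward() on it raises
# "element 0 of tensors does not require grad and does not have a grad_fn".
detached_loss = torch.tensor(0.7)
# detached_loss.backward()  # <-- would raise the RuntimeError above

# nn.Parameter sets requires_grad=True by default.
dummy_param = torch.nn.Parameter(torch.randn(1))

# Adding 0 * dummy_param.sum() attaches the loss to the autograd graph
# without changing its value; the dummy parameter receives a zero gradient.
loss = detached_loss + 0 * dummy_param.sum()
loss.backward()
print(dummy_param.grad)  # tensor([0.])
```

Note that `torch.nn.Parameter` already defaults to `requires_grad=True`, so the `requires_grad=True` passed to `torch.randn` in the diff below is redundant but harmless.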

File tree: 1 file changed, +26 −5 lines

chebai/models/ensemble.py

```diff
@@ -1,6 +1,6 @@
 import os.path
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -94,7 +94,7 @@ class ChebiEnsemble(_EnsembleBase):
     def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
         super().__init__(model_configs, **kwargs)
         # Add a dummy trainable parameter
-        self.dummy_param = torch.nn.Parameter(torch.randn(1))
+        self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True))
 
     def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
         predictions = {}
@@ -103,8 +103,6 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
             data["labels"].shape[0], data["labels"].shape[1], device=self.device
         ).to(self.device)
 
-        print(data["features"].shape)  # Debugging
-
         for name, model in self.models.items():
             output = model(data)
             confidences[name] = torch.sigmoid(output["logits"])
@@ -193,7 +191,8 @@ def _execute(
         )
         loss = loss[0]
 
-        d["loss"] = loss
+        d["loss"] = loss + 0 * self.dummy_param.sum()
+
         self.log(
             f"{prefix}loss",
             loss.item(),
@@ -229,6 +228,28 @@ def aggregate_predictions(self, predictions, confidences):
 
         return (true_scores > false_scores).long()  # Final class decision
 
+    def _process_for_loss(
+        self,
+        model_output: Dict[str, Tensor],
+        labels: Tensor,
+        loss_kwargs: Dict[str, Any],
+    ) -> Tuple[Tensor, Tensor, Dict[str, Any]]:
+        """
+        Process the model output for calculating the loss.
+
+        Args:
+            model_output (Dict[str, Tensor]): The output of the model.
+            labels (Tensor): The target labels.
+            loss_kwargs (Dict[str, Any]): Additional loss arguments.
+
+        Returns:
+            tuple: A tuple containing the processed model output, labels, and loss arguments.
+        """
+        kwargs_copy = dict(loss_kwargs)
+        if labels is not None:
+            labels = labels.float()
+        return model_output["logits"], labels, kwargs_copy
+
 
 class ChebiEnsembleLearning(_EnsembleBase):
```
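The new `_process_for_loss` hook casts integer targets to float before the loss is computed, which BCE-style criteria require. A minimal, self-contained sketch of how its return values would typically be consumed (the shapes and the `BCEWithLogitsLoss` choice are illustrative assumptions, not taken from this repository):

```python
import torch

# Stand-ins for model_output["logits"] and integer multi-label targets.
logits = torch.randn(4, 3)
labels = torch.randint(0, 2, (4, 3))

# BCEWithLogitsLoss expects float targets, which is why
# _process_for_loss converts labels with labels.float().
criterion = torch.nn.BCEWithLogitsLoss()
loss = criterion(logits, labels.float())
print(loss.item())
```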
