 import os.path
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 from torch import Tensor
@@ -94,7 +94,7 @@ class ChebiEnsemble(_EnsembleBase):
     def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
         super().__init__(model_configs, **kwargs)
         # Add a dummy trainable parameter
-        self.dummy_param = torch.nn.Parameter(torch.randn(1))
+        self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True))

     def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
         predictions = {}
@@ -103,8 +103,6 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
             data["labels"].shape[0], data["labels"].shape[1], device=self.device
         ).to(self.device)

-        print(data["features"].shape)  # Debugging
-
         for name, model in self.models.items():
             output = model(data)
             confidences[name] = torch.sigmoid(output["logits"])
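For context, here is a minimal sketch of what the `forward` loop above collects and how a `(true_scores > false_scores)` decision like the one at the end of `aggregate_predictions` could fall out of it. The model names, shapes, and the averaging step are hypothetical; only the sigmoid call and the final comparison appear in this diff:

```python
import torch

# Hypothetical per-model logits for one sample with two candidate classes.
logits = {
    "electra": torch.tensor([[2.0, -1.0]]),
    "gnn": torch.tensor([[0.5, -0.5]]),
}

# Mirrors the loop body: squash each model's logits into [0, 1] confidences.
confidences = {name: torch.sigmoid(lg) for name, lg in logits.items()}

# One plausible aggregation (the real aggregate_predictions body is not
# shown here): mean confidence as the "true" score, its complement as "false".
true_scores = torch.stack(list(confidences.values())).mean(dim=0)
false_scores = 1 - true_scores
print((true_scores > false_scores).long())  # tensor([[1, 0]])
```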
@@ -193,7 +191,8 @@ def _execute(
         )
         loss = loss[0]

-        d["loss"] = loss
+        d["loss"] = loss + 0 * self.dummy_param.sum()
+
         self.log(
             f"{prefix}loss",
             loss.item(),
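The `0 * self.dummy_param.sum()` term deserves a note: it leaves the loss value untouched while connecting the computation graph to a trainable leaf. A plausible motivation, not stated in the commit, is that wrappers such as PyTorch Lightning or DDP complain when a module exposes no parameter that ever receives a gradient. A self-contained sketch of the trick:

```python
import torch

class Dummy(torch.nn.Module):  # hypothetical stand-in for ChebiEnsemble
    def __init__(self):
        super().__init__()
        # nn.Parameter defaults to requires_grad=True, so the explicit
        # requires_grad=True added in the hunk above is redundant but harmless.
        self.dummy_param = torch.nn.Parameter(torch.randn(1))

model = Dummy()
loss = torch.tensor(0.7)  # pretend ensemble loss with no graph of its own
# Zero-weighted graft: value unchanged, but backward() now has a leaf to reach.
loss = loss + 0 * model.dummy_param.sum()
loss.backward()
print(model.dummy_param.grad)  # tensor([0.])
```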
@@ -229,6 +228,28 @@ def aggregate_predictions(self, predictions, confidences):

         return (true_scores > false_scores).long()  # Final class decision

+    def _process_for_loss(
+        self,
+        model_output: Dict[str, Tensor],
+        labels: Tensor,
+        loss_kwargs: Dict[str, Any],
+    ) -> Tuple[Tensor, Tensor, Dict[str, Any]]:
+        """
+        Process the model output for calculating the loss.
+
+        Args:
+            model_output (Dict[str, Tensor]): The output of the model.
+            labels (Tensor): The target labels.
+            loss_kwargs (Dict[str, Any]): Additional loss arguments.
+
+        Returns:
+            tuple: A tuple containing the processed model output, labels, and loss arguments.
+        """
+        kwargs_copy = dict(loss_kwargs)
+        if labels is not None:
+            labels = labels.float()
+        return model_output["logits"], labels, kwargs_copy
+

 class ChebiEnsembleLearning(_EnsembleBase):

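Finally, the `labels.float()` cast in the new `_process_for_loss` matters if the downstream criterion is a BCE-style loss, which is typical for multi-label ChEBI classification; the actual loss class is configured elsewhere and is not visible in this diff. A minimal illustration:

```python
import torch

criterion = torch.nn.BCEWithLogitsLoss()
logits = torch.tensor([[1.2, -0.7]])
labels = torch.tensor([[1, 0]])  # integer multi-hot targets

# criterion(logits, labels) would raise a dtype error: BCEWithLogitsLoss
# expects floating-point targets, hence the labels.float() in _process_for_loss.
loss = criterion(logits, labels.float())
print(loss)  # scalar tensor
```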