When are buffers moved to gpu? #12207
-
I have an issue with a weighted MSE loss that I build in `setup`, with a buffer as one of its arguments. Something like this:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


@torch.jit.script
def weighted_mse_func(weights, y, y_hat):
    # weighted regression loss
    reg_loss = torch.dot(weights, torch.mean(F.mse_loss(y_hat, y, reduction='none'), dim=0))
    return reg_loss


def weighted_mse(weights):
    def func(y, y_hat):
        return weighted_mse_func(weights, y, y_hat)
    return func


class model(pl.LightningModule):
    def __init__(self, weights):
        super().__init__()
        weights = torch.tensor(weights.copy(), dtype=self.dtype, device=self.device)
        self.register_buffer("weights", weights)

    def setup(self, stage):
        super().setup(stage)
        self.loss = weighted_mse(self.weights)
```

When initializing training on the GPU I get an error because …
Replies: 2 comments 5 replies
-
try:

```python
class model(pl.LightningModule):
    def __init__(self, weights):
        super().__init__()
        self.register_buffer("weights", torch.tensor(weights.copy(), dtype=self.dtype))

    def on_fit_start(self):
        self.loss = weighted_mse(self.weights)
```
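A note on why this helps (my reading of the hook order, not stated explicitly in the reply): `on_fit_start` runs after the Trainer has moved the model to its target device, so `self.weights` is already the on-device buffer when the closure captures it, whereas in `setup` the module is still on the CPU. Under the same assumption, an alternative sketch is to skip the closure entirely and look the buffer up at call time, for example in `training_step`:

```python
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    # self.weights is looked up on every call, so it is always the buffer
    # that currently lives on the module's device (hypothetical usage sketch).
    return weighted_mse_func(self.weights, y, y_hat)
```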
-
It seems that PyTorch doesn't move buffers in place (as it does for parameters), so references to a buffer become stale when the module is moved from one device to another. This issue is discussed in pytorch/pytorch#43815.
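To make the rebinding behaviour concrete, here is a minimal standalone sketch (not from the thread; it assumes a CUDA device is available):

```python
import torch
import torch.nn as nn

m = nn.Module()
m.register_buffer("weights", torch.ones(3))

ref = m.weights          # reference captured while the module is still on the CPU

m.to("cuda")             # buffers are rebound to new tensors, not moved in place

print(m.weights.device)  # cuda:0 -- the buffer attribute now points to the GPU tensor
print(ref.device)        # cpu    -- the earlier reference still points to the old CPU tensor
```

This is why a loss closure built in `setup` from `self.weights` keeps using the CPU tensor after the model has been moved to the GPU.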