Looks good to me! I do have some comments/questions/requests though.

I noticed a `ComputeConfig` was added. I'm a bit concerned about the direction in its current form, mainly because it introduces a second "global source of truth" for dtype/device that can easily get out of sync with normal PyTorch conventions.

- It keeps a global dtype that, if I understand correctly, should be the default dtype for initializing float tensors when not specified otherwise. PyTorch already has this built in (see `torch.set_default_dtype`). If these are not in sync it will cause unexpected behavior. If nothing is implemented, it already behaves exactly as intended.
- It checks whether a GPU is available and, if so, uses it as the default. To disable this, a user has to explicitly add `ComputeConfig.set_device('cpu')` any time a model is used, and would get an error otherwise. I think it would be more intuitive to use the CPU by default, for which no extra code is required either.

My impression is that PyTorch Modules generally follow the following convention:

```python
class A(nn.Module):
    def __init__(self, *args, dtype=None, device=None):
        # Pass dtype and device to any sub-modules or parameters that are
        # initialized here. Any tensor for which dtype or device is None
        # falls back to torch.get_default_dtype() and 'cpu', respectively.
        ...

    def forward(self, x):
        # Infer dtype and device from x, and initialize any tensors to be
        # consistent with the dtype and device inferred from x.
        ...
```

and that by following this convention no other global config would be required.
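As a minimal runnable sketch of that convention (the module name `Linearish` and its parameters are illustrative, not taken from diffwofost): construction respects explicit `dtype`/`device` keywords and otherwise falls back to PyTorch's defaults, while `forward` infers both from its input.

```python
import torch
import torch.nn as nn


class Linearish(nn.Module):
    """Toy module following the common PyTorch convention: dtype/device are
    optional keyword arguments; when None, parameters fall back to
    torch.get_default_dtype() and the CPU."""

    def __init__(self, n_in, n_out, *, dtype=None, device=None):
        super().__init__()
        # dtype=None / device=None defer to PyTorch's own defaults.
        self.weight = nn.Parameter(
            torch.zeros(n_out, n_in, dtype=dtype, device=device))

    def forward(self, x):
        # Infer dtype/device from the input rather than from a global config,
        # so tensors created here always match x.
        bias = torch.zeros(self.weight.shape[0],
                           dtype=x.dtype, device=x.device)
        return x @ self.weight.T + bias
```

With this pattern, `Linearish(3, 2)` lives on the CPU with the default dtype, while `Linearish(3, 2, dtype=torch.float64, device="cuda")` is explicit, and two instances can sit on different devices in the same process.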
@ronvree Thanks for the review and taking a close look! We think the `ComputeConfig` approach is a solid (and cleaner) alternative to using `torch.set_default_dtype`/`torch.set_default_device` in diffwofost. Relying on PyTorch's global defaults can get risky for reusable software, since they're process-wide and affect everything created afterward, including third-party code, tests, or models we don't control. As a rule of thumb, `torch.set_default_*` is fine for notebooks or tiny scripts. For modular code we use a config like `ComputeConfig`, e.g.:

```python
dtype = ComputeConfig.get_dtype()
device = ComputeConfig.get_device()
torch.tensor(..., dtype=dtype, device=device)
```

That said, I noticed that we're currently setting device/dtype in two places (`ComputeConfig` and `EngineTestHelper(device=..., dtype=...)`). Also, I didn't set device/dtype correctly when creating tensors in the notebooks. We'll fix that in #84.

One more thing: `model.to("cuda")` won't actually do anything here. Instead, the intended pattern is to set the device/dtype via `ComputeConfig` before running the forward pass. This approach is consistent with how other scientific simulators handle execution context.
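For reference, a minimal sketch of what such a config could look like (the method names match the snippets in this thread, but the body is an assumption, not the actual diffwofost implementation). The key property is that, unlike `torch.set_default_dtype`, it only affects code that reads it explicitly:

```python
import torch


class ComputeConfig:
    """Hypothetical process-wide dtype/device registry. Only code that calls
    get_dtype()/get_device() is affected; PyTorch's own defaults and any
    third-party code are left untouched."""

    _dtype = torch.float32
    _device = torch.device("cpu")

    @classmethod
    def set_dtype(cls, dtype):
        cls._dtype = dtype

    @classmethod
    def get_dtype(cls):
        return cls._dtype

    @classmethod
    def set_device(cls, device):
        # Accept either a string like "cpu"/"cuda" or a torch.device.
        cls._device = torch.device(device)

    @classmethod
    def get_device(cls):
        return cls._device
```

Typical use would then be `ComputeConfig.set_dtype(torch.float64)` once at startup, followed by `torch.zeros(3, dtype=ComputeConfig.get_dtype(), device=ComputeConfig.get_device())` inside the models.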
Thanks for your reply! I completely agree with the concerns regarding `torch.set_default_dtype`, and my suggestion was not to use it to stay in sync with `ComputeConfig`, but rather not to have a globally defined device/dtype for diffwofost at all. I'm mainly proposing to follow the PyTorch convention where the torch defaults are used if nothing is specified, while the user can still have different instances on different devices by specifying them explicitly.
Originally posted by @ronvree in #76 (comment)