
ComputeConfig and torch defaults #87

@SarahAlidoost

Description


Looks good to me! I do have some comments/questions/requests, though.
I noticed a ComputeConfig was added. I’m a bit concerned about the direction in its current form, mainly because it introduces a second “global source of truth” for dtype/device that can be easy to get out of sync with normal PyTorch conventions.

  • It keeps a global dtype that, if I understand correctly, serves as the default dtype for float tensors when none is specified. PyTorch already has this built in (see torch.set_default_dtype). If the two are not in sync, it will cause unexpected behavior; if nothing is implemented at all, PyTorch already behaves exactly as intended.
  • It checks whether a GPU is available and, if so, uses it as the default. To opt out, a user has to explicitly add ComputeConfig.set_device('cpu') any time a model is used, and gets an error otherwise. I think it would be more intuitive to use the CPU by default, which also requires no extra code.
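To illustrate the first point, this is how PyTorch's built-in global default already behaves (a small self-contained example, not code from diffwofost):

```python
import torch

# PyTorch's built-in global default dtype: newly created float tensors
# use it unless a dtype is passed explicitly.
torch.set_default_dtype(torch.float64)
x = torch.zeros(3)
assert x.dtype == torch.float64

# An explicit dtype always wins over the global default.
y = torch.zeros(3, dtype=torch.float32)
assert y.dtype == torch.float32

# Restore the usual default so the global state does not leak.
torch.set_default_dtype(torch.float32)
```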

My impression is that PyTorch Modules generally follow this convention:

class A(nn.Module):

    def __init__(self, *args, dtype=None, device=None):
        super().__init__()
        # Pass dtype and device to any sub-modules or parameters initialized here.
        # Any tensor initialized with dtype=None or device=None falls back to
        # torch.get_default_dtype() and 'cpu', respectively.
        ...

    def forward(self, x):
        # Infer dtype and device from x, and initialize any new tensors
        # to be consistent with them.
        ...

and that by following this convention no other global config would be required.
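A minimal runnable version of that convention might look like this (the class name `Scale` and its single parameter are illustrative, not part of diffwofost):

```python
import torch
import torch.nn as nn

class Scale(nn.Module):
    """Multiplies its input by a learnable scale, following the convention."""

    def __init__(self, dtype=None, device=None):
        super().__init__()
        # dtype=None / device=None fall back to torch.get_default_dtype()
        # and 'cpu', so no package-level global config is needed.
        self.weight = nn.Parameter(torch.ones(1, dtype=dtype, device=device))

    def forward(self, x):
        # Infer dtype/device from the input rather than from any global.
        bias = torch.zeros(1, dtype=x.dtype, device=x.device)
        return x * self.weight.to(x.dtype) + bias

m = Scale()                       # torch defaults: float32 on CPU
m64 = Scale(dtype=torch.float64)  # explicit dtype; no global state touched
```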

@ronvree Thanks for the review and taking a close look! We think the ComputeConfig approach is a solid (and cleaner) alternative to using torch.set_default_dtype/device in diffwofost. Relying on PyTorch's global defaults is risky for reusable software, since they are process-wide and affect everything created afterwards, including third-party code, tests, or models we don't control. As a rule of thumb, torch.set_default_* is fine for notebooks or tiny scripts; for modular code we use a config like ComputeConfig, e.g.:

dtype = ComputeConfig.get_dtype()
device = ComputeConfig.get_device()
torch.tensor(..., dtype=dtype, device=device)
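For readers following along, a class-level config along these lines would support that call pattern (a sketch only, not the actual diffwofost implementation; just set_device/get_device/get_dtype appear in this thread):

```python
import torch

class ComputeConfig:
    """Sketch of a package-scoped compute config. The GPU-if-available
    default mirrors the behavior discussed above."""

    _dtype = torch.float32
    _device = "cuda" if torch.cuda.is_available() else "cpu"

    @classmethod
    def set_device(cls, device):
        cls._device = device

    @classmethod
    def get_device(cls):
        return cls._device

    @classmethod
    def get_dtype(cls):
        return cls._dtype

# Tensors are then created against the config, not torch's globals:
ComputeConfig.set_device("cpu")
t = torch.tensor([1.0, 2.0],
                 dtype=ComputeConfig.get_dtype(),
                 device=ComputeConfig.get_device())
```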

That said, I noticed that we're currently setting device/dtype in two places (ComputeConfig and EngineTestHelper(device=..., dtype=...)). Also, I didn't set device/dtype correctly when creating tensors in the notebooks. We'll fix that in #84.

One more thing: model.to("cuda") won’t actually do anything here. Instead, the intended pattern is to set the device/dtype via ComputeConfig before running the forward pass. This approach is consistent with how other scientific simulators handle execution context.

Thanks for your reply! I completely agree with the concerns regarding torch.set_default_dtype, and my suggestion was not to use it to sync with ComputeConfig, but rather not to have a globally defined device/dtype for diffwofost at all. I'm mainly proposing to follow the PyTorch conventions, where the torch defaults are used if nothing is specified, but the user can still have different instances on different devices by specifying them explicitly.
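Concretely, per-instance placement under the torch conventions looks like this (using a plain nn.Linear as a stand-in; the GPU line is guarded since not every machine has one):

```python
import torch
import torch.nn as nn

# Each instance carries its own dtype/device; no global state is set.
a = nn.Linear(4, 2)                       # float32 on CPU (torch defaults)
b = nn.Linear(4, 2, dtype=torch.float64)  # float64 on CPU, explicitly

if torch.cuda.is_available():
    c = nn.Linear(4, 2, device="cuda")    # only this instance lives on GPU

# The two CPU instances coexist with different dtypes:
a(torch.randn(4))
b(torch.randn(4, dtype=torch.float64))
```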

Originally posted by @ronvree in #76 (comment)
