Hi! I'm currently implementing a custom heterogeneous GNN, and I ran into something I don't understand while performing a forward pass with a `HeteroData` object.
Concretely, I'm trying to print the network output inside a training loop. Instead of printing the output tensor, as I would expect from a normal PyTorch network, it prints `Proxy(getattr_1)` only in the first iteration of the loop and then prints nothing at all in the remaining iterations.
```python
def forward(self, x, edge_index):
    print('forward')
    x = self.gat(x=x, edge_index=edge_index)
    print('x', x['node1'].x)
    return x

# heterograph is the HeteroData object; model is an instance of the GNN
for _ in range(3):
    model(heterograph.x_dict, heterograph.edge_index_dict)
```
This is what gets printed:

```
forward
x Proxy(getattr_1)
```

I was expecting something more like this instead:

```
forward
x node1 attributes tensor...
forward
x node1 attributes tensor...
forward
x node1 attributes tensor...
```

Can someone explain to me what this Proxy thing is for, and why it does not behave like a normal PyTorch forward pass?
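The `Proxy(...)` repr is what `torch.fx` prints for a value that is being symbolically traced, and PyG's `to_hetero` transformation is built on `torch.fx`. Assuming the model was passed through `to_hetero` (or another `torch.fx`-based transform) somewhere outside this snippet, one plausible explanation is that `forward()` is executed once with `Proxy` placeholders to record a graph, and the resulting module then replays only the recorded tensor operations, so the `print()` calls never run again. The following minimal sketch reproduces that pattern with plain `torch.fx` on a hypothetical `TinyModel`, not the model from the question:

```python
import contextlib
import io

import torch
import torch.fx


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in model with prints inside forward()."""

    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 3)

    def forward(self, x):
        print("forward")
        x = self.lin(x)
        print("x", x)  # while tracing, x is a torch.fx.Proxy, not a tensor
        return x


# Symbolic tracing calls forward() exactly once, with Proxy placeholders.
trace_log = io.StringIO()
with contextlib.redirect_stdout(trace_log):
    traced = torch.fx.symbolic_trace(TinyModel())

# The GraphModule replays only the recorded ops; the print() calls are not in it.
run_log = io.StringIO()
with contextlib.redirect_stdout(run_log):
    for _ in range(3):
        out = traced(torch.randn(2, 4))

print("during tracing:", repr(trace_log.getvalue()))
print("during 3 calls:", repr(run_log.getvalue()))
print(traced.code)
```

Running it should show `forward` and a `Proxy(...)` line once, during `symbolic_trace`, and nothing during the three later calls; the exact proxy name depends on the traced graph.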
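A simplified version of the model definition is as follows:
```python
import torch
from torch_geometric.nn import GATv2Conv

class gnn(torch.nn.Module):
    def __init__(self):
        super(gnn, self).__init__()
        # in_channels=(-1, -1): source/target feature sizes are inferred lazily
        self.gat = GATv2Conv(in_channels=(-1, -1), out_channels=4, heads=2,
                             concat=False, add_self_loops=False)
```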
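If the goal is to see intermediate values on every call of a traced model, one option in plain `torch.fx` is `torch.fx.wrap`, which makes the tracer record calls to a module-level function as a graph node instead of executing that function during tracing. The sketch below uses a hypothetical `debug_print` helper on a hypothetical stand-in `TinyModel`; it is an assumption, not verified here, that `to_hetero` carries such a node through its own transformation.

```python
import contextlib
import io

import torch
import torch.fx


def debug_print(label, t):
    # Hypothetical helper; runs on real tensors each time the traced module is called.
    print(label, tuple(t.shape))


# Must run at module top level: tells torch.fx to record debug_print calls as graph nodes.
torch.fx.wrap("debug_print")


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 3)

    def forward(self, x):
        x = self.lin(x)
        debug_print("x", x)  # recorded while tracing, executed on every call
        return x


traced = torch.fx.symbolic_trace(TinyModel())

log = io.StringIO()
with contextlib.redirect_stdout(log):
    for _ in range(3):
        traced(torch.randn(2, 4))

print(log.getvalue(), end="")  # one "x (2, 3)" line per call
```

A simpler alternative is to keep the prints out of `forward()` entirely: assigning the result in the training loop (for example `out = model(heterograph.x_dict, heterograph.edge_index_dict)`) and printing `out['node1']` there works on real tensors, since that code is never traced.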
Thank you very much for your help :)