Commit d721f51

Init submit
1 parent 8862a11 commit d721f51

10 files changed: +1455 −1 lines

Diff for: .DS_Store

6 KB
Binary file not shown.

Diff for: README.md

+31 −1

# GraphCTA

Collaborate to Adapt: Source-Free Graph Domain Adaptation via Bi-directional Adaptation (WWW 2024)

![](https://github.com/cszhangzhen/GraphCTA/blob/main/fig/model.png)

This is a PyTorch implementation of the GraphCTA algorithm, which addresses the domain adaptation problem without access to the labelled source graph. It performs model adaptation and graph adaptation collaboratively through a series of procedures: (1) conduct model adaptation based on the node neighborhood predictions in the target graph, considering both local and global information; (2) perform graph adaptation by updating the graph structure and node attributes via neighborhood contrastive learning; and (3) feed the updated graph into the next iteration of model adaptation, establishing a collaborative loop between model adaptation and graph adaptation.
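
The sketch below only illustrates this loop; `model_adaptation` and `graph_adaptation` are hypothetical placeholders for the adaptation steps that `train_target.py` carries out, not functions exported by this repository.

```
# Illustrative sketch of the collaborative loop described above; the helper
# functions are placeholders, not this repository's actual API.

def model_adaptation(model, graph):
    # (1) update the model with neighborhood-based predictions on the
    #     current target graph, using both local and global information
    return model

def graph_adaptation(graph, model):
    # (2) refine the graph structure and node attributes via neighborhood
    #     contrastive learning, guided by the current model
    return graph

def collaborative_adaptation(model, graph, num_rounds=10):
    for _ in range(num_rounds):
        model = model_adaptation(model, graph)
        graph = graph_adaptation(graph, model)
        # (3) the updated graph is the input to the next round of model
        #     adaptation, closing the collaborative loop
    return model, graph
```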

## Requirements

* python3.8
* pytorch==2.0.0
* torch-scatter==2.1.1+pt20cu118
* torch-sparse==0.6.17+pt20cu118
* torch-cluster==1.6.1+pt20cu118
* torch-geometric==2.3.1
* numpy==1.24.3
* scipy==1.10.1
* tqdm==4.65.0
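
The `+pt20cu118` wheels above are the PyTorch 2.0 / CUDA 11.8 builds of the PyG extensions. As a quick, optional sanity check (not part of the repository), the installed versions can be printed:

```
import torch
import torch_geometric
import torch_scatter
import torch_sparse
import torch_cluster

# Expected versions per the list above (CUDA 11.8 builds for the extensions).
print(torch.__version__)            # 2.0.0 (+cu118)
print(torch_geometric.__version__)  # 2.3.1
print(torch_scatter.__version__)    # 2.1.1
print(torch_sparse.__version__)     # 0.6.17
print(torch_cluster.__version__)    # 1.6.1
```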

## Datasets

The datasets used in the paper are all publicly available. You can find [Elliptic](https://www.kaggle.com/datasets/ellipticco/elliptic-data-set), [Twitch](https://github.com/benedekrozemberczki/datasets#twitch-social-networks) and [Citation](https://github.com/yuntaodu/ASN/tree/main/data) via the links.

## Quick Start

Execute the following command for source model pre-training:
```
python train_source.py
```
Then, execute the following command for adaptation:
```
python train_target.py
```

Diff for: data/Citation.zip

5.42 MB
Binary file not shown.

Diff for: datasets.py

+271
import os.path as osp
import torch
import numpy as np
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.io import read_txt_array
import torch.nn.functional as F

import scipy.sparse  # used to build the Twitch adjacency matrix
import pickle as pkl
import csv
import json

import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)


class CitationDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(CitationDataset, self).__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["docs.txt", "edgelist.txt", "labels.txt"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        # Edge list: one "src,dst" pair per line.
        edge_path = osp.join(self.raw_dir, '{}_edgelist.txt'.format(self.name))
        edge_index = read_txt_array(edge_path, sep=',', dtype=torch.long).t()

        # Node features: comma-separated values, one node per line.
        docs_path = osp.join(self.raw_dir, '{}_docs.txt'.format(self.name))
        f = open(docs_path, 'rb')
        content_list = []
        for line in f.readlines():
            line = str(line, encoding="utf-8")
            content_list.append(line.split(","))
        x = np.array(content_list, dtype=float)
        x = torch.from_numpy(x).to(torch.float)

        # Node labels: one integer label per line.
        label_path = osp.join(self.raw_dir, '{}_labels.txt'.format(self.name))
        f = open(label_path, 'rb')
        content_list = []
        for line in f.readlines():
            line = str(line, encoding="utf-8")
            line = line.replace("\r", "").replace("\n", "")
            content_list.append(line)
        y = np.array(content_list, dtype=int)
        y = torch.from_numpy(y).to(torch.int64)

        data_list = []
        data = Data(edge_index=edge_index, x=x, y=y)

        # Random 80/10/10 train/val/test node split.
        random_node_indices = np.random.permutation(y.shape[0])
        training_size = int(len(random_node_indices) * 0.8)
        val_size = int(len(random_node_indices) * 0.1)
        train_node_indices = random_node_indices[:training_size]
        val_node_indices = random_node_indices[training_size:training_size + val_size]
        test_node_indices = random_node_indices[training_size + val_size:]

        train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        train_masks[train_node_indices] = 1
        val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        val_masks[val_node_indices] = 1
        test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        test_masks[test_node_indices] = 1

        data.train_mask = train_masks
        data.val_mask = val_masks
        data.test_mask = test_masks

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data_list.append(data)

        data, slices = self.collate([data])

        torch.save((data, slices), self.processed_paths[0])


class EllipticDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(EllipticDataset, self).__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [".pkl"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        # The pickle stores (adjacency matrix, labels, node features).
        path = osp.join(self.raw_dir, '{}.pkl'.format(self.name))
        result = pkl.load(open(path, 'rb'))
        A, label, features = result
        label = label + 1
        edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
        features = np.array(features)
        x = torch.from_numpy(features).to(torch.float)
        y = torch.tensor(label).to(torch.int64)

        data_list = []
        data = Data(edge_index=edge_index, x=x, y=y)

        # Random 80/10/10 train/val/test node split.
        random_node_indices = np.random.permutation(y.shape[0])
        training_size = int(len(random_node_indices) * 0.8)
        val_size = int(len(random_node_indices) * 0.1)
        train_node_indices = random_node_indices[:training_size]
        val_node_indices = random_node_indices[training_size:training_size + val_size]
        test_node_indices = random_node_indices[training_size + val_size:]

        train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        train_masks[train_node_indices] = 1
        val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        val_masks[val_node_indices] = 1
        test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        test_masks[test_node_indices] = 1

        data.train_mask = train_masks
        data.val_mask = val_masks
        data.test_mask = test_masks

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data_list.append(data)

        data, slices = self.collate([data])

        torch.save((data, slices), self.processed_paths[0])


class TwitchDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(TwitchDataset, self).__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["edges.csv", "features.json", "target.csv"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def load_twitch(self, lang):
        assert lang in ('DE', 'EN', 'FR'), 'Invalid dataset'
        filepath = self.raw_dir
        label = []
        node_ids = []
        src = []
        targ = []
        uniq_ids = set()
        # Node labels: row[2] holds the binary target, row[5] the node id.
        with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f:
            reader = csv.reader(f)
            next(reader)
            for row in reader:
                node_id = int(row[5])
                # handle FR case of non-unique rows
                if node_id not in uniq_ids:
                    uniq_ids.add(node_id)
                    label.append(int(row[2] == "True"))
                    node_ids.append(int(row[5]))

        node_ids = np.array(node_ids, dtype=np.int32)

        # Edge list.
        with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f:
            reader = csv.reader(f)
            next(reader)
            for row in reader:
                src.append(int(row[0]))
                targ.append(int(row[1]))

        # Sparse node features: node id -> list of active feature indices.
        with open(f"{filepath}/musae_{lang}_features.json", 'r') as f:
            j = json.load(f)

        src = np.array(src)
        targ = np.array(targ)
        label = np.array(label)

        inv_node_ids = {node_id: idx for (idx, node_id) in enumerate(node_ids)}
        reorder_node_ids = np.zeros_like(node_ids)
        for i in range(label.shape[0]):
            reorder_node_ids[i] = inv_node_ids[i]

        n = label.shape[0]
        A = scipy.sparse.csr_matrix((np.ones(len(src)), (np.array(src), np.array(targ))), shape=(n, n))
        # Multi-hot feature matrix over the 3170 possible feature indices.
        features = np.zeros((n, 3170))
        for node, feats in j.items():
            if int(node) >= n:
                continue
            features[int(node), np.array(feats, dtype=int)] = 1
        new_label = label[reorder_node_ids]
        label = new_label

        return A, label, features

    def process(self):
        A, label, features = self.load_twitch(self.name)
        edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
        features = np.array(features)
        x = torch.from_numpy(features).to(torch.float)
        y = torch.from_numpy(label).to(torch.int64)

        data_list = []
        data = Data(edge_index=edge_index, x=x, y=y)

        # Random 80/10/10 train/val/test node split.
        random_node_indices = np.random.permutation(y.shape[0])
        training_size = int(len(random_node_indices) * 0.8)
        val_size = int(len(random_node_indices) * 0.1)
        train_node_indices = random_node_indices[:training_size]
        val_node_indices = random_node_indices[training_size:training_size + val_size]
        test_node_indices = random_node_indices[training_size + val_size:]

        train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        train_masks[train_node_indices] = 1
        val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        val_masks[val_node_indices] = 1
        test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        test_masks[test_node_indices] = 1

        data.train_mask = train_masks
        data.val_mask = val_masks
        data.test_mask = test_masks

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data_list.append(data)

        data, slices = self.collate([data])

        torch.save((data, slices), self.processed_paths[0])
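
For reference (not part of this commit), a minimal example of how these dataset classes might be used; the root directories and dataset names below are assumptions about how the unzipped data is laid out locally:

```
# Illustrative only: the root paths and dataset names are assumed, not
# defined by this commit. Each class expects its raw files under <root>/raw.
from datasets import CitationDataset, TwitchDataset

citation = CitationDataset(root='data/Citation/ACMv9', name='ACMv9')  # assumed layout
twitch = TwitchDataset(root='data/Twitch/DE', name='DE')              # name must be 'DE', 'EN' or 'FR'

data = citation[0]
print(data.num_nodes, int(data.train_mask.sum()))  # random 80/10/10 split built in process()
```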

Diff for: fig/model.png

206 KB
