
Commit 561f2c7

UltraQuery release (#22)
Main UltraQuery code
1 parent c414f83 commit 561f2c7

17 files changed (+2725 −20 lines)

README.md

+183 −8
@@ -4,7 +4,8 @@
[![pytorch](https://img.shields.io/badge/PyTorch_2.1+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/)
[![pyg](https://img.shields.io/badge/PyG_2.4+-3C2179?logo=pyg&logoColor=#3C2179)](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)
-[![arxiv](http://img.shields.io/badge/arxiv-2310.04562-yellow.svg)](https://arxiv.org/abs/2310.04562)
+[![ULTRA arxiv](http://img.shields.io/badge/arxiv-2310.04562-yellow.svg)](https://arxiv.org/abs/2310.04562)
+[![UltraQuery arxiv](http://img.shields.io/badge/arxiv-2404.07198-yellow.svg)](https://arxiv.org/abs/2404.07198)
[![HuggingFace Hub](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-black)](https://huggingface.co/collections/mgalkin/ultra-65699bb28369400a5827669d)
![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)

@@ -37,6 +38,7 @@ This repository is based on PyTorch 2.1 and PyTorch-Geometric 2.4.
* [Pre-train](#pretraining) ULTRA on your own mixture of graphs.
* Run [evaluation on many datasets](#run-on-many-datasets) sequentially.
* Use the pre-trained checkpoints to run inference and fine-tuning on [your own KGs](#adding-your-own-graph).
+* (NEW) Execute complex logical queries on any KG with [UltraQuery](#ultraquery).

Table of contents:
* [Installation](#installation)
@@ -47,8 +49,10 @@ Table of contents:
* [Pretraining](#pretraining)
* [Datasets](#datasets)
* [Adding custom datasets](#adding-your-own-graph)
+* [UltraQuery](#ultraquery)

## Updates
+* **Apr 23rd, 2024**: Release of [UltraQuery](#ultraquery) for complex multi-hop logical query answering on _any_ KG (with a new checkpoint and 23 datasets).
* **Jan 15th, 2024**: Accepted at [ICLR 2024](https://openreview.net/forum?id=jVEoydFOl9)!
* **Dec 4th, 2023**: Added a new ULTRA checkpoint `ultra_50g` pre-trained on 50 graphs. Averaged over 16 larger transductive graphs, it delivers 0.389 MRR / 0.549 Hits@10 compared to 0.329 MRR / 0.479 Hits@10 of the `ultra_3g` checkpoint. The inductive performance is still as good! Use this checkpoint for inference on larger graphs.
* **Dec 4th, 2023**: Pre-trained ULTRA models (3g, 4g, 50g) are now also available on the [HuggingFace Hub](https://huggingface.co/collections/mgalkin/ultra-65699bb28369400a5827669d)!
@@ -340,17 +344,188 @@ class CustomDataset(InductiveDataset):
TSV / CSV files are supported by setting a delimiter (e.g., `delimiter = "\t"`) in the class definition.

After adding your own dataset, you can immediately run 0-shot inference or fine-tuning of any ULTRA checkpoint.
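
For illustration, a minimal sketch of such a class, assuming the `InductiveDataset` base class referenced above; the module path, class name, and file layout are placeholders and not part of this commit:

```python
# Hedged sketch only: the import path and class name are assumptions;
# the `delimiter` attribute is the hook described above.
from ultra.datasets import InductiveDataset

class CustomTSVDataset(InductiveDataset):
    delimiter = "\t"  # triples stored one per line as "head<TAB>relation<TAB>tail"
```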

## UltraQuery ##

You can now run complex logical queries on any KG with UltraQuery, an inductive query answering approach that uses any ULTRA checkpoint with non-parametric fuzzy logic operators. Read more in the [new preprint](https://arxiv.org/abs/2404.07198).

Similar to ULTRA, UltraQuery transfers to any KG in a zero-shot fashion and sets a few SOTA results on a variety of query answering benchmarks.

### Checkpoint ###

Any existing ULTRA checkpoint is compatible with UltraQuery, but we also ship a newly trained `ultraquery.pth` checkpoint in the `ckpts` folder.

* The new `ultraquery.pth` checkpoint was trained on complex queries from the `FB15k237LogicalQuery` dataset for 40,000 steps (config in `config/ultraquery/pretrain.yaml`). It uses the same ULTRA architecture but is tuned for the multi-source propagation needed in complex queries, so no score thresholding is required.
* You can use any existing ULTRA checkpoint (`3g` / `4g` / `50g`) for starters - don't forget to set the `--threshold` argument to 0.8 or higher (depending on the dataset). Score thresholding is required because those models were trained on simple one-hop link prediction and suffer from the multi-source propagation issue; read Section 4.1 in the [new preprint](https://arxiv.org/abs/2404.07198) for more details.

### Performance ###

The numbers reported in the preprint were obtained with a model trained with TorchDrug. In this PyG implementation, we managed to get even better performance across the board with the `ultraquery.pth` checkpoint.

`EPFO` is the average performance over the 9 query types with relation projection, intersection, and union. `Neg` is the average performance over the 5 query types with negation.

<table>
  <tr>
    <th rowspan=2>Model</th>
    <th colspan=4>Total Average (23 datasets)</th>
    <th colspan=4>Transductive (3 datasets)</th>
    <th colspan=4>Inductive (e) (9 graphs)</th>
    <th colspan=4>Inductive (e,r) (11 graphs)</th>
  </tr>
  <tr>
    <th>EPFO MRR</th> <th>EPFO Hits@10</th> <th>Neg MRR</th> <th>Neg Hits@10</th>
    <th>EPFO MRR</th> <th>EPFO Hits@10</th> <th>Neg MRR</th> <th>Neg Hits@10</th>
    <th>EPFO MRR</th> <th>EPFO Hits@10</th> <th>Neg MRR</th> <th>Neg Hits@10</th>
    <th>EPFO MRR</th> <th>EPFO Hits@10</th> <th>Neg MRR</th> <th>Neg Hits@10</th>
  </tr>
  <tr>
    <th>UltraQuery Paper</th>
    <td align="center">0.301</td> <td align="center">0.428</td> <td align="center">0.152</td> <td align="center">0.264</td>
    <td align="center">0.335</td> <td align="center">0.467</td> <td align="center">0.132</td> <td align="center">0.260</td>
    <td align="center">0.321</td> <td align="center">0.479</td> <td align="center">0.156</td> <td align="center">0.291</td>
    <td align="center">0.275</td> <td align="center">0.375</td> <td align="center">0.153</td> <td align="center">0.242</td>
  </tr>
  <tr>
    <th>UltraQuery PyG</th>
    <td align="center">0.309</td> <td align="center">0.432</td> <td align="center">0.178</td> <td align="center">0.286</td>
    <td align="center">0.411</td> <td align="center">0.518</td> <td align="center">0.240</td> <td align="center">0.352</td>
    <td align="center">0.312</td> <td align="center">0.468</td> <td align="center">0.139</td> <td align="center">0.262</td>
    <td align="center">0.280</td> <td align="center">0.380</td> <td align="center">0.193</td> <td align="center">0.288</td>
  </tr>
</table>

In particular, we reach SOTA on FB15k queries (0.764 MRR & 0.834 Hits@10 on EPFO; 0.567 MRR & 0.725 Hits@10 on negation) compared to much larger and heavier baselines (such as QTO).

### Run Inference ###

The running format is similar to the KG completion pipeline - use `run_query.py` and `run_query_many` for running a single experiment on one dataset or on a sequence of datasets, respectively.
Due to the size of the datasets and query complexity, it is recommended to run inference on a GPU.

An example command for running transductive inference with UltraQuery on FB15k237 queries:

```bash
python script/run_query.py -c config/ultraquery/transductive.yaml --dataset FB15k237LogicalQuery --epochs 0 --bpe null --gpus [0] --bs 32 --threshold 0.0 --ultra_ckpt null --qe_ckpt /path/to/ultra/ckpts/ultraquery.pth
```

An example command for running transductive inference with a vanilla ULTRA 4g checkpoint on FB15k237 queries with score thresholding:

```bash
python script/run_query.py -c config/ultraquery/transductive.yaml --dataset FB15k237LogicalQuery --epochs 0 --bpe null --gpus [0] --bs 32 --threshold 0.8 --ultra_ckpt /path/to/ultra/ckpts/ultra_4g.pth --qe_ckpt null
```

An example command for running inductive inference with UltraQuery on `InductiveFB15k237Query:550` queries:

```bash
python script/run_query.py -c config/ultraquery/inductive.yaml --dataset InductiveFB15k237Query --version 550 --epochs 0 --bpe null --gpus [0] --bs 32 --threshold 0.0 --ultra_ckpt null --qe_ckpt /path/to/ultra/ckpts/ultraquery.pth
```

New arguments for the `_query` scripts:
* `--threshold`: set to 0.0 when using the main UltraQuery checkpoint `ultraquery.pth`, or to 0.8 (and higher) when using vanilla ULTRA checkpoints
* `--qe_ckpt`: path to the UltraQuery checkpoint; set to `null` if you want to run vanilla ULTRA checkpoints
* `--ultra_ckpt`: path to the original ULTRA checkpoints; set to `null` if you want to run the UltraQuery checkpoint

### Datasets ###

The 23 new datasets are available in `datasets_query.py` and will be automatically downloaded upon the first launch.
All datasets include 14 standard query types (`1p`, `2p`, `3p`, `2i`, `3i`, `ip`, `pi`, `2u-DNF`, `up-DNF`, `2in`, `3in`, `inp`, `pin`, `pni`).

The standard protocol is to train on 10 patterns (`1p`, `2p`, `3p`, `2i`, `3i`, `2in`, `3in`, `inp`, `pin`, `pni`), i.e., without unions and without `ip`, `pi` queries, and to evaluate on all 14 patterns including `2u`, `up`, `ip`, `pi`.
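
For illustration, a hypothetical sketch of loading these datasets in Python; the class names, root path, and the `550` version appear in this commit, while the module path and constructor signatures are assumptions based on the repo's dataset conventions:

```python
# Hedged sketch: instantiating the new query datasets.
# Module path and constructor signatures are assumptions.
from ultra.datasets_query import FB15k237LogicalQuery, InductiveFB15k237Query

transductive = FB15k237LogicalQuery(root="~/git/ULTRA/query-datasets/")
inductive = InductiveFB15k237Query(root="~/git/ULTRA/query-datasets/", version="550")
```
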
<details>
<summary>Transductive query datasets (3)</summary>

All are the [BetaE](https://arxiv.org/abs/2010.11465) versions of the datasets, which include queries with negation and limit the max number of answers to 100:
* `FB15k237LogicalQuery`, `FB15kLogicalQuery`, `NELL995LogicalQuery`

</details>

<details>
<summary>Inductive (e) query datasets (9)</summary>

9 inductive datasets extracted from FB15k237 - first proposed in [Inductive Logical Query Answering in Knowledge Graphs](https://openreview.net/forum?id=-vXEN5rIABY) (NeurIPS 2022).

`InductiveFB15k237Query` comes in 9 versions, where the number shows how large the inference graph is compared to the training graph (in the number of nodes):
* `550`, `300`, `217`, `175`, `150`, `134`, `122`, `113`, `106`

In addition, we include the `InductiveFB15k237QueryExtendedEval` dataset with the same versions. Those are inference-only datasets that measure the _faithfulness_ of complex query answering approaches. In each split, the validation and test graphs extend the training graph with more nodes and edges, so training queries gain more true answers achievable by simple edge traversal (no missing link prediction required) - the task is to measure how well CLQA models retrieve these new easy answers to training queries on larger unseen graphs.

</details>

<details>
<summary>Inductive (e,r) query datasets (11)</summary>

11 new inductive query datasets (WikiTopics-CLQA) that we built specifically for testing UltraQuery.
The queries were sampled from the WikiTopics splits proposed in [Double Equivariance for Inductive Link Prediction for Both New Nodes and New Relation Types](https://arxiv.org/abs/2302.01313).

`WikiTopicsQuery` comes in 11 versions:
* `art`, `award`, `edu`, `health`, `infra`, `loc`, `org`, `people`, `sci`, `sport`, `tax`

</details>

### Metrics ###

New metrics include `auroc`, `spearmanr`, and `mape`. Mean Rank (`mr`) is not supported for complex queries. If you ever see `nan` in one of those metrics, consider reducing the batch size: the metrics are computed with variadic functions that might be numerically unstable on large batches.
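
The reported metrics are selected per config; for example, the `config/ultraquery/inductive.yaml` added in this commit lists them in its `task` section (excerpt below, with nesting reconstructed):

```yaml
# Excerpt from config/ultraquery/inductive.yaml in this commit
task:
  name: InductiveInference
  metric: [mrr, hits@1, hits@3, hits@10, auroc, spearmanr] # mape is supported as well
```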

## Citation ##

-If you find this codebase useful in your research, please cite the original paper.
+If you find this codebase useful in your research, please cite the original papers.

The main ULTRA paper:

```bibtex
@inproceedings{galkin2023ultra,
  title={Towards Foundation Models for Knowledge Graph Reasoning},
  author={Mikhail Galkin and Xinyu Yuan and Hesham Mostafa and Jian Tang and Zhaocheng Zhu},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024},
  url={https://openreview.net/forum?id=jVEoydFOl9}
}
```

UltraQuery:

```bibtex
-@article{galkin2023ultra,
-  title={Towards Foundation Models for Knowledge Graph Reasoning},
-  author={Mikhail Galkin and Xinyu Yuan and Hesham Mostafa and Jian Tang and Zhaocheng Zhu},
-  year={2023},
-  eprint={2310.04562},
+@article{galkin2024ultraquery,
+  title={Zero-shot Logical Query Reasoning on any Knowledge Graph},
+  author={Mikhail Galkin and Jincheng Zhou and Bruno Ribeiro and Jian Tang and Zhaocheng Zhu},
+  year={2024},
+  eprint={2404.07198},
  archivePrefix={arXiv},
-  primaryClass={cs.CL}
+  primaryClass={cs.AI}
}
```

ckpts/ultraquery.pth

2.03 MB (binary file, not shown in the diff)

config/inductive/inference.yaml

+2 −2

@@ -8,15 +8,15 @@ dataset:
model:
  class: Ultra
  relation_model:
-    class: NBFNet
+    class: RelNBFNet
    input_dim: 64
    hidden_dims: [64, 64, 64, 64, 64, 64]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes
  entity_model:
-    class: IndNBFNet
+    class: EntityNBFNet
    input_dim: 64
    hidden_dims: [64, 64, 64, 64, 64, 64]
    message_func: distmult

config/transductive/inference.yaml

+2 −2

@@ -7,15 +7,15 @@ dataset:
model:
  class: Ultra
  relation_model:
-    class: NBFNet
+    class: RelNBFNet
    input_dim: 64
    hidden_dims: [64, 64, 64, 64, 64, 64]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes
  entity_model:
-    class: IndNBFNet
+    class: EntityNBFNet
    input_dim: 64
    hidden_dims: [64, 64, 64, 64, 64, 64]
    message_func: distmult

config/transductive/pretrain_3g.yaml

+2 −2

@@ -8,15 +8,15 @@ dataset:
model:
  class: Ultra
  relation_model:
-    class: NBFNet
+    class: RelNBFNet
    input_dim: 64
    hidden_dims: [64, 64, 64, 64, 64, 64]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes
  entity_model:
-    class: IndNBFNet
+    class: EntityNBFNet
    input_dim: 64
    hidden_dims: [64, 64, 64, 64, 64, 64]
    message_func: distmult

config/transductive/pretrain_4g.yaml

+2 −2

@@ -8,15 +8,15 @@ dataset:
model:
  class: Ultra
  relation_model:
-    class: NBFNet
+    class: RelNBFNet
    input_dim: 64
    hidden_dims: [64, 64, 64, 64, 64, 64]
    message_func: distmult
    aggregate_func: sum
    short_cut: yes
    layer_norm: yes
  entity_model:
-    class: IndNBFNet
+    class: EntityNBFNet
    input_dim: 64
    hidden_dims: [64, 64, 64, 64, 64, 64]
    message_func: distmult

config/ultraquery/inductive.yaml

+53

@@ -0,0 +1,53 @@
output_dir: ~/git/ULTRA/output

dataset:
  class: {{ dataset }}
  root: ~/git/ULTRA/query-datasets/
  version: {{ version }} # specify dataset version here or when running the script

model:
  class: UltraQuery
  model:
    class: Ultra
    relation_model:
      class: RelNBFNet
      input_dim: 64
      hidden_dims: [64, 64, 64, 64, 64, 64]
      message_func: distmult
      aggregate_func: sum
      short_cut: yes
      layer_norm: yes
    entity_model:
      class: QueryNBFNet
      input_dim: 64
      hidden_dims: [64, 64, 64, 64, 64, 64]
      message_func: distmult
      aggregate_func: sum
      short_cut: yes
      layer_norm: yes
  logic: product
  dropout_ratio: 0.5
  threshold: {{ threshold }}
  more_dropout: 0.0

task:
  name: InductiveInference
  strict_negative: yes
  adversarial_temperature: 0.2
  sample_weight: no
  metric: [mrr, hits@1, hits@3, hits@10, auroc, spearmanr] # mape is supported as well

optimizer:
  class: Adam
  lr: 5.0e-4

train:
  gpus: {{ gpus }}
  batch_size: {{ bs }} # reduce if doesn't fit on a GPU
  num_epoch: {{ epochs }} # total number of optimization steps will be num_epochs * batch_per_epoch
  batch_per_epoch: {{ bpe }} # number of batches to be considered as "one epoch"
  log_interval: 100
  fast_test: 1000 # UltraQuery is slower in inference, use this option for a random subsample of valid data

ultra_ckpt: {{ ultra_ckpt }} # Ultra checkpoint pre-trained on simple link prediction
ultraquery_ckpt: {{ qe_ckpt }} # UltraQuery checkpoint trained on complex queries
