
Commit e3214a5

Style fix (#35)

* precommit changes
* add precommit hook

1 parent: 7cf6c04

24 files changed: +1126 -699 lines

Diff for: .pre-commit-config.yaml (+16)

```diff
@@ -0,0 +1,16 @@
+# To run locally:
+# % pre-commit run -a
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v2.3.0
+    hooks:
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+  - repo: local
+    hooks:
+      - id: black
+        name: black
+        entry: black
+        language: system
+        types: [python]
```
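
For contributors, this is the workflow the new config enables, as a minimal sketch (standard `pre-commit` CLI commands; note that the `black` hook uses `language: system`, so `black` itself must already be installed in your environment):

```
pip install pre-commit
pre-commit install   # install the git hook; checks then run on every commit
pre-commit run -a    # or run all hooks once over the whole repo
```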

Diff for: LICENSE (+1 -1)

```diff
@@ -25,4 +25,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+SOFTWARE.
```

The -/+ pair prints identically: the change is whitespace-only, evidently the `end-of-file-fixer` hook adding the file's missing final newline.

Diff for: README.md (+3 -3)

````diff
@@ -51,7 +51,7 @@ There are four scales in our competition:
 - `xlarge`: 12.8B pool size, 12.8B examples seen
 
 
-The script will create two directories inside `$data_dir`: `metadata` and `shards`.
+The script will create two directories inside `$data_dir`: `metadata` and `shards`.
 
 Along with the images and captions, this script will also download metadata, including `.parquet` files that contain the image urls, captions, and other potentially useful information such as the similarities between the images and captions given by trained OpenAI CLIP models.
 If the flag `--download_npz` is used, the script will also download the `.npz` files with features extracted by the trained OpenAI CLIP models for each sample.
@@ -161,7 +161,7 @@ A image clustering based method that retains samples whose images have content c
 python baselines.py --metadata_dir path/to/metadata --save_path path/to/image_based.npy --name image_based --image_based_scale small --batch_size 512
 ```
 
-**Note**: this baseline requires pre-computed image cluster centroids which will be downloaded automatically the first time you run it.
+**Note**: this baseline requires pre-computed image cluster centroids which will be downloaded automatically the first time you run it.
 If you want to generate the centroids yourself, please see `baselines/image_based_clustering.md` for instructions.
 
 ### Intersection of image-based and CLIP score filtering
@@ -239,7 +239,7 @@ We also highly encourage participants to also upload the checkpoints for their t
 
 ## Checkpoints
 
-We release the checkpoints for our main baselines as part of [OpenCLIP](https://github.com/mlfoundations/open_clip). More details can be found at https://github.com/mlfoundations/open_clip/blob/main/docs/datacomp_models.md.
+We release the checkpoints for our main baselines as part of [OpenCLIP](https://github.com/mlfoundations/open_clip). More details can be found at https://github.com/mlfoundations/open_clip/blob/main/docs/datacomp_models.md.
 
 ## Citation
 
````

Each -/+ pair above renders identically: the removed lines carried trailing whitespace, which the new `trailing-whitespace` hook strips.
Diff for: aggregate_scores.py (+38 -21)

```diff
@@ -1,51 +1,68 @@
 import argparse
-import pandas as pd
 import os
+
 import numpy as np
+import pandas as pd
 
 DATASET_GROUPS = {
-    'ImageNet dist. shifts': {
-        'ImageNet Sketch', 'ImageNet v2', 'ImageNet-A', 'ImageNet-O', 'ImageNet-R', 'ObjectNet'
+    "ImageNet dist. shifts": {
+        "ImageNet Sketch",
+        "ImageNet v2",
+        "ImageNet-A",
+        "ImageNet-O",
+        "ImageNet-R",
+        "ObjectNet",
     },
-    'VTAB': {
-        'Caltech-101', 'CIFAR-100', 'CLEVR Counts', 'CLEVR Distance', 'Describable Textures', 'EuroSAT',
-        'KITTI Vehicle Distance', 'Oxford Flowers-102', 'Oxford-IIIT Pet', 'PatchCamelyon', 'RESISC45',
-        'SVHN', 'SUN397'},
-    'Retrieval': {'Flickr', 'MSCOCO', 'WinoGAViL'},
+    "VTAB": {
+        "Caltech-101",
+        "CIFAR-100",
+        "CLEVR Counts",
+        "CLEVR Distance",
+        "Describable Textures",
+        "EuroSAT",
+        "KITTI Vehicle Distance",
+        "Oxford Flowers-102",
+        "Oxford-IIIT Pet",
+        "PatchCamelyon",
+        "RESISC45",
+        "SVHN",
+        "SUN397",
+    },
+    "Retrieval": {"Flickr", "MSCOCO", "WinoGAViL"},
 }
 
 
 def get_aggregate_scores(results_file):
     """Returns a dictionary with aggregated scores from a results file."""
     df = pd.read_json(results_file, lines=True)
-    df = pd.concat([df.drop(['metrics'], axis=1), df['metrics'].apply(pd.Series)], axis=1)
-    df = df.dropna(subset=['main_metric'])
-    assert len(df) == 38, f'Results file has unexpected size, {len(df)}'
+    df = pd.concat(
+        [df.drop(["metrics"], axis=1), df["metrics"].apply(pd.Series)], axis=1
+    )
+    df = df.dropna(subset=["main_metric"])
+    assert len(df) == 38, f"Results file has unexpected size, {len(df)}"
     results = dict(zip(df.dataset, df.main_metric))
-
-    aggregate_results = {
-        'ImageNet': results['ImageNet 1k']
-    }
+
+    aggregate_results = {"ImageNet": results["ImageNet 1k"]}
 
     for group, datasets in DATASET_GROUPS.items():
         score = np.mean([results[dataset] for dataset in datasets])
         aggregate_results[group] = score
-
 
-    aggregate_results['Average'] = np.mean(list(results.values()))
+    aggregate_results["Average"] = np.mean(list(results.values()))
 
     return aggregate_results
 
 
-
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
-    parser.add_argument('--input', type=str, required=True, help='Path to the results file.')
+    parser.add_argument(
+        "--input", type=str, required=True, help="Path to the results file."
+    )
 
     args = parser.parse_args()
 
     scores = get_aggregate_scores(args.input)
 
     for group, score in scores.items():
-        print(f'{group}: {score:.3f}')
+        print(f"{group}: {score:.3f}")
```

Diff for: baselines/apply_filter.py (+4 -4)

```diff
@@ -1,8 +1,10 @@
+import multiprocessing as mp
 import os
+import time
 from functools import partial
-import multiprocessing as mp
 from multiprocessing import Pool
-from typing import Any, Set, Tuple, List, Union
+from queue import Empty
+from typing import Any, List, Set, Tuple, Union
 
 import fasttext
 import fsspec
@@ -13,8 +15,6 @@
 import torch
 from nltk.corpus import wordnet
 from tqdm import tqdm
-import time
-from queue import Empty
 
 from baselines.utils import download, worker_threadpool
 
```

The new import order (plain `import` statements before `from` imports, alphabetized, standard library separated from third-party) matches what `isort` produces by default; only `black` is wired into the committed hook config, so this reordering was presumably done separately.

Diff for: baselines/image_based_clustering.md (+2 -3)

````diff
@@ -1,6 +1,6 @@
 # Clustering
 
-Generates cluster centroids from the `image-based` baselines using k-means clustering.
+Generates cluster centroids from the `image-based` baselines using k-means clustering.
 
 
 ## Installing dependencies
@@ -20,7 +20,7 @@ To run clustering for the `small` pool, run the following command:
 
 ```
 python image_based_clustering.py \
-    --metadata_dir path/to/metadata \
+    --metadata_dir path/to/metadata \
     --save_path path/to/output/centroids \
     --num_clusters 100000 \
     --sample_ratio -1.0 \
@@ -34,4 +34,3 @@ Explanation to several arguments:
 - `disable_caption_filtering`: whether to disable caption filtering to the dataset. Default is `False`
 
 On a machine with 8 GPUs and 26 CPUs (there are 26 parquet files for the `small` pool), the clustering process takes about 10 minutes.
-
````

The first two changes are trailing-whitespace fixes; the final `-` removes an extra blank line at the end of the file.

Diff for: baselines/image_based_clustering.py (+30 -23)

```diff
@@ -8,7 +8,7 @@
 import multiprocessing as mp
 from functools import partial
 from multiprocessing import Pool
-from typing import Any, Tuple, List
+from typing import Any, List, Tuple
 
 import faiss
 import fasttext
@@ -18,17 +18,14 @@
 import torch
 from tqdm import tqdm
 
-from baselines.utils import random_seed, download
 from baselines.apply_filter import caption_filter
+from baselines.utils import download, random_seed
 
 torch.backends.cudnn.benchmark = True
 
 
 def train_kmeans(
-    embeddings: np.ndarray,
-    num_clusters: int,
-    num_gpus: int,
-    seed: int = 0
+    embeddings: np.ndarray, num_clusters: int, num_gpus: int, seed: int = 0
 ) -> torch.Tensor:
     """train kmeans on embeddings
 
@@ -59,7 +56,9 @@ def train_kmeans(
     if num_gpus == 1:
         index = faiss.GpuIndexFlatL2(res[0], d, flat_config[0])
     else:
-        indexes = [faiss.GpuIndexFlatL2(res[i], d, flat_config[i]) for i in range(num_gpus)]
+        indexes = [
+            faiss.GpuIndexFlatL2(res[i], d, flat_config[i]) for i in range(num_gpus)
+        ]
         index = faiss.IndexReplicas()
         for sub_index in indexes:
             index.addIndex(sub_index)
@@ -72,10 +71,10 @@ def train_kmeans(
 
 
 def load_embedding_helper(
-    fs_root: Tuple[Any, str],
-    key: str = "l14_img",
-    caption_filtering: bool = False,
-    sample_ratio: float = -1.0
+    fs_root: Tuple[Any, str],
+    key: str = "l14_img",
+    caption_filtering: bool = False,
+    sample_ratio: float = -1.0,
 ) -> np.ndarray:
     """worker function to load embeddings
 
@@ -89,8 +88,12 @@ def load_embedding_helper(
     fs, path_root = fs_root
     embed = np.load(fs.open(f"{path_root}.npz"))[key]
     if caption_filtering:
-        lang_detect_model = fasttext.load_model(download("fasttext", "~/.cache/fasttext"))
-        df = pd.read_parquet(f"{path_root}.parquet", columns=["uid", "text"], filesystem=fs)
+        lang_detect_model = fasttext.load_model(
+            download("fasttext", "~/.cache/fasttext")
+        )
+        df = pd.read_parquet(
+            f"{path_root}.parquet", columns=["uid", "text"], filesystem=fs
+        )
         mask = caption_filter(df, lang_detect_model)
         embed = embed[mask]
     if sample_ratio > 0:
@@ -101,11 +104,11 @@ def load_embedding_helper(
 
 
 def load_embedding(
-    paths: List[Tuple[Any, str]],
-    n_workers: int = 10,
-    key: str = "l14_img",
-    caption_filtering: bool = False,
-    sample_ratio: float = -1.0
+    paths: List[Tuple[Any, str]],
+    n_workers: int = 10,
+    key: str = "l14_img",
+    caption_filtering: bool = False,
+    sample_ratio: float = -1.0,
 ) -> np.ndarray:
     """worker function to load embeddings
 
@@ -128,7 +131,9 @@ def load_embedding(
     with Pool(n_workers) as pool:
         embeds = [
             res
-            for res in tqdm(pool.imap(worker, paths), total=len(paths))  # imap so that it can be reproduced
+            for res in tqdm(
+                pool.imap(worker, paths), total=len(paths)
+            )  # imap so that it can be reproduced
             if len(res) > 0
         ]
     return np.vstack(embeds)
@@ -147,10 +152,10 @@ def load_embedding(
     )
     parser.add_argument(
         "--embedding_key",
-        default='l14_img',
+        default="l14_img",
         type=str,
-        choices=['l14_img', 'b32_img'],
-        help="precomputed embeddings used for clustering"
+        choices=["l14_img", "b32_img"],
+        help="precomputed embeddings used for clustering",
    )
     parser.add_argument(
         "--sample_ratio",
@@ -202,5 +207,7 @@ def load_embedding(
 
     print(f"start clustering: num_clusters = {num_clusters}, num_gpus = {num_gpus}")
     embeddings = embeddings.astype(np.float32)
-    centroids = train_kmeans(embeddings, num_clusters, num_gpus=num_gpus, seed=args.seed)
+    centroids = train_kmeans(
+        embeddings, num_clusters, num_gpus=num_gpus, seed=args.seed
+    )
     torch.save(centroids, args.save_path, pickle_protocol=4)
```
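
For orientation, a minimal sketch of the same k-means step using faiss's high-level `Kmeans` wrapper; this is an illustrative assumption, not the script's code, which builds `GpuIndexFlatL2` indexes by hand and replicates them across GPUs with `IndexReplicas`:

```python
import faiss
import numpy as np

# Dummy sizes for illustration; the script clusters ViT-L/14 image
# embeddings (768-d) into 100000 centroids.
d, k = 768, 1000
embeddings = np.random.rand(50_000, d).astype(np.float32)

# gpu=True shards k-means training across all visible GPUs.
kmeans = faiss.Kmeans(d, k, niter=20, seed=0, gpu=True)
kmeans.train(embeddings)
centroids = kmeans.centroids  # ndarray of shape (k, d)
```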
