Commit fbcd9ea

Merge pull request #19674 from github/redsun82/mad
Rust: regenerate MaD files using DCA
2 parents: 6811cad + 4ac4e44

71 files changed: +3946 −936 lines changed


cpp/bulk_generation_targets.yml

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+language: cpp
+strategy: dca
+destination: cpp/ql/lib/ext/generated
+targets:
+  - name: openssl
+    with-sinks: false
+    with-sources: false
+  - name: sqlite
+    with-sinks: false
+    with-sources: false

cpp/misc/bulk_generation_targets.json

Lines changed: 0 additions & 9 deletions
This file was deleted.

misc/scripts/models-as-data/bulk_generate_mad.py

File mode changed: 100644 → 100755
Lines changed: 131 additions & 76 deletions
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 """
 Experimental script for bulk generation of MaD models based on a list of projects.
@@ -7,15 +8,31 @@
 import os.path
 import subprocess
 import sys
-from typing import NotRequired, TypedDict, List
+from typing import Required, TypedDict, List, Callable, Optional
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import time
 import argparse
-import json
-import requests
 import zipfile
 import tarfile
-from functools import cmp_to_key
+import shutil
+
+
+def missing_module(module_name: str) -> None:
+    print(
+        f"ERROR: {module_name} is not installed. Please install it with 'pip install {module_name}'."
+    )
+    sys.exit(1)
+
+
+try:
+    import yaml
+except ImportError:
+    missing_module("pyyaml")
+
+try:
+    import requests
+except ImportError:
+    missing_module("requests")
 
 import generate_mad as mad
 
@@ -28,22 +45,18 @@
 
 
 # A project to generate models for
-class Project(TypedDict):
-    """
-    Type definition for projects (acquired via a GitHub repo) to model.
-
-    Attributes:
-        name: The name of the project
-        git_repo: URL to the git repository
-        git_tag: Optional Git tag to check out
-    """
-
-    name: str
-    git_repo: NotRequired[str]
-    git_tag: NotRequired[str]
-    with_sinks: NotRequired[bool]
-    with_sinks: NotRequired[bool]
-    with_summaries: NotRequired[bool]
+Project = TypedDict(
+    "Project",
+    {
+        "name": Required[str],
+        "git-repo": str,
+        "git-tag": str,
+        "with-sinks": bool,
+        "with-sources": bool,
+        "with-summaries": bool,
+    },
+    total=False,
+)
 
 
 def should_generate_sinks(project: Project) -> bool:
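
The switch from the class-based TypedDict to the functional form is what allows hyphenated keys such as "git-repo" and "with-sinks", which are not valid Python attribute names. A minimal, self-contained sketch of that pattern; the sample values below are illustrative, not taken from the commit:

from typing import Required, TypedDict

# The functional TypedDict syntax accepts keys that are not valid Python
# identifiers, which the class syntax cannot express.
Project = TypedDict(
    "Project",
    {
        "name": Required[str],
        "git-repo": str,
        "with-sinks": bool,
    },
    total=False,
)

# Hypothetical usage: hyphenated keys are read with ordinary dict access.
p: Project = {"name": "openssl", "with-sinks": False}
print(p.get("git-repo", "<no repo configured>"))
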
@@ -63,14 +76,14 @@ def clone_project(project: Project) -> str:
     Shallow clone a project into the build directory.
 
     Args:
-        project: A dictionary containing project information with 'name', 'git_repo', and optional 'git_tag' keys.
+        project: A dictionary containing project information with 'name', 'git-repo', and optional 'git-tag' keys.
 
     Returns:
         The path to the cloned project directory.
     """
     name = project["name"]
-    repo_url = project["git_repo"]
-    git_tag = project.get("git_tag")
+    repo_url = project["git-repo"]
+    git_tag = project.get("git-tag")
 
     # Determine target directory
     target_dir = os.path.join(build_dir, name)
@@ -103,6 +116,39 @@ def clone_project(project: Project) -> str:
     return target_dir
 
 
+def run_in_parallel[
+    T, U
+](
+    func: Callable[[T], U],
+    items: List[T],
+    *,
+    on_error=lambda item, exc: None,
+    error_summary=lambda failures: None,
+    max_workers=8,
+) -> List[Optional[U]]:
+    if not items:
+        return []
+    max_workers = min(max_workers, len(items))
+    results = [None for _ in range(len(items))]
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Start cloning tasks and keep track of them
+        futures = {
+            executor.submit(func, item): index for index, item in enumerate(items)
+        }
+        # Process results as they complete
+        for future in as_completed(futures):
+            index = futures[future]
+            try:
+                results[index] = future.result()
+            except Exception as e:
+                on_error(items[index], e)
+    failed = [item for item, result in zip(items, results) if result is None]
+    if failed:
+        error_summary(failed)
+        sys.exit(1)
+    return results
+
+
 def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
     """
     Clone all projects in parallel.
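
The new run_in_parallel helper generalizes the executor bookkeeping that clone_projects previously did inline: futures are mapped back to input indices so results keep the input order, and any item whose worker raised is reported before the script exits. A reduced, self-contained sketch of that pattern; the square worker and the input list are made up for illustration:

from concurrent.futures import ThreadPoolExecutor, as_completed


def square(n: int) -> int:
    # stand-in worker; the script passes clone_project or download_and_decompress here
    return n * n


items = [1, 2, 3, 4]
results: list[int | None] = [None] * len(items)
with ThreadPoolExecutor(max_workers=min(8, len(items))) as executor:
    # map each future to the index of its input item, so results stay in input
    # order even though futures complete in arbitrary order
    futures = {executor.submit(square, item): i for i, item in enumerate(items)}
    for future in as_completed(futures):
        results[futures[future]] = future.result()
print(results)  # [1, 4, 9, 16]
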
@@ -114,40 +160,19 @@ def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
         List of (project, project_dir) pairs in the same order as the input projects
     """
     start_time = time.time()
-    max_workers = min(8, len(projects))  # Use at most 8 threads
-    project_dirs_map = {}  # Map to store results by project name
-
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        # Start cloning tasks and keep track of them
-        future_to_project = {
-            executor.submit(clone_project, project): project for project in projects
-        }
-
-        # Process results as they complete
-        for future in as_completed(future_to_project):
-            project = future_to_project[future]
-            try:
-                project_dir = future.result()
-                project_dirs_map[project["name"]] = (project, project_dir)
-            except Exception as e:
-                print(f"ERROR: Failed to clone {project['name']}: {e}")
-
-    if len(project_dirs_map) != len(projects):
-        failed_projects = [
-            project["name"]
-            for project in projects
-            if project["name"] not in project_dirs_map
-        ]
-        print(
-            f"ERROR: Only {len(project_dirs_map)} out of {len(projects)} projects were cloned successfully. Failed projects: {', '.join(failed_projects)}"
-        )
-        sys.exit(1)
-
-    project_dirs = [project_dirs_map[project["name"]] for project in projects]
-
+    dirs = run_in_parallel(
+        clone_project,
+        projects,
+        on_error=lambda project, exc: print(
+            f"ERROR: Failed to clone project {project['name']}: {exc}"
+        ),
+        error_summary=lambda failures: print(
+            f"ERROR: Failed to clone {len(failures)} projects: {', '.join(p['name'] for p in failures)}"
+        ),
+    )
     clone_time = time.time() - start_time
     print(f"Cloning completed in {clone_time:.2f} seconds")
-    return project_dirs
+    return list(zip(projects, dirs))
 
 
 def build_database(
@@ -159,7 +184,7 @@ def build_database(
     Args:
         language: The language for which to build the database (e.g., "rust").
        extractor_options: Additional options for the extractor.
-        project: A dictionary containing project information with 'name' and 'git_repo' keys.
+        project: A dictionary containing project information with 'name' and 'git-repo' keys.
         project_dir: Path to the CodeQL database.
 
     Returns:
@@ -307,7 +332,10 @@ def pretty_name_from_artifact_name(artifact_name: str) -> str:
 
 
 def download_dca_databases(
-    experiment_name: str, pat: str, projects: List[Project]
+    language: str,
+    experiment_name: str,
+    pat: str,
+    projects: List[Project],
 ) -> List[tuple[Project, str | None]]:
     """
     Download databases from a DCA experiment.
@@ -318,14 +346,14 @@
     Returns:
         List of (project_name, database_dir) pairs, where database_dir is None if the download failed.
     """
-    database_results = {}
     print("\n=== Finding projects ===")
     response = get_json_from_github(
         f"https://raw.githubusercontent.com/github/codeql-dca-main/data/{experiment_name}/reports/downloads.json",
         pat,
     )
     targets = response["targets"]
     project_map = {project["name"]: project for project in projects}
+    analyzed_databases = {}
     for data in targets.values():
         downloads = data["downloads"]
         analyzed_database = downloads["analyzed_database"]
@@ -336,6 +364,15 @@
             print(f"Skipping {pretty_name} as it is not in the list of projects")
             continue
 
+        if pretty_name in analyzed_databases:
+            print(
+                f"Skipping previous database {analyzed_databases[pretty_name]['artifact_name']} for {pretty_name}"
+            )
+
+        analyzed_databases[pretty_name] = analyzed_database
+
+    def download_and_decompress(analyzed_database: dict) -> str:
+        artifact_name = analyzed_database["artifact_name"]
         repository = analyzed_database["repository"]
         run_id = analyzed_database["run_id"]
         print(f"=== Finding artifact: {artifact_name} ===")
@@ -351,27 +388,40 @@
         artifact_zip_location = download_artifact(
             archive_download_url, artifact_name, pat
         )
-        print(f"=== Extracting artifact: {artifact_name} ===")
+        print(f"=== Decompressing artifact: {artifact_name} ===")
         # The database is in a zip file, which contains a tar.gz file with the DB
         # First we open the zip file
         with zipfile.ZipFile(artifact_zip_location, "r") as zip_ref:
             artifact_unzipped_location = os.path.join(build_dir, artifact_name)
+            # clean up any remnants of previous runs
+            shutil.rmtree(artifact_unzipped_location, ignore_errors=True)
             # And then we extract it to build_dir/artifact_name
             zip_ref.extractall(artifact_unzipped_location)
-            # And then we iterate over the contents of the extracted directory
-            # and extract the tar.gz files inside it
-            for entry in os.listdir(artifact_unzipped_location):
-                artifact_tar_location = os.path.join(artifact_unzipped_location, entry)
-                with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
-                    # And we just untar it to the same directory as the zip file
-                    tar_ref.extractall(artifact_unzipped_location)
-                database_results[pretty_name] = os.path.join(
-                    artifact_unzipped_location, remove_extension(entry)
-                )
+            # And then we extract the language tar.gz file inside it
+            artifact_tar_location = os.path.join(
+                artifact_unzipped_location, f"{language}.tar.gz"
+            )
+            with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
+                # And we just untar it to the same directory as the zip file
+                tar_ref.extractall(artifact_unzipped_location)
+        ret = os.path.join(artifact_unzipped_location, language)
+        print(f"Decompression complete: {ret}")
+        return ret
+
+    results = run_in_parallel(
+        download_and_decompress,
+        list(analyzed_databases.values()),
+        on_error=lambda db, exc: print(
+            f"ERROR: Failed to download and decompress {db["artifact_name"]}: {exc}"
+        ),
+        error_summary=lambda failures: print(
+            f"ERROR: Failed to download {len(failures)} databases: {', '.join(item[0] for item in failures)}"
+        ),
+    )
 
-    print(f"\n=== Extracted {len(database_results)} databases ===")
+    print(f"\n=== Fetched {len(results)} databases ===")
 
-    return [(project, database_results[project["name"]]) for project in projects]
+    return [(project_map[n], r) for n, r in zip(analyzed_databases, results)]
 
 
 def get_mad_destination_for_project(config, name: str) -> str:
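
As the hunk above shows, each DCA artifact is a zip that wraps a per-language tarball, so fetching a database is a two-step decompression. A standalone sketch of those steps with stand-in paths; the real script derives the locations from the artifact metadata and its build directory:

import os
import shutil
import tarfile
import zipfile

# stand-in values for illustration only
artifact_zip_location = "build/example-artifact.zip"
artifact_unzipped_location = "build/example-artifact"
language = "rust"

# remove leftovers from earlier runs, then unzip the artifact
shutil.rmtree(artifact_unzipped_location, ignore_errors=True)
with zipfile.ZipFile(artifact_zip_location, "r") as zip_ref:
    zip_ref.extractall(artifact_unzipped_location)

# the zip contains a <language>.tar.gz holding the CodeQL database
tar_path = os.path.join(artifact_unzipped_location, f"{language}.tar.gz")
with tarfile.open(tar_path, "r:gz") as tar_ref:
    tar_ref.extractall(artifact_unzipped_location)

print(os.path.join(artifact_unzipped_location, language))  # database directory
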
@@ -422,7 +472,9 @@ def main(config, args) -> None:
         case "repo":
             extractor_options = config.get("extractor_options", [])
             database_results = build_databases_from_projects(
-                language, extractor_options, projects
+                language,
+                extractor_options,
+                projects,
             )
         case "dca":
             experiment_name = args.dca
@@ -439,7 +491,10 @@ def main(config, args) -> None:
             with open(args.pat, "r") as f:
                 pat = f.read().strip()
             database_results = download_dca_databases(
-                experiment_name, pat, projects
+                language,
+                experiment_name,
+                pat,
+                projects,
             )
 
     # Generate models for all projects
@@ -492,9 +547,9 @@ def main(config, args) -> None:
         sys.exit(1)
     try:
         with open(args.config, "r") as f:
-            config = json.load(f)
-    except json.JSONDecodeError as e:
-        print(f"ERROR: Failed to parse JSON file {args.config}: {e}")
+            config = yaml.safe_load(f)
+    except yaml.YAMLError as e:
+        print(f"ERROR: Failed to parse YAML file {args.config}: {e}")
         sys.exit(1)
 
     main(config, args)
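
With the config now parsed as YAML, the bulk_generation_targets.yml files added in this commit can be read directly with yaml.safe_load. A small hypothetical check of such a file; the flag defaults below are an assumption based on the comments in rust/bulk_generation_targets.yml:

import sys

import yaml

config_path = "rust/bulk_generation_targets.yml"
try:
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
except yaml.YAMLError as e:
    print(f"ERROR: Failed to parse YAML file {config_path}: {e}")
    sys.exit(1)

for target in config["targets"]:
    # with-sinks / with-sources are assumed to default to True when omitted
    print(target["name"], target.get("with-sinks", True), target.get("with-sources", True))
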

rust/bulk_generation_targets.yml

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+strategy: dca
+language: rust
+destination: rust/ql/lib/ext/generated
+# targets must have name specified and corresponding to the name in the DCA suite
+# they can optionally specify any of
+# with-sinks: false
+# with-sources: false
+# with-summaries: false
+# if a target has a dependency in this same list, it should be listed after that dependency
+targets:
+  - name: rust
+  - name: libc
+  - name: log
+  - name: memchr
+  - name: once_cell
+  - name: rand
+  - name: smallvec
+  - name: serde
+  - name: tokio
+  - name: reqwest
+  - name: rocket
+  - name: actix-web
+  - name: hyper
+  - name: clap
