+ #!/usr/bin/env python3
"""
Experimental script for bulk generation of MaD models based on a list of projects.

@@ ... @@
import os.path
import subprocess
import sys
- from typing import NotRequired, TypedDict, List
+ from typing import Required, TypedDict, List, Callable, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
import argparse
- import json
- import requests
import zipfile
import tarfile
- from functools import cmp_to_key
+ import shutil
+
+
+ def missing_module(module_name: str) -> None:
+     print(
+         f"ERROR: {module_name} is not installed. Please install it with 'pip install {module_name}'."
+     )
+     sys.exit(1)
+
+
+ try:
+     import yaml
+ except ImportError:
+     missing_module("pyyaml")
+
+ try:
+     import requests
+ except ImportError:
+     missing_module("requests")

import generate_mad as mad
@@ ... @@


# A project to generate models for
- class Project(TypedDict):
-     """
-     Type definition for projects (acquired via a GitHub repo) to model.
-
-     Attributes:
-         name: The name of the project
-         git_repo: URL to the git repository
-         git_tag: Optional Git tag to check out
-     """
-
-     name: str
-     git_repo: NotRequired[str]
-     git_tag: NotRequired[str]
-     with_sinks: NotRequired[bool]
-     with_sources: NotRequired[bool]
-     with_summaries: NotRequired[bool]
+ Project = TypedDict(
+     "Project",
+     {
+         "name": Required[str],
+         "git-repo": str,
+         "git-tag": str,
+         "with-sinks": bool,
+         "with-sources": bool,
+         "with-summaries": bool,
+     },
+     total=False,
+ )
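The functional `TypedDict` form is what makes the renamed keys possible: names like `git-repo` and `with-sinks` are not valid Python identifiers, so they cannot be declared with the class-based syntax, and `total=False` keeps every key except `name` optional. A minimal sketch of a conforming entry (the project name, URL, and tag are hypothetical placeholders):

```python
# Sketch only; the project name, URL, and tag are hypothetical.
example_project: Project = {
    "name": "example-crate",
    "git-repo": "https://github.com/example/example-crate.git",
    "git-tag": "v1.0.0",
    "with-sinks": True,
}
```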


def should_generate_sinks(project: Project) -> bool:
@@ -63,14 +76,14 @@ def clone_project(project: Project) -> str:
    Shallow clone a project into the build directory.

    Args:
-         project: A dictionary containing project information with 'name', 'git_repo', and optional 'git_tag' keys.
+         project: A dictionary containing project information with 'name', 'git-repo', and optional 'git-tag' keys.

    Returns:
        The path to the cloned project directory.
    """
    name = project["name"]
-     repo_url = project["git_repo"]
-     git_tag = project.get("git_tag")
+     repo_url = project["git-repo"]
+     git_tag = project.get("git-tag")

    # Determine target directory
    target_dir = os.path.join(build_dir, name)
@@ -103,6 +116,39 @@ def clone_project(project: Project) -> str:
    return target_dir


+ def run_in_parallel[T, U](
+     func: Callable[[T], U],
+     items: List[T],
+     *,
+     on_error=lambda item, exc: None,
+     error_summary=lambda failures: None,
+     max_workers=8,
+ ) -> List[Optional[U]]:
+     if not items:
+         return []
+     max_workers = min(max_workers, len(items))
+     results = [None for _ in range(len(items))]
+     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+         # Start tasks and keep track of them
+         futures = {
+             executor.submit(func, item): index for index, item in enumerate(items)
+         }
+         # Process results as they complete
+         for future in as_completed(futures):
+             index = futures[future]
+             try:
+                 results[index] = future.result()
+             except Exception as e:
+                 on_error(items[index], e)
+     failed = [item for item, result in zip(items, results) if result is None]
+     if failed:
+         error_summary(failed)
+         sys.exit(1)
+     return results
+
+
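`run_in_parallel` fills a result slot indexed by each item's position, so the returned list is in input order and callers can safely `zip` it against the inputs. Note that a slot still holding `None` counts as a failure, so the helper only suits functions that never legitimately return `None`. A small, self-contained usage sketch (the squaring function is invented for illustration):

```python
# Hypothetical usage sketch; square_non_negative is invented for illustration.
def square_non_negative(n: int) -> int:
    if n < 0:
        raise ValueError(f"negative input: {n}")
    return n * n

squares = run_in_parallel(
    square_non_negative,
    [1, 2, 3],
    on_error=lambda n, exc: print(f"ERROR: could not square {n}: {exc}"),
    error_summary=lambda failed: print(f"{len(failed)} items failed"),
    max_workers=4,
)
assert squares == [1, 4, 9]  # results come back in input order
```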
def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
    """
    Clone all projects in parallel.
@@ -114,40 +160,19 @@ def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
        List of (project, project_dir) pairs in the same order as the input projects
    """
    start_time = time.time()
-     max_workers = min(8, len(projects))  # Use at most 8 threads
-     project_dirs_map = {}  # Map to store results by project name
-
-     with ThreadPoolExecutor(max_workers=max_workers) as executor:
-         # Start cloning tasks and keep track of them
-         future_to_project = {
-             executor.submit(clone_project, project): project for project in projects
-         }
-
-         # Process results as they complete
-         for future in as_completed(future_to_project):
-             project = future_to_project[future]
-             try:
-                 project_dir = future.result()
-                 project_dirs_map[project["name"]] = (project, project_dir)
-             except Exception as e:
-                 print(f"ERROR: Failed to clone {project['name']}: {e}")
-
-     if len(project_dirs_map) != len(projects):
-         failed_projects = [
-             project["name"]
-             for project in projects
-             if project["name"] not in project_dirs_map
-         ]
-         print(
-             f"ERROR: Only {len(project_dirs_map)} out of {len(projects)} projects were cloned successfully. Failed projects: {', '.join(failed_projects)}"
-         )
-         sys.exit(1)
-
-     project_dirs = [project_dirs_map[project["name"]] for project in projects]
-
+     dirs = run_in_parallel(
+         clone_project,
+         projects,
+         on_error=lambda project, exc: print(
+             f"ERROR: Failed to clone project {project['name']}: {exc}"
+         ),
+         error_summary=lambda failures: print(
+             f"ERROR: Failed to clone {len(failures)} projects: {', '.join(p['name'] for p in failures)}"
+         ),
+     )
    clone_time = time.time() - start_time
    print(f"Cloning completed in {clone_time:.2f} seconds")
-     return project_dirs
+     return list(zip(projects, dirs))


def build_database(
@@ -159,7 +184,7 @@ def build_database(
    Args:
        language: The language for which to build the database (e.g., "rust").
        extractor_options: Additional options for the extractor.
-         project: A dictionary containing project information with 'name' and 'git_repo' keys.
+         project: A dictionary containing project information with 'name' and 'git-repo' keys.
        project_dir: Path to the CodeQL database.

    Returns:
@@ -307,7 +332,10 @@ def pretty_name_from_artifact_name(artifact_name: str) -> str:


def download_dca_databases(
-     experiment_name: str, pat: str, projects: List[Project]
+     language: str,
+     experiment_name: str,
+     pat: str,
+     projects: List[Project],
) -> List[tuple[Project, str | None]]:
    """
    Download databases from a DCA experiment.
@@ -318,14 +346,14 @@ def download_dca_databases(
    Returns:
        List of (project, database_dir) pairs, where database_dir is None if the download failed.
    """
-     database_results = {}
    print("\n=== Finding projects ===")
    response = get_json_from_github(
        f"https://raw.githubusercontent.com/github/codeql-dca-main/data/{experiment_name}/reports/downloads.json",
        pat,
    )
    targets = response["targets"]
    project_map = {project["name"]: project for project in projects}
+     analyzed_databases = {}
    for data in targets.values():
        downloads = data["downloads"]
        analyzed_database = downloads["analyzed_database"]
@@ -336,6 +364,15 @@ def download_dca_databases(
            print(f"Skipping {pretty_name} as it is not in the list of projects")
            continue

+         if pretty_name in analyzed_databases:
+             print(
+                 f"Skipping previous database {analyzed_databases[pretty_name]['artifact_name']} for {pretty_name}"
+             )
+
+         analyzed_databases[pretty_name] = analyzed_database
+
+     def download_and_decompress(analyzed_database: dict) -> str:
+         artifact_name = analyzed_database["artifact_name"]
        repository = analyzed_database["repository"]
        run_id = analyzed_database["run_id"]
        print(f"=== Finding artifact: {artifact_name} ===")
@@ -351,27 +388,40 @@ def download_dca_databases(
        artifact_zip_location = download_artifact(
            archive_download_url, artifact_name, pat
        )
-         print(f"=== Extracting artifact: {artifact_name} ===")
+         print(f"=== Decompressing artifact: {artifact_name} ===")
        # The database is in a zip file, which contains a tar.gz file with the DB
        # First we open the zip file
        with zipfile.ZipFile(artifact_zip_location, "r") as zip_ref:
            artifact_unzipped_location = os.path.join(build_dir, artifact_name)
+             # clean up any remnants of previous runs
+             shutil.rmtree(artifact_unzipped_location, ignore_errors=True)
            # And then we extract it to build_dir/artifact_name
            zip_ref.extractall(artifact_unzipped_location)
-             # And then we iterate over the contents of the extracted directory
-             # and extract the tar.gz files inside it
-             for entry in os.listdir(artifact_unzipped_location):
-                 artifact_tar_location = os.path.join(artifact_unzipped_location, entry)
-                 with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
-                     # And we just untar it to the same directory as the zip file
-                     tar_ref.extractall(artifact_unzipped_location)
-                 database_results[pretty_name] = os.path.join(
-                     artifact_unzipped_location, remove_extension(entry)
-                 )
+             # And then we extract the language tar.gz file inside it
+             artifact_tar_location = os.path.join(
+                 artifact_unzipped_location, f"{language}.tar.gz"
+             )
+             with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
+                 # And we just untar it to the same directory as the zip file
+                 tar_ref.extractall(artifact_unzipped_location)
+         ret = os.path.join(artifact_unzipped_location, language)
+         print(f"Decompression complete: {ret}")
+         return ret
+
+     results = run_in_parallel(
+         download_and_decompress,
+         list(analyzed_databases.values()),
+         on_error=lambda db, exc: print(
+             f"ERROR: Failed to download and decompress {db['artifact_name']}: {exc}"
+         ),
+         error_summary=lambda failures: print(
+             f"ERROR: Failed to download {len(failures)} databases: {', '.join(item['artifact_name'] for item in failures)}"
+         ),
+     )

-     print(f"\n=== Extracted {len(database_results)} databases ===")
+     print(f"\n=== Fetched {len(results)} databases ===")

-     return [(project, database_results[project["name"]]) for project in projects]
+     return [(project_map[n], r) for n, r in zip(analyzed_databases, results)]
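The closing `zip(analyzed_databases, results)` leans on two ordering guarantees: Python dicts iterate in insertion order, and `run_in_parallel` returns results in input order, so the n-th project name lines up with the n-th database path. A toy illustration of the invariant (data invented for the example):

```python
# Toy illustration: dict iteration order matches .values() order,
# so keys stay aligned with per-value results.
dbs = {"proj-a": {"artifact_name": "a.zip"}, "proj-b": {"artifact_name": "b.zip"}}
results = [db["artifact_name"] for db in dbs.values()]  # stand-in for run_in_parallel
assert list(zip(dbs, results)) == [("proj-a", "a.zip"), ("proj-b", "b.zip")]
```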


def get_mad_destination_for_project(config, name: str) -> str:
@@ -422,7 +472,9 @@ def main(config, args) -> None:
        case "repo":
            extractor_options = config.get("extractor_options", [])
            database_results = build_databases_from_projects(
-                 language, extractor_options, projects
+                 language,
+                 extractor_options,
+                 projects,
            )
        case "dca":
            experiment_name = args.dca
@@ -439,7 +491,10 @@ def main(config, args) -> None:
            with open(args.pat, "r") as f:
                pat = f.read().strip()
            database_results = download_dca_databases(
-                 experiment_name, pat, projects
+                 language,
+                 experiment_name,
+                 pat,
+                 projects,
            )

    # Generate models for all projects
@@ -492,9 +547,9 @@ def main(config, args) -> None:
        sys.exit(1)
    try:
        with open(args.config, "r") as f:
-             config = json.load(f)
-     except json.JSONDecodeError as e:
-         print(f"ERROR: Failed to parse JSON file {args.config}: {e}")
+             config = yaml.safe_load(f)
+     except yaml.YAMLError as e:
+         print(f"ERROR: Failed to parse YAML file {args.config}: {e}")
        sys.exit(1)

    main(config, args)
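Since the config is now parsed with `yaml.safe_load` rather than `json.load`, the config file is YAML. A hypothetical sketch of the shape this script consumes, restricted to keys visible in this diff; the top-level `projects` key and all values are assumptions:

```python
import yaml

# Hypothetical config sketch; the "projects" key and all values are assumed.
# Only "extractor_options" and the per-project fields appear in this diff.
example_config = yaml.safe_load(
    """
extractor_options: []
projects:
  - name: example-crate
    git-repo: https://github.com/example/example-crate.git
    git-tag: v1.0.0
    with-summaries: true
"""
)
print(example_config["projects"][0]["git-repo"])
```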