-
Notifications
You must be signed in to change notification settings - Fork 189
/
Copy pathfirework.py
1402 lines (1190 loc) · 50.6 KB
/
firework.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
This module contains some of the most central FireWorks classes.
- A Workflow is a sequence of FireWorks as a DAG (directed acyclic graph).
- A Firework defines a workflow step and contains one or more Firetasks along with its Launches.
- A Launch describes the run of a Firework on a computing resource.
- A FiretaskBase defines the contract for tasks that run within a Firework (Firetasks).
- A FWAction encapsulates the output of a Firetask and tells FireWorks what to do next after a job completes.
"""
from __future__ import annotations
import abc
import os
import pprint
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from typing import Any, Iterator, NoReturn, Sequence
from monty.io import reverse_readline, zopen
from monty.os.path import zpath
from fireworks.core.fworker import FWorker
from fireworks.fw_config import EXCEPT_DETAILS_ON_RERUN, TRACKER_LINES
from fireworks.fw_config import NEGATIVE_FWID_CTR as NEGATIVE_FWID_CTR # noqa: PLC0414
from fireworks.utilities.dict_mods import apply_mod
from fireworks.utilities.fw_serializers import FWSerializable, recursive_deserialize, recursive_serialize, serialize_fw
from fireworks.utilities.fw_utilities import NestedClassGetter, get_my_host, get_my_ip
__author__ = "Anubhav Jain"
__credits__ = "Shyue Ping Ong"
__copyright__ = "Copyright 2013, The Materials Project"
__maintainer__ = "Anubhav Jain"
__email__ = "[email protected]"
__date__ = "Feb 5, 2013"
class FiretaskBase(defaultdict, FWSerializable, abc.ABC):
    """
    FiretaskBase is used like an abstract class that defines a computing task
    (Firetask). All Firetasks should inherit from FiretaskBase.

    You can set parameters of a Firetask like you'd use a dict.
    """

    # names of the parameters that must be present after construction; checked in __init__
    required_params = None
    # if set (even to an empty sequence), keyword arguments are restricted to
    # required_params + optional_params; checked in __init__
    optional_params = None

    def __init__(self, *args, **kwargs) -> None:
        """
        Populate the task like a dict and validate its parameters.

        Args:
            *args: positional arguments forwarded to dict.__init__ (e.g. a mapping).
            **kwargs: task parameters given as keywords.

        Raises:
            RuntimeError: if a required parameter is missing, or if optional_params is
                declared and a keyword argument is neither required nor optional.
        """
        dict.__init__(self, *args, **kwargs)

        required_params = list(self.required_params or [])
        for k in required_params:
            if k not in self:
                raise RuntimeError(f"{self}: Required parameter {k} not specified!")

        if self.optional_params is not None:
            # Unpack into a fresh list so that subclasses may declare either attribute
            # as a list or a tuple (list + tuple would raise TypeError).
            allowed_params = [*required_params, *self.optional_params]
            # Only keyword arguments are validated; keys arriving through a positional
            # mapping (as from_dict does) are not checked here.
            for k in kwargs:
                if k not in allowed_params:
                    raise RuntimeError(
                        f"Invalid keyword argument specified for: {self.__class__}. "
                        f"You specified: {k}. Allowed values are: {allowed_params}."
                    )

    @abc.abstractmethod
    def run_task(self, fw_spec) -> FWAction:
        """
        This method gets called when the Firetask is run. It can take in a
        Firework spec, perform some task using that data, and then return an
        output in the form of a FWAction.

        Args:
            fw_spec (dict): A Firework spec. This comes from the master spec.
                In addition, this spec contains a special "_fw_env" key that
                contains the env settings of the FWorker calling this method.
                This provides for abstracting out certain commands or
                settings. For example, "foo" may be named "foo1" in resource
                1 and "foo2" in resource 2. The FWorker env can specify {
                "foo": "foo1"}, which maps an abstract variable "foo" to the
                relevant "foo1" or "foo2". You can then write a task that
                uses fw_spec["_fw_env"]["foo"] that will work across all
                these multiple resources.

        Returns:
            (FWAction)

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError("You must have a run_task implemented!")

    @serialize_fw
    @recursive_serialize
    def to_dict(self):
        """Return the task parameters as a plain dict (post-processed by the decorators)."""
        return dict(self)

    @classmethod
    @recursive_deserialize
    def from_dict(cls, m_dict):
        """Build a task of this class from a parameter dict."""
        return cls(m_dict)

    def __repr__(self) -> str:
        return f"<{self.fw_name}>:{dict(self)}"

    # not strictly needed here for pickle/unpickle, but complements __setstate__
    def __getstate__(self):
        return self.to_dict()

    # added to support pickle/unpickle
    def __setstate__(self, state):
        self.__init__(state)

    # added to support pickle/unpickle
    def __reduce__(self):
        # Reuse defaultdict's reduce tuple but replace the constructor arguments
        # with the serialized dict, so unpickling goes through __init__.
        t = defaultdict.__reduce__(self)
        return (t[0], (self.to_dict(),), t[2], t[3], t[4])
class FWAction(FWSerializable):
    """
    A FWAction encapsulates the output of a Firetask (it is returned by a Firetask after the
    Firetask completes). The FWAction allows a user to store rudimentary output data as well
    as return commands that alter the workflow.
    """

    def __init__(
        self,
        stored_data=None,
        exit=False,
        update_spec=None,
        mod_spec=None,
        additions=None,
        detours=None,
        defuse_children=False,
        defuse_workflow=False,
        propagate=False,
    ) -> None:
        """
        Args:
            stored_data (dict): data to store from the run. Does not affect the operation of FireWorks.
            exit (bool): if set to True, any remaining Firetasks within the same Firework are skipped.
            update_spec (dict): specifies how to update the child FW's spec
            mod_spec ([dict]): update the child FW's spec using the DictMod language (more flexible
                than update_spec)
            additions ([Workflow]): a list of WFs/FWs to add as children
            detours ([Workflow]): a list of WFs/FWs to add as children (they will inherit the
                current FW's children)
            defuse_children (bool): defuse all the original children of this Firework
            defuse_workflow (bool): defuse all incomplete steps of this workflow
            propagate (bool): apply any update_spec and mod_spec modifications
                not only to direct children, but to all dependent FireWorks
                down to the Workflow's leaves.
        """
        mod_spec = mod_spec if mod_spec is not None else []
        additions = additions if additions is not None else []
        detours = detours if detours is not None else []

        self.stored_data = stored_data or {}
        self.exit = exit
        self.update_spec = update_spec or {}
        # a single (non list/tuple) item is accepted for convenience and wrapped in a list
        self.mod_spec = mod_spec if isinstance(mod_spec, (list, tuple)) else [mod_spec]
        self.additions = additions if isinstance(additions, (list, tuple)) else [additions]
        self.detours = detours if isinstance(detours, (list, tuple)) else [detours]
        self.defuse_children = defuse_children
        self.defuse_workflow = defuse_workflow
        self.propagate = propagate

    @recursive_serialize
    def to_dict(self):
        """Return all fields of the action as a dict (post-processed by recursive_serialize)."""
        return {
            "stored_data": self.stored_data,
            "exit": self.exit,
            "update_spec": self.update_spec,
            "mod_spec": self.mod_spec,
            "additions": self.additions,
            "detours": self.detours,
            "defuse_children": self.defuse_children,
            "defuse_workflow": self.defuse_workflow,
            "propagate": self.propagate,
        }

    @classmethod
    @recursive_deserialize
    def from_dict(cls, m_dict):
        """
        Rebuild a FWAction from its dict form.

        "defuse_workflow" and "propagate" are optional keys and default to False;
        all other keys are required.
        """
        d = m_dict
        additions = [Workflow.from_dict(f) for f in d["additions"]]
        detours = [Workflow.from_dict(f) for f in d["detours"]]
        return FWAction(
            d["stored_data"],
            d["exit"],
            d["update_spec"],
            d["mod_spec"],
            additions,
            detours,
            d["defuse_children"],
            d.get("defuse_workflow", False),
            d.get("propagate", False),
        )

    @property
    def skip_remaining_tasks(self):
        """
        If the FWAction gives any dynamic action, we skip the subsequent Firetasks.

        Returns:
            bool
        """
        # bool(): the bare `or` chain would hand back whichever operand is truthy
        # (e.g. the detours list) instead of the documented bool.
        return bool(self.exit or self.detours or self.additions or self.defuse_children or self.defuse_workflow)

    def __str__(self) -> str:
        return "FWAction\n" + pprint.pformat(self.to_dict())
class Firework(FWSerializable):
    """A Firework is a workflow step and may contain several Firetasks."""

    # numeric rank for every known state name; Launch validates its state against these keys
    STATE_RANKS = {
        "ARCHIVED": -2,
        "FIZZLED": -1,
        "DEFUSED": 0,
        "PAUSED": 0,
        "WAITING": 1,
        "READY": 2,
        "RESERVED": 3,
        "RUNNING": 4,
        "COMPLETED": 5,
    }

    # note: if you modify this signature, you must also modify LazyFirework
    def __init__(
        self,
        tasks,
        spec=None,
        name=None,
        launches=None,
        archived_launches=None,
        state="WAITING",
        created_on=None,
        fw_id=None,
        parents=None,
        updated_on=None,
    ) -> None:
        """
        Args:
            tasks (Firetask or [Firetask]): a list of Firetasks to run in sequence.
            spec (dict): specification of the job to run. Used by the Firetask.
            name (str): name of this Firework; falls back to "Unnamed FW".
            launches ([Launch]): a list of Launch objects of this Firework.
            archived_launches ([Launch]): a list of archived Launch objects of this Firework.
            state (str): the state of the FW (e.g. WAITING, RUNNING, COMPLETED, ARCHIVED)
            created_on (datetime): time of creation
            fw_id (int): an identification number for this Firework.
            parents (Firework or [Firework]): list of parent FWs this FW depends on.
            updated_on (datetime): last time the STATE was updated.
        """
        self.tasks = tasks if isinstance(tasks, (list, tuple)) else [tasks]
        # shallow copy, so top-level keys added here never show up in the caller's dict
        self.spec = spec.copy() if spec else {}
        self.name = name or "Unnamed FW"  # do it this way to prevent None names

        if fw_id is not None:
            self.fw_id = fw_id
        else:
            # no id given: hand out the next value of the module-wide negative counter
            global NEGATIVE_FWID_CTR  # noqa: PLW0603
            NEGATIVE_FWID_CTR -= 1
            self.fw_id = NEGATIVE_FWID_CTR

        self.launches = launches or []
        self.archived_launches = archived_launches or []
        self.created_on = created_on or datetime.utcnow()
        self.updated_on = updated_on or datetime.utcnow()

        parents = [parents] if isinstance(parents, Firework) else parents
        self.parents = parents or []

        # assign the private field directly: the public setter would overwrite updated_on
        self._state = state

    @property
    def state(self):
        """
        Returns:
            str: The current state of the Firework.
        """
        return self._state

    @state.setter
    def state(self, state) -> None:
        """
        Setter for the FW state, which triggers updated_on change.

        Args:
            state (str): the state to set for the FW
        """
        self._state = state
        self.updated_on = datetime.utcnow()

    @recursive_serialize
    def to_dict(self):
        """
        Return the dict form of this Firework.

        Note that the serialized tasks are written into self.spec["_tasks"], i.e.
        self.spec itself gains/refreshes that key as a side effect.
        """
        # put tasks in a special location of the spec
        spec = self.spec
        spec["_tasks"] = [t.to_dict() for t in self.tasks]
        m_dict = {"spec": spec, "fw_id": self.fw_id, "created_on": self.created_on, "updated_on": self.updated_on}

        # only serialize these fields if non-empty
        if len(list(self.launches)) > 0:
            m_dict["launches"] = self.launches

        if len(list(self.archived_launches)) > 0:
            m_dict["archived_launches"] = self.archived_launches

        # keep export of new FWs to files clean
        if self.state != "WAITING":
            m_dict["state"] = self.state

        m_dict["name"] = self.name

        return m_dict

    def _rerun(self) -> None:
        """
        Moves all Launches to archived Launches and resets the state to 'WAITING'. The Firework
        can thus be re-run even if it was Launched in the past. This method should be called by
        a Workflow because a refresh is needed after calling this method.
        """
        if self.state == "FIZZLED":
            if len(self.launches) == 0:
                self.spec.pop("_exception_details", None)
            else:
                last_launch = self.launches[-1]
                if (
                    EXCEPT_DETAILS_ON_RERUN
                    and last_launch.action
                    and last_launch.action.stored_data.get("_exception", {}).get("_details")
                ):
                    # add the exception details to the spec
                    self.spec["_exception_details"] = last_launch.action.stored_data["_exception"]["_details"]
                else:
                    # clean spec from stale details
                    self.spec.pop("_exception_details", None)

        self.archived_launches.extend(self.launches)
        # filter duplicates; dict.fromkeys keeps the first occurrence of each launch
        # and, unlike set(), preserves the archive order
        self.archived_launches = list(dict.fromkeys(self.archived_launches))
        self.launches = []
        self.state = "WAITING"

    def to_db_dict(self):
        """Return firework dict with updated launches and state."""
        m_dict = self.to_dict()
        # the launches are stored separately
        m_dict["launches"] = [launch.launch_id for launch in self.launches]
        # the archived launches are stored separately
        m_dict["archived_launches"] = [launch.launch_id for launch in self.archived_launches]
        m_dict["state"] = self.state
        return m_dict

    @classmethod
    @recursive_deserialize
    def from_dict(cls, m_dict):
        """
        Rebuild a Firework from its dict form.

        Only "spec" (containing "_tasks") is required. Missing "fw_id" defaults to -1
        and missing "state" to "WAITING".
        """
        tasks = m_dict["spec"]["_tasks"]
        launches = [Launch.from_dict(tmp) for tmp in m_dict.get("launches", [])]
        archived_launches = [Launch.from_dict(tmp) for tmp in m_dict.get("archived_launches", [])]
        fw_id = m_dict.get("fw_id", -1)
        state = m_dict.get("state", "WAITING")
        created_on = m_dict.get("created_on")
        updated_on = m_dict.get("updated_on")
        name = m_dict.get("name", None)

        return Firework(
            tasks, m_dict["spec"], name, launches, archived_launches, state, created_on, fw_id, updated_on=updated_on
        )

    def __str__(self) -> str:
        # NOTE(review): this prints self.fw_name (not defined in this class, presumably
        # inherited from FWSerializable) rather than self.name - confirm this is intended.
        return f"Firework object: (id: {int(self.fw_id)} , name: {self.fw_name})"

    def __iter__(self) -> Iterator[FiretaskBase]:
        return self.tasks.__iter__()

    def __len__(self) -> int:
        return len(self.tasks)

    def __getitem__(self, idx: int) -> FiretaskBase:
        return self.tasks[idx]
class Tracker(FWSerializable):
    """A Tracker monitors a file and returns the last N lines for updating the Launch object."""

    # upper bound accepted for nlines; enforced in __init__
    MAX_TRACKER_LINES = 1000

    def __init__(self, filename, nlines=TRACKER_LINES, content="", allow_zipped=False) -> None:
        """
        Args:
            filename (str)
            nlines (int): number of lines to track
            content (str): tracked content
            allow_zipped (bool): if set, will look for zipped file.

        Raises:
            ValueError: if nlines exceeds MAX_TRACKER_LINES.
        """
        if nlines > self.MAX_TRACKER_LINES:
            raise ValueError(f"Tracker only supports a maximum of {self.MAX_TRACKER_LINES} lines; you put {nlines}.")
        self.filename = filename
        self.nlines = nlines
        self.content = content
        self.allow_zipped = allow_zipped

    def track_file(self, launch_dir=None):
        """
        Reads the monitored file and returns back the last N lines.

        The result is also stored in self.content. A missing file, or nlines <= 0,
        yields an empty string.

        Args:
            launch_dir (str): directory where job was launched in case of relative filename

        Returns:
            str: the content(last N lines)
        """
        m_file = self.filename
        if launch_dir and not os.path.isabs(self.filename):
            m_file = os.path.join(launch_dir, m_file)

        lines = []
        if self.allow_zipped:
            m_file = zpath(m_file)

        # Guard on nlines: with an equality-only stop test, nlines <= 0 would never
        # trigger the break and the whole file would be read and stored.
        if self.nlines > 0 and os.path.exists(m_file):
            with zopen(m_file, "rt", errors="surrogateescape") as f:
                # read from the end of the file and stop once enough lines are collected
                for line in reverse_readline(f):
                    lines.append(line)
                    if len(lines) >= self.nlines:
                        break

        # lines were collected last-to-first; restore file order
        self.content = "\n".join(reversed(lines))
        return self.content

    def to_dict(self):
        """Return the dict form; "content" is only included when non-empty."""
        m_dict = {"filename": self.filename, "nlines": self.nlines, "allow_zipped": self.allow_zipped}
        if self.content:
            m_dict["content"] = self.content
        return m_dict

    @classmethod
    def from_dict(cls, m_dict):
        """Rebuild a Tracker; "content" and "allow_zipped" are optional keys."""
        return Tracker(
            m_dict["filename"], m_dict["nlines"], m_dict.get("content", ""), m_dict.get("allow_zipped", False)
        )

    def __str__(self) -> str:
        return f"### Filename: {self.filename}\n{self.content}"
class Launch(FWSerializable):
    """A Launch encapsulates data about a specific run of a Firework on a computing resource."""

    def __init__(
        self,
        state,
        launch_dir,
        fworker=None,
        host=None,
        ip=None,
        trackers=None,
        action=None,
        state_history=None,
        launch_id=None,
        fw_id=None,
    ) -> None:
        """
        Args:
            state (str): the state of the Launch (e.g. RUNNING, COMPLETED)
            launch_dir (str): the directory where the Launch takes place
            fworker (FWorker): The FireWorker running the Launch
            host (str): the hostname where the launch took place (set automatically if None)
            ip (str): the IP address where the launch took place (set automatically if None)
            trackers ([Tracker]): File Trackers for this Launch
            action (FWAction): the output of the Launch
            state_history ([dict]): a history of all states of the Launch and when they occurred
            launch_id (int): launch_id set by the LaunchPad
            fw_id (int): id of the Firework this Launch is running.

        Raises:
            ValueError: if state is not one of the keys of Firework.STATE_RANKS.
        """
        if state not in Firework.STATE_RANKS:
            raise ValueError(f"Invalid launch {state=}")
        self.launch_dir = launch_dir
        self.fworker = fworker or FWorker()
        self.host = host or get_my_host()
        self.ip = ip or get_my_ip()
        self.trackers = trackers or []
        self.action = action or None
        self.state_history = state_history or []
        # must come after state_history is set: the state setter records into it
        self.state = state
        self.launch_id = launch_id
        self.fw_id = fw_id

    def touch_history(self, update_time=None, checkpoint=None) -> None:
        """
        Updates the update_on field of the state history of a Launch. Used to ping that a Launch
        is still alive.

        Args:
            update_time (datetime): time to record; defaults to the current UTC time.
            checkpoint: if truthy, stored under "checkpoint" in the latest history entry.
        """
        update_time = update_time or datetime.utcnow()
        if checkpoint:
            self.state_history[-1]["checkpoint"] = checkpoint
        self.state_history[-1]["updated_on"] = update_time

    def set_reservation_id(self, reservation_id) -> None:
        """
        Adds the job_id to the reservation.

        Only the first RESERVED history entry that has no reservation id yet is updated.

        Args:
            reservation_id (int or str): the id of the reservation (e.g., queue reservation)
        """
        for data in self.state_history:
            if data["state"] == "RESERVED" and "reservation_id" not in data:
                data["reservation_id"] = str(reservation_id)
                break

    @property
    def state(self):
        """
        Returns:
            str: The current state of the Launch.
        """
        return self._state

    @state.setter
    def state(self, state) -> None:
        """
        Setter for the Launch's state. Automatically triggers an update to state_history.

        Args:
            state (str): the state to set for the Launch
        """
        self._state = state
        self._update_state_history(state)

    @property
    def time_start(self):
        """
        Returns:
            datetime: the time the Launch started RUNNING.
        """
        return self._get_time("RUNNING")

    @property
    def time_end(self):
        """
        Returns:
            datetime: the time the Launch was COMPLETED or FIZZLED.
        """
        return self._get_time(["COMPLETED", "FIZZLED"])

    @property
    def time_reserved(self):
        """
        Returns:
            datetime: the time the Launch was RESERVED in the queue.
        """
        return self._get_time("RESERVED")

    @property
    def last_pinged(self):
        """
        Returns:
            datetime: the time the Launch last pinged a heartbeat that it was still running.
        """
        return self._get_time("RUNNING", True)

    @property
    def runtime_secs(self):
        """
        Returns:
            float: the number of seconds that the Launch ran for, or None if the start
                or the end time is not known.
        """
        start = self.time_start
        end = self.time_end
        if start and end:
            return (end - start).total_seconds()
        return None

    @property
    def reservedtime_secs(self):
        """
        Returns:
            float: number of seconds the Launch was stuck as RESERVED in a queue, or None
                if it was never reserved. While the Launch has not started yet, the
                current UTC time is used as the end of the interval.
        """
        start = self.time_reserved
        if start:
            end = self.time_start or datetime.utcnow()
            return (end - start).total_seconds()
        return None

    @recursive_serialize
    def to_dict(self):
        """Return the dict form of the Launch (post-processed by recursive_serialize)."""
        return {
            "fworker": self.fworker,
            "fw_id": self.fw_id,
            "launch_dir": self.launch_dir,
            "host": self.host,
            "ip": self.ip,
            "trackers": self.trackers,
            "action": self.action,
            "state": self.state,
            "state_history": self.state_history,
            "launch_id": self.launch_id,
        }

    @recursive_serialize
    def to_db_dict(self):
        """Return to_dict() extended with the derived timing fields."""
        m_d = self.to_dict()
        m_d["time_start"] = self.time_start
        m_d["time_end"] = self.time_end
        m_d["runtime_secs"] = self.runtime_secs
        # Evaluate once: the property may fall back to the current time, so a second
        # evaluation could store a different value than the one that was tested.
        reserved_secs = self.reservedtime_secs
        if reserved_secs:
            m_d["reservedtime_secs"] = reserved_secs
        return m_d

    @classmethod
    @recursive_deserialize
    def from_dict(cls, m_dict):
        """
        Rebuild a Launch from its dict form.

        "action" and "trackers" are optional keys; a falsy "fworker" lets the
        constructor create a default FWorker.
        """
        fworker = FWorker.from_dict(m_dict["fworker"]) if m_dict["fworker"] else None
        action = FWAction.from_dict(m_dict["action"]) if m_dict.get("action") else None
        trackers = [Tracker.from_dict(f) for f in m_dict["trackers"]] if m_dict.get("trackers") else None
        return Launch(
            m_dict["state"],
            m_dict["launch_dir"],
            fworker,
            m_dict["host"],
            m_dict["ip"],
            trackers,
            action,
            m_dict["state_history"],
            m_dict["launch_id"],
            m_dict["fw_id"],
        )

    def _update_state_history(self, state) -> None:
        """
        Internal method to update the state history whenever the Launch state is modified.

        A new entry is appended only when the state differs from the latest recorded one.

        Args:
            state (str)
        """
        if len(self.state_history) > 0:
            last_state = self.state_history[-1]["state"]
            last_checkpoint = self.state_history[-1].get("checkpoint", None)
        else:
            last_state, last_checkpoint = None, None

        if state != last_state:
            now_time = datetime.utcnow()
            new_history_entry = {"state": state, "created_on": now_time}
            # carry the last checkpoint over to the new entry, except into COMPLETED
            if state != "COMPLETED" and last_checkpoint:
                new_history_entry.update(checkpoint=last_checkpoint)
            self.state_history.append(new_history_entry)
            if state in ["RUNNING", "RESERVED"]:
                self.touch_history()  # add updated_on key

    def _get_time(self, states, use_update_time=False):
        """
        Internal method to help get the time of various events in the Launch (e.g. RUNNING)
        from the state history.

        Args:
            states (list/tuple): match one of these states
            use_update_time (bool): use the "updated_on" time rather than "created_on"

        Returns:
            (datetime): time of the first matching history entry, or None if none matches.
        """
        states = states if isinstance(states, (list, tuple)) else [states]
        for data in self.state_history:
            if data["state"] in states:
                if use_update_time:
                    return data["updated_on"]
                return data["created_on"]
        return None
class Workflow(FWSerializable):
"""A Workflow connects a group of FireWorks in an execution order."""
class Links(dict, FWSerializable):
    """An inner class for storing the DAG links between FireWorks."""

    def __init__(self, *args, **kwargs) -> None:
        """
        Build the mapping like a dict, then normalize it in place so that every key is
        an int id and every value is a list of child ids.
        """
        super().__init__(*args, **kwargs)

        # iterate over a snapshot because entries are added and removed along the way
        for key, raw_children in list(self.items()):
            if not isinstance(raw_children, (list, tuple)):
                self[key] = [raw_children]  # a lone child becomes a one-element list

            # objects carrying an fw_id are replaced by that id
            child_ids = [getattr(child, "fw_id", child) for child in self[key]]
            self[key] = child_ids

            if isinstance(key, int):
                continue

            if hasattr(key, "fw_id"):  # the key is an object with an id, e.g. a Firework
                self[key.fw_id] = child_ids
            else:  # maybe it's a String?
                try:
                    self[int(key)] = child_ids  # the key must be an int
                except Exception:
                    pass  # garbage input
            # the non-int key is dropped whether or not it could be converted
            del self[key]

    @property
    def nodes(self):
        """Return list of all nodes."""
        node_ids = [*self]
        for child_ids in self.values():
            node_ids += child_ids
        return list(set(node_ids))

    @property
    def parent_links(self):
        """
        Return a dict of child and its parents.

        Note: if performance of parent_links becomes an issue, override delitem/setitem to
        update parent_links
        """
        parents_of = {}
        for parent_id, child_ids in self.items():
            for child_id in child_ids:
                parents_of.setdefault(child_id, []).append(parent_id)
        return parents_of

    def to_dict(self):
        """
        Convert to str form for Mongo, which cannot have int keys.

        Returns:
            dict
        """
        return {str(fw_id): child_ids for fw_id, child_ids in self.items()}

    def to_db_dict(self):
        """
        Convert to str form for Mongo, which cannot have int keys.

        Returns:
            dict: with the keys "links", "parent_links" and "nodes".
        """
        str_links = {str(fw_id): child_ids for fw_id, child_ids in self.items()}
        str_parent_links = {str(fw_id): parent_ids for fw_id, parent_ids in self.parent_links.items()}
        return {
            "links": str_links,
            "parent_links": str_parent_links,
            "nodes": self.nodes,
        }

    @classmethod
    def from_dict(cls, m_dict):
        return Workflow.Links(m_dict)

    def __setstate__(self, state):
        # state is the list of (key, value) pairs produced by __reduce__
        for fw_id, child_ids in state:
            self[fw_id] = child_ids

    def __reduce__(self):
        """
        To support Pickling of inner classes (for multi-job launcher's multiprocessing).
        Return a class which can return this class when called with the appropriate tuple of
        arguments.
        """
        item_pairs = list(self.items())
        owner_and_name = (Workflow, self.__class__.__name__)
        return (NestedClassGetter(), owner_and_name, item_pairs)
def __init__(
    self,
    fireworks: Sequence[Firework],
    links_dict: dict[int, list[int]] | None = None,
    name: str | None = None,
    metadata: dict[str, Any] | None = None,
    created_on: datetime | None = None,
    updated_on: datetime | None = None,
    fw_states: dict[int, str] | None = None,
) -> None:
    """
    Args:
        fireworks ([Firework]): all FireWorks in this workflow.
        links_dict (dict): links between the FWs as (parent_id):[(child_id1, child_id2)].
            The mapping passed in is not modified.
        name (str): name of workflow.
        metadata (dict): metadata for this Workflow.
        created_on (datetime): time of creation
        updated_on (datetime): time of update
        fw_states (dict): leave this alone unless you are purposefully creating a Lazy-style WF.

    Raises:
        ValueError: if two FireWorks share an fw_id, if a Firework names a parent that
            is not part of the links, if the ids in the links differ from the ids of
            the given FireWorks, or if the workflow is empty.
    """
    name = name or "unnamed WF"  # prevent None names

    # Work on a shallow copy: default entries are inserted below and must not leak
    # into the mapping owned by the caller.
    links_dict = dict(links_dict) if links_dict else {}

    # main dict containing mapping of an id to a Firework object
    self.id_fw: dict[int, Firework] = {}
    for fw in fireworks:
        if fw.fw_id in self.id_fw:
            raise ValueError("FW ids must be unique!")
        self.id_fw[fw.fw_id] = fw

        # every Firework gets an entry, keyed either by its id or by the object itself
        if fw.fw_id not in links_dict and fw not in links_dict:
            links_dict[fw.fw_id] = []

    self.links = Workflow.Links(links_dict)

    # add depends on
    for fw in fireworks:
        for pfw in fw.parents:
            if pfw.fw_id not in self.links:
                raise ValueError(
                    f"FW_id: {fw.fw_id} defines a dependent link to FW_id: {pfw.fw_id}, but the latter was not "
                    "added to the workflow!"
                )
            if fw.fw_id not in self.links[pfw.fw_id]:
                self.links[pfw.fw_id].append(fw.fw_id)

    self.name = name

    # sanity check: make sure the set of nodes from the links_dict is equal to the set
    # of nodes from id_fw
    if set(self.links.nodes) != set(map(int, self.id_fw)):
        raise ValueError("Specified links don't match given FW")

    if len(self.links.nodes) == 0:
        raise ValueError("Workflow cannot be empty (must contain at least 1 FW)")

    self.metadata = metadata or {}
    self.created_on = created_on or datetime.utcnow()
    self.updated_on = updated_on or datetime.utcnow()

    # dict containing mapping of an id to a firework state. The states are stored locally and
    # redundantly for speed purpose
    self.fw_states = fw_states or {key: val.state for key, val in self.id_fw.items()}
@property
def fws(self) -> list[Firework]:
    """Return a new list holding every Firework of this workflow."""
    return [*self.id_fw.values()]
def __iter__(self) -> Iterator[Firework]:
    """Return an iterator over the Firework objects (the values of id_fw)."""
    return iter(self.id_fw.values())
def __len__(self) -> int:
    """Return the number of FireWorks held in id_fw."""
    fw_count = len(self.id_fw)
    return fw_count
def __getitem__(self, idx: int) -> Firework:
    """Return the Firework at position idx, in the iteration order of id_fw."""
    all_fws = [*self.id_fw.values()]
    return all_fws[idx]
@property
def state(self) -> str:
    """
    Derive the state of the workflow from the locally stored Firework states.

    The checks run in a fixed priority order; the first one that matches wins.

    Returns:
        state (str): state of workflow.
    """
    fw_state_values = self.fw_states.values()
    leaves = self.leaf_fw_ids  # looked up once and reused below

    # all leaves done -> the whole workflow is done
    if all(self.fw_states[leaf_id] == "COMPLETED" for leaf_id in leaves):
        return "COMPLETED"

    if all(fw_state == "ARCHIVED" for fw_state in fw_state_values):
        return "ARCHIVED"

    if "DEFUSED" in fw_state_values:
        return "DEFUSED"

    if "PAUSED" in fw_state_values:
        return "PAUSED"

    if "FIZZLED" in fw_state_values:
        for fizzled_id, fw_state in self.fw_states.items():
            if fw_state != "FIZZLED":
                continue
            # If a fizzled fw is a leaf fw, then the workflow is fizzled
            if fizzled_id in leaves:
                return "FIZZLED"
            # Otherwise all children must be ok with the fizzled parent
            children_tolerate_it = all(
                self.id_fw[child_id].spec.get("_allow_fizzled_parents", False)
                for child_id in self.links[fizzled_id]
            )
            if not children_tolerate_it:
                return "FIZZLED"
        # every fizzled Firework is tolerated by all of its children
        return "RUNNING"

    if "COMPLETED" in fw_state_values or "RUNNING" in fw_state_values:
        return "RUNNING"

    if "RESERVED" in fw_state_values:
        return "RESERVED"

    return "READY"
def apply_action(self, action: FWAction, fw_id: int) -> list[int]:
    """
    Apply a FWAction on a Firework in the Workflow.

    The steps run in this order: spec updates, spec mods, defusing of children,
    defusing of the workflow, detours, additions.

    Args:
        action (FWAction): action to apply
        fw_id (int): id of Firework on which to apply the action

    Returns:
        list[int]: list of Firework ids that were updated or new.

    Raises:
        ValueError: if a detour or an addition reports an id that was already updated.
    """
    updated_ids = []

    def spec_targets() -> list[int]:
        """Ids whose spec is modified: direct children, or all descendants when propagating."""
        if not action.propagate:
            return list(self.links[fw_id])

        # Depth-first pre-order walk down to the leaves. It is iterative so that very
        # deep workflows cannot hit the interpreter's recursion limit.
        seen = set()  # avoid double-updating for diamond deps
        ordered = []
        pending = [iter(self.links[fw_id])]
        while pending:
            for cfid in pending[-1]:
                if cfid not in seen:
                    seen.add(cfid)
                    ordered.append(cfid)
                    pending.append(iter(self.links[cfid]))
                    break
            else:
                pending.pop()
        return ordered

    # note: update specs before inserting additions to give user more control
    # see: https://github.com/materialsproject/fireworks/pull/407
    if action.update_spec or action.mod_spec:
        targets = spec_targets()

        # update the spec of the children FireWorks
        if action.update_spec:
            for cfid in targets:
                self.id_fw[cfid].spec.update(action.update_spec)
                updated_ids.append(cfid)

        # update the spec of the children FireWorks using DictMod language
        if action.mod_spec:
            for cfid in targets:
                for mod in action.mod_spec:
                    apply_mod(mod, self.id_fw[cfid].spec)
                updated_ids.append(cfid)

    # defuse children
    if action.defuse_children:
        for cfid in self.links[fw_id]:
            self.id_fw[cfid].state = "DEFUSED"
            self.fw_states[cfid] = "DEFUSED"
            updated_ids.append(cfid)

    # defuse workflow
    if action.defuse_workflow:
        # The loop variable must not be called fw_id: rebinding the parameter would
        # make the detours/additions below attach to the wrong Firework.
        for node_id in self.links.nodes:
            if self.id_fw[node_id].state not in ["FIZZLED", "COMPLETED"]:
                self.id_fw[node_id].state = "DEFUSED"
                self.fw_states[node_id] = "DEFUSED"
                updated_ids.append(node_id)

    # add detour FireWorks. This should be done *before* additions
    if action.detours:
        for wf in action.detours:
            new_updates = self.append_wf(wf, [fw_id], detour=True, pull_spec_mods=False)
            if len(set(updated_ids).intersection(new_updates)) > 0:
                raise ValueError("Cannot use duplicated fw_ids when dynamically detouring workflows!")
            updated_ids.extend(new_updates)

    # add additional FireWorks
    if action.additions:
        for wf in action.additions:
            new_updates = self.append_wf(wf, [fw_id], detour=False, pull_spec_mods=False)
            if len(set(updated_ids).intersection(new_updates)) > 0:
                raise ValueError("Cannot use duplicated fw_ids when dynamically adding workflows!")
            updated_ids.extend(new_updates)

    return list(set(updated_ids))
def rerun_fw(self, fw_id, updated_ids=None):
"""
Archives the launches of a Firework so that it can be re-run.
Args:
fw_id (int): id of firework to rerun
updated_ids (set(int)): set of fireworks id to rerun
Returns:
list[int]: list of Firework ids that were updated.
"""
updated_ids = updated_ids or set()
m_fw = self.id_fw[fw_id]
m_fw._rerun()
updated_ids.add(fw_id)
# refresh the states of the current fw before rerunning the children
# so that they get the correct state of the parent.
updated_ids.union(self.refresh(fw_id, updated_ids))