11# pylint: disable=invalid-name, missing-docstring, c-extension-no-member
22
3- from datetime import timedelta
4-
53import fastremap
64import numpy as np
75from pychunkedgraph .graph import ChunkedGraph
8- from pychunkedgraph .graph .attributes import Connectivity
6+ from pychunkedgraph .graph .attributes import Connectivity , Hierarchy
97from pychunkedgraph .graph .utils import serializers
108
11- from .utils import exists_as_parent , get_parent_timestamps
9+ from .utils import exists_as_parent , get_end_timestamps , get_parent_timestamps
10+
11+ CHILDREN = {}
1212
1313
1414def update_cross_edges (
15- cg : ChunkedGraph , node , cx_edges_d : dict , node_ts , timestamps : set , earliest_ts
15+ cg : ChunkedGraph , node , cx_edges_d : dict , node_ts , node_end_ts , timestamps : set
1616) -> list :
1717 """
1818 Helper function to update a single L2 ID.
@@ -27,13 +27,15 @@ def update_cross_edges(
2727 assert not exists_as_parent (cg , node , edges [:, 0 ])
2828 return rows
2929
30- partner_parent_ts_d = get_parent_timestamps (cg , edges [:, 1 ])
30+ partner_parent_ts_d = get_parent_timestamps (cg , np . unique ( edges [:, 1 ]) )
3131 for v in partner_parent_ts_d .values ():
3232 timestamps .update (v )
3333
3434 for ts in sorted (timestamps ):
35- if ts < earliest_ts :
36- ts = earliest_ts
35+ if ts < node_ts :
36+ continue
37+ if ts > node_end_ts :
38+ break
3739 val_dict = {}
3840 svs = edges [:, 1 ]
3941 parents = cg .get_parents (svs , time_stamp = ts )
@@ -51,21 +53,22 @@ def update_cross_edges(
5153 return rows
5254
5355
54- def update_nodes (cg : ChunkedGraph , nodes ) -> list :
55- nodes_ts = cg .get_node_timestamps (nodes , return_numpy = False , normalize = True )
56- earliest_ts = cg .get_earliest_timestamp ()
56+ def update_nodes (cg : ChunkedGraph , nodes , nodes_ts , children_map = None ) -> list :
57+ if children_map is None :
58+ children_map = CHILDREN
59+ end_timestamps = get_end_timestamps (cg , nodes , nodes_ts , children_map )
5760 timestamps_d = get_parent_timestamps (cg , nodes )
5861 cx_edges_d = cg .get_atomic_cross_edges (nodes )
5962 rows = []
60- for node , node_ts in zip (nodes , nodes_ts ):
63+ for node , node_ts , end_ts in zip (nodes , nodes_ts , end_timestamps ):
6164 if cg .get_parent (node ) is None :
62- # invalid id caused by failed ingest task
65+ # invalid id caused by failed ingest task / edits
6366 continue
6467 _cx_edges_d = cx_edges_d .get (node , {})
6568 if not _cx_edges_d :
6669 continue
6770 _rows = update_cross_edges (
68- cg , node , _cx_edges_d , node_ts , timestamps_d [node ], earliest_ts
71+ cg , node , _cx_edges_d , node_ts , end_ts , timestamps_d [node ]
6972 )
7073 rows .extend (_rows )
7174 return rows
@@ -76,10 +79,26 @@ def update_chunk(cg: ChunkedGraph, chunk_coords: list[int], layer: int = 2):
7679 Iterate over all L2 IDs in a chunk and update their cross chunk edges,
7780 within the periods they were valid/active.
7881 """
82+ global CHILDREN
83+
7984 x , y , z = chunk_coords
8085 chunk_id = cg .get_chunk_id (layer = layer , x = x , y = y , z = z )
8186 cg .copy_fake_edges (chunk_id )
8287 rr = cg .range_read_chunk (chunk_id )
83- nodes = list (rr .keys ())
84- rows = update_nodes (cg , nodes )
88+
89+ nodes = []
90+ nodes_ts = []
91+ earliest_ts = cg .get_earliest_timestamp ()
92+ for k , v in rr .items ():
93+ nodes .append (k )
94+ CHILDREN [k ] = v [Hierarchy .Child ][0 ].value
95+ ts = v [Hierarchy .Child ][0 ].timestamp
96+ nodes_ts .append (earliest_ts if ts < earliest_ts else ts )
97+
98+ if len (nodes ):
99+ assert len (CHILDREN ) > 0 , (nodes , CHILDREN )
100+ else :
101+ return
102+
103+ rows = update_nodes (cg , nodes , nodes_ts )
85104 cg .client .write (rows )
0 commit comments