Skip to content

Commit c7367fc

Browse files
committed
fix(migration): use None for non stale node end ts, performance improvements
1 parent cee6773 commit c7367fc

File tree

3 files changed

+50
-19
lines changed

3 files changed

+50
-19
lines changed

pychunkedgraph/ingest/upgrade/atomic_layer.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections import defaultdict
44
from concurrent.futures import ThreadPoolExecutor, as_completed
5-
from datetime import timedelta
5+
from datetime import datetime, timedelta, timezone
66
import logging, math, time
77
from copy import copy
88

@@ -39,6 +39,7 @@ def update_cross_edges(
3939
for partner in partners:
4040
timestamps.update(timestamps_d[partner])
4141

42+
node_end_ts = node_end_ts or datetime.now(timezone.utc)
4243
for ts in sorted(timestamps):
4344
if ts < node_ts:
4445
continue
@@ -76,15 +77,20 @@ def update_nodes(cg: ChunkedGraph, nodes, nodes_ts, children_map=None) -> list:
7677

7778
rows = []
7879
for node, node_ts, end_ts in zip(nodes, nodes_ts, end_timestamps):
79-
end_ts -= timedelta(milliseconds=1)
80+
is_stale = end_ts is not None
8081
_cx_edges_d = cx_edges_d.get(node, {})
8182
if not _cx_edges_d:
8283
continue
84+
if is_stale:
85+
end_ts -= timedelta(milliseconds=1)
86+
8387
_rows = update_cross_edges(cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d)
84-
row_id = serializers.serialize_uint64(node)
85-
val_dict = {Hierarchy.StaleTimeStamp: 0}
86-
_rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=end_ts))
88+
if is_stale:
89+
row_id = serializers.serialize_uint64(node)
90+
val_dict = {Hierarchy.StaleTimeStamp: 0}
91+
_rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=end_ts))
8792
rows.extend(_rows)
93+
8894
return rows
8995

9096

pychunkedgraph/ingest/upgrade/parent_layer.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import multiprocessing as mp
55
from collections import defaultdict
66
from concurrent.futures import ThreadPoolExecutor, as_completed
7+
from datetime import datetime, timezone
78

89
import fastremap
910
import numpy as np
@@ -59,33 +60,47 @@ def _populate_cx_edges_with_timestamps(
5960
for all IDs involved in an edit, we can use the timestamps of
6061
when cross edges of children were updated.
6162
"""
63+
64+
start = time.time()
6265
global CX_EDGES
6366
attrs = [Connectivity.CrossChunkEdge[l] for l in range(layer, cg.meta.layer_count)]
6467
all_children = np.concatenate(list(CHILDREN.values()))
6568
response = cg.client.read_nodes(node_ids=all_children, properties=attrs)
6669
timestamps_d = get_parent_timestamps(cg, nodes)
6770
end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, CHILDREN, layer=layer)
71+
logging.info(f"_populate_nodes_and_children init: {time.time() - start}")
6872

69-
rows = []
70-
for node, node_ts, node_end_ts in zip(nodes, nodes_ts, end_timestamps):
73+
start = time.time()
74+
partners_map = {}
75+
for node, node_ts in zip(nodes, nodes_ts):
7176
CX_EDGES[node] = {}
72-
timestamps = timestamps_d[node]
7377
cx_edges_d_node_ts = _get_cx_edges_at_timestamp(node, response, node_ts)
74-
7578
edges = np.concatenate([empty_2d] + list(cx_edges_d_node_ts.values()))
76-
partner_parent_ts_d = get_parent_timestamps(cg, edges[:, 1])
77-
for v in partner_parent_ts_d.values():
78-
timestamps.update(v)
79+
partners_map[node] = edges[:, 1]
7980
CX_EDGES[node][node_ts] = cx_edges_d_node_ts
8081

82+
partners = np.unique(np.concatenate([*partners_map.values()]))
83+
partner_parent_ts_d = get_parent_timestamps(cg, partners)
84+
logging.info(f"get partners timestamps init: {time.time() - start}")
85+
86+
rows = []
87+
for node, node_ts, node_end_ts in zip(nodes, nodes_ts, end_timestamps):
88+
timestamps = timestamps_d[node]
89+
for partner in partners_map[node]:
90+
timestamps.update(partner_parent_ts_d[partner])
91+
92+
is_stale = node_end_ts is not None
93+
node_end_ts = node_end_ts or datetime.now(timezone.utc)
8194
for ts in sorted(timestamps):
8295
if ts > node_end_ts:
8396
break
8497
CX_EDGES[node][ts] = _get_cx_edges_at_timestamp(node, response, ts)
8598

86-
row_id = serializers.serialize_uint64(node)
87-
val_dict = {Hierarchy.StaleTimeStamp: 0}
88-
rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=node_end_ts))
99+
if is_stale:
100+
row_id = serializers.serialize_uint64(node)
101+
val_dict = {Hierarchy.StaleTimeStamp: 0}
102+
rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=node_end_ts))
103+
89104
cg.client.write(rows)
90105

91106

@@ -140,7 +155,6 @@ def _update_cross_edges_helper(args):
140155
futures = [executor.submit(_update_cross_edges_helper_thread, task) for task in tasks]
141156
for future in tqdm(as_completed(futures), total=len(futures)):
142157
rows.extend(future.result())
143-
144158
cg.client.write(rows)
145159

146160

@@ -154,13 +168,21 @@ def update_chunk(
154168
start = time.time()
155169
x, y, z = chunk_coords
156170
chunk_id = cg.get_chunk_id(layer=layer, x=x, y=y, z=z)
171+
157172
_populate_nodes_and_children(cg, chunk_id, nodes=nodes)
173+
logging.info(f"_populate_nodes_and_children: {time.time() - start}")
158174
if not CHILDREN:
159175
return
160176
nodes = list(CHILDREN.keys())
161177
random.shuffle(nodes)
178+
179+
start = time.time()
162180
nodes_ts = cg.get_node_timestamps(nodes, return_numpy=False, normalize=True)
181+
logging.info(f"get_node_timestamps: {time.time() - start}")
182+
183+
start = time.time()
163184
_populate_cx_edges_with_timestamps(cg, layer, nodes, nodes_ts)
185+
logging.info(f"_populate_cx_edges_with_timestamps: {time.time() - start}")
164186

165187
if debug:
166188
rows = []

pychunkedgraph/ingest/upgrade/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,19 @@ def get_end_timestamps(
6464

6565
for node, node_ts in zip(nodes, nodes_ts):
6666
node_children = children_map[node]
67-
_timestamps = set().union(*[timestamps_d[k] for k in node_children])
67+
_children_timestamps = []
68+
for k in node_children:
69+
if k in timestamps_d:
70+
_children_timestamps.append(timestamps_d[k])
71+
_timestamps = set().union(*_children_timestamps)
6872
_timestamps.add(node_ts)
6973
try:
7074
_timestamps = sorted(_timestamps)
7175
_index = np.searchsorted(_timestamps, node_ts)
72-
assert _timestamps[_index] == node_ts, (_index, node_ts, _timestamps)
7376
end_ts = _timestamps[_index + 1]
7477
except IndexError:
7578
# this node has not been edited, but might have its edges updated
76-
end_ts = datetime.now(timezone.utc)
79+
end_ts = None
7780
result.append(end_ts)
7881
return result
7982

0 commit comments

Comments
 (0)