Skip to content

Commit 387ecc4

Browse files
committed
Adjust formatting of introspection plots
1 parent 463a7f1 commit 387ecc4

File tree

1 file changed

+28
-9
lines changed

1 file changed

+28
-9
lines changed

Diff for: data_prototype/introspection.py

+28-9
Original file line numberDiff line numberDiff line change
@@ -113,30 +113,49 @@ def get_node(n):
113113
return nodes, edges
114114

115115

116-
def draw_graph(graph: Graph, ax=None):
116+
def draw_graph(graph: Graph, ax=None, *, adjust_axes=None):
117117
if ax is None:
118118
fig, ax = plt.subplots()
119+
if adjust_axes is None:
120+
adjust_axes = True
121+
122+
inverted = adjust_axes or ax.yaxis.get_inverted()
119123

120124
origin_y = 0
125+
xmax = 0
121126

122127
for sg in graph._subgraphs:
123128
nodes, edges = _position_subgraph(sg)
129+
annotations = {}
124130
# Draw nodes
125131
for node in nodes:
126-
ax.annotate(
127-
node.format(), (node.x, node.y + origin_y), bbox={"boxstyle": "round"}
132+
annotations[node.format()] = ax.annotate(
133+
node.format(),
134+
(node.x, node.y + origin_y),
135+
ha="center",
136+
va="center",
137+
bbox={"boxstyle": "round", "facecolor": "none"},
128138
)
129139

130140
# Draw edges
131141
for edge in edges:
132-
ax.annotate(
142+
arr = ax.annotate(
133143
"",
134-
(edge.child.x, edge.child.y + origin_y),
135-
(edge.parent.x, edge.parent.y + origin_y),
144+
(0.5, 1.05 if inverted else -0.05),
145+
(0.5, -0.05 if inverted else 1.05),
146+
xycoords=annotations[edge.child.format()],
147+
textcoords=annotations[edge.parent.format()],
136148
arrowprops={"arrowstyle": "->"},
149+
annotation_clip=True,
137150
)
138-
mid_x = (edge.child.x + edge.parent.x) / 2
139-
mid_y = (edge.child.y + edge.parent.y) / 2
140-
ax.text(mid_x, mid_y + origin_y, edge.name)
151+
ax.annotate(edge.name, (0.5, 0.5), xytext=(0.5, 0.5), textcoords=arr)
141152

142153
origin_y += max(node.y for node in nodes) + 1
154+
xmax = max(xmax, max(node.x for node in nodes))
155+
156+
if adjust_axes:
157+
ax.set_ylim(origin_y, -1)
158+
ax.set_xlim(-1, xmax + 1)
159+
ax.spines[:].set_visible(False)
160+
ax.set_xticks([])
161+
ax.set_yticks([])

0 commit comments

Comments
 (0)