@@ -113,30 +113,49 @@ def get_node(n):
113
113
return nodes , edges
114
114
115
115
116
- def draw_graph (graph : Graph , ax = None ):
116
+ def draw_graph (graph : Graph , ax = None , * , adjust_axes = None ):
117
117
if ax is None :
118
118
fig , ax = plt .subplots ()
119
+ if adjust_axes is None :
120
+ adjust_axes = True
121
+
122
+ inverted = adjust_axes or ax .yaxis .get_inverted ()
119
123
120
124
origin_y = 0
125
+ xmax = 0
121
126
122
127
for sg in graph ._subgraphs :
123
128
nodes , edges = _position_subgraph (sg )
129
+ annotations = {}
124
130
# Draw nodes
125
131
for node in nodes :
126
- ax .annotate (
127
- node .format (), (node .x , node .y + origin_y ), bbox = {"boxstyle" : "round" }
132
+ annotations [node .format ()] = ax .annotate (
133
+ node .format (),
134
+ (node .x , node .y + origin_y ),
135
+ ha = "center" ,
136
+ va = "center" ,
137
+ bbox = {"boxstyle" : "round" , "facecolor" : "none" },
128
138
)
129
139
130
140
# Draw edges
131
141
for edge in edges :
132
- ax .annotate (
142
+ arr = ax .annotate (
133
143
"" ,
134
- (edge .child .x , edge .child .y + origin_y ),
135
- (edge .parent .x , edge .parent .y + origin_y ),
144
+ (0.5 , 1.05 if inverted else - 0.05 ),
145
+ (0.5 , - 0.05 if inverted else 1.05 ),
146
+ xycoords = annotations [edge .child .format ()],
147
+ textcoords = annotations [edge .parent .format ()],
136
148
arrowprops = {"arrowstyle" : "->" },
149
+ annotation_clip = True ,
137
150
)
138
- mid_x = (edge .child .x + edge .parent .x ) / 2
139
- mid_y = (edge .child .y + edge .parent .y ) / 2
140
- ax .text (mid_x , mid_y + origin_y , edge .name )
151
+ ax .annotate (edge .name , (0.5 , 0.5 ), xytext = (0.5 , 0.5 ), textcoords = arr )
141
152
142
153
origin_y += max (node .y for node in nodes ) + 1
154
+ xmax = max (xmax , max (node .x for node in nodes ))
155
+
156
+ if adjust_axes :
157
+ ax .set_ylim (origin_y , - 1 )
158
+ ax .set_xlim (- 1 , xmax + 1 )
159
+ ax .spines [:].set_visible (False )
160
+ ax .set_xticks ([])
161
+ ax .set_yticks ([])
0 commit comments