@@ -133,14 +133,16 @@ type T struct {
133
133
succeeded []* event.CommandSucceededEvent
134
134
failed []* event.CommandFailedEvent
135
135
136
- Client * mongo.Client
137
- DB * mongo.Database
138
- Coll * mongo.Collection
136
+ Client * mongo.Client
137
+ fpClients map [* mongo.Client ]bool
138
+ DB * mongo.Database
139
+ Coll * mongo.Collection
139
140
}
140
141
141
142
func newT (wrapped * testing.T , opts ... * Options ) * T {
142
143
t := & T {
143
- T : wrapped ,
144
+ T : wrapped ,
145
+ fpClients : make (map [* mongo.Client ]bool ),
144
146
}
145
147
for _ , opt := range opts {
146
148
for _ , optFn := range opt .optFuncs {
@@ -207,6 +209,12 @@ func (t *T) cleanup() {
207
209
// always disconnect the client regardless of clientType because Client.Disconnect will work against
208
210
// all deployments
209
211
_ = t .Client .Disconnect (context .Background ())
212
+ for client , v := range t .fpClients {
213
+ if v {
214
+ client .Disconnect (context .Background ())
215
+ }
216
+ }
217
+ t .fpClients = make (map [* mongo.Client ]bool )
210
218
}
211
219
212
220
// Run creates a new T instance for a sub-test and runs the given callback. It also creates a new collection using the
@@ -261,7 +269,9 @@ func (t *T) RunOpts(name string, opts *Options, callback func(mt *T)) {
261
269
}
262
270
// only disconnect client if it's not being shared
263
271
if sub .shareClient == nil || ! * sub .shareClient {
264
- _ = sub .Client .Disconnect (context .Background ())
272
+ if v , ok := sub .fpClients [sub .Client ]; ! ok || ! v {
273
+ _ = sub .Client .Disconnect (context .Background ())
274
+ }
265
275
}
266
276
assert .Equal (sub , 0 , sessions , "%v sessions checked out" , sessions )
267
277
assert .Equal (sub , 0 , conns , "%v connections checked out" , conns )
@@ -410,7 +420,9 @@ func (t *T) ResetClient(opts *options.ClientOptions) {
410
420
t .clientOpts = opts
411
421
}
412
422
413
- _ = t .Client .Disconnect (context .Background ())
423
+ if v , ok := t .fpClients [t .Client ]; ! ok || ! v {
424
+ _ = t .Client .Disconnect (context .Background ())
425
+ }
414
426
t .createTestClient ()
415
427
t .DB = t .Client .Database (t .dbName )
416
428
t .Coll = t .DB .Collection (t .collName , t .collOpts )
@@ -564,42 +576,31 @@ func (t *T) SetFailPoint(fp FailPoint) {
564
576
}
565
577
}
566
578
567
- client , err := mongo .NewClient (t .clientOpts )
568
- if err != nil {
569
- t .Fatalf ("error creating client: %v" , err )
570
- }
571
- if err = client .Connect (context .Background ()); err != nil {
572
- t .Fatalf ("error connecting client: %v" , err )
573
- }
574
- if err = SetFailPoint (fp , client ); err != nil {
579
+ if err := SetFailPoint (fp , t .Client ); err != nil {
575
580
t .Fatal (err )
576
581
}
577
- t .failPoints = append (t .failPoints , failPoint {fp .ConfigureFailPoint , client })
582
+ t .fpClients [t .Client ] = true
583
+ t .failPoints = append (t .failPoints , failPoint {fp .ConfigureFailPoint , t .Client })
578
584
}
579
585
580
586
// SetFailPointFromDocument sets the fail point represented by the given document for the client associated with T. This
581
587
// method assumes that the given document is in the form {configureFailPoint: <failPointName>, ...}. Commands to create
582
588
// the failpoint will appear in command monitoring channels. The fail point will be automatically disabled after this
583
589
// test has run.
584
590
func (t * T ) SetFailPointFromDocument (fp bson.Raw ) {
585
- client , err := mongo .NewClient (t .clientOpts )
586
- if err != nil {
587
- t .Fatalf ("error creating client: %v" , err )
588
- }
589
- if err = client .Connect (context .Background ()); err != nil {
590
- t .Fatalf ("error connecting client: %v" , err )
591
- }
592
- if err = SetRawFailPoint (fp , client ); err != nil {
591
+ if err := SetRawFailPoint (fp , t .Client ); err != nil {
593
592
t .Fatal (err )
594
593
}
595
594
595
+ t .fpClients [t .Client ] = true
596
596
name := fp .Index (0 ).Value ().StringValue ()
597
- t .failPoints = append (t .failPoints , failPoint {name , client })
597
+ t .failPoints = append (t .failPoints , failPoint {name , t . Client })
598
598
}
599
599
600
600
// TrackFailPoint adds the given fail point to the list of fail points to be disabled when the current test finishes.
601
601
// This function does not create a fail point on the server.
602
602
func (t * T ) TrackFailPoint (fpName string , client * mongo.Client ) {
603
+ t .fpClients [client ] = true
603
604
t .failPoints = append (t .failPoints , failPoint {fpName , client })
604
605
}
605
606
@@ -614,7 +615,10 @@ func (t *T) ClearFailPoints() {
614
615
if err != nil {
615
616
t .Fatalf ("error clearing fail point %s: %v" , fp .name , err )
616
617
}
617
- _ = fp .client .Disconnect (context .Background ())
618
+ if fp .client != t .Client {
619
+ _ = fp .client .Disconnect (context .Background ())
620
+ t .fpClients [fp .client ] = false
621
+ }
618
622
}
619
623
t .failPoints = t .failPoints [:0 ]
620
624
}
0 commit comments