Skip to content

Commit a859b1e

Browse files
committed
updates
1 parent c1ed704 commit a859b1e

File tree

1 file changed

+29
-25
lines changed

1 file changed

+29
-25
lines changed

mongo/integration/mtest/mongotest.go

+29-25
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,16 @@ type T struct {
133133
succeeded []*event.CommandSucceededEvent
134134
failed []*event.CommandFailedEvent
135135

136-
Client *mongo.Client
137-
DB *mongo.Database
138-
Coll *mongo.Collection
136+
Client *mongo.Client
137+
fpClients map[*mongo.Client]bool
138+
DB *mongo.Database
139+
Coll *mongo.Collection
139140
}
140141

141142
func newT(wrapped *testing.T, opts ...*Options) *T {
142143
t := &T{
143-
T: wrapped,
144+
T: wrapped,
145+
fpClients: make(map[*mongo.Client]bool),
144146
}
145147
for _, opt := range opts {
146148
for _, optFn := range opt.optFuncs {
@@ -207,6 +209,12 @@ func (t *T) cleanup() {
207209
// always disconnect the client regardless of clientType because Client.Disconnect will work against
208210
// all deployments
209211
_ = t.Client.Disconnect(context.Background())
212+
for client, v := range t.fpClients {
213+
if v {
214+
client.Disconnect(context.Background())
215+
}
216+
}
217+
t.fpClients = make(map[*mongo.Client]bool)
210218
}
211219

212220
// Run creates a new T instance for a sub-test and runs the given callback. It also creates a new collection using the
@@ -261,7 +269,9 @@ func (t *T) RunOpts(name string, opts *Options, callback func(mt *T)) {
261269
}
262270
// only disconnect client if it's not being shared
263271
if sub.shareClient == nil || !*sub.shareClient {
264-
_ = sub.Client.Disconnect(context.Background())
272+
if v, ok := sub.fpClients[sub.Client]; !ok || !v {
273+
_ = sub.Client.Disconnect(context.Background())
274+
}
265275
}
266276
assert.Equal(sub, 0, sessions, "%v sessions checked out", sessions)
267277
assert.Equal(sub, 0, conns, "%v connections checked out", conns)
@@ -410,7 +420,9 @@ func (t *T) ResetClient(opts *options.ClientOptions) {
410420
t.clientOpts = opts
411421
}
412422

413-
_ = t.Client.Disconnect(context.Background())
423+
if v, ok := t.fpClients[t.Client]; !ok || !v {
424+
_ = t.Client.Disconnect(context.Background())
425+
}
414426
t.createTestClient()
415427
t.DB = t.Client.Database(t.dbName)
416428
t.Coll = t.DB.Collection(t.collName, t.collOpts)
@@ -564,42 +576,31 @@ func (t *T) SetFailPoint(fp FailPoint) {
564576
}
565577
}
566578

567-
client, err := mongo.NewClient(t.clientOpts)
568-
if err != nil {
569-
t.Fatalf("error creating client: %v", err)
570-
}
571-
if err = client.Connect(context.Background()); err != nil {
572-
t.Fatalf("error connecting client: %v", err)
573-
}
574-
if err = SetFailPoint(fp, client); err != nil {
579+
if err := SetFailPoint(fp, t.Client); err != nil {
575580
t.Fatal(err)
576581
}
577-
t.failPoints = append(t.failPoints, failPoint{fp.ConfigureFailPoint, client})
582+
t.fpClients[t.Client] = true
583+
t.failPoints = append(t.failPoints, failPoint{fp.ConfigureFailPoint, t.Client})
578584
}
579585

580586
// SetFailPointFromDocument sets the fail point represented by the given document for the client associated with T. This
581587
// method assumes that the given document is in the form {configureFailPoint: <failPointName>, ...}. Commands to create
582588
// the failpoint will appear in command monitoring channels. The fail point will be automatically disabled after this
583589
// test has run.
584590
func (t *T) SetFailPointFromDocument(fp bson.Raw) {
585-
client, err := mongo.NewClient(t.clientOpts)
586-
if err != nil {
587-
t.Fatalf("error creating client: %v", err)
588-
}
589-
if err = client.Connect(context.Background()); err != nil {
590-
t.Fatalf("error connecting client: %v", err)
591-
}
592-
if err = SetRawFailPoint(fp, client); err != nil {
591+
if err := SetRawFailPoint(fp, t.Client); err != nil {
593592
t.Fatal(err)
594593
}
595594

595+
t.fpClients[t.Client] = true
596596
name := fp.Index(0).Value().StringValue()
597-
t.failPoints = append(t.failPoints, failPoint{name, client})
597+
t.failPoints = append(t.failPoints, failPoint{name, t.Client})
598598
}
599599

600600
// TrackFailPoint adds the given fail point to the list of fail points to be disabled when the current test finishes.
601601
// This function does not create a fail point on the server.
602602
func (t *T) TrackFailPoint(fpName string, client *mongo.Client) {
603+
t.fpClients[client] = true
603604
t.failPoints = append(t.failPoints, failPoint{fpName, client})
604605
}
605606

@@ -614,7 +615,10 @@ func (t *T) ClearFailPoints() {
614615
if err != nil {
615616
t.Fatalf("error clearing fail point %s: %v", fp.name, err)
616617
}
617-
_ = fp.client.Disconnect(context.Background())
618+
if fp.client != t.Client {
619+
_ = fp.client.Disconnect(context.Background())
620+
t.fpClients[fp.client] = false
621+
}
618622
}
619623
t.failPoints = t.failPoints[:0]
620624
}

0 commit comments

Comments
 (0)