Skip to content

Commit 4dcc585

Browse files
committed
updates
1 parent 928555a commit 4dcc585

File tree

3 files changed

+206
-122
lines changed

3 files changed

+206
-122
lines changed

mongo/integration/cmd_monitoring_helpers_test.go

+180-104
Original file line numberDiff line numberDiff line change
@@ -270,21 +270,30 @@ func checkExpectations(mt *mtest.T, expectations *[]*expectation, id0, id1 bson.
270270
return
271271
}
272272

273-
for idx, expectation := range *expectations {
274-
var err error
273+
startedEvents := make([]*cmdStartedEvt, 0, len(*expectations))
274+
succeededEvents := make([]*cmdSucceededEvt, 0, len(*expectations))
275+
failedEvents := make([]*cmdFailedEvt, 0, len(*expectations))
275276

277+
for _, expectation := range *expectations {
276278
if expectation.CommandStartedEvent != nil {
277-
err = compareStartedEvent(mt, expectation, id0, id1)
279+
startedEvents = append(startedEvents, expectation.CommandStartedEvent)
278280
}
279281
if expectation.CommandSucceededEvent != nil {
280-
err = compareSucceededEvent(mt, expectation)
282+
succeededEvents = append(succeededEvents, expectation.CommandSucceededEvent)
281283
}
282284
if expectation.CommandFailedEvent != nil {
283-
err = compareFailedEvent(mt, expectation)
285+
failedEvents = append(failedEvents, expectation.CommandFailedEvent)
284286
}
285-
286-
assert.Nil(mt, err, "expectation comparison error at index %v: %s", idx, err)
287287
}
288+
289+
var err error
290+
err = compareStartedEvents(mt, startedEvents, id0, id1)
291+
assert.Nil(mt, err, "expectation comparison %s", err)
292+
err = compareSucceededEvents(mt, succeededEvents)
293+
assert.Nil(mt, err, "expectation comparison %s", err)
294+
err = compareFailedEvents(mt, failedEvents)
295+
assert.Nil(mt, err, "expectation comparison %s", err)
296+
288297
}
289298

290299
// newMatchError appends `expected` and `actual` BSON data to an error.
@@ -298,83 +307,104 @@ func newMatchError(mt *mtest.T, expected bson.Raw, actual bson.Raw, format strin
298307
return fmt.Errorf("%s\nExpected %s\nGot: %s", msg, string(expectedJSON), string(actualJSON))
299308
}
300309

301-
func compareStartedEvent(mt *mtest.T, expectation *expectation, id0, id1 bson.Raw) error {
310+
func compareStartedEvents(mt *mtest.T, expectations []*cmdStartedEvt, id0, id1 bson.Raw) error {
302311
mt.Helper()
303312

304-
expected := expectation.CommandStartedEvent
305-
306-
if len(expected.Extra) > 0 {
307-
return fmt.Errorf("unrecognized fields for CommandStartedEvent: %v", expected.Extra)
308-
}
309-
310-
evt := mt.GetStartedEvent()
311-
if evt == nil {
312-
return errors.New("expected CommandStartedEvent, got nil")
313-
}
314-
315-
if expected.CommandName != "" && expected.CommandName != evt.CommandName {
316-
return fmt.Errorf("command name mismatch; expected %s, got %s", expected.CommandName, evt.CommandName)
317-
}
318-
if expected.DatabaseName != "" && expected.DatabaseName != evt.DatabaseName {
319-
return fmt.Errorf("database name mismatch; expected %s, got %s", expected.DatabaseName, evt.DatabaseName)
320-
}
321-
322-
eElems, err := expected.Command.Elements()
323-
if err != nil {
324-
return fmt.Errorf("error getting expected command elements: %s", err)
313+
expectedCmds := make(map[string]bool)
314+
for _, expected := range expectations {
315+
expectedCmds[expected.CommandName] = true
325316
}
326317

327-
for _, elem := range eElems {
328-
key := elem.Key()
329-
val := elem.Value()
330-
331-
actualVal, err := evt.Command.LookupErr(key)
318+
compare := func(expected *cmdStartedEvt) error {
319+
if len(expected.Extra) > 0 {
320+
return fmt.Errorf("unrecognized fields for CommandStartedEvent: %v", expected.Extra)
321+
}
332322

333-
// Keys that may be nil
334-
if val.Type == bson.TypeNull {
335-
// Expected value is BSON null. Expect the actual field to be omitted.
336-
if errors.Is(err, bsoncore.ErrElementNotFound) {
337-
continue
323+
var evt *event.CommandStartedEvent
324+
// skip events not in expectations
325+
for {
326+
evt = mt.GetStartedEvent()
327+
if evt == nil {
328+
return fmt.Errorf("expected CommandStartedEvent %s, got nil", expected.CommandName)
338329
}
339-
if err != nil {
340-
return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got error: %v", key, err)
330+
if expected.CommandName == "" {
331+
break
332+
} else if v, ok := expectedCmds[evt.CommandName]; ok && v {
333+
break
341334
}
342-
return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got %q", key, actualVal)
343335
}
344-
assert.Nil(mt, err, "expected command to contain key %q", key)
345336

346-
if key == "batchSize" {
347-
// Some command monitoring tests expect that the driver will send a lower batch size if the required batch
348-
// size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server
349-
// versions do not support the limit option, but not for 3.2+. We've already validated that the command
350-
// contains a batchSize field above and we can skip the actual value comparison below.
351-
continue
337+
if expected.CommandName != "" && expected.CommandName != evt.CommandName {
338+
return fmt.Errorf("command name mismatch for started event; expected %s, got %s", expected.CommandName, evt.CommandName)
339+
}
340+
if expected.DatabaseName != "" && expected.DatabaseName != evt.DatabaseName {
341+
return fmt.Errorf("database name mismatch; expected %s, got %s", expected.DatabaseName, evt.DatabaseName)
352342
}
353343

354-
switch key {
355-
case "lsid":
356-
sessName := val.StringValue()
357-
var expectedID bson.Raw
358-
actualID := actualVal.Document()
344+
eElems, err := expected.Command.Elements()
345+
if err != nil {
346+
return fmt.Errorf("error getting expected command elements: %s", err)
347+
}
359348

360-
switch sessName {
361-
case "session0":
362-
expectedID = id0
363-
case "session1":
364-
expectedID = id1
365-
default:
366-
return newMatchError(mt, expected.Command, evt.Command, "unrecognized session identifier in command document: %s", sessName)
349+
for _, elem := range eElems {
350+
key := elem.Key()
351+
val := elem.Value()
352+
353+
actualVal, err := evt.Command.LookupErr(key)
354+
355+
// Keys that may be nil
356+
if val.Type == bson.TypeNull {
357+
// Expected value is BSON null. Expect the actual field to be omitted.
358+
if errors.Is(err, bsoncore.ErrElementNotFound) {
359+
continue
360+
}
361+
if err != nil {
362+
return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got error: %v", key, err)
363+
}
364+
return newMatchError(mt, expected.Command, evt.Command, "expected key %q to be omitted but got %q", key, actualVal)
367365
}
366+
assert.Nil(mt, err, "expected command to contain key %q", key)
368367

369-
if !bytes.Equal(expectedID, actualID) {
370-
return newMatchError(mt, expected.Command, evt.Command, "session ID mismatch for session %s; expected %s, got %s", sessName, expectedID,
371-
actualID)
368+
if key == "batchSize" {
369+
// Some command monitoring tests expect that the driver will send a lower batch size if the required batch
370+
// size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server
371+
// versions do not support the limit option, but not for 3.2+. We've already validated that the command
372+
// contains a batchSize field above and we can skip the actual value comparison below.
373+
continue
372374
}
373-
default:
374-
if err := compareValues(mt, key, val, actualVal); err != nil {
375-
return newMatchError(mt, expected.Command, evt.Command, "%s", err)
375+
376+
switch key {
377+
case "lsid":
378+
sessName := val.StringValue()
379+
var expectedID bson.Raw
380+
actualID := actualVal.Document()
381+
382+
switch sessName {
383+
case "session0":
384+
expectedID = id0
385+
case "session1":
386+
expectedID = id1
387+
default:
388+
return newMatchError(mt, expected.Command, evt.Command, "unrecognized session identifier in command document: %s", sessName)
389+
}
390+
391+
if !bytes.Equal(expectedID, actualID) {
392+
return newMatchError(mt, expected.Command, evt.Command, "session ID mismatch for session %s; expected %s, got %s", sessName, expectedID,
393+
actualID)
394+
}
395+
default:
396+
if err := compareValues(mt, key, val, actualVal); err != nil {
397+
return newMatchError(mt, expected.Command, evt.Command, "%s", err)
398+
}
376399
}
377400
}
401+
return nil
402+
}
403+
for idx, expected := range expectations {
404+
err := compare(expected)
405+
if err != nil {
406+
return fmt.Errorf("error at index %d: %s", idx, err)
407+
}
378408
}
379409
return nil
380410
}
@@ -416,60 +446,106 @@ func compareWriteErrors(mt *mtest.T, expected, actual bson.Raw) error {
416446
return nil
417447
}
418448

419-
func compareSucceededEvent(mt *mtest.T, expectation *expectation) error {
449+
func compareSucceededEvents(mt *mtest.T, expectations []*cmdSucceededEvt) error {
420450
mt.Helper()
421451

422-
expected := expectation.CommandSucceededEvent
423-
if len(expected.Extra) > 0 {
424-
return fmt.Errorf("unrecognized fields for CommandSucceededEvent: %v", expected.Extra)
425-
}
426-
evt := mt.GetSucceededEvent()
427-
if evt == nil {
428-
return errors.New("expected CommandSucceededEvent, got nil")
452+
expectedCmds := make(map[string]bool)
453+
for _, expected := range expectations {
454+
expectedCmds[expected.CommandName] = true
429455
}
430456

431-
if expected.CommandName != "" && expected.CommandName != evt.CommandName {
432-
return fmt.Errorf("command name mismatch; expected %s, got %s", expected.CommandName, evt.CommandName)
433-
}
457+
compare := func(expected *cmdSucceededEvt) error {
458+
if len(expected.Extra) > 0 {
459+
return fmt.Errorf("unrecognized fields for CommandSucceededEvent: %v", expected.Extra)
460+
}
434461

435-
eElems, err := expected.Reply.Elements()
436-
if err != nil {
437-
return fmt.Errorf("error getting expected reply elements: %s", err)
438-
}
462+
var evt *event.CommandSucceededEvent
463+
// skip events not in expectations
464+
for {
465+
evt = mt.GetSucceededEvent()
466+
if evt == nil {
467+
return fmt.Errorf("expected CommandSucceededEvent %s, got nil", expected.CommandName)
468+
}
469+
if expected.CommandName == "" {
470+
break
471+
} else if v, ok := expectedCmds[evt.CommandName]; ok && v {
472+
break
473+
}
474+
}
439475

440-
for _, elem := range eElems {
441-
key := elem.Key()
442-
val := elem.Value()
443-
actualVal := evt.Reply.Lookup(key)
476+
if expected.CommandName != "" && expected.CommandName != evt.CommandName {
477+
return fmt.Errorf("command name mismatch for succeeded event; expected %s, got %s", expected.CommandName, evt.CommandName)
478+
}
444479

445-
switch key {
446-
case "writeErrors":
447-
if err = compareWriteErrors(mt, val.Array(), actualVal.Array()); err != nil {
448-
return newMatchError(mt, expected.Reply, evt.Reply, "%s", err)
449-
}
450-
default:
451-
if err := compareValues(mt, key, val, actualVal); err != nil {
452-
return newMatchError(mt, expected.Reply, evt.Reply, "%s", err)
480+
eElems, err := expected.Reply.Elements()
481+
if err != nil {
482+
return fmt.Errorf("error getting expected reply elements: %s", err)
483+
}
484+
485+
for _, elem := range eElems {
486+
key := elem.Key()
487+
val := elem.Value()
488+
actualVal := evt.Reply.Lookup(key)
489+
490+
switch key {
491+
case "writeErrors":
492+
if err = compareWriteErrors(mt, val.Array(), actualVal.Array()); err != nil {
493+
return newMatchError(mt, expected.Reply, evt.Reply, "%s", err)
494+
}
495+
default:
496+
if err := compareValues(mt, key, val, actualVal); err != nil {
497+
return newMatchError(mt, expected.Reply, evt.Reply, "%s", err)
498+
}
453499
}
454500
}
501+
return nil
502+
}
503+
for idx, expected := range expectations {
504+
err := compare(expected)
505+
if err != nil {
506+
return fmt.Errorf("error at index %d: %s", idx, err)
507+
}
455508
}
456509
return nil
457510
}
458511

459-
func compareFailedEvent(mt *mtest.T, expectation *expectation) error {
512+
func compareFailedEvents(mt *mtest.T, expectations []*cmdFailedEvt) error {
460513
mt.Helper()
461514

462-
expected := expectation.CommandFailedEvent
463-
if len(expected.Extra) > 0 {
464-
return fmt.Errorf("unrecognized fields for CommandFailedEvent: %v", expected.Extra)
465-
}
466-
evt := mt.GetFailedEvent()
467-
if evt == nil {
468-
return errors.New("expected CommandFailedEvent, got nil")
515+
expectedCmds := make(map[string]bool)
516+
for _, expected := range expectations {
517+
expectedCmds[expected.CommandName] = true
469518
}
470519

471-
if expected.CommandName != "" && expected.CommandName != evt.CommandName {
472-
return fmt.Errorf("command name mismatch; expected %s, got %s", expected.CommandName, evt.CommandName)
520+
compare := func(expected *cmdFailedEvt) error {
521+
if len(expected.Extra) > 0 {
522+
return fmt.Errorf("unrecognized fields for CommandFailedEvent: %v", expected.Extra)
523+
}
524+
525+
var evt *event.CommandFailedEvent
526+
// skip events not in expectations
527+
for {
528+
evt = mt.GetFailedEvent()
529+
if evt == nil {
530+
return fmt.Errorf("expected CommandFailedEvent %s, got nil", expected.CommandName)
531+
}
532+
if expected.CommandName == "" {
533+
break
534+
} else if v, ok := expectedCmds[evt.CommandName]; ok && v {
535+
break
536+
}
537+
}
538+
539+
if expected.CommandName != "" && expected.CommandName != evt.CommandName {
540+
return fmt.Errorf("command name mismatch for failed event; expected %s, got %s", expected.CommandName, evt.CommandName)
541+
}
542+
return nil
543+
}
544+
for idx, expected := range expectations {
545+
err := compare(expected)
546+
if err != nil {
547+
return fmt.Errorf("error at index %d: %s", idx, err)
548+
}
473549
}
474550
return nil
475551
}

mongo/integration/mtest/mongotest.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -616,9 +616,11 @@ func (t *T) ClearFailPoints() {
616616
if err != nil {
617617
t.Fatalf("error clearing fail point %s: %v", fp.name, err)
618618
}
619-
if fp.client != t.Client {
620-
_ = fp.client.Disconnect(context.Background())
621-
t.fpClients[fp.client] = false
619+
t.fpClients[fp.client] = false
620+
}
621+
for client, active := range t.fpClients {
622+
if !active && client != t.Client {
623+
_ = client.Disconnect(context.Background())
622624
}
623625
}
624626
t.failPoints = t.failPoints[:0]

0 commit comments

Comments
 (0)