@@ -270,21 +270,30 @@ func checkExpectations(mt *mtest.T, expectations *[]*expectation, id0, id1 bson.
270
270
return
271
271
}
272
272
273
- for idx , expectation := range * expectations {
274
- var err error
273
+ startedEvents := make ([]* cmdStartedEvt , 0 , len (* expectations ))
274
+ succeededEvents := make ([]* cmdSucceededEvt , 0 , len (* expectations ))
275
+ failedEvents := make ([]* cmdFailedEvt , 0 , len (* expectations ))
275
276
277
+ for _ , expectation := range * expectations {
276
278
if expectation .CommandStartedEvent != nil {
277
- err = compareStartedEvent ( mt , expectation , id0 , id1 )
279
+ startedEvents = append ( startedEvents , expectation . CommandStartedEvent )
278
280
}
279
281
if expectation .CommandSucceededEvent != nil {
280
- err = compareSucceededEvent ( mt , expectation )
282
+ succeededEvents = append ( succeededEvents , expectation . CommandSucceededEvent )
281
283
}
282
284
if expectation .CommandFailedEvent != nil {
283
- err = compareFailedEvent ( mt , expectation )
285
+ failedEvents = append ( failedEvents , expectation . CommandFailedEvent )
284
286
}
285
-
286
- assert .Nil (mt , err , "expectation comparison error at index %v: %s" , idx , err )
287
287
}
288
+
289
+ var err error
290
+ err = compareStartedEvents (mt , startedEvents , id0 , id1 )
291
+ assert .Nil (mt , err , "expectation comparison %s" , err )
292
+ err = compareSucceededEvents (mt , succeededEvents )
293
+ assert .Nil (mt , err , "expectation comparison %s" , err )
294
+ err = compareFailedEvents (mt , failedEvents )
295
+ assert .Nil (mt , err , "expectation comparison %s" , err )
296
+
288
297
}
289
298
290
299
// newMatchError appends `expected` and `actual` BSON data to an error.
@@ -298,83 +307,104 @@ func newMatchError(mt *mtest.T, expected bson.Raw, actual bson.Raw, format strin
298
307
return fmt .Errorf ("%s\n Expected %s\n Got: %s" , msg , string (expectedJSON ), string (actualJSON ))
299
308
}
300
309
301
- func compareStartedEvent (mt * mtest.T , expectation * expectation , id0 , id1 bson.Raw ) error {
310
+ func compareStartedEvents (mt * mtest.T , expectations [] * cmdStartedEvt , id0 , id1 bson.Raw ) error {
302
311
mt .Helper ()
303
312
304
- expected := expectation .CommandStartedEvent
305
-
306
- if len (expected .Extra ) > 0 {
307
- return fmt .Errorf ("unrecognized fields for CommandStartedEvent: %v" , expected .Extra )
308
- }
309
-
310
- evt := mt .GetStartedEvent ()
311
- if evt == nil {
312
- return errors .New ("expected CommandStartedEvent, got nil" )
313
- }
314
-
315
- if expected .CommandName != "" && expected .CommandName != evt .CommandName {
316
- return fmt .Errorf ("command name mismatch; expected %s, got %s" , expected .CommandName , evt .CommandName )
317
- }
318
- if expected .DatabaseName != "" && expected .DatabaseName != evt .DatabaseName {
319
- return fmt .Errorf ("database name mismatch; expected %s, got %s" , expected .DatabaseName , evt .DatabaseName )
320
- }
321
-
322
- eElems , err := expected .Command .Elements ()
323
- if err != nil {
324
- return fmt .Errorf ("error getting expected command elements: %s" , err )
313
+ expectedCmds := make (map [string ]bool )
314
+ for _ , expected := range expectations {
315
+ expectedCmds [expected .CommandName ] = true
325
316
}
326
317
327
- for _ , elem := range eElems {
328
- key := elem .Key ()
329
- val := elem .Value ()
330
-
331
- actualVal , err := evt .Command .LookupErr (key )
318
+ compare := func (expected * cmdStartedEvt ) error {
319
+ if len (expected .Extra ) > 0 {
320
+ return fmt .Errorf ("unrecognized fields for CommandStartedEvent: %v" , expected .Extra )
321
+ }
332
322
333
- // Keys that may be nil
334
- if val .Type == bson .TypeNull {
335
- // Expected value is BSON null. Expect the actual field to be omitted.
336
- if errors .Is (err , bsoncore .ErrElementNotFound ) {
337
- continue
323
+ var evt * event.CommandStartedEvent
324
+ // skip events not in expectations
325
+ for {
326
+ evt = mt .GetStartedEvent ()
327
+ if evt == nil {
328
+ return fmt .Errorf ("expected CommandStartedEvent %s, got nil" , expected .CommandName )
338
329
}
339
- if err != nil {
340
- return newMatchError (mt , expected .Command , evt .Command , "expected key %q to be omitted but got error: %v" , key , err )
330
+ if expected .CommandName == "" {
331
+ break
332
+ } else if v , ok := expectedCmds [evt .CommandName ]; ok && v {
333
+ break
341
334
}
342
- return newMatchError (mt , expected .Command , evt .Command , "expected key %q to be omitted but got %q" , key , actualVal )
343
335
}
344
- assert .Nil (mt , err , "expected command to contain key %q" , key )
345
336
346
- if key == "batchSize" {
347
- // Some command monitoring tests expect that the driver will send a lower batch size if the required batch
348
- // size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server
349
- // versions do not support the limit option, but not for 3.2+. We've already validated that the command
350
- // contains a batchSize field above and we can skip the actual value comparison below.
351
- continue
337
+ if expected .CommandName != "" && expected .CommandName != evt .CommandName {
338
+ return fmt .Errorf ("command name mismatch for started event; expected %s, got %s" , expected .CommandName , evt .CommandName )
339
+ }
340
+ if expected .DatabaseName != "" && expected .DatabaseName != evt .DatabaseName {
341
+ return fmt .Errorf ("database name mismatch; expected %s, got %s" , expected .DatabaseName , evt .DatabaseName )
352
342
}
353
343
354
- switch key {
355
- case "lsid" :
356
- sessName := val .StringValue ()
357
- var expectedID bson.Raw
358
- actualID := actualVal .Document ()
344
+ eElems , err := expected .Command .Elements ()
345
+ if err != nil {
346
+ return fmt .Errorf ("error getting expected command elements: %s" , err )
347
+ }
359
348
360
- switch sessName {
361
- case "session0" :
362
- expectedID = id0
363
- case "session1" :
364
- expectedID = id1
365
- default :
366
- return newMatchError (mt , expected .Command , evt .Command , "unrecognized session identifier in command document: %s" , sessName )
349
+ for _ , elem := range eElems {
350
+ key := elem .Key ()
351
+ val := elem .Value ()
352
+
353
+ actualVal , err := evt .Command .LookupErr (key )
354
+
355
+ // Keys that may be nil
356
+ if val .Type == bson .TypeNull {
357
+ // Expected value is BSON null. Expect the actual field to be omitted.
358
+ if errors .Is (err , bsoncore .ErrElementNotFound ) {
359
+ continue
360
+ }
361
+ if err != nil {
362
+ return newMatchError (mt , expected .Command , evt .Command , "expected key %q to be omitted but got error: %v" , key , err )
363
+ }
364
+ return newMatchError (mt , expected .Command , evt .Command , "expected key %q to be omitted but got %q" , key , actualVal )
367
365
}
366
+ assert .Nil (mt , err , "expected command to contain key %q" , key )
368
367
369
- if ! bytes .Equal (expectedID , actualID ) {
370
- return newMatchError (mt , expected .Command , evt .Command , "session ID mismatch for session %s; expected %s, got %s" , sessName , expectedID ,
371
- actualID )
368
+ if key == "batchSize" {
369
+ // Some command monitoring tests expect that the driver will send a lower batch size if the required batch
370
+ // size is lower than the operation limit. We only do this for legacy servers <= 3.0 because those server
371
+ // versions do not support the limit option, but not for 3.2+. We've already validated that the command
372
+ // contains a batchSize field above and we can skip the actual value comparison below.
373
+ continue
372
374
}
373
- default :
374
- if err := compareValues (mt , key , val , actualVal ); err != nil {
375
- return newMatchError (mt , expected .Command , evt .Command , "%s" , err )
375
+
376
+ switch key {
377
+ case "lsid" :
378
+ sessName := val .StringValue ()
379
+ var expectedID bson.Raw
380
+ actualID := actualVal .Document ()
381
+
382
+ switch sessName {
383
+ case "session0" :
384
+ expectedID = id0
385
+ case "session1" :
386
+ expectedID = id1
387
+ default :
388
+ return newMatchError (mt , expected .Command , evt .Command , "unrecognized session identifier in command document: %s" , sessName )
389
+ }
390
+
391
+ if ! bytes .Equal (expectedID , actualID ) {
392
+ return newMatchError (mt , expected .Command , evt .Command , "session ID mismatch for session %s; expected %s, got %s" , sessName , expectedID ,
393
+ actualID )
394
+ }
395
+ default :
396
+ if err := compareValues (mt , key , val , actualVal ); err != nil {
397
+ return newMatchError (mt , expected .Command , evt .Command , "%s" , err )
398
+ }
376
399
}
377
400
}
401
+ return nil
402
+ }
403
+ for idx , expected := range expectations {
404
+ err := compare (expected )
405
+ if err != nil {
406
+ return fmt .Errorf ("error at index %d: %s" , idx , err )
407
+ }
378
408
}
379
409
return nil
380
410
}
@@ -416,60 +446,106 @@ func compareWriteErrors(mt *mtest.T, expected, actual bson.Raw) error {
416
446
return nil
417
447
}
418
448
419
- func compareSucceededEvent (mt * mtest.T , expectation * expectation ) error {
449
+ func compareSucceededEvents (mt * mtest.T , expectations [] * cmdSucceededEvt ) error {
420
450
mt .Helper ()
421
451
422
- expected := expectation .CommandSucceededEvent
423
- if len (expected .Extra ) > 0 {
424
- return fmt .Errorf ("unrecognized fields for CommandSucceededEvent: %v" , expected .Extra )
425
- }
426
- evt := mt .GetSucceededEvent ()
427
- if evt == nil {
428
- return errors .New ("expected CommandSucceededEvent, got nil" )
452
+ expectedCmds := make (map [string ]bool )
453
+ for _ , expected := range expectations {
454
+ expectedCmds [expected .CommandName ] = true
429
455
}
430
456
431
- if expected .CommandName != "" && expected .CommandName != evt .CommandName {
432
- return fmt .Errorf ("command name mismatch; expected %s, got %s" , expected .CommandName , evt .CommandName )
433
- }
457
+ compare := func (expected * cmdSucceededEvt ) error {
458
+ if len (expected .Extra ) > 0 {
459
+ return fmt .Errorf ("unrecognized fields for CommandSucceededEvent: %v" , expected .Extra )
460
+ }
434
461
435
- eElems , err := expected .Reply .Elements ()
436
- if err != nil {
437
- return fmt .Errorf ("error getting expected reply elements: %s" , err )
438
- }
462
+ var evt * event.CommandSucceededEvent
463
+ // skip events not in expectations
464
+ for {
465
+ evt = mt .GetSucceededEvent ()
466
+ if evt == nil {
467
+ return fmt .Errorf ("expected CommandSucceededEvent %s, got nil" , expected .CommandName )
468
+ }
469
+ if expected .CommandName == "" {
470
+ break
471
+ } else if v , ok := expectedCmds [evt .CommandName ]; ok && v {
472
+ break
473
+ }
474
+ }
439
475
440
- for _ , elem := range eElems {
441
- key := elem .Key ()
442
- val := elem .Value ()
443
- actualVal := evt .Reply .Lookup (key )
476
+ if expected .CommandName != "" && expected .CommandName != evt .CommandName {
477
+ return fmt .Errorf ("command name mismatch for succeeded event; expected %s, got %s" , expected .CommandName , evt .CommandName )
478
+ }
444
479
445
- switch key {
446
- case "writeErrors" :
447
- if err = compareWriteErrors (mt , val .Array (), actualVal .Array ()); err != nil {
448
- return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
449
- }
450
- default :
451
- if err := compareValues (mt , key , val , actualVal ); err != nil {
452
- return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
480
+ eElems , err := expected .Reply .Elements ()
481
+ if err != nil {
482
+ return fmt .Errorf ("error getting expected reply elements: %s" , err )
483
+ }
484
+
485
+ for _ , elem := range eElems {
486
+ key := elem .Key ()
487
+ val := elem .Value ()
488
+ actualVal := evt .Reply .Lookup (key )
489
+
490
+ switch key {
491
+ case "writeErrors" :
492
+ if err = compareWriteErrors (mt , val .Array (), actualVal .Array ()); err != nil {
493
+ return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
494
+ }
495
+ default :
496
+ if err := compareValues (mt , key , val , actualVal ); err != nil {
497
+ return newMatchError (mt , expected .Reply , evt .Reply , "%s" , err )
498
+ }
453
499
}
454
500
}
501
+ return nil
502
+ }
503
+ for idx , expected := range expectations {
504
+ err := compare (expected )
505
+ if err != nil {
506
+ return fmt .Errorf ("error at index %d: %s" , idx , err )
507
+ }
455
508
}
456
509
return nil
457
510
}
458
511
459
- func compareFailedEvent (mt * mtest.T , expectation * expectation ) error {
512
+ func compareFailedEvents (mt * mtest.T , expectations [] * cmdFailedEvt ) error {
460
513
mt .Helper ()
461
514
462
- expected := expectation .CommandFailedEvent
463
- if len (expected .Extra ) > 0 {
464
- return fmt .Errorf ("unrecognized fields for CommandFailedEvent: %v" , expected .Extra )
465
- }
466
- evt := mt .GetFailedEvent ()
467
- if evt == nil {
468
- return errors .New ("expected CommandFailedEvent, got nil" )
515
+ expectedCmds := make (map [string ]bool )
516
+ for _ , expected := range expectations {
517
+ expectedCmds [expected .CommandName ] = true
469
518
}
470
519
471
- if expected .CommandName != "" && expected .CommandName != evt .CommandName {
472
- return fmt .Errorf ("command name mismatch; expected %s, got %s" , expected .CommandName , evt .CommandName )
520
+ compare := func (expected * cmdFailedEvt ) error {
521
+ if len (expected .Extra ) > 0 {
522
+ return fmt .Errorf ("unrecognized fields for CommandFailedEvent: %v" , expected .Extra )
523
+ }
524
+
525
+ var evt * event.CommandFailedEvent
526
+ // skip events not in expectations
527
+ for {
528
+ evt = mt .GetFailedEvent ()
529
+ if evt == nil {
530
+ return fmt .Errorf ("expected CommandFailedEvent %s, got nil" , expected .CommandName )
531
+ }
532
+ if expected .CommandName == "" {
533
+ break
534
+ } else if v , ok := expectedCmds [evt .CommandName ]; ok && v {
535
+ break
536
+ }
537
+ }
538
+
539
+ if expected .CommandName != "" && expected .CommandName != evt .CommandName {
540
+ return fmt .Errorf ("command name mismatch for failed event; expected %s, got %s" , expected .CommandName , evt .CommandName )
541
+ }
542
+ return nil
543
+ }
544
+ for idx , expected := range expectations {
545
+ err := compare (expected )
546
+ if err != nil {
547
+ return fmt .Errorf ("error at index %d: %s" , idx , err )
548
+ }
473
549
}
474
550
return nil
475
551
}
0 commit comments