@@ -437,5 +437,80 @@ def test_add_block_containing_multiple_constraints(self):
437
437
self .assertEqual (opt ._solver_model .linear_constraints .get_num (), 3 )
438
438
439
439
440
+ class TestLoadVars (unittest .TestCase ):
441
+ def setUp (self ):
442
+ opt = SolverFactory ("cplex" , solver_io = "python" )
443
+ model = ConcreteModel ()
444
+ model .X = Var (within = NonNegativeReals , initialize = 0 )
445
+ model .Y = Var (within = NonNegativeReals , initialize = 0 )
446
+
447
+ model .C1 = Constraint (expr = 2 * model .X + model .Y >= 8 )
448
+ model .C2 = Constraint (expr = model .X + 3 * model .Y >= 6 )
449
+
450
+ model .O = Objective (expr = model .X + model .Y )
451
+
452
+ opt .solve (model , load_solutions = False , save_results = False )
453
+
454
+ self ._model = model
455
+ self ._opt = opt
456
+
457
+ def test_all_vars_are_loaded (self ):
458
+ self .assertTrue (self ._model .X .stale )
459
+ self .assertTrue (self ._model .Y .stale )
460
+ self .assertEqual (value (self ._model .X ), 0 )
461
+ self .assertEqual (value (self ._model .Y ), 0 )
462
+
463
+ with unittest .mock .patch .object (
464
+ self ._opt ._solver_model .solution ,
465
+ "get_values" ,
466
+ wraps = self ._opt ._solver_model .solution .get_values ,
467
+ ) as wrapped_values_call :
468
+ self ._opt .load_vars ()
469
+
470
+ self .assertEqual (wrapped_values_call .call_count , 1 )
471
+ self .assertEqual (wrapped_values_call .call_args , tuple ())
472
+
473
+ self .assertFalse (self ._model .X .stale )
474
+ self .assertFalse (self ._model .Y .stale )
475
+ self .assertAlmostEqual (value (self ._model .X ), 3.6 )
476
+ self .assertAlmostEqual (value (self ._model .Y ), 0.8 )
477
+
478
+ def test_only_specified_vars_are_loaded (self ):
479
+ self .assertTrue (self ._model .X .stale )
480
+ self .assertTrue (self ._model .Y .stale )
481
+ self .assertEqual (value (self ._model .X ), 0 )
482
+ self .assertEqual (value (self ._model .Y ), 0 )
483
+
484
+ with unittest .mock .patch .object (
485
+ self ._opt ._solver_model .solution ,
486
+ "get_values" ,
487
+ wraps = self ._opt ._solver_model .solution .get_values ,
488
+ ) as wrapped_values_call :
489
+ self ._opt .load_vars ([self ._model .X ])
490
+
491
+ self .assertEqual (wrapped_values_call .call_count , 1 )
492
+ self .assertEqual (wrapped_values_call .call_args , (([0 ],), {}))
493
+
494
+ self .assertFalse (self ._model .X .stale )
495
+ self .assertTrue (self ._model .Y .stale )
496
+ self .assertAlmostEqual (value (self ._model .X ), 3.6 )
497
+ self .assertEqual (value (self ._model .Y ), 0 )
498
+
499
+ with unittest .mock .patch .object (
500
+ self ._opt ._solver_model .solution ,
501
+ "get_values" ,
502
+ wraps = self ._opt ._solver_model .solution .get_values ,
503
+ ) as wrapped_values_call :
504
+ self ._opt .load_vars ([self ._model .Y ])
505
+
506
+ self .assertEqual (wrapped_values_call .call_count , 1 )
507
+ self .assertEqual (wrapped_values_call .call_args , (([1 ],), {}))
508
+
509
+ self .assertFalse (self ._model .X .stale )
510
+ self .assertFalse (self ._model .Y .stale )
511
+ self .assertAlmostEqual (value (self ._model .X ), 3.6 )
512
+ self .assertAlmostEqual (value (self ._model .Y ), 0.8 )
513
+
514
+
440
515
if __name__ == "__main__" :
441
516
unittest .main ()
0 commit comments