 # LICENSE file in the root directory of this source tree.

 import os
-import copy
 import unittest
 import parlai.utils.testing as testing_utils
 import parlai.scripts.build_dict as build_dict

 BATCHSIZE = 4


-def _forced_parse(parser, opt):
-    parser.set_params(**opt)
-    parser.set_params(log_every_n_sec=10)
-    popt = parser.parse_args([])
-    # in some rare cases, like for instance if the model class also
-    # overrides its default params, the params override will not
-    # be taken into account.
-    for k, v in opt.items():
-        popt[k] = v
-    return popt
+class _AbstractTest(unittest.TestCase):
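+    """
+    Shared harness: merges per-test overrides into ``base_config``, trains in a
+    temporary directory, and returns the final (valid, test) reports.
+    """
+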
+    def _distributed_train_model(self, **overrides):
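+        # per-test keyword overrides take precedence over the class-level base_config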
+        opt = {**self.base_config, **overrides}
+        with testing_utils.tempdir() as tmpdir:
+            if 'model_file' not in opt:
+                opt['model_file'] = os.path.join(tmpdir, 'model')
+            if 'dict_file' not in opt:
+                opt['dict_file'] = os.path.join(tmpdir, 'model.dict')
+
+            parser = mp_train.setup_args()
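+            # parse_kwargs builds the final opt straight from keyword arguments,
+            # replacing the old _forced_parse() helper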
+            popt = parser.parse_kwargs(**opt)
+
+            # we need a prebuilt dictionary
+            parser = build_dict.setup_args()
+            build_dict.build_dict(popt)
+
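+            # launch_and_train spawns the distributed worker processes; the
+            # hard-coded port argument from the old helper is no longer passed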
+            valid, test = mp_train.launch_and_train(popt)
+
+        return (valid, test)


 @testing_utils.skipUnlessGPU
-class TestDistributed(unittest.TestCase):
-    _base_config = dict(
+class TestDistributed(_AbstractTest):
+    base_config = dict(
         task='integration_tests:overfit',
         model='transformer/generator',
         optimizer='adam',
@@ -46,30 +54,8 @@ class TestDistributed(unittest.TestCase):
         verbose=True,
     )

-    def setUp(self):
-        print(f'[Setting up test {self._testMethodName}]')
-
-    def _distributed_train_model(self, opt):
-        with testing_utils.tempdir() as tmpdir:
-            if 'model_file' not in opt:
-                opt['model_file'] = os.path.join(tmpdir, 'model')
-            if 'dict_file' not in opt:
-                opt['dict_file'] = os.path.join(tmpdir, 'model.dict')
-
-            parser = mp_train.setup_args()
-            popt = _forced_parse(parser, opt)
-
-            # we need a prebuilt dictionary
-            parser = build_dict.setup_args()
-            build_dict.build_dict(popt)
-
-            valid, test = mp_train.launch_and_train(popt, 31338)
-
-        return (valid, test)
-
     def test_generator_distributed(self):
-        config = copy.deepcopy(self._base_config)
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model()

         self.assertLessEqual(valid['ppl'], 1.60)
         self.assertLessEqual(test['ppl'], 1.60)
@@ -80,11 +66,11 @@ def test_generator_distributed(self):
         self.assertEqual(test['exs'].value(), BATCHSIZE)

     def test_multitask_distributed(self):
-        config = copy.deepcopy(self._base_config)
-        config['num_epochs'] = 50
-        config['task'] = 'integration_tests:overfit,integration_tests:overfit_multiturn'
-        config['dynb'] = 'full'
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            num_epochs=50,
+            task='integration_tests:overfit,integration_tests:overfit_multiturn',
+            truncate=16,
+        )

         self.assertLessEqual(valid['ppl'], 1.20)
         self.assertLessEqual(test['ppl'], 1.20)
@@ -100,12 +86,12 @@ def test_multitask_distributed(self):
         )

     def test_distributed_eval_max_exs(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['validation_max_exs'] = 90
-        config['short_final_eval'] = True
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests',
+            num_epochs=0.01,
+            validation_max_exs=90,
+            short_final_eval=True,
+        )

         # Tests that DialogData.get() is doing the right thing
         # Ensure no duplication of examples among workers
@@ -120,11 +106,9 @@ def test_distributed_eval_max_exs(self):
         self.assertEqual(test['exs'].value(), 96)

     def test_distributed_eval_stream_mode(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['datatype'] = 'train:stream'
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests', num_epochs=0.01, datatype='train:stream'
+        )

         # Tests that StreamDialogData.get() is doing the right thing
         # Ensure no duplication of examples among workers
@@ -133,14 +117,13 @@ def test_distributed_eval_stream_mode(self):
         self.assertEqual(test['exs'].value(), inttests.NUM_TEST)

     def test_distributed_eval_stream_mode_max_exs(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['datatype'] = 'train:stream'
-        config['validation_max_exs'] = 90
-        config['short_final_eval'] = True
-
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests',
+            num_epochs=0.01,
+            datatype='train:stream',
+            validation_max_exs=90,
+            short_final_eval=True,
+        )

         # Tests that StreamDialogData.get() is doing the right thing
         # Ensure no duplication of examples among workers
@@ -155,45 +138,68 @@ def test_distributed_eval_stream_mode_max_exs(self):
         self.assertEqual(test['exs'].value(), 96)

     def test_chunked_dynamic_teacher(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['datatype'] = 'train:stream'
-        config['dynamic_batching'] = 'full'
-        config['truncate'] = 16
-
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests',
+            num_epochs=0.01,
+            datatype='train:stream',
+            dynamic_batching='full',
+            truncate=16,
+        )
         assert valid['exs'].value() == inttests.NUM_TEST
         assert test['exs'].value() == inttests.NUM_TEST

     def test_chunked_teacher(self):
-        config = copy.deepcopy(self._base_config)
-        config['task'] = 'integration_tests'
-        config['num_epochs'] = 0.01
-        config['datatype'] = 'train:stream'
-        config['num_epochs'] = 5
-        config['dynamic_batching'] = None
-
-        valid, test = self._distributed_train_model(config)
+        valid, test = self._distributed_train_model(
+            task='integration_tests',
+            datatype='train:stream',
+            num_epochs=5,
+            dynamic_batching=None,
+        )
         assert valid['exs'].value() == inttests.NUM_TEST
         assert test['exs'].value() == inttests.NUM_TEST

+
+@testing_utils.skipUnlessGPU
+class TestZero2(TestDistributed):
+    """
+    Integration tests for zero2 FSDP.
+    """
+
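+    # Re-runs every test in TestDistributed with only the DDP backend changed.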
+    base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero2'}
+
+
+@unittest.skip
+@testing_utils.skipUnlessGPU
+class TestZero3(TestDistributed):
+    # Not supported at this time. See:
+    # https://github.com/facebookresearch/ParlAI/pull/3740
+    base_config = {**TestDistributed.base_config, 'ddp_backend': 'zero3'}
+
+
+@testing_utils.skipUnlessGPU
+class TestNoModelParallel(_AbstractTest):
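+    # base_config omits ``model``; test_no_model_parallel supplies one per loop iteration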
+    base_config = dict(
+        task='integration_tests:overfit',
+        optimizer='sgd',
+        validation_metric='loss',
+        learningrate=1e-2,
+        batchsize=BATCHSIZE,
+        validation_every_n_epochs=1,
+        num_epochs=1,
+        n_layers=1,
+        n_heads=1,
+        ffn_size=32,
+        embedding_size=8,
+        verbose=True,
+    )
+
     def test_no_model_parallel(self):
         """
-        Checks that we throw an error when combining mp_train with.
-
-        --model-parallel true.
+        Checks that we throw an error when combining mp_train with --model-parallel.
         """
-        config = copy.deepcopy(self._base_config)
-        config['model_parallel'] = True
-        for m in [
-            'transformer/generator',
-            'transformer/ranker',
-            'transformer/classifier',
-        ]:
-            config['model'] = m
+        for m in ['transformer/generator', 'transformer/ranker']:
             try:
-                _ = self._distributed_train_model(config)
+                _ = self._distributed_train_model(model=m, model_parallel=True)
             except RuntimeError:
                 pass
             else: