1313
1414from gokart .file_processor import _ChunkedLargeFileReader
1515from gokart .target import make_model_target , make_target
16-
17-
18- def _get_temporary_directory ():
19- return os .path .abspath (os .path .join (os .path .dirname (__name__ ), 'temporary' ))
16+ from test .util import _get_temporary_directory
2017
2118
2219class LocalTargetTest (unittest .TestCase ):
20+ def setUp (self ):
21+ self .temporary_directory = _get_temporary_directory ()
22+
2323 def tearDown (self ):
24- shutil .rmtree (_get_temporary_directory () , ignore_errors = True )
24+ shutil .rmtree (self . temporary_directory , ignore_errors = True )
2525
2626 def test_save_and_load_pickle_file (self ):
2727 obj = 1
28- file_path = os .path .join (_get_temporary_directory () , 'test.pkl' )
28+ file_path = os .path .join (self . temporary_directory , 'test.pkl' )
2929
3030 target = make_target (file_path = file_path , unique_id = None )
3131 target .dump (obj )
@@ -37,7 +37,7 @@ def test_save_and_load_pickle_file(self):
3737
3838 def test_save_and_load_text_file (self ):
3939 obj = 1
40- file_path = os .path .join (_get_temporary_directory () , 'test.txt' )
40+ file_path = os .path .join (self . temporary_directory , 'test.txt' )
4141
4242 target = make_target (file_path = file_path , unique_id = None )
4343 target .dump (obj )
@@ -47,7 +47,7 @@ def test_save_and_load_text_file(self):
4747
4848 def test_save_and_load_gzip (self ):
4949 obj = 1
50- file_path = os .path .join (_get_temporary_directory () , 'test.gz' )
50+ file_path = os .path .join (self . temporary_directory , 'test.gz' )
5151
5252 target = make_target (file_path = file_path , unique_id = None )
5353 target .dump (obj )
@@ -57,7 +57,7 @@ def test_save_and_load_gzip(self):
5757
5858 def test_save_and_load_npz (self ):
5959 obj = np .ones (shape = 10 , dtype = np .float32 )
60- file_path = os .path .join (_get_temporary_directory () , 'test.npz' )
60+ file_path = os .path .join (self . temporary_directory , 'test.npz' )
6161 target = make_target (file_path = file_path , unique_id = None )
6262 target .dump (obj )
6363 loaded = target .load ()
@@ -69,7 +69,7 @@ def test_save_and_load_figure(self):
6969 pd .DataFrame (dict (x = range (10 ), y = range (10 ))).plot .scatter (x = 'x' , y = 'y' )
7070 pyplot .savefig (figure_binary )
7171 figure_binary .seek (0 )
72- file_path = os .path .join (_get_temporary_directory () , 'test.png' )
72+ file_path = os .path .join (self . temporary_directory , 'test.png' )
7373 target = make_target (file_path = file_path , unique_id = None )
7474 target .dump (figure_binary .read ())
7575
@@ -78,7 +78,7 @@ def test_save_and_load_figure(self):
7878
7979 def test_save_and_load_csv (self ):
8080 obj = pd .DataFrame (dict (a = [1 , 2 ], b = [3 , 4 ]))
81- file_path = os .path .join (_get_temporary_directory () , 'test.csv' )
81+ file_path = os .path .join (self . temporary_directory , 'test.csv' )
8282
8383 target = make_target (file_path = file_path , unique_id = None )
8484 target .dump (obj )
@@ -88,7 +88,7 @@ def test_save_and_load_csv(self):
8888
8989 def test_save_and_load_tsv (self ):
9090 obj = pd .DataFrame (dict (a = [1 , 2 ], b = [3 , 4 ]))
91- file_path = os .path .join (_get_temporary_directory () , 'test.tsv' )
91+ file_path = os .path .join (self . temporary_directory , 'test.tsv' )
9292
9393 target = make_target (file_path = file_path , unique_id = None )
9494 target .dump (obj )
@@ -98,7 +98,7 @@ def test_save_and_load_tsv(self):
9898
9999 def test_save_and_load_parquet (self ):
100100 obj = pd .DataFrame (dict (a = [1 , 2 ], b = [3 , 4 ]))
101- file_path = os .path .join (_get_temporary_directory () , 'test.parquet' )
101+ file_path = os .path .join (self . temporary_directory , 'test.parquet' )
102102
103103 target = make_target (file_path = file_path , unique_id = None )
104104 target .dump (obj )
@@ -108,7 +108,7 @@ def test_save_and_load_parquet(self):
108108
109109 def test_save_and_load_feather (self ):
110110 obj = pd .DataFrame (dict (a = [1 , 2 ], b = [3 , 4 ]), index = pd .Index ([33 , 44 ], name = 'object_index' ))
111- file_path = os .path .join (_get_temporary_directory () , 'test.feather' )
111+ file_path = os .path .join (self . temporary_directory , 'test.feather' )
112112
113113 target = make_target (file_path = file_path , unique_id = None )
114114 target .dump (obj )
@@ -118,7 +118,7 @@ def test_save_and_load_feather(self):
118118
119119 def test_save_and_load_feather_without_store_index_in_feather (self ):
120120 obj = pd .DataFrame (dict (a = [1 , 2 ], b = [3 , 4 ]), index = pd .Index ([33 , 44 ], name = 'object_index' )).reset_index ()
121- file_path = os .path .join (_get_temporary_directory () , 'test.feather' )
121+ file_path = os .path .join (self . temporary_directory , 'test.feather' )
122122
123123 target = make_target (file_path = file_path , unique_id = None , store_index_in_feather = False )
124124 target .dump (obj )
@@ -128,22 +128,22 @@ def test_save_and_load_feather_without_store_index_in_feather(self):
128128
129129 def test_last_modified_time (self ):
130130 obj = pd .DataFrame (dict (a = [1 , 2 ], b = [3 , 4 ]))
131- file_path = os .path .join (_get_temporary_directory () , 'test.csv' )
131+ file_path = os .path .join (self . temporary_directory , 'test.csv' )
132132
133133 target = make_target (file_path = file_path , unique_id = None )
134134 target .dump (obj )
135135 t = target .last_modification_time ()
136136 self .assertIsInstance (t , datetime )
137137
138138 def test_last_modified_time_without_file (self ):
139- file_path = os .path .join (_get_temporary_directory () , 'test.csv' )
139+ file_path = os .path .join (self . temporary_directory , 'test.csv' )
140140 target = make_target (file_path = file_path , unique_id = None )
141141 with self .assertRaises (FileNotFoundError ):
142142 target .last_modification_time ()
143143
144144 def test_save_pandas_series (self ):
145145 obj = pd .Series (data = [1 , 2 ], name = 'column_name' )
146- file_path = os .path .join (_get_temporary_directory () , 'test.csv' )
146+ file_path = os .path .join (self . temporary_directory , 'test.csv' )
147147
148148 target = make_target (file_path = file_path , unique_id = None )
149149 target .dump (obj )
@@ -154,7 +154,7 @@ def test_save_pandas_series(self):
154154 def test_dump_with_lock (self ):
155155 with patch ('gokart.target.wrap_dump_with_lock' ) as wrap_with_lock_mock :
156156 obj = 1
157- file_path = os .path .join (_get_temporary_directory () , 'test.pkl' )
157+ file_path = os .path .join (self . temporary_directory , 'test.pkl' )
158158 target = make_target (file_path = file_path , unique_id = None )
159159 target .dump (obj , lock_at_dump = True )
160160
@@ -163,7 +163,7 @@ def test_dump_with_lock(self):
163163 def test_dump_without_lock (self ):
164164 with patch ('gokart.target.wrap_dump_with_lock' ) as wrap_with_lock_mock :
165165 obj = 1
166- file_path = os .path .join (_get_temporary_directory () , 'test.pkl' )
166+ file_path = os .path .join (self . temporary_directory , 'test.pkl' )
167167 target = make_target (file_path = file_path , unique_id = None )
168168 target .dump (obj , lock_at_dump = False )
169169
@@ -238,8 +238,11 @@ def test_save_on_s3_parquet(self):
238238
239239
240240class ModelTargetTest (unittest .TestCase ):
241+ def setUp (self ):
242+ self .temporary_directory = _get_temporary_directory ()
243+
241244 def tearDown (self ):
242- shutil .rmtree (_get_temporary_directory () , ignore_errors = True )
245+ shutil .rmtree (self . temporary_directory , ignore_errors = True )
243246
244247 @staticmethod
245248 def _save_function (obj , path ):
@@ -251,10 +254,10 @@ def _load_function(path):
251254
252255 def test_model_target_on_local (self ):
253256 obj = 1
254- file_path = os .path .join (_get_temporary_directory () , 'test.zip' )
257+ file_path = os .path .join (self . temporary_directory , 'test.zip' )
255258
256259 target = make_model_target (
257- file_path = file_path , temporary_directory = _get_temporary_directory () , save_function = self ._save_function , load_function = self ._load_function
260+ file_path = file_path , temporary_directory = self . temporary_directory , save_function = self ._save_function , load_function = self ._load_function
258261 )
259262
260263 target .dump (obj )
@@ -271,7 +274,7 @@ def test_model_target_on_s3(self):
271274 file_path = os .path .join ('s3://test/' , 'test.zip' )
272275
273276 target = make_model_target (
274- file_path = file_path , temporary_directory = _get_temporary_directory () , save_function = self ._save_function , load_function = self ._load_function
277+ file_path = file_path , temporary_directory = self . temporary_directory , save_function = self ._save_function , load_function = self ._load_function
275278 )
276279
277280 target .dump (obj )
0 commit comments