@@ -28,6 +28,25 @@ def wrapper(self, *args, **kwargs):
28
28
return wrapper
29
29
30
30
31
+ def refresh_check (func ):
32
+ def wrapper (self , * args , ** kwargs ):
33
+ """
34
+ A wrapper function that checks if the access token needs to be refreshed
35
+ Args:
36
+ *args: Variable length argument list.
37
+ **kwargs: Arbitrary keyword arguments.
38
+ Raises:
39
+ Exception: If the access token refresh fails (i.e., response status code is not 200).
40
+ Returns:
41
+ The result of the function `func` if the access token does not need to be refreshed.
42
+ """
43
+ if self .expires_at is not None and datetime .now ().timestamp () > self .expires_at :
44
+ self ._refresh_token ()
45
+ return func (self , * args , ** kwargs )
46
+
47
+ return wrapper
48
+
49
+
31
50
class Public :
32
51
def __init__ (self , filename = None , path = None ):
33
52
self .session = requests .Session ()
@@ -38,6 +57,7 @@ def __init__(self, filename=None, path=None):
38
57
self .account_number = None
39
58
self .all_login_info = None
40
59
self .timeout = 10
60
+ self .expires_at = None
41
61
self .filename = "public_credentials.pkl"
42
62
if filename is not None :
43
63
self .filename = filename
@@ -110,6 +130,15 @@ def login(self, username=None, password=None, wait_for_2fa=True, code=None) -> d
110
130
data = payload ,
111
131
timeout = self .timeout ,
112
132
)
133
+ if response .status_code != 200 :
134
+ # Perhaps cookies are expired
135
+ self ._clear_cookies ()
136
+ response = self .session .post (
137
+ self .endpoints .login_url (),
138
+ headers = headers ,
139
+ data = payload ,
140
+ timeout = self .timeout ,
141
+ )
113
142
if response .status_code != 200 :
114
143
print (response .text )
115
144
raise Exception ("Login failed, check credentials" )
@@ -145,12 +174,13 @@ def login(self, username=None, password=None, wait_for_2fa=True, code=None) -> d
145
174
self .access_token = response ["loginResponse" ]["accessToken" ]
146
175
self .account_uuid = response ["loginResponse" ]["accounts" ][0 ]["accountUuid" ]
147
176
self .account_number = response ["loginResponse" ]["accounts" ][0 ]["account" ]
177
+ self .expires_at = (int (response ["loginResponse" ]["serverTime" ]) / 1000 ) + int (response ["loginResponse" ]["expiresIn" ])
148
178
self .all_login_info = response
149
179
self ._save_cookies ()
150
180
return response
151
181
152
182
@login_required
153
- def refresh_token (self ) -> dict :
183
+ def _refresh_token (self ) -> dict :
154
184
"""
155
185
Refreshes the access token by making a POST request to the refresh URL.
156
186
Returns:
@@ -166,10 +196,13 @@ def refresh_token(self) -> dict:
166
196
raise Exception ("Token refresh failed" )
167
197
response = response .json ()
168
198
self .access_token = response ["accessToken" ]
199
+ self .expires_at = (int (response ["serverTime" ]) / 1000 ) + int (response ["expiresIn" ])
200
+ self .account_uuid = response ["accounts" ][0 ]["accountUuid" ]
169
201
self ._save_cookies ()
170
202
return response
171
203
172
204
@login_required
205
+ @refresh_check
173
206
def get_portfolio (self ) -> dict :
174
207
"""
175
208
Gets the user's portfolio by making a GET request to the portfolio URL.
@@ -178,7 +211,7 @@ def get_portfolio(self) -> dict:
178
211
Raises:
179
212
Exception: If the portfolio request fails (i.e., response status code is not 200).
180
213
"""
181
- headers = self .endpoints .build_headers (self .access_token )
214
+ headers = self .endpoints .build_headers (self .access_token , prodApi = True )
182
215
portfolio = self .session .get (
183
216
self .endpoints .portfolio_url (self .account_uuid ),
184
217
headers = headers ,
@@ -222,6 +255,7 @@ def _history_filter_date(date: str) -> dict:
222
255
}
223
256
224
257
@login_required
258
+ @refresh_check
225
259
def get_account_history (
226
260
self ,
227
261
date = "all" ,
@@ -423,6 +457,7 @@ def get_account_cash(self) -> float:
423
457
return account_info ["equity" ]["cash" ]
424
458
425
459
@login_required
460
+ @refresh_check
426
461
def get_symbol_price (self , symbol ) -> float :
427
462
"""
428
463
Gets the price of a stock by making a GET request to the quote URL.
@@ -445,6 +480,7 @@ def get_symbol_price(self, symbol) -> float:
445
480
return response .json ()["price" ]
446
481
447
482
@login_required
483
+ @refresh_check
448
484
def get_order_quote (self , symbol ) -> dict :
449
485
"""
450
486
Gets a quote for an order by making a GET request to the order quote URL.
@@ -466,6 +502,7 @@ def get_order_quote(self, symbol) -> dict:
466
502
return response .json ()
467
503
468
504
@login_required
505
+ @refresh_check
469
506
def place_order (
470
507
self ,
471
508
symbol ,
@@ -587,6 +624,7 @@ def place_order(
587
624
return check_response
588
625
589
626
@login_required
627
+ @refresh_check
590
628
def get_pending_orders (self ) -> dict :
591
629
"""
592
630
Gets the user's pending orders by making a GET request to the pending orders URL.
@@ -606,6 +644,7 @@ def get_pending_orders(self) -> dict:
606
644
return response .json ()
607
645
608
646
@login_required
647
+ @refresh_check
609
648
def cancel_order (self , order_id ) -> dict :
610
649
"""
611
650
Cancels an order by making a DELETE request to the cancel order URL.
0 commit comments