Skip to content

Commit 2b4e66c

Browse files
committed
fix refresh expiring
1 parent 0a64244 commit 2b4e66c

File tree

3 files changed

+44
-6
lines changed

3 files changed

+44
-6
lines changed

public_invest_api/endpoints.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ def get_pending_orders_url(self, account_uuid):
5151
def cancel_pending_order_url(self, account_uuid, order_id):
5252
return f"{self.ordergateway}/accounts/{account_uuid}/orders/{order_id}"
5353

54-
@staticmethod
55-
def build_headers(auth=None, prodApi=False):
54+
def build_headers(self, auth=None, prodApi=False):
5655
headers = {
5756
"authority": "public.com",
5857
"accept": "*/*",
@@ -76,7 +75,7 @@ def build_headers(auth=None, prodApi=False):
7675
if auth is not None:
7776
headers["authorization"] = auth
7877
if prodApi:
79-
headers["authority"] = "prod-api.154310543964.hellopublic.com"
78+
headers["authority"] = self.prodapi.replace("https://", "")
8079
headers["sec-fetch-site"] = "cross-site"
8180
return headers
8281

public_invest_api/public.py

+41-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,25 @@ def wrapper(self, *args, **kwargs):
2828
return wrapper
2929

3030

31+
def refresh_check(func):
32+
def wrapper(self, *args, **kwargs):
33+
"""
34+
A wrapper function that checks if the access token needs to be refreshed
35+
Args:
36+
*args: Variable length argument list.
37+
**kwargs: Arbitrary keyword arguments.
38+
Raises:
39+
Exception: If the access token refresh fails (i.e., response status code is not 200).
40+
Returns:
41+
The result of the function `func` if the access token does not need to be refreshed.
42+
"""
43+
if self.expires_at is not None and datetime.now().timestamp() > self.expires_at:
44+
self._refresh_token()
45+
return func(self, *args, **kwargs)
46+
47+
return wrapper
48+
49+
3150
class Public:
3251
def __init__(self, filename=None, path=None):
3352
self.session = requests.Session()
@@ -38,6 +57,7 @@ def __init__(self, filename=None, path=None):
3857
self.account_number = None
3958
self.all_login_info = None
4059
self.timeout = 10
60+
self.expires_at = None
4161
self.filename = "public_credentials.pkl"
4262
if filename is not None:
4363
self.filename = filename
@@ -110,6 +130,15 @@ def login(self, username=None, password=None, wait_for_2fa=True, code=None) -> d
110130
data=payload,
111131
timeout=self.timeout,
112132
)
133+
if response.status_code != 200:
134+
# Perhaps cookies are expired
135+
self._clear_cookies()
136+
response = self.session.post(
137+
self.endpoints.login_url(),
138+
headers=headers,
139+
data=payload,
140+
timeout=self.timeout,
141+
)
113142
if response.status_code != 200:
114143
print(response.text)
115144
raise Exception("Login failed, check credentials")
@@ -145,12 +174,13 @@ def login(self, username=None, password=None, wait_for_2fa=True, code=None) -> d
145174
self.access_token = response["loginResponse"]["accessToken"]
146175
self.account_uuid = response["loginResponse"]["accounts"][0]["accountUuid"]
147176
self.account_number = response["loginResponse"]["accounts"][0]["account"]
177+
self.expires_at = (int(response["loginResponse"]["serverTime"]) / 1000) + int(response["loginResponse"]["expiresIn"])
148178
self.all_login_info = response
149179
self._save_cookies()
150180
return response
151181

152182
@login_required
153-
def refresh_token(self) -> dict:
183+
def _refresh_token(self) -> dict:
154184
"""
155185
Refreshes the access token by making a POST request to the refresh URL.
156186
Returns:
@@ -166,10 +196,13 @@ def refresh_token(self) -> dict:
166196
raise Exception("Token refresh failed")
167197
response = response.json()
168198
self.access_token = response["accessToken"]
199+
self.expires_at = (int(response["serverTime"]) / 1000) + int(response["expiresIn"])
200+
self.account_uuid = response["accounts"][0]["accountUuid"]
169201
self._save_cookies()
170202
return response
171203

172204
@login_required
205+
@refresh_check
173206
def get_portfolio(self) -> dict:
174207
"""
175208
Gets the user's portfolio by making a GET request to the portfolio URL.
@@ -178,7 +211,7 @@ def get_portfolio(self) -> dict:
178211
Raises:
179212
Exception: If the portfolio request fails (i.e., response status code is not 200).
180213
"""
181-
headers = self.endpoints.build_headers(self.access_token)
214+
headers = self.endpoints.build_headers(self.access_token, prodApi=True)
182215
portfolio = self.session.get(
183216
self.endpoints.portfolio_url(self.account_uuid),
184217
headers=headers,
@@ -222,6 +255,7 @@ def _history_filter_date(date: str) -> dict:
222255
}
223256

224257
@login_required
258+
@refresh_check
225259
def get_account_history(
226260
self,
227261
date="all",
@@ -423,6 +457,7 @@ def get_account_cash(self) -> float:
423457
return account_info["equity"]["cash"]
424458

425459
@login_required
460+
@refresh_check
426461
def get_symbol_price(self, symbol) -> float:
427462
"""
428463
Gets the price of a stock by making a GET request to the quote URL.
@@ -445,6 +480,7 @@ def get_symbol_price(self, symbol) -> float:
445480
return response.json()["price"]
446481

447482
@login_required
483+
@refresh_check
448484
def get_order_quote(self, symbol) -> dict:
449485
"""
450486
Gets a quote for an order by making a GET request to the order quote URL.
@@ -466,6 +502,7 @@ def get_order_quote(self, symbol) -> dict:
466502
return response.json()
467503

468504
@login_required
505+
@refresh_check
469506
def place_order(
470507
self,
471508
symbol,
@@ -587,6 +624,7 @@ def place_order(
587624
return check_response
588625

589626
@login_required
627+
@refresh_check
590628
def get_pending_orders(self) -> dict:
591629
"""
592630
Gets the user's pending orders by making a GET request to the pending orders URL.
@@ -606,6 +644,7 @@ def get_pending_orders(self) -> dict:
606644
return response.json()
607645

608646
@login_required
647+
@refresh_check
609648
def cancel_order(self, order_id) -> dict:
610649
"""
611650
Cancels an order by making a DELETE request to the cancel order URL.

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="public_invest_api",
5-
version="1.2.1",
5+
version="1.2.2",
66
description="Unofficial Public.com Invest API written in Python Requests",
77
long_description=open("README.md").read(),
88
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)