Skip to content

Commit ab16c81

Browse files
committed
refresh before logging in
1 parent 651aa38 commit ab16c81

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

public_invest_api/public.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,16 @@ def login(self, username=None, password=None, wait_for_2fa=True, code=None) -> d
119119
"""
120120
if username is None or password is None:
121121
raise Exception("Username or password not provided")
122+
# See if we can refresh login
123+
refresh_success = False
124+
try:
125+
response = self._refresh_token()
126+
refresh_success = True
127+
except Exception:
128+
pass
122129
headers = self.session.headers
123130
need_2fa = True
124-
if code is None:
131+
if code is None and not refresh_success:
125132
payload = self.endpoints.build_payload(username, password)
126133
self._load_cookies()
127134
response = self.session.post(
@@ -152,7 +159,7 @@ def login(self, username=None, password=None, wait_for_2fa=True, code=None) -> d
152159
code = input("Enter code: ")
153160
else:
154161
need_2fa = False
155-
if need_2fa:
162+
if need_2fa and not refresh_success:
156163
payload = self.endpoints.build_payload(username, password, code)
157164
response = self.session.post(
158165
self.endpoints.mfa_url(),
@@ -171,17 +178,19 @@ def login(self, username=None, password=None, wait_for_2fa=True, code=None) -> d
171178
if response.status_code != 200:
172179
raise Exception("Login failed, check credentials")
173180
response = response.json()
174-
self.access_token = response["loginResponse"]["accessToken"]
175-
self.account_uuid = response["loginResponse"]["accounts"][0]["accountUuid"]
176-
self.account_number = response["loginResponse"]["accounts"][0]["account"]
177-
self.expires_at = (int(response["loginResponse"]["serverTime"]) / 1000) + int(
178-
response["loginResponse"]["expiresIn"]
181+
# Get info from response
182+
if "loginResponse" in response:
183+
response = response["loginResponse"]
184+
self.access_token = response["accessToken"]
185+
self.account_uuid = response["accounts"][0]["accountUuid"]
186+
self.account_number = response["accounts"][0]["account"]
187+
self.expires_at = (int(response["serverTime"]) / 1000) + int(
188+
response["expiresIn"]
179189
)
180190
self.all_login_info = response
181191
self._save_cookies()
182192
return response
183193

184-
@login_required
185194
def _refresh_token(self) -> dict:
186195
"""
187196
Refreshes the access token by making a POST request to the refresh URL.
@@ -448,7 +457,7 @@ def get_account_type(self) -> str:
448457
Returns:
449458
str: The user's account type.
450459
"""
451-
return self.all_login_info["loginResponse"]["accounts"][0]["type"]
460+
return self.all_login_info["accounts"][0]["type"]
452461

453462
@login_required
454463
def get_account_cash(self) -> float:

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="public_invest_api",
5-
version="1.2.2",
5+
version="1.2.3",
66
description="Unofficial Public.com Invest API written in Python Requests",
77
long_description=open("README.md").read(),
88
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)