diff --git a/td/client.py b/td/client.py index 792df70..ff5cfe2 100644 --- a/td/client.py +++ b/td/client.py @@ -152,6 +152,14 @@ def __init__(self, client_id: str, redirect_uri: str, account_number: str = None # Initalize the client with no streaming session. self.streaming_session = None + self.request_session = requests.Session() + self.request_session.verify = True + + def __del__(self): + # clear session + if self.request_session: + self.request_session.close() + def __repr__(self) -> str: """String representation of our TD Ameritrade Class instance.""" @@ -277,7 +285,7 @@ def logout(self) -> None: # new access token or refresh token next time they use the API self._state_manager('init') - def grab_access_token(self) -> dict: + def grab_access_token(self) -> bool: """Refreshes the current access token. This takes a valid refresh token and refreshes @@ -297,7 +305,7 @@ def grab_access_token(self) -> dict: } # Make the request. - response = requests.post( + response = self.request_session.post( url="https://api.tdameritrade.com/v1/oauth2/token", headers={'Content-Type': 'application/x-www-form-urlencoded'}, data=data @@ -309,6 +317,10 @@ def grab_access_token(self) -> dict: token_dict=response.json(), includes_refresh=False ) + return True + else: + raise RuntimeError("http response error, status:{}, msg:{}".format(response.status_code, response.text)) + def grab_refresh_token(self) -> bool: """Grabs a new refresh token if expired. @@ -334,20 +346,21 @@ def grab_refresh_token(self) -> bool: } # Make the request. - response = requests.post( + response = self.request_session.post( url="https://api.tdameritrade.com/v1/oauth2/token", headers={'Content-Type': 'application/x-www-form-urlencoded'}, data=data ) if response.ok: - self._token_save( token_dict=response.json(), includes_refresh=True ) return True + else: + raise RuntimeError("http response error, status:{}, msg:{}".format(response.status_code, response.text)) def grab_url(self) -> dict: """Builds the URL that is used for oAuth.""" @@ -570,7 +583,7 @@ def _token_save(self, token_dict: dict, includes_refresh: bool = False) -> dict: return self.state def _make_request(self, method: str, endpoint: str, mode: str = None, params: dict = None, data: dict = None, json:dict = None, - order_details: bool = False) -> Any: + order_details: bool = False, **session_kwargs) -> Any: """Handles all the requests in the library. A central function used to handle all the requests made in the library, @@ -604,9 +617,7 @@ def _make_request(self, method: str, endpoint: str, mode: str = None, params: di self.validate_token() headers = self._headers(mode=mode) - # Define a new session. - request_session = requests.Session() - request_session.verify = True + request_session = self.request_session # Define a new request. request_request = requests.Request( @@ -619,9 +630,8 @@ def _make_request(self, method: str, endpoint: str, mode: str = None, params: di ).prepare() # Send the request. - response: requests.Response = request_session.send(request=request_request) + response: requests.Response = request_session.send(request=request_request, **session_kwargs) - request_session.close() # grab the status code status_code = response.status_code @@ -753,7 +763,7 @@ def _prepare_arguments_list(self, parameter_list: List) -> str: return ','.join(parameter_list) - def get_quotes(self, instruments: List) -> Dict: + def get_quotes(self, instruments: List, **session_kwargs) -> Dict: """Grabs real-time quotes for an instrument. Serves as the mechanism to make a request to the Get Quote and Get Quotes Endpoint. @@ -789,10 +799,10 @@ def get_quotes(self, instruments: List) -> Dict: endpoint = 'marketdata/quotes' # return the response of the get request. - return self._make_request(method='get', endpoint=endpoint, params=params) + return self._make_request(method='get', endpoint=endpoint, params=params, **session_kwargs) def get_price_history(self, symbol: str, period_type:str = None, period: str = None, start_date:str = None, end_date:str = None, - frequency_type: str = None, frequency: str = None, extended_hours: bool = True) -> Dict: + frequency_type: str = None, frequency: str = None, extended_hours: bool = True, **session_kwargs) -> Dict: """Gets historical candle data for a financial instrument. ### Documentation: @@ -859,9 +869,9 @@ def get_price_history(self, symbol: str, period_type:str = None, period: str = N endpoint = 'marketdata/{}/pricehistory'.format(symbol) # return the response of the get request. - return self._make_request(method='get', endpoint=endpoint, params=params) + return self._make_request(method='get', endpoint=endpoint, params=params, **session_kwargs) - def search_instruments(self, symbol: str, projection: str = None) -> Dict: + def search_instruments(self, symbol: str, projection: str = None, **session_kwargs) -> Dict: """ Search or retrieve instrument data, including fundamental data. ### Documentation: @@ -938,9 +948,9 @@ def search_instruments(self, symbol: str, projection: str = None) -> Dict: endpoint = 'instruments' # return the response of the get request. - return self._make_request(method='get', endpoint=endpoint, params=params) + return self._make_request(method='get', endpoint=endpoint, params=params, **session_kwargs) - def get_instruments(self, cusip: str) -> Dict: + def get_instruments(self, cusip: str, **session_kwargs) -> Dict: """Searches an Instrument. Get an instrument by CUSIP (Committee on Uniform Securities Identification Procedures) code. @@ -969,9 +979,9 @@ def get_instruments(self, cusip: str) -> Dict: endpoint = 'instruments/{cusip}'.format(cusip=cusip) # return the response of the get request. - return self._make_request(method='get', endpoint=endpoint, params=params) + return self._make_request(method='get', endpoint=endpoint, params=params, **session_kwargs) - def get_market_hours(self, markets: List[str], date: str) -> Dict: + def get_market_hours(self, markets: List[str], date: str, **session_kwargs) -> Dict: """Returns the hours for a specific market. Serves as the mechanism to make a request to the "Get Hours for Multiple Markets" and @@ -1019,9 +1029,9 @@ def get_market_hours(self, markets: List[str], date: str) -> Dict: endpoint = 'marketdata/hours' # return the response of the get request. - return self._make_request(method='get', endpoint=endpoint, params=params) + return self._make_request(method='get', endpoint=endpoint, params=params, **session_kwargs) - def get_movers(self, market: str, direction: str, change: str) -> Dict: + def get_movers(self, market: str, direction: str, change: str, **session_kwargs) -> Dict: """Gets Active movers for a specific Index. Top 10 (up or down) movers by value or percent for a particular market. @@ -1082,9 +1092,9 @@ def get_movers(self, market: str, direction: str, change: str) -> Dict: endpoint = 'marketdata/{market_id}/movers'.format(market_id=market) # return the response of the get request. - return self._make_request(method='get', endpoint=endpoint, params=params) + return self._make_request(method='get', endpoint=endpoint, params=params, **session_kwargs) - def get_options_chain(self, option_chain: Union[Dict, OptionChain]) -> Dict: + def get_options_chain(self, option_chain: Union[Dict, OptionChain], **session_kwargs) -> Dict: """Returns Option Chain Data and Quotes. Get option chain for an optionable Symbol using one of two methods. Either, @@ -1123,7 +1133,7 @@ def get_options_chain(self, option_chain: Union[Dict, OptionChain]) -> Dict: endpoint = 'marketdata/chains' # return the response of the get request. - return self._make_request(method='get', endpoint=endpoint, params=params) + return self._make_request(method='get', endpoint=endpoint, params=params, **session_kwargs) """ ----------------------------------------------------------- @@ -1135,7 +1145,7 @@ def get_options_chain(self, option_chain: Union[Dict, OptionChain]) -> Dict: ----------------------------------------------------------- """ - def get_accounts(self, account: str = 'all', fields: List[str] = None) -> Dict: + def get_accounts(self, account: str = 'all', fields: List[str] = None, **session_kwargs) -> Dict: """Queries accounts for a user. Serves as the mechanism to make a request to the "Get Accounts" and "Get Account" Endpoint. @@ -1183,11 +1193,11 @@ def get_accounts(self, account: str = 'all', fields: List[str] = None) -> Dict: endpoint = 'accounts/{}'.format(account) # return the response of the get request. - return self._make_request(method='get', endpoint=endpoint, params=params) + return self._make_request(method='get', endpoint=endpoint, params=params, **session_kwargs) def get_transactions(self, account: str = None, transaction_type: str = None, symbol: str = None, - start_date: str = None, end_date: str = None, transaction_id: str= None) -> Dict: + start_date: str = None, end_date: str = None, transaction_id: str= None, **session_kwargs) -> Dict: """Queries the transactions for an account. Serves as the mechanism to make a request to the "Get Transactions" and "Get Transaction" Endpoint. @@ -1283,7 +1293,7 @@ def get_transactions(self, account: str = None, transaction_type: str = None, sy endpoint = 'accounts/{}/transactions'.format(account) # return the response of the get request. - return self._make_request(method='get', endpoint=endpoint, params=params) + return self._make_request(method='get', endpoint=endpoint, params=params, **session_kwargs) """ ----------------------------------------------------------- @@ -1295,7 +1305,7 @@ def get_transactions(self, account: str = None, transaction_type: str = None, sy ----------------------------------------------------------- """ - def get_preferences(self, account: str) -> Dict: + def get_preferences(self, account: str, **session_kwargs) -> Dict: """Get's User Preferences for a specific account. ### Documentation: @@ -1320,9 +1330,9 @@ def get_preferences(self, account: str) -> Dict: endpoint = 'accounts/{}/preferences'.format(account) # return the response of the get request. - return self._make_request(method='get', endpoint=endpoint) + return self._make_request(method='get', endpoint=endpoint, **session_kwargs) - def get_streamer_subscription_keys(self, accounts: List[str]) -> Dict: + def get_streamer_subscription_keys(self, accounts: List[str], **session_kwargs) -> Dict: """SubscriptionKey for provided accounts or default accounts. ### Documentation: @@ -1353,9 +1363,9 @@ def get_streamer_subscription_keys(self, accounts: List[str]) -> Dict: } # return the response of the get request. - return self._make_request(method='get', endpoint=endpoint, params=params) + return self._make_request(method='get', endpoint=endpoint, params=params, **session_kwargs) - def get_user_principals(self, fields: List[str]) -> Dict: + def get_user_principals(self, fields: List[str], **session_kwargs) -> Dict: """Returns User Principal details. ### Documentation: @@ -1398,9 +1408,9 @@ def get_user_principals(self, fields: List[str]) -> Dict: } # return the response of the get request. - return self._make_request(method='get', endpoint=endpoint, params=params) + return self._make_request(method='get', endpoint=endpoint, params=params, **session_kwargs) - def update_preferences(self, account: str, data_payload: Dict) -> Dict: + def update_preferences(self, account: str, data_payload: Dict, **session_kwargs) -> Dict: """Updates the User's Preferences. Overview: @@ -1446,7 +1456,7 @@ def update_preferences(self, account: str, data_payload: Dict) -> Dict: endpoint = 'accounts/{}/preferences'.format(account) # make the request - return self._make_request(method='put', endpoint=endpoint, mode='json', data=data_payload) + return self._make_request(method='put', endpoint=endpoint, mode='json', data=data_payload, **session_kwargs) """ ----------------------------------------------------------- @@ -1458,7 +1468,7 @@ def update_preferences(self, account: str, data_payload: Dict) -> Dict: ----------------------------------------------------------- """ - def create_watchlist(self, account: str, name: str, watchlistItems=None) -> Dict: + def create_watchlist(self, account: str, name: str, watchlistItems=None, **session_kwargs) -> Dict: """Creates a new watchlist. Create watchlist for specific account. This method does not verify that the @@ -1498,9 +1508,9 @@ def create_watchlist(self, account: str, name: str, watchlistItems=None) -> Dict } # make the request - return self._make_request(method='put', endpoint=endpoint, mode='json', data=payload) + return self._make_request(method='put', endpoint=endpoint, mode='json', data=payload, **session_kwargs) - def get_watchlist_accounts(self, account: str = 'all') -> Dict: + def get_watchlist_accounts(self, account: str = 'all', **session_kwargs) -> Dict: """Gets watchlist, by account number. Serves as the mechanism to make a request to the "Get Watchlist for Single Account" and @@ -1532,9 +1542,9 @@ def get_watchlist_accounts(self, account: str = 'all') -> Dict: endpoint = 'accounts/{}/watchlists'.format(account) # make the request - return self._make_request(method='get', endpoint=endpoint) + return self._make_request(method='get', endpoint=endpoint, **session_kwargs) - def get_watchlist(self, account: str, watchlist_id: str) -> Dict: + def get_watchlist(self, account: str, watchlist_id: str, **session_kwargs) -> Dict: """Queries a watchlist. Returns a specific watchlist for a specific account designated by the @@ -1565,9 +1575,9 @@ def get_watchlist(self, account: str, watchlist_id: str) -> Dict: endpoint = 'accounts/{}/watchlists/{}'.format(account, watchlist_id) # make the request - return self._make_request(method='get', endpoint=endpoint) + return self._make_request(method='get', endpoint=endpoint, **session_kwargs) - def delete_watchlist(self, account: str, watchlist_id: str) -> Dict: + def delete_watchlist(self, account: str, watchlist_id: str, **session_kwargs) -> Dict: """Deletes an existing watchlist Deletes a specific watchlist for a specific account. @@ -1598,9 +1608,9 @@ def delete_watchlist(self, account: str, watchlist_id: str) -> Dict: endpoint = 'accounts/{}/watchlists/{}'.format(account, watchlist_id) # make the request - return self._make_request(method='delete', endpoint=endpoint) + return self._make_request(method='delete', endpoint=endpoint, **session_kwargs) - def update_watchlist(self, account: str, watchlist_id: str, name: str, watchlistItems: Dict) -> Dict: + def update_watchlist(self, account: str, watchlist_id: str, name: str, watchlistItems: Dict, **session_kwargs) -> Dict: """Updates an Exisitng watchlist. Partially update watchlist for a specific account: change watchlist name, add to the beginning/end of a @@ -1641,9 +1651,10 @@ def update_watchlist(self, account: str, watchlist_id: str, name: str, watchlist endpoint = 'accounts/{}/watchlists/{}'.format(account, watchlist_id) # make the request - return self._make_request(method='patch', endpoint=endpoint, data=payload) + return self._make_request(method='patch', endpoint=endpoint, data=payload, **session_kwargs) - def replace_watchlist(self, account: str, watchlist_id_new: dict, watchlist_id_old: dict, name_new: str, watchlistItems_new: dict) -> Dict: + def replace_watchlist(self, account: str, watchlist_id_new: dict, watchlist_id_old: dict, name_new: str, + watchlistItems_new: dict, **session_kwargs) -> Dict: """Replaces an existing watchlist. Replace watchlist for a specific account. This method does not verify that @@ -1690,7 +1701,7 @@ def replace_watchlist(self, account: str, watchlist_id_new: dict, watchlist_id_o endpoint = 'accounts/{}/watchlists/{}'.format(account, watchlist_id_old) # make the request - return self._make_request(method='put', endpoint=endpoint, mode='json', data=payload) + return self._make_request(method='put', endpoint=endpoint, mode='json', data=payload, **session_kwargs) """ ----------------------------------------------------------- @@ -1703,7 +1714,7 @@ def replace_watchlist(self, account: str, watchlist_id_new: dict, watchlist_id_o """ def get_orders_path(self, account: str, max_results: int = None, from_entered_time: - str = None, to_entered_time: str = None, status: str = None) -> Dict: + str = None, to_entered_time: str = None, status: str = None, **session_kwargs) -> Dict: """Returns the orders for a specific account. ### Documentation: @@ -1781,10 +1792,10 @@ def get_orders_path(self, account: str, max_results: int = None, from_entered_ti endpoint = 'accounts/{}/orders'.format(account) # make the request - return self._make_request(method='get', endpoint=endpoint, params=params) + return self._make_request(method='get', endpoint=endpoint, params=params, **session_kwargs) def get_orders_query(self, account: str = None, max_results: int = None, from_entered_time: str = None, - to_entered_time: str = None, status: str = None) -> Dict: + to_entered_time: str = None, status: str = None, **session_kwargs) -> Dict: """Get's all the orders for an account. All orders for a specific account or, if account ID isn't specified, orders will be returned for all linked accounts @@ -1865,9 +1876,9 @@ def get_orders_query(self, account: str = None, max_results: int = None, from_en endpoint = 'orders' # make the request - return self._make_request(method='get', endpoint=endpoint, params=params) + return self._make_request(method='get', endpoint=endpoint, params=params, **session_kwargs) - def get_orders(self, account: str, order_id: str = None) -> Dict: + def get_orders(self, account: str, order_id: str = None, **session_kwargs) -> Dict: """Gets the orders for an account Returns all orders for a specific account or, if account ID @@ -1903,9 +1914,9 @@ def get_orders(self, account: str, order_id: str = None) -> Dict: endpoint = 'accounts/{}/orders'.format(account) # make the request - return self._make_request(method='get', endpoint=endpoint) + return self._make_request(method='get', endpoint=endpoint, **session_kwargs) - def cancel_order(self, account: str, order_id: str) -> Dict: + def cancel_order(self, account: str, order_id: str, **session_kwargs) -> Dict: """Cancel a specific order for a specific account. ### Documentation: @@ -1931,10 +1942,10 @@ def cancel_order(self, account: str, order_id: str) -> Dict: endpoint = 'accounts/{}/orders/{}'.format(account, order_id) # delete the request - return self._make_request(method='delete', endpoint=endpoint, order_details=True) + return self._make_request(method='delete', endpoint=endpoint, order_details=True, **session_kwargs) - def place_order(self, account: str, order: dict) -> dict: + def place_order(self, account: str, order: dict, **session_kwargs) -> dict: """Places an order for a specific account. ### Documentation: @@ -1964,9 +1975,10 @@ def place_order(self, account: str, order: dict) -> dict: # make the request endpoint = 'accounts/{}/orders'.format(account) - return self._make_request(method='post', endpoint=endpoint, mode='json', json=order, order_details=True) + return self._make_request(method='post', endpoint=endpoint, mode='json', json=order, order_details=True, + **session_kwargs) - def modify_order(self, account: str, order: dict, order_id: str) -> dict: + def modify_order(self, account: str, order: dict, order_id: str, **session_kwargs) -> dict: """Modifies an exisiting order. ### Documentation: @@ -2007,10 +2019,11 @@ def modify_order(self, account: str, order: dict, order_id: str) -> dict: endpoint=endpoint, mode='json', json=order, - order_details=True + order_details=True, + **session_kwargs ) - def get_saved_order(self, account: str, saved_order_id: str = None) -> Dict: + def get_saved_order(self, account: str, saved_order_id: str = None, **session_kwargs) -> Dict: """Grabs a saved order. Grabs all the saved orders for a specific account or, if account @@ -2037,9 +2050,9 @@ def get_saved_order(self, account: str, saved_order_id: str = None) -> Dict: # define the endpoint endpoint = 'accounts/{}/savedorders/{}'.format(account, saved_order_id) - return self._make_request(method='get', endpoint=endpoint) + return self._make_request(method='get', endpoint=endpoint, **session_kwargs) - def cancel_saved_order(self, account: str, saved_order_id: str) -> Dict: + def cancel_saved_order(self, account: str, saved_order_id: str, **session_kwargs) -> Dict: """Cancel a saved order Using a saved order ID and account number, will delete the order from @@ -2066,10 +2079,10 @@ def cancel_saved_order(self, account: str, saved_order_id: str) -> Dict: # define the endpoint endpoint = 'accounts/{}/savedorders/{}'.format(account, saved_order_id) - return self._make_request(method='delete', endpoint=endpoint, order_details=True) + return self._make_request(method='delete', endpoint=endpoint, order_details=True, **session_kwargs) - def create_saved_order(self, account: str, saved_order: dict) -> dict: + def create_saved_order(self, account: str, saved_order: dict, **session_kwargs) -> dict: """Creates a saved order Creates a saved order for the specified account. @@ -2101,7 +2114,8 @@ def create_saved_order(self, account: str, saved_order: dict) -> dict: # make the request endpoint = 'accounts/{}/savedorders'.format(account) - return self._make_request(method='post', endpoint=endpoint, mode='json', json=saved_order, order_details=True) + return self._make_request(method='post', endpoint=endpoint, mode='json', json=saved_order, order_details=True, + **session_kwargs) def _create_token_timestamp(self, token_timestamp: str) -> int: """Parses the token and converts it to a timestamp.