1
- from yahoo_finance import Share
2
1
from matplotlib import pyplot as plt
3
2
import numpy as np
4
3
import random
5
4
import tensorflow as tf
6
5
import random
7
-
6
+ import pandas as pd
7
+ pd .core .common .is_list_like = pd .api .types .is_list_like
8
+ from pandas_datareader import data
9
+ import datetime
10
+ import requests_cache
8
11
9
12
class DecisionPolicy :
10
13
def select_action (self , current_state , step ):
@@ -43,7 +46,7 @@ def __init__(self, actions, input_dim):
43
46
loss = tf .square (self .y - self .q )
44
47
self .train_op = tf .train .AdagradOptimizer (0.01 ).minimize (loss )
45
48
self .sess = tf .Session ()
46
- self .sess .run (tf .initialize_all_variables ())
49
+ self .sess .run (tf .global_variables_initializer ())
47
50
48
51
def select_action (self , current_state , step ):
49
52
threshold = min (self .epsilon , step / 1000. )
@@ -108,17 +111,12 @@ def run_simulations(policy, budget, num_stocks, prices, hist):
108
111
return avg , std
109
112
110
113
111
- def get_prices (share_symbol , start_date , end_date , cache_filename = 'stock_prices.npy' ):
112
- try :
113
- stock_prices = np .load (cache_filename )
114
- except IOError :
115
- share = Share (share_symbol )
116
- stock_hist = share .get_historical (start_date , end_date )
117
- stock_prices = [stock_price ['Open' ] for stock_price in stock_hist ]
118
- np .save (cache_filename , stock_prices )
119
-
120
- return stock_prices
121
-
114
+ def get_prices (share_symbol , start_date , end_date ):
115
+ expire_after = datetime .timedelta (days = 3 )
116
+ session = requests_cache .CachedSession (cache_name = 'cache' , backend = 'sqlite' , expire_after = expire_after )
117
+ stock_hist = data .DataReader (share_symbol , 'iex' , start_date , end_date , session = session )
118
+ open_prices = stock_hist ['open' ]
119
+ return open_prices .values .tolist ()
122
120
123
121
def plot_prices (prices ):
124
122
plt .title ('Opening stock prices' )
@@ -129,7 +127,7 @@ def plot_prices(prices):
129
127
130
128
131
129
if __name__ == '__main__' :
132
- prices = get_prices ('MSFT' , '1992 -07-22' , '2016 -07-22' )
130
+ prices = get_prices ('MSFT' , '2013 -07-22' , '2018 -07-22' )
133
131
plot_prices (prices )
134
132
actions = ['Buy' , 'Sell' , 'Hold' ]
135
133
hist = 200
0 commit comments