Skip to content

Commit d7e51b8

Browse files
committed
Changed the get_prices function to use pandas DataReader because
the yahoo library doesn't work any more. Also changed to tf.global_variables_initializer() to get rid of the deprecation warning
1 parent 4461953 commit d7e51b8

File tree

2 files changed

+13
-15
lines changed

2 files changed

+13
-15
lines changed

ch08_rl/prices.png

-1000 Bytes
Loading

ch08_rl/rl.py

+13-15
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from yahoo_finance import Share
21
from matplotlib import pyplot as plt
32
import numpy as np
43
import random
54
import tensorflow as tf
65
import random
7-
6+
import pandas as pd
7+
pd.core.common.is_list_like = pd.api.types.is_list_like
8+
from pandas_datareader import data
9+
import datetime
10+
import requests_cache
811

912
class DecisionPolicy:
1013
def select_action(self, current_state, step):
@@ -43,7 +46,7 @@ def __init__(self, actions, input_dim):
4346
loss = tf.square(self.y - self.q)
4447
self.train_op = tf.train.AdagradOptimizer(0.01).minimize(loss)
4548
self.sess = tf.Session()
46-
self.sess.run(tf.initialize_all_variables())
49+
self.sess.run(tf.global_variables_initializer())
4750

4851
def select_action(self, current_state, step):
4952
threshold = min(self.epsilon, step / 1000.)
@@ -108,17 +111,12 @@ def run_simulations(policy, budget, num_stocks, prices, hist):
108111
return avg, std
109112

110113

111-
def get_prices(share_symbol, start_date, end_date, cache_filename='stock_prices.npy'):
112-
try:
113-
stock_prices = np.load(cache_filename)
114-
except IOError:
115-
share = Share(share_symbol)
116-
stock_hist = share.get_historical(start_date, end_date)
117-
stock_prices = [stock_price['Open'] for stock_price in stock_hist]
118-
np.save(cache_filename, stock_prices)
119-
120-
return stock_prices
121-
114+
def get_prices(share_symbol, start_date, end_date):
115+
expire_after = datetime.timedelta(days=3)
116+
session = requests_cache.CachedSession(cache_name='cache', backend='sqlite', expire_after=expire_after)
117+
stock_hist = data.DataReader(share_symbol, 'iex', start_date, end_date, session=session)
118+
open_prices = stock_hist['open']
119+
return open_prices.values.tolist()
122120

123121
def plot_prices(prices):
124122
plt.title('Opening stock prices')
@@ -129,7 +127,7 @@ def plot_prices(prices):
129127

130128

131129
if __name__ == '__main__':
132-
prices = get_prices('MSFT', '1992-07-22', '2016-07-22')
130+
prices = get_prices('MSFT', '2013-07-22', '2018-07-22')
133131
plot_prices(prices)
134132
actions = ['Buy', 'Sell', 'Hold']
135133
hist = 200

0 commit comments

Comments
 (0)