"""
Create a price-weighted index of Japanese stocks from Yahoo! Finance price data
Date: 2018-06-15
"""

from __future__ import print_function
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

#  Global variable. Sub-directory (relative to the current working directory)
#  that holds one price CSV file per stock; every file in it is read.
data_dir = 'Data'

###############################################################################
def read_data(fname, directory=None):
   """Read one stock's daily price CSV and remove unusable rows.

   Parameters
   ----------
   fname : str
      Name of the CSV file inside the data directory.
   directory : str, optional
      Directory containing the file.  Defaults to the module-level
      ``data_dir`` when omitted.

   Returns
   -------
   pandas.DataFrame
      The file's rows indexed 0..n-1, minus the rows treated as Japanese
      holidays (zero volume and High == Low).  Surviving rows keep their
      original integer labels, so the index may have gaps.  When the stock
      is excluded because of a stock split, a message is printed and an
      empty frame with the same columns is returned.
   """
   if directory is None:
      directory = data_dir
   data = pd.read_csv(os.path.join(directory, fname))

   # Give the rows a plain 0..n-1 integer index.
   n = len(data)
   data.index = range(n)

   # Exclude stocks that have stock splits.
   # NOTE(review): this treats a stock as split-free only when the
   # 'Stock Splits' column sums to exactly n, i.e. it assumes a ratio of 1 is
   # recorded on every day without a split.  A source that records 0 on
   # non-split days would exclude every stock -- confirm against the data.
   if data['Stock Splits'].sum() != n:
      print('%s Stock split and excluded!' % (fname))
      return data[0:0]

   # Exclude days corresponding to Japanese holidays: rows with no trading
   # volume and no price movement.  A single boolean mask keeps this linear
   # in the number of rows (dropping rows one by one copies the frame each time).
   holiday = (data.Volume == 0) & (data.High == data.Low)
   return data[~holiday]

###############################################################################

# Obtain the list of stock data files in the sub-directory data_dir.
# os.listdir() returns entries in arbitrary order, so sort them to make the
# run reproducible; hidden files (e.g. .DS_Store) are not price data.
stocks = sorted(f for f in os.listdir(data_dir) if not f.startswith('.'))
n_stocks = len(stocks)

# Build a DataFrame of closing prices with one column per stock, named after
# the first four characters of the file name (the ticker code).  The dates of
# the first stock that survives read_data() become the row index; later
# stocks are aligned to it, so their missing dates become NaN and dates not in
# that index are discarded.  Excluded stocks are skipped even when they come
# first, so the reference index is never empty.
q = None
for fname in stocks:
   data = read_data(fname)
   if len(data) == 0:
      continue   # excluded by read_data() (e.g. a stock split) or no rows

   data.index = pd.to_datetime(data.Date, format='%Y-%m-%d')
   ticker = fname[0:4]
   if q is None:
      q = pd.DataFrame({ticker: data.Close}, index=data.index)
   else:
      q[ticker] = data.Close   # aligned on the date index of q

if q is None:
   raise SystemExit('No usable stock data found in %s' % data_dir)

n = len(q)

# Replace NaN by the previous day's price (forward fill), column by column,
# so a stock without a price on some day keeps its last close.  Values that
# are missing on the very first row have no previous day and stay NaN.
q = q.ffill()
    
# Price-weighted index: the plain sum of the constituents' closing prices on
# each day.  pandas skips NaN when summing, so a stock without a price on a
# given day simply does not contribute to that day's sum.
# NOTE(review): the divisor below is fixed at inception and never adjusted,
# so a stock that only gets prices after the first day shows up as a jump in
# the index.
pw = np.asarray(q.sum(axis=1), dtype=float)

# Without a positive first-day sum the divisor is undefined and the whole
# series would silently become inf/nan.
if len(pw) == 0 or pw[0] == 0:
   raise SystemExit('Cannot set the index divisor: no prices on the first day')

# Set the divisor so that the inception index value is 1000
divisor = pw[0]/1000
pw = pw/divisor

ts = pd.Series(pw, index=q.index)

# These two holidays were not dropped in the first pass in read_data(fname)
# (their rows did not match its zero-volume / flat-price test):
# https://www.timeanddate.com/holidays/japan/2005#!hol=1
# drop(errors='ignore') is a no-op for dates outside the data, whereas
# assigning NaN to a missing date would append a new, out-of-order entry.
holidays = pd.to_datetime(['2005-11-03', '2006-07-17'])
ts = ts.drop(holidays, errors='ignore')
# Also discard any undefined index values.
ts = ts.dropna()

# Save the time series to a csv file called pw.csv
ts.to_csv('pw.csv')

# Plot the index
ts.plot()
plt.title('Price-Weighted Index')
plt.show()
   




