# z_curve.py

import numpy as np
import matplotlib.pyplot as plt
import random
from scipy.stats import norm

#########################################################################################

def indices(a, func):
    """Return the positions of the items in *a* for which *func* is truthy.

    Args:
        a: Any iterable (for example a list or a 1-D NumPy array).
        func: Predicate applied once to each item, in iteration order.

    Returns:
        A list of zero-based integer positions, in ascending order.
    """
    return [pos for pos, hit in enumerate(map(func, a)) if hit]

#########################################################################################

# Significance level (alpha) of the two-sided test; each tail holds alpha/2.
p = 0.05

# NOTE(review): nothing below draws random numbers, so this seed has no
# visible effect on the figure; kept to avoid changing global random state.
random.seed(19610526)

# Two-sided critical value: the z whose upper-tail area is alpha/2
# (about 1.96 for p = 0.05).
value = norm.ppf(1 - p / 2)

x_all = np.arange(-10, 10, 0.001)  # entire range of x, both in and out of spec
pdf_all = norm.pdf(x_all)          # standard normal density over x_all

# Boolean masks selecting the two rejection tails (|x| beyond the critical value).
left_tail = x_all < -value
right_tail = x_all > value

# Density height at +/- value (the curve is symmetric); the solid
# critical-value markers run from the axis up to this height.
crit_height = norm.pdf(value)

# build the plot
fig, ax = plt.subplots(figsize=(9, 6))
# plt.style.use('fivethirtyeight')
ax.plot(x_all, pdf_all, color='r', lw=2)

# Darker shading for the rejection tails, light shading under the whole curve.
ax.fill_between(x_all[left_tail], pdf_all[left_tail], 0, alpha=0.3, color='r')
ax.fill_between(x_all[right_tail], pdf_all[right_tail], 0, alpha=0.3, color='r')
ax.fill_between(x_all, pdf_all, 0, color='r', alpha=0.1)
ax.set_xlim([-3.5, 3.5])
ax.set_ylim([0, 0.45])

# Right critical value: dashed guide, solid marker up to the curve, label + arrow.
ax.axvline(x=value, linestyle='--', color='r', lw=2, alpha=0.4)
ax.plot([value, value], [0, crit_height], lw=2, color='r')
ax.text(value + 0.8, 0.30, "Critical", fontsize=12)
ax.text(value + 0.8, 0.28, "Value", fontsize=12)
# NOTE(review): both `fc` and `color` are passed to each arrow; which one ends
# up controlling the face colour depends on matplotlib's kwarg handling —
# confirm the intended arrow colour.
ax.arrow(value + 0.8, 0.295, -0.7, 0, fc='k', lw=1,
         head_width=0.015, head_length=0.07, color='b')

# Left critical value: mirror image of the right-hand annotations.
ax.axvline(x=-value, linestyle='--', color='r', lw=2, alpha=0.4)
ax.plot([-value, -value], [0, crit_height], lw=2, color='r')
ax.text(-value - 1.35, 0.30, "Critical", fontsize=12)
ax.text(-value - 1.35, 0.28, "Value", fontsize=12)
ax.arrow(-value - 0.78, 0.295, 0.7, 0, fc='k', lw=1,
         head_width=0.015, head_length=0.07, color='b')

# Central (non-rejection) region label: probability 1 - alpha.
ax.text(0.0, 0.2, r'$1 - \alpha$', horizontalalignment='center', fontsize=16)
ax.text(0.0, 0.1, r'Non-rejection Region', horizontalalignment='center', fontsize=12)

# Tail-area labels: each rejection tail has probability alpha/2.
ax.text(value + 0.15, 0.02, r'$\frac{\alpha}{2}$',
        horizontalalignment='center', verticalalignment='center', fontsize=18)
ax.text(-value - 0.15, 0.02, r'$\frac{\alpha}{2}$',
        horizontalalignment='center', verticalalignment='center', fontsize=18)

ax.annotate(r'Rejection Region', xy=(value + 0.5, 0.015), xytext=(value + 0.75, 0.1),
            horizontalalignment='center', color='r',
            fontsize=12, arrowprops=dict(facecolor='red', shrink=0.05))
ax.annotate(r'Rejection Region', xy=(-value - 0.5, 0.015), xytext=(-value - 0.75, 0.1),
            horizontalalignment='center', color='r',
            fontsize=12, arrowprops=dict(facecolor='red', shrink=0.05))

# Y values are densities, not meaningful to the reader; hide the tick labels.
ax.set_yticklabels([])
ax.set_xlabel('Number of Standard Deviations from Mean 0', fontsize=12)
ax.set_title("Standard Normal Probability Density Function", fontsize=12)

plt.savefig('z_curve.png', dpi=200, bbox_inches='tight', transparent=True)
plt.show()
