import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# Global plot theme for every figure below: seaborn "ticks" style, fonts scaled by 1.5.
sns.set(style="ticks", font_scale=1.5)
def budyko_original(DI):
    """Budyko's original curve: evaporation index ET/P for dryness index DI.

    DI is the dryness index ET0/P (a scalar or a numpy array). The result is
    the geometric mean of Ol'dekop's form ``DI * tanh(1/DI)`` and Schreiber's
    form ``1 - exp(-DI)``.
    """
    oldekop = DI * np.tanh(1 / DI)
    schreiber = 1 - np.exp(-DI)
    return (oldekop * schreiber) ** 0.5

def fu(DI, omega):
    """Fu's one-parameter Budyko curve: evaporation index ET/P.

    DI is the dryness index ET0/P (a scalar or a numpy array) and ``omega``
    is the shape parameter. ``omega = 1`` gives ET/P = 0; as ``omega`` grows,
    the curve approaches the limits ``min(1, DI)``.
    """
    return 1 + DI - (1 + DI**omega)**(1/omega)


def turk_pike(DI, nu=2):
    """Turc-Pike Budyko curve: evaporation index ET/P for dryness index DI.

    ET/P = (1 + DI**-nu)**(-1/nu). It is computed here in the algebraically
    equivalent form DI / (1 + DI**nu)**(1/nu), so DI = 0 yields 0 instead of
    a division by zero. ``nu = 2`` is Pike's value. The name keeps the
    spelling used by the plotting code ("turk" for Turc).

    DI may be a scalar or a numpy array.
    """
    return DI / (1 + DI**nu)**(1 / nu)
fig, ax = plt.subplots(1, 1, figsize=(10, 7))

# Shade climate zones along the dryness-index axis (ET0/P); a more opaque
# blue means a wetter zone. The band edges 1.54, 2, 5 and 20 are the
# reciprocals of the aridity-index (P/ET0) thresholds 0.65, 0.5, 0.2 and 0.05.
# Each band is filled from y=0 up to y=2, i.e. past the top of the axes.
zone_bands = [
    ((0, 1.54), None),   # None: default (fully opaque) fill
    ((1.54, 2), 0.6),
    ((2, 5), 0.4),
    ((5, 20), 0.2),
]
for x_span, band_alpha in zone_bands:
    ax.fill_between(x_span, 2, color="tab:blue", alpha=band_alpha, edgecolor="face")

# Right-align each label at its zone's right edge ("Arid" at the x-limit, 7);
# white text on the two darkest bands, black on the lighter ones.
zone_labels = [
    (1.54, 1.0, "Humid", "white"),
    (2, 0.8, "Dry Subhumid", "white"),
    (5, 0.6, "Semi-arid", "black"),
    (7, 0.4, "Arid", "black"),
]
for label_x, label_y, label_text, label_color in zone_labels:
    ax.text(label_x, label_y, label_text, color=label_color, ha="right")

ax.set(xlabel=r"dryness index ($ET_0/P$)",
       ylabel=r"evaporation index ($ET/P$)",
       xlim=[0, 7],
       ylim=[0, 1.1])

fig.savefig("hydrology_figures/budyko0.png")
fig, ax = plt.subplots(1, 1, figsize=(10, 7))

# Limits only, with no curve yet: the water limit ET/P = 1 (horizontal line)
# and the energy limit ET/P = ET0/P (the 1:1 line).
ax.plot([0, 4], [1, 1], color="tab:blue", lw=3)
ax.plot([0, 1.1], [0, 1.1], color="tab:red", lw=3)

ax.text(2, 1.02, "water limit", color="tab:blue")
# The rotation roughly matches the on-screen slope of the 1:1 line for this
# figure size and these axis limits.
ax.text(0.5, 0.65, "energy limit", rotation=68, color="tab:red")

ax.set(xlabel=r"dryness index ($ET_0/P$)",
       ylabel=r"evaporation index ($ET/P$)",
       xlim=[0, 4],
       ylim=[0, 1.1])

fig.savefig("hydrology_figures/budyko1.png")
fig, ax = plt.subplots(1, 1, figsize=(10, 7))

# Budyko's original curve over the plotted dryness-index range.
DI = np.linspace(0.01, 4, 101)
ax.plot(DI, budyko_original(DI), color="black", lw=3)

# Water limit (ET/P = 1) and energy limit (the 1:1 line).
ax.plot([0, 4], [1, 1], color="tab:blue", lw=3)
ax.plot([0, 1.1], [0, 1.1], color="tab:red", lw=3)

ax.text(2, 1.02, "water limit", color="tab:blue")
ax.text(0.5, 0.65, "energy limit", rotation=68, color="tab:red")
ax.text(2, 0.84, "Budyko curve")

ax.set_xlabel(r"dryness index ($ET_0/P$)")
ax.set_ylabel(r"evaporation index ($ET/P$)")
ax.set_xlim(0, 4)
ax.set_ylim(0, 1.1)

fig.savefig("hydrology_figures/budyko2.png")
fig, ax = plt.subplots(1, 1, figsize=(10, 7))

# Turc-Pike Budyko curve over the plotted dryness-index range.
ai = np.linspace(0.01, 4, 101)
ax.plot(ai, turk_pike(ai), color="black", lw=3)

# Water limit (ET/P = 1) over the fixed x-range [0, 4], matching the other
# figures, rather than the autoscaled ax.get_xlim(); then the energy limit
# (the 1:1 line).
ax.plot([0, 4], [1, 1], color="tab:blue", lw=3)
ax.plot([0, 1.1], [0, 1.1], color="tab:red", lw=3)

ax.text(2, 1.02, "water limit", color="tab:blue")
ax.text(0.5, 0.65, "energy limit", rotation=68, color="tab:red")
ax.text(2, 0.84, "Budyko curve")

# At DI = arrow_x, draw double-headed arrows splitting the unit height into
# the evaporated fraction (0 -> curve) and the remainder labelled "Q"
# (curve -> water limit at 1); presumably Q denotes runoff.
arrow_x = 1.5
curve_y = turk_pike(arrow_x)
for arrow_end_y in (0, 1):
    ax.annotate("",
                xy=(arrow_x, curve_y),
                xytext=(arrow_x, arrow_end_y),
                size=26,
                arrowprops=dict(arrowstyle="<->",
                                shrinkA=0, shrinkB=0,
                                connectionstyle="arc3",
                                color="black"))

ax.text(1.55, 0.4, "actual evaporation")
ax.text(1.55, 0.9, "Q")

# Same axis labels as the other Budyko figures. The former "aridity index
# (PET/P)" contradicted figure 0, where the aridity index is P/PET.
ax.set(xlabel=r"dryness index ($ET_0/P$)",
       ylabel=r"evaporation index ($ET/P$)",
       xlim=[0, 4],
       ylim=[0, 1.1])

# Next file in the budyko0..budyko2 sequence; saving as budyko2.png would
# overwrite the previous figure.
fig.savefig("hydrology_figures/budyko3.png")