"""
To do interactive plotting, run:

    bokeh serve betacat_destruction_cycle.py

on the command line.  Then, point a web browser to
http://localhost:5006/betacat_destruction_cycle.
"""

import numpy as np
import pandas as pd
import scipy.integrate

import bokeh.core.properties
import bokeh.io
import bokeh.models.widgets
import bokeh.palettes
import bokeh.plotting

def dcdt(c, t, k8, km8, k9, k10, k11, k12):
    """
    Right-hand side of the destruction-cycle ODE system.

    Parameters
    ----------
    c : sequence of 5 floats
        Current concentrations, ordered (c3, c8, c9, c10, c11).
    t : float
        Current time. It is not used because the equations do not depend
        explicitly on time, but scipy.integrate.odeint passes it.
    k8, km8, k9, k10, k11, k12 : float
        Rate constants of the mass-action reactions.

    Returns
    -------
    numpy.ndarray of shape (5,), float
        Time derivatives of (c3, c8, c9, c10, c11), in that order.
    """
    c3, c8, c9, c10, c11 = c

    # Rate terms that appear in more than one equation
    bind = k8 * c3 * c11      # c3 + c11 -> c8
    unbind = km8 * c8         # c8 -> c3 + c11
    release = k10 * c9        # c9 -> c3 + c10

    return np.array(
        [
            -bind + unbind + release,        # d c3 / dt
            bind - (km8 + k9) * c8,          # d c8 / dt
            k9 * c8 - release,               # d c9 / dt
            release - k11 * c10,             # d c10 / dt
            -bind + unbind + k12,            # d c11 / dt (k12 = source)
        ],
        dtype=float,
    )

# Display names for the five species, in the same order as the state
# vector handled by dcdt: (c3, c8, c9, c10, c11)
names = ['Axin complex', 'Axin-βcat', 'Axin-βcat*', 'βcat*', 'βcat']

# Line color for each species, paired with ``names`` by position
colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00']

# Known parameters, from Lee, et al, PLoS Biology, 2003
c_A = 50          # nM; initial Axin complex (set by fixed GSK-3 concentration)
k9 = 206          # 1/min
k10 = 206         # 1/min
k11 = 0.417       # 1/min
Kd8 = 120         # nM

# Unknown parameters; these are the values the sliders also control
log10_km8 = 0     # log10(1/min)
k12 = 100         # nM/min

# The on-rate k8 follows from the dissociation constant Kd8 = km8 / k8
km8 = 10**log10_km8
k8 = km8 / Kd8    # 1/(nM min)

# Initially all Axin is in the free complex and every other species is zero
c0 = np.array([c_A] + [0] * 4)

# Integrate over 0-15 min on a 400-point grid
t = np.linspace(0, 15, 400)
c = scipy.integrate.odeint(dcdt, c0, t, args=(k8, km8, k9, k10, k11, k12))

# Tabulate the solution: one column per species, plus a time column 't'
df = pd.DataFrame(data=c, columns=names).assign(t=t)

# Wrap the solution in a ColumnDataSource; update_data later replaces its
# .data so the plotted lines refresh
source = bokeh.models.ColumnDataSource(data=df)

# Figure with transparent fill and a fixed concentration axis
p = bokeh.plotting.Figure(
    plot_width=650,
    plot_height=450,
    x_axis_label='time (min)',
    y_axis_label='conc (nM)',
    y_range=[-10, 310],
    border_fill_alpha=0,
    background_fill_alpha=0,
)

# One line per species, each with its own color and legend entry
for name, color in zip(names, colors):
    p.line(x='t', y=name, source=source, line_width=3, color=color,
           legend=bokeh.core.properties.value(name))

p.legend.location = 'top_left'

# Sliders for the two unknown parameters. Their initial values come from
# the module-level defaults k12 and log10_km8, which are also used for the
# initial solve, so the widgets always agree with the first plot.
k12_val = bokeh.models.Slider(title='k12 [nM/min]',  # k12 units: nM/min
                              value=float(k12),
                              start=20.0,
                              end=100.0,
                              step=1.0)
log10_km8_val = bokeh.models.Slider(title='log10 km8 [log10(1/min)]',
                                    value=float(log10_km8),
                                    start=-2.0,
                                    end=4.0,
                                    step=0.1)

# Set up callbacks
def update_data(attrname, old, new):
    """
    Re-solve the model with the current slider settings and replot.

    This is registered as an ``on_change('value', ...)`` callback. The
    three arguments come from Bokeh and are ignored. Both slider values
    are read directly, so either slider can trigger the update.
    """
    # Rate constants taken from the sliders; k8 follows from Kd8 = km8 / k8
    km8 = 10**log10_km8_val.value
    k8 = km8 / Kd8
    k12 = k12_val.value

    # Integrate from the same initial state and over the same time grid
    # as the startup solve
    conc = scipy.integrate.odeint(dcdt, c0, t,
                                  args=(k8, km8, k9, k10, k11, k12))

    # Repackage as column name -> values, including the time column
    frame = pd.DataFrame(data=conc, columns=names)
    frame['t'] = t
    source.data = dict(frame)

# Re-run the simulation whenever either slider's value changes
for widget in (k12_val, log10_km8_val):
    widget.on_change('value', update_data)

# Put the slider box and the plot side by side in a row, then register the
# row with the current document
inputs = bokeh.models.layouts.WidgetBox(children=[k12_val, log10_km8_val])
root = bokeh.models.Row(children=[inputs, p], width=800)
bokeh.io.curdoc().add_root(root)