import math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation


# Module-level state shared with the animation callbacks:
# ``ani`` holds the FuncAnimation created in main_func (None until then) and
# ``saved`` records whether update() has already written the profile to disk.
ani, saved = None, False


def acc_model_simple(v, v1=0.3, a0=0.2, ap=0.1, am=0.2):
    """
    naive model of a motor acceleration bound as a function of its current speed

    The bound blends linearly from ``a0`` at standstill to its high-speed
    value, which is reached once ``|v| >= v1``:

    * moving forward (``v > 0``): speeding up is limited to ``ap`` and
      braking to ``am``;
    * otherwise (``v <= 0``): the roles are mirrored, so speeding up in the
      negative direction is limited to ``ap`` and braking to ``am``.

    Entries that are neither ``> 0`` nor ``<= 0`` (NaN) get bounds of 0.

    inputs: speed vector v (array-like of any shape, or a scalar),
            v1: speed at which the high-speed limits are fully reached,
            a0: acceleration limit at standstill,
            ap / am: high-speed limits for speeding up / braking
    returns: negative and positive acceleration bounds (a_minus, a_plus),
             arrays with the same shape as v
    """
    v = np.asarray(v)
    # blend factor: 0 at standstill, 1 for |v| >= v1 (|v| is v for v > 0 and
    # -v for v <= 0, matching the two regimes below)
    alpha = np.minimum(np.abs(v) / v1, 1)
    at_rest = (1 - alpha) * a0
    speeding_up = alpha * ap + at_rest
    braking = alpha * am + at_rest
    forward = v > 0
    # test v <= 0 explicitly rather than ~forward so NaN falls in neither regime
    backward = v <= 0
    a_minus = np.where(forward, -braking, np.where(backward, -speeding_up, 0.0))
    a_plus = np.where(forward, speeding_up, np.where(backward, braking, 0.0))
    return a_minus, a_plus


def constraint_a(x, v, acc_model, jerk_max=0.5):
    """
    Per-segment acceleration bounds from a jerk limit and a motor model.

    The profile is sampled at positions x with speeds v. Segment i joins
    samples i and i+1 and has the constant acceleration
    a_i = (v[i+1] - v[i]) / t_i, where t_i = (x[i+1] - x[i]) / v_mid_i is the
    traversal time at the segment's mean speed v_mid_i.

    Neighbouring segment accelerations may differ by at most jerk_max times
    the mean of their two durations. Interior segments take the tighter of
    the bounds from both neighbours. A lone segment has no neighbour, so no
    jerk bound applies to it. The result is intersected with
    ``acc_model(v_mid)`` and finally widened so that it always contains the
    current acceleration a_i.

    inputs: position vector x, speed vector v (same length n),
            acc_model: callable mapping segment mean speeds to
                       (negative bound, positive bound),
            jerk_max: maximum change of acceleration per unit time
    returns: min. and max. acceleration for each segment (length n - 1)
    """
    x_delta = x[1:] - x[:-1]
    v_mid = (v[:-1] + v[1:]) * 0.5
    # traversal time of each segment at its mean speed; a zero mean speed
    # gives an infinite duration (NumPy warns on the division)
    t_delta = x_delta / v_mid
    a = (v[1:] - v[:-1]) / t_delta

    if len(a) < 2:
        # no neighbouring segment to couple to: leave the jerk bound open
        a_min = np.full_like(a, -np.inf)
        a_max = np.full_like(a, np.inf)
    else:
        # time between the centres of consecutive segments
        t_delta_a = (t_delta[:-1] + t_delta[1:]) * 0.5

        # *_next[i]: bound on segment i+1 derived from segment i
        a_max_next = a[:-1] + jerk_max * t_delta_a
        a_min_next = a[:-1] - jerk_max * t_delta_a
        # *_prev[i]: bound on segment i derived from segment i+1
        a_max_prev = a[1:] + jerk_max * t_delta_a
        a_min_prev = a[1:] - jerk_max * t_delta_a

        a_min = np.zeros_like(a)
        a_max = np.zeros_like(a)
        # first/last segment only have one neighbour
        a_min[0] = a_min_prev[0]
        a_min[-1] = a_min_next[-1]
        a_max[0] = a_max_prev[0]
        a_max[-1] = a_max_next[-1]
        # interior segments: tighter of the two neighbour bounds
        a_min[1:-1] = np.maximum(a_min_next[:-1], a_min_prev[1:])
        a_max[1:-1] = np.minimum(a_max_next[:-1], a_max_prev[1:])

    # intersect with the speed-dependent motor limits
    a_neg_bound, a_pos_bound = acc_model(v_mid)
    a_min = np.maximum(a_min, a_neg_bound)
    a_max = np.minimum(a_max, a_pos_bound)
    # always keep the current acceleration inside the returned interval
    a_min = np.minimum(a_min, a)
    a_max = np.maximum(a_max, a)

    return a_min, a_max


def constraint_v(x, v, a_min, a_max, v_abs_max=0.3):
    """
    Per-point speed bounds implied by the segment acceleration limits.

    Uses the constant-acceleration relation v_end**2 = v_start**2 + 2*a*dx
    on every segment. Propagating forward from v[i] bounds v[i+1], and
    propagating backward from v[i+1] bounds v[i]. Interior points keep the
    tighter of their two neighbour bounds, and endpoints keep the single one
    available. The result is clipped to +/- v_abs_max. Squared speeds that
    would turn negative are clipped to 0 before taking the square root.

    inputs: position vector x, speed vector v (length n),
            a_min, a_max: acceleration bounds for each segment (length n - 1),
            v_abs_max: absolute speed limit
    returns: min. and max. speed for each point (length n)
    """
    def root_clipped(sq):
        return np.sqrt(np.maximum(sq, 0))

    step = x[1:] - x[:-1]
    start_sq = v[:-1] * v[:-1]
    end_sq = v[1:] * v[1:]

    # forward pass: speeds reachable at point i+1 when leaving point i at v[i]
    fwd_upper = root_clipped(start_sq + 2 * a_max * step)
    fwd_lower = root_clipped(start_sq + 2 * a_min * step)
    # backward pass: speeds at point i from which v[i+1] can be reached
    bwd_lower = root_clipped(end_sq - 2 * a_max * step)
    bwd_upper = root_clipped(end_sq - 2 * a_min * step)

    lower = np.zeros_like(v)
    upper = np.zeros_like(v)
    # endpoints have a single neighbour
    lower[0], upper[0] = bwd_lower[0], bwd_upper[0]
    lower[-1], upper[-1] = fwd_lower[-1], fwd_upper[-1]
    # interior points: tighter of the forward and backward bound
    lower[1:-1] = np.maximum(fwd_lower[:-1], bwd_lower[1:])
    upper[1:-1] = np.minimum(fwd_upper[:-1], bwd_upper[1:])

    return np.maximum(lower, -v_abs_max), np.minimum(upper, v_abs_max)


def update(frame, lines, x, v, lambd=0.1, tol=1e-10):
    """
    Advance the speed-profile iteration by one step and refresh the plots.

    Used as the FuncAnimation callback. Each call does the following:

    * recomputes the acceleration and speed bounds for the current profile;
    * moves the interior speeds a fraction ``lambd`` of the way toward their
      upper bound, modifying v in place and leaving the endpoints fixed;
    * on the first call where the remaining gap drops below ``tol``, saves
      the profile to ``x.npy`` / ``v.npy``.

    inputs: frame (unused, supplied by FuncAnimation),
            lines: the six Line2D artists showing speed, v_min, v_max,
                   acceleration, a_min and a_max, in that order,
            position vector x, speed vector v (updated in place),
            lambd: relaxation factor of the speed update,
            tol: threshold on the mean squared gap between the interior
                 speeds after the step and the v_max bound
    returns: lines (the artists to redraw when blitting)
    """
    global saved
    # acceleration horizon
    a_min, a_max = constraint_a(x, v, acc_model_simple)
    # speed horizon
    v_min, v_max = constraint_v(x, v, a_min, a_max)
    # relax the interior speeds toward their upper bound (a large lambd can
    # make the iteration unstable)
    v[1:-1] = v_max[1:-1] * lambd + v[1:-1] * (1 - lambd)

    # stop criterion: save once when the remaining gap is negligible
    mse_diff = np.mean(np.square(v_max[1:-1] - v[1:-1]))
    if mse_diff < tol and not saved:
        np.save("x.npy", x)
        np.save("v.npy", v)
        print("Solution saved!")
        saved = True

    # acceleration for display; it uses the updated speeds, while the bounds
    # above were derived from the previous profile
    x_mid = (x[1:] + x[:-1]) * 0.5
    x_delta = x[1:] - x[:-1]
    v_mid = (v[:-1] + v[1:]) * 0.5
    t_delta = x_delta / v_mid
    a = (v[1:] - v[:-1]) / t_delta

    xs = (x, x, x, x_mid, x_mid, x_mid)
    ys = (v, v_min, v_max, a, a_min, a_max)
    for line, x_data, y_data in zip(lines, xs, ys):
        line.set_xdata(x_data)
        line.set_ydata(y_data)

    return lines


def main_func():
    """
    Build the speed/acceleration figure and start the optimisation animation.

    The left panel shows the speed profile with its min/max bounds and the
    right panel the segment accelerations with theirs. The profile starts at
    rest at both ends with a tiny interior speed, and each animation frame
    calls update() to advance the iteration.
    """
    global ani

    fig, (ax_speed, ax_acc) = plt.subplots(1, 2)
    x_end = 4
    # (axis, y limits, y label, title, legend entries) for each panel
    panels = (
        (ax_speed, [0, 0.4], "v", "Speed", ["speed", "min", "max"]),
        (ax_acc, [-0.4, 0.4], "a", "Acceleration", ["acc.", "min", "max"]),
    )
    lines = []
    for axis, y_limits, y_label, title, legend_labels in panels:
        # black: current value, blue: lower bound, red: upper bound
        for color in ("k", "b", "r"):
            trace, = axis.plot([], [], color)
            lines.append(trace)
        axis.set_xlim([0, x_end])
        axis.set_ylim(y_limits)
        axis.set_xlabel("x")
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.legend(legend_labels)
    plt.tight_layout()

    # position axis
    n_points = 256
    x = np.linspace(0, x_end, n_points)
    v = np.zeros((n_points, ))
    # tiny positive interior speed: with all-zero speeds every segment's mean
    # speed would be 0 and the x_delta / v_mid division in constraint_a
    # would blow up
    v[1:-1] = 1e-4

    def init():
        return lines

    # the Animation object must stay referenced while it runs; storing it in
    # the module-level ``ani`` keeps it alive
    ani = FuncAnimation(fig, update, init_func=init, fargs=(lines, x, v), blit=True, interval=2)
    plt.show()


if __name__ == "__main__":
    # Launch the interactive optimisation only when run as a script.
    main_func()