# Main script: draw a picture on screen and have the robot reproduce it on a
# whiteboard by following a DMP (dynamic movement primitive) with force feedback.
# TODO:
# 2. delete the unnecessary comments
# 8. add some code to pick up the marker from a prespecified location
# 10. write documentation as you go along
# BIG TODO:
# DOES NOT WORK WITH SPEED SLIDER != 1.0!!!
# Realistic solution: keep the slider at 1.0 at all times and make the
# acceleration and/or speeds small instead.

import pinocchio as pin
import numpy as np
import matplotlib.pyplot as plt
import copy
import argparse
import time
from functools import partial
from ur_simple_control.util.get_model import get_model
from ur_simple_control.visualize.visualize import plotFromDict
from ur_simple_control.util.draw_path import drawPath
from ur_simple_control.dmp.dmp import DMP, NoTC,TCVelAccConstrained 
# TODO merge these clik files as well, they don't deserve to be separate
# TODO but first you need to clean up clik.py as specified there
from ur_simple_control.clik.clik_point_to_point import getClikController, moveL, moveUntilContact
from ur_simple_control.clik.clik_trajectory_following import map2DPathTo3DPlane, clikCartesianPathIntoJointPath
from ur_simple_control.managers import ControlLoopManager, RobotManager
from ur_simple_control.util.calib_board_hacks import calibratePlane, getSpeedInDirectionOfN
from ur_simple_control.visualize.visualize import plotFromDict
from ur_simple_control.basics.basics import moveJ

#######################################################################
#                            arguments                                #
#######################################################################

def getArgs():
    """
    getArgs
    -------
    Build the command-line parser for the drawing demo and parse sys.argv.

    Returns
    -------
    argparse.Namespace
        Generic robot options, controller options, DMP options and
        task-specific options (see the help strings below).

    Raises
    ------
    NotImplementedError
        If --gripper and --simulation are both given (putting the gripper
        in the simulation is not supported yet).
    """
    #######################################################################
    #                          generic arguments                          #
    #######################################################################
    parser = argparse.ArgumentParser(description='Make a drawing on screen,\
            watch the robot do it on the whiteboard.',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # TODO this one won't really work but let's leave it here for the future
    parser.add_argument('--simulation', action=argparse.BooleanOptionalAction,
            help="whether you are running the UR simulator. "
                 "NOTE: doesn't actually work because it's not a physics simulator",
            default=False)
    parser.add_argument('--pinocchio-only', action=argparse.BooleanOptionalAction,
            help="whether you want to just integrate with pinocchio. "
                 "NOTE: doesn't actually work because it's not a physics simulator",
            default=False)
    parser.add_argument('--visualize', action=argparse.BooleanOptionalAction,
            help="whether you want to visualize with gepetto, but "
                 "NOTE: not implemented yet", default=False)
    parser.add_argument('--gripper', action=argparse.BooleanOptionalAction,
            help="whether you're using the gripper", default=False)
    # The help text used to claim "default 0.4" while the real default is 0.3.
    # ArgumentDefaultsHelpFormatter prints the actual default, so it is not
    # repeated in the text.
    parser.add_argument('--acceleration', type=float,
            help="robot's joints acceleration. scalar positive constant, max 1.7. "
                 "BE CAREFUL WITH THIS. the urscript doc says this is "
                 "'lead axis acceleration'. TODO: check what this means",
            default=0.3)
    parser.add_argument('--speed-slider', type=float,
            help="cap robot's speed with the speed slider to something between "
                 "0 and 1, 1.0 by default because of the dmp. "
                 "BE CAREFUL WITH THIS.", default=1.0)
    parser.add_argument('--max-iterations', type=int,
            help="maximum allowable iteration number (it runs at 500Hz)", default=50000)
    #######################################################################
    #                 your controller specific arguments                  #
    #######################################################################
    # not applicable here, but leaving it in the case it becomes applicable
    # it's also in the robot manager even though it shouldn't be
    parser.add_argument('--past-window-size', type=int,
            help="how many timesteps of past data you want to save", default=5)
    parser.add_argument('--goal-error', type=float,
            help="the final position error you are happy with. NOTE: not used here",
            default=1e-3)
    # TODO: test the interaction of this and the overall demo
    parser.add_argument('--tikhonov-damp', type=float,
            help="damping scalar in Tikhonov regularization. "
                 "This is used when generating the joint trajectory from the drawing.",
            default=1e-2)
    # TODO add the rest
    parser.add_argument('--clik-controller', type=str,
            help="select which clik algorithm you want",
            default='dampedPseudoinverse',
            choices=['dampedPseudoinverse', 'jacobianTranspose'])
    # maybe you want to scale the control signal
    # (float default to match type=float)
    parser.add_argument('--controller-speed-scaling', type=float,
            default=1.0, help='not actually used atm')
    #############################
    #  dmp  specific arguments  #
    #############################
    parser.add_argument('--temporal-coupling', action=argparse.BooleanOptionalAction,
            help="whether you want to use temporal coupling", default=True)
    parser.add_argument('--kp', type=float,
            help="proportional control constant for position errors",
            default=1.0)
    # float default so args.tau0 has the declared type even when not given
    parser.add_argument('--tau0', type=float,
            help="total time needed for trajectory. if you use temporal coupling, "
                 "you can still follow the path even if it's too fast",
            default=5.0)
    parser.add_argument('--gamma-nominal', type=float,
            help="positive constant for tuning temporal coupling: the higher, "
                 "the faster the return rate to nominal tau",
            default=1.0)
    parser.add_argument('--gamma-a', type=float,
            help="positive constant for tuning temporal coupling, potential term",
            default=0.5)
    parser.add_argument('--eps-tc', type=float,
            help="temporal coupling term, should be small",
            default=0.001)
    parser.add_argument('--alpha', type=float,
            help="force feedback proportional coefficient",
            default=0.007)
    # TODO add low pass filtering and make its parameters arguments too
    #######################################################################
    #                       task specific arguments                       #
    #######################################################################
    # TODO measure this for the new board
    parser.add_argument('--board-width', type=float,
            help="width of the board (in meters) the robot will write on",
            default=0.5)
    parser.add_argument('--board-height', type=float,
            help="height of the board (in meters) the robot will write on",
            default=0.35)
    parser.add_argument('--calibration', action=argparse.BooleanOptionalAction,
            help="whether you want to do calibration", default=False)
    parser.add_argument('--draw-new', action=argparse.BooleanOptionalAction,
            help="whether draw a new picture, or use the saved path path_in_pixels.csv",
            default=True)
    # Every other option is hyphenated; '--pick_up_marker' is kept as an alias
    # for backward compatibility. Both spellings (and their --no- forms)
    # set args.pick_up_marker.
    parser.add_argument('--pick-up-marker', '--pick_up_marker', dest='pick_up_marker',
            action=argparse.BooleanOptionalAction,
            help="whether the robot should pick up the marker. "
                 "NOTE: THIS IS FROM A PREDEFINED LOCATION.",
            default=True)
    parser.add_argument('--find-marker-offset', action=argparse.BooleanOptionalAction,
            help="whether you want to do find marker offset "
                 "(recalculate TCP based on the marker)", default=False)
    parser.add_argument('--n-calibration-tests', type=int,
            help="number of calibration tests you want to run", default=10)
    parser.add_argument('--clik-goal-error', type=float,
            help="the clik error you are happy with", default=1e-2)
    parser.add_argument('--max-init-clik-iterations', type=int,
            help="number of max clik iterations to get to the first point", default=10000)
    parser.add_argument('--max-running-clik-iterations', type=int,
            help="number of max clik iterations between path points", default=1000)
    args = parser.parse_args()
    if args.gripper and args.simulation:
        raise NotImplementedError('Did not figure out how to put the gripper in \
                the simulation yet, sorry :/ . You can have only 1 these flags right now')
    return args

"""
calibrateFT
-----------
Read from the f/t sensor a bit, average the results
and return the result.
This can be used to offset the bias of the f/t sensor.
NOTE: this is not an ideal solution.
ALSO TODO: test whether the offset changes when 
the manipulator is in different poses.
"""
def calibrateFT(robot):
    ft_readings = []
    print("Will read from f/t sensors for a some number of seconds")
    print("and give you the average.")
    print("Use this as offset.")
    for i in range(2000):
        start = time.time()
        q = robot.rtde_receive.getActualQ()
        ft = robot.rtde_receive.getActualTCPForce()
        tau = robot.rtde_control.getJointTorques()
        current = robot.rtde_receive.getActualCurrent()
        ft_readings.append(ft)
        end = time.time()
        diff = end - start
        if diff < robot.dt:
            time.sleep(robot.dt - diff)

    ft_readings = np.array(ft_readings)
    avg = np.average(ft_readings, axis=0)
    print("average ft time", avg)
    return avg


def getMarker(q_init):
    """
    getMarker
    ---------
    Placeholder for going to a predefined location and picking up the marker.

    Not implemented yet: ``q_init`` is ignored and the function returns None.
    """
    return None

"""
getMarkerOffset
---------------
This relies on having the correct orientation of the plane 
and the correct translation vector for top-left corner.
Idea is you pick up the marker, go to the top corner,
touch it, and see the difference between that and the translation vector.
Obviously it's just a hacked solution, but it works so who cares.
"""
def getMarkerOffset(args, robot, rotation_matrix, translation_vector, q_init):
    # TODO make this more general
    # so TODO: calculate TCP speed based on the rotation matrix
    # and then go
    #z_of_rot = rotation_matrix[:,2]
    y_of_rot = rotation_matrix[:,1]
    # it's going out of the board, and we want to go into the board, right????
    # TODO test this
    #z_of_rot = z_of_rot 
    print("vector i'm following:", y_of_rot)
    speed = getSpeedInDirectionOfN(rotation_matrix)
    #speed[2] = speed[2] * -1
    #robot.rtde_control.moveUntilContact(speed)
    moveUntilContact(args, robot, speed)
    # we use the pin coordinate system because that's what's 
    # the correct thing long term accross different robots etc
    q = robot.getQ()
    pin.forwardKinematics(robot.model, robot.data, np.array(q))
    current_translation = robot.data.oMi[6].translation
    # i only care about the z because i'm fixing the path atm
    # but, let's account for the possible milimiter offset 'cos why not
    print("translation_vector", translation_vector)
    print("current_translation", current_translation)
    print("translation_vector - current_translation", \
            translation_vector - current_translation)
    marker_offset = np.linalg.norm(translation_vector - current_translation)
#    robot.setSpeedSlider(old_speed_slider)
    return marker_offset

#######################################################################
#                            control loop                             #
#######################################################################

def controller():
    """
    controller
    ----------
    Placeholder for the writing controller: feedforward velocity, feedback
    on position and force (impedance-style).

    Not implemented: returns None. The control law currently lives inline in
    controlLoopWriting, which receives this function but does not call it.
    """
    return None

# TODO:
# Regarding saving data, there are 2 options:
# 1) explicitly return what you want to save - you can't magically read local variables
# 2) make controlLoop a class and handle the saving behind the scenes -
#    the variables are then stored on the class, so they are no longer local variables.
# Option 1 is clearly more hands-on and thus worse;
# option 2 introduces a third big class and makes everything convoluted.
# For now, we go with option 1 because it's simpler to implement and deal with,
# but in the future option 2 should be tried: you really have to try
# to do it cleanly to see how good/bad it is.
# In that case you'd most likely want to inherit ControlLoopManager and go from there.
# You'd still need to specify what you're saving with self.that_variable, so
# there's no running away from that in any case.
# It's option 1 for now; it's the only not-ideally-clean part of this solution, and it's OK really.
# TODO: also look into something fancy like a decorator and try
# to find an option 3.

# control loop to be passed to ControlLoopManager
def controlLoopWriting(wrench_offset, dmp, tc, controller, robot, i, past_data):
    """
    controlLoopWriting
    ------------------
    One iteration of the writing control loop. It is meant to be wrapped
    with functools.partial (binding everything up to ``robot``) and handed
    to ControlLoopManager, which supplies ``i`` and ``past_data``.

    Per iteration it:
    1. steps the DMP by robot.dt and updates its tau via temporal coupling;
    2. reads the wrench, removes the calibrated bias, low-pass filters it,
       maps it with the transposed dual action matrix of the tool pose and
       weights it with Z;
    3. maps the wrench to joint space (J^T @ wrench) and sends the velocity
       command  dmp.vel + kp * (dmp.pos - q) + alpha * tau_ext.

    Parameters
    ----------
    wrench_offset : array
        Bias subtracted from every wrench reading (as returned by calibrateFT).
    dmp : DMP
    tc : temporal coupling object (NoTC or TCVelAccConstrained)
    controller :
        Unused (see controller()).
    robot : RobotManager
    i : int
        Iteration counter.
    past_data : dict
        Rolling window kept by ControlLoopManager. ``past_data['wrench'][-1]``
        is the previous filtered wrench, i.e. the low-pass filter state.

    Returns
    -------
    (breakFlag, save_past_dict, log_item)
        breakFlag: True to stop the loop.
        save_past_dict: {'wrench': filtered wrench} for the rolling window.
        log_item: 'qs', 'dmp_poss', 'dqs', 'dmp_vels', each reshaped to (6,).

    NOTE: reads ``args.kp`` and ``args.alpha`` from the module-level ``args``
    created in ``__main__``; it cannot be used from another module as is.
    """
    breakFlag = False
    # TODO rename this into something less confusing
    save_past_dict = {}
    log_item = {}

    # DMP step, then let temporal coupling adjust the DMP's time constant
    dmp.step(robot.dt)
    new_tau = dmp.tau + tc.update(dmp, robot.dt) * robot.dt
    dmp.set_tau(new_tau)
    q = robot.getQ()

    # Z weights the individual wrench components; identity = no weighting.
    # TODO look into UR code/api for estimating the wrench from joint currents
    # (maybe fuse it with the f/t-based estimate).
    Z = np.diag(np.ones(6))

    # remove the sensor bias measured by calibrateFT
    wrench = robot.getWrench() - wrench_offset
    # First-order low-pass filter. beta is a smoothing coefficient in [0, 1];
    # smaller values smooth more. The filter state must be the previous
    # *filtered* wrench in the same frame as this sample, so it is saved
    # before the frame mapping below.
    beta = 0.007
    wrench = beta * wrench + (1 - beta) * past_data['wrench'][-1]
    # deepcopy for good coding practise (and correctness here)
    save_past_dict['wrench'] = copy.deepcopy(wrench)

    # BUGFIX: the weighting step used to be `Z @ robot.getWrench()`, which
    # re-read the raw sensor and silently discarded the offset subtraction,
    # the filter and the frame mapping above.
    # NOTE(review): args.alpha and the sign of the force term were tuned while
    # the raw wrench was used - re-verify on hardware. Also confirm that this
    # mapping puts the wrench in the frame the joint Jacobian below is
    # expressed in.
    wrench = robot.getMtool().toDualActionMatrix().T @ wrench
    wrench = Z @ wrench
    if i % 100 == 0:
        print(wrench)

    pin.forwardKinematics(robot.model, robot.data, q)
    J = pin.computeJointJacobian(robot.model, robot.data, q, robot.JOINT_ID)
    dq = robot.getQd()[:6].reshape((6, 1))
    # wrench mapped to joint torques, first 6 joints only
    tau_ext = (J.T @ wrench)[:6].reshape((6, 1))

    # control law:
    # - feedforward the DMP velocity
    # - feedback the position
    # - force feedback through tau_ext
    # TODO evil hack: the commented-out variant used `- args.alpha * tau`;
    # the sign was flipped during tuning.
    vel_cmd = dmp.vel + args.kp * (dmp.pos - q[:6].reshape((6, 1))) + args.alpha * tau_ext
    robot.sendQd(vel_cmd)

    # TODO find a better criterion for stopping
    if (np.linalg.norm(dmp.vel) < 0.0001) and (i > 5000):
        breakFlag = True
    # Stop immediately if any joint command is NaN (e.g. some
    # non-convergence). Previously only the first joint was checked.
    if np.any(np.isnan(vel_cmd)):
        breakFlag = True

    # log what you said you'd log
    # TODO fix the q6 situation (hide this)
    log_item['qs'] = q[:6].reshape((6,))
    log_item['dmp_poss'] = dmp.pos.reshape((6,))
    log_item['dqs'] = dq.reshape((6,))
    log_item['dmp_vels'] = dmp.vel.reshape((6,))

    return breakFlag, save_past_dict, log_item

if __name__ == "__main__":
    #######################################################################
    #                           software setup                            #
    #######################################################################
    args = getArgs()
    clikController = getClikController(args)
    robot = RobotManager(args)

    # calibrate FT first
    wrench_offset = calibrateFT(robot)
    #######################################################################
    #          drawing a path, making a joint trajectory for it           #
    #######################################################################
    # TODO make these ifs make more sense
    
    # draw the path on the screen
    if args.draw_new:
        pixel_path = drawPath()
        # make it 3D
    else:
        pixel_path_file_path = './path_in_pixels.csv'
        pixel_path = np.genfromtxt(pixel_path_file_path, delimiter=',')
    # do calibration if specified
    if args.calibration:
        rotation_matrix, translation_vector, q_init = \
            calibratePlane(args, robot, args.board_width, args.board_height, \
                           args.n_calibration_tests)
    else:
        # TODO: save this somewhere obviously
        # also make it prettier if possible
        print("using predefined values")
        q_init = np.array([1.4545, -1.7905, -1.1806, -1.0959, 1.6858, -0.1259, 0.0, 0.0])
        translation_vector = np.array([0.10125722 ,0.43077874 ,0.9110792 ])
        rotation_matrix = np.array([[1.  ,       0.         ,0.00336406],
                                    [-0.        , -0.00294646,  0.99999   ],
                                    [ 0.        , -0.99999  ,  -0.00294646]])

    # make the path 3D
    path = map2DPathTo3DPlane(pixel_path, args.board_width, args.board_height)
    # TODO: fix and trust z axis in 2D to 3D path
    # TODO: add an offset of the marker (this is of course approximate)
    # TODO: make this an argument once the rest is OK
    # ---> just go to the board while you have the marker ready to find this
    # ---> do that right here
    if args.pick_up_marker:
        # pick up the marker
        #TODO
        pass
    if args.find_marker_offset:
        # find the marker offset
        # TODO find a better q init (just moveL away from the board)
        marker_offset = getMarkerOffset(args, robot, rotation_matrix, translation_vector, q_init)
        print('marker_offset', marker_offset)
        robot.stopHandler(None,None)
        # Z
        #path = path + np.array([0.0, 0.0, -1 * marker_offset])
        # Y
        #path = path + np.array([0.0, -1 * marker_offset, 0.0])
        path = path + np.array([0.0, 0.0, -1 * marker_offset])
    else:
        #path = path + np.array([0.0, -0.0938, 0.0])
        #path = path + np.array([0.0, 0.0, -0.0938])
        # NEW MARKER IS SHORTER
        #path = path + np.array([0.0, 0.0, -0.0813])
        # NOTE GOOD FOR SHORT
        #path = path + np.array([0.0, 0.0, -0.0750])
        #path = path + np.array([0.0, 0.0, -0.0791])
        #path = path + np.array([0.0, 0.0, -0.0808])
        #path = path + np.array([0.0, 0.0, -0.0108])
        # NOTE: THIS IS THE ONE
        # NOTE: not a single one is the one, the f/t sensor sucks
        # and this is the best number to change to get it to work [upside_down_emoji]
        # but this is very close
        #path = path + np.array([0.0, 0.0, -0.0803])
        #path = path + np.array([0.0, 0.0, -0.1043])
        path = path + np.array([0.0, 0.0, -0.0100])
        #path = path + np.array([0.0, 0.0, -0.1573])
        #path = path + np.array([0.0, 0.2938, 0.0])

    # and if you don't want to draw new nor calibrate, but you want the same path
    # with a different clik, i'm sorry, i can't put that if here.
    # atm running the same joint trajectory on the same thing makes for easier testing
    # of the final system.
    if args.draw_new or args.calibration:
        
        #path = path + np.array([0.0, 0.0, -0.0938])
        # create a joint space trajectory based on the 3D path
    # TODO: add flag of success (now it's your eyeballs and printing)
    # and immediatelly exit if it didn't work
        joint_trajectory = clikCartesianPathIntoJointPath(path, args, robot, \
            clikController, q_init, rotation_matrix, translation_vector)
    else:
        joint_trajectory_file_path = './joint_trajectory.csv'
        joint_trajectory = np.genfromtxt(joint_trajectory_file_path, delimiter=',')
    
    # create DMP based on the trajectory
    dmp = DMP(joint_trajectory)
    if not args.temporal_coupling:
        tc = NoTC()
    else:
        # TODO test whether this works (it should, but test it)
        # test the interplay between this and the speed slider
        # ---> SPEED SLIDER HAS TO BE AT 1.0
        v_max_ndarray = np.ones(robot.n_joints) * robot.max_qd #* args.speed_slider
        # test the interplay between this and the speed slider
        # args.acceleration is the actual maximum you're using
        a_max_ndarray = np.ones(robot.n_joints) * args.acceleration #* args.speed_slider
        tc = TCVelAccConstrained(args.gamma_nominal, args.gamma_a, v_max_ndarray, a_max_ndarray, args.eps_tc)

    # TODO and NOTE the weight, TCP and inertial matrix needs to be set on the robot
    # you already found an API in rtde_control for this, just put it in initialization 
    # under using/not-using gripper parameters
    # ALSO NOTE: to use this you need to change the version inclusions in
    # ur_rtde due to a bug there in the current ur_rtde + robot firmware version 
    # (the bug is it works with the firmware verion, but ur_rtde thinks it doesn't)
    # here you give what you're saving in the rolling past window 
    # it's initial value.
    # controlLoopManager will populate the queue with these initial values
    save_past_dict = {
            'wrench' : np.zeros(6),
        }
    # here you give it it's initial value
    log_dict = {
            'qs' : np.zeros((args.max_iterations, 6)),
            'dmp_poss' : np.zeros((args.max_iterations, 6)),
            'dqs' : np.zeros((args.max_iterations, 6)),
            'dmp_vels' : np.zeros((args.max_iterations, 6)),
        }
    controlLoop = partial(controlLoopWriting, wrench_offset, dmp, tc, controller, robot)
    loop_manager = ControlLoopManager(robot, controlLoop, args, save_past_dict, log_dict)
    #######################################################################
    #                           physical setup                            #
    #######################################################################
    # TODO: add marker picking
    # get up from the board
    current_pose = robot.getMtool()
    # Z
    #current_pose.translation[2] = current_pose.translation[2] + 0.03
    # Y
    #current_pose.translation[1] = current_pose.translation[1] + 0.03
    #moveL(args, robot, current_pose)
    # move to initial pose
    dmp.step(1/500)
    first_q = dmp.pos.reshape((6,))
    first_q = list(first_q)
    first_q.append(0.0)
    first_q.append(0.0)
    first_q = np.array(first_q)
    #pin.forwardKinematics(robot.model, robot.data, first_q)
    mtool = robot.getMtool(q_given=first_q)
    #mtool.translation[1] = mtool.translation[1] - 0.0035
    mtool.translation[1] = mtool.translation[1] - 0.03
    moveL(args, robot, mtool)
    #moveL

    #moveJ(args, robot, dmp.pos.reshape((6,)))
    # and now we can actually run
    loop_manager.run()
    mtool = robot.getMtool()
    mtool.translation[1] = mtool.translation[1] - 0.1
    moveL(args, robot, mtool)

    plotFromDict(log_dict, args)
    robot.stopHandler(None, None)
    robot.stopHandler(None, None)
    robot.stopHandler(None, None)
    # plot results
    plotFromDict(log_dict, args)
    # TODO: add some math to analyze path precision