#include <stdlib.h>
#include <stdio.h>
#include "traj2dmp.h"
#include "optpart.h"
#include "cuttraj.h"
#include <vector>
#include <netdb.h>
#include <string.h>
#include <sys/times.h>
#include <iostream>
#include "dmp.h"
#include <unistd.h>
#include <armadillo>
#include "plotMat3.h"



using namespace std;
using namespace arma;


int main(int argc, char *argv[]) {
	
	// Initialize original trajectories:
	double dt = 0.040; // 25 Hz sampling rate, in this example
	mat yDef;
	yDef.load("yDef.txt", raw_ascii); 
	mat yCorr;
	yCorr.load("yCorr.txt", raw_ascii); 	
	int nDims = yDef.n_cols; // Number of trajectory dimensions
	
	mat yDefKeep;  // kept part of deficient trajectory
	mat yCorrKeep; // kept part of corrective trajectory

	cuttraj(yDef, yCorr, yDefKeep, yCorrKeep); //Determine which part of trajectories to keep

	cout << "cuttraj done "  << endl;
	// extrapolate/intrapolate to length 100 (for CVXGEN solver):
	vec x = linspace<vec>(0,1,yDefKeep.n_rows);

	mat xx = linspace<vec>(0,1,100);
	mat yDefKeepExtrap;
	mat yMod = zeros(100, nDims);
	// loop over dimensions j and solve optimization problem:
	cout << "extrap done "  << endl;
	for (int j = 0; j < nDims; ++j) {
		interp1(x, yDefKeep.col(j), xx, yDefKeepExtrap);
		yMod.col(j) = optpart(yDefKeepExtrap, yCorrKeep.col(j));
	}
	cout << "optimization done "  << endl;	
	
	Dmp resDmp1 = traj2dmp(yMod, dt*yDefKeep.n_rows/yMod.n_rows).speedupTimes(1);

	Dmp resDmp2 = traj2dmp(yCorrKeep, dt).speedupTimes(1);

	cout << "First resulting DMP: " << endl << resDmp1 << endl;
	cout << "Second resulting DMP: " << endl << resDmp2 << endl;
	
	
	int simSamples1 = 500;
	mat yRes1 = yDef.row(0);
	mat vel = zeros(1,nDims);
	mat yResNext = zeros(1,nDims);
	for (int i = 0; i < simSamples1; ++i) {  // Simulate trajectory from resDmp1
		vel = resDmp1.getVel(yRes1.row(yRes1.n_rows-1), dt);
		for (int j = 0; j < nDims; ++j) {
			yResNext(j) = yRes1(yRes1.n_rows-1,j) + vel(j) * dt;
		}
		yRes1 = join_vert(yRes1, yResNext);
	}
	
	mat yRes2 = yCorrKeep.row(0);	
	for (int i = 0; i < simSamples1; ++i) {  // Simulate trajectory from resDmp2
		vel = resDmp2.getVel(yRes2.row(yRes2.n_rows-1), dt);
		for (int j = 0; j < nDims; ++j) {
			yResNext(j) = yRes2(yRes2.n_rows-1,j) + vel(j) * dt;
		}
		yRes2 = join_vert(yRes2, yResNext);
	}
		
	
	cout << "Start plotting" << endl;
	plotMat3(yDef, "Deficient trajectory");
	plotMat3(yCorr, "Corrective trajectory");
	plotMat3(yMod, "Modified trajectory");
	plotMat3(yRes1, "Resulting trajectory part 1");
	plotMat3(yRes2, "Resulting trajectory part 2");
	
}