/*
 * This module is written in C for efficiency.
 */
/* Leigh's notes:
 * more work on make_range
 */

#include "matrix.h"
#include <math.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/types.h>
#include <sys/times.h>


/* Fixed capacity limits; all storage in this module is statically sized. */
#define MAX_LINE_LEN 100	/* PDB line buffer, range/pattern copies, working dir, returned range string */
#define MAX_STRUC_NAME 50	/* object (structure) name buffer, including the NUL */
#define MAX_NUM_STRUCS 1000 	/* number of objects; calc_mean() also uses the slot after the last one */
#define MAX_FNAME_LEN 100	/* full path of one PDB file */
#define MAX_NAME_LEN 10	/* atom / residue name buffer, including the NUL */
#define MAX_ATOM_NUM 5000	/* ATOM records per structure */
#define MAX_ATOM_TYPES 20	/* not referenced in this file */
#define MAX_RES_NUM 500	/* residue tables; residue numbers are used directly as indexes */

/* Names of the objects (structures) to compare, filled by init() from the
 * space-separated object string; size is the number of names. */
struct S_list {
	int size;
	char list[MAX_NUM_STRUCS][MAX_STRUC_NAME];
} S_list;	/* NOTE(review): this also defines a global named S_list that nothing in this file uses */
struct S_list objs;


/* One ATOM record as parsed by readpdb(): atom name, the integer read right
 * after "ATOM" (atom number), residue name, residue number and x/y/z. */
struct Pdbline {
	char atom[MAX_NAME_LEN];
	int aNum;
	char res[MAX_NAME_LEN];
	int rNum;
	float x, y, z;
} Pdb;	/* NOTE(review): unused global in this file */

/* All ATOM records of one structure: obj is its name and
 * line[0..numatoms-1] are the records in file order. */
struct pdb_data {
	char obj[MAX_STRUC_NAME];
	struct Pdbline line[MAX_ATOM_NUM];
	int numatoms;
} pdb_data;	/* NOTE(review): unused global in this file (and large) */
/* One entry per input object, at the same index as objs.list; calc_mean()
 * also writes the mean structure at index _numobjs.  Statically allocated
 * (MAX_NUM_STRUCS * MAX_ATOM_NUM records). */
static struct pdb_data data[MAX_NUM_STRUCS];


void print_matrix();
void print_vector();
void update_pdb();
void make_range();

/* global variables (module state shared by the routines below).
 * NOTE(review): file-scope identifiers starting with '_' are reserved to
 * the C implementation; consider renaming. */
int _numobjs;	/* number of input objects; data[_numobjs] holds the mean structure (calc_mean) */
char _workingdir[MAX_LINE_LEN];	/* directory that holds the <object>.pdb files */
float _rmsresavg[MAX_RES_NUM];	/* per-residue average rmsd, indexed by residue number (calc_res_rmsd) */

int debug = 1;	/* when 1, get_global_rmsd prints each rmsd; C_do_best_range sets it to 0 after the first superimposition */

/* hold all of the residue numbers in the range.
 * idx[0..num] are residue numbers; num is the index of the LAST entry, not
 * the count (every loop runs i <= num), so num == -1 is an empty range. */
struct Range {
		int num;
		int idx[MAX_RES_NUM];
};
struct Range _range;	/* range parsed from the user's string by make_range() */

/* dictionary of (actual) residue numbers to atom Indexes:
 * _resdict[r][k] is the index into data[0].line of the k-th atom whose
 * residue number is r, and _resdict2[r] is the number of such atoms.
 * Built from data[0] only, in init(). */
int _resdict[MAX_RES_NUM][MAX_ATOM_NUM];
int _resdict2[MAX_RES_NUM];
int _numres;  /* set in init() to the residue number of data[0]'s last atom (equals the residue count only for gap-free numbering from 1); not otherwise read in this file */

/*------------------------------------------------------------------------------
 * read in one pdb file
 */
void readpdb(char *object, int idx) {
FILE *fp;
char fname[MAX_FNAME_LEN], line[MAX_LINE_LEN];
register int i, count, ret;
char *rc;
char junk[25];


	sprintf(fname, "%s/%s.pdb", _workingdir, object);

	if ((fp = fopen(fname, "r")) == NULL) {
		printf("ERROR:  cannot open pdb file %s\n", fname);
		return;
	}
	count = 0;
	rc = fgets(line, MAX_LINE_LEN - 1, fp);
	while (rc != NULL) {
		if (0 == strncmp(line, "ATOM", 4)) {
			if (count >= MAX_ATOM_NUM) {
				printf(" \n Maximum number of atoms (%d) exceeded\n", MAX_ATOM_NUM);
				exit(1);
			}
			/* chain indicator in 5th column */
			ret = sscanf(line, "ATOM %d %s %s %*[A-Z] %d %f %f %f %*lf %*lf", 
				&data[idx].line[count].aNum, data[idx].line[count].atom,
				data[idx].line[count].res, &data[idx].line[count].rNum,
				&data[idx].line[count].x, &data[idx].line[count].y, 
				&data[idx].line[count].z);

			/* chain indicator in last column */
			if (ret != 7) {
			ret = sscanf(line, "ATOM %d %s %s %d %f %f %f %*lf %*lf %*c", 
				&data[idx].line[count].aNum, data[idx].line[count].atom,
				data[idx].line[count].res, &data[idx].line[count].rNum,
				&data[idx].line[count].x, &data[idx].line[count].y, 
				&data[idx].line[count].z);
			}

			if (ret != 7) {  /* try a different format */
				ret = sscanf(line, "%*s%d%s%s%d%f%f%f%*lf%*lf", 
				&data[idx].line[count].aNum, data[idx].line[count].atom,
				data[idx].line[count].res, &data[idx].line[count].rNum,
				&data[idx].line[count].x, &data[idx].line[count].y, 
				&data[idx].line[count].z);
			}

			if (ret != 7) {
				printf("SORRY.  cannot read PDB file.\n");
				exit(1);
			}
			count++;
		}
		rc = fgets(line, MAX_LINE_LEN - 1, fp);
	}

	data[idx].numatoms = count;
	strcpy(data[idx].obj, object);

	fclose(fp);
}


/*------------------------------------------------------------------------------
 * initialize everything.
 */
static void init(char *obj_str, char *range, char *workingdir) {
int i, j, rNum, numatoms;
char *result = NULL;


	/* set global variable */
	strcpy(_workingdir, workingdir);

	/* initialize the object list */
	i = 0;
	result = strtok(obj_str, " ");
	while (result != NULL) {
		strcpy(objs.list[i++], result);
		result = strtok(NULL, " ");
	}
	objs.size = i;
	_numobjs = i;
		
	/* read in the pdb files for all objects */
	for (i=0; i<objs.size; i++) {
		readpdb(objs.list[i], i);
	}

	_numres = data[0].line[data[0].numatoms-1].rNum;
	/* make up the res->atom number dictionary */
	/* this is actual residue numbers to atom indexes */
	for (i=0; i<data[0].numatoms; i++) {
		rNum = data[0].line[i].rNum;
		numatoms = _resdict2[rNum]++;
		_resdict[rNum][numatoms] = i;
	}

	/* turn the range from a string into our data structure */
	make_range(range);

}

/*------------------------------------------------------------------------------
 * binary search used to check whether a residue number occurs in the input.
 *
 * Returns 1 if resnum is found in all[0..lenall], otherwise 0.
 * NOTE: lenall is the index of the LAST element (an inclusive bound), not
 * the element count, and all[] must be sorted in ascending order.
 */
int isvalid(int resnum, int all[MAX_RES_NUM], int lenall) {
int lo, hi, probe;

	for (lo = 0, hi = lenall; lo <= hi; ) {
		probe = (lo + hi) / 2;
		if (all[probe] == resnum) {
			return 1;
		}
		if (all[probe] > resnum) {
			hi = probe - 1;
		} else {
			lo = probe + 1;
		}
	}
	return 0;
}


/*------------------------------------------------------------------------------
 * turn the range from string form into our data structure
 * ex. input:  "1,3-9,11-111"  or  "all"
 *
 * The result goes into the global _range: _range.idx[0.._range.num] holds
 * the selected residue numbers and _range.num is the index of the LAST
 * entry (-1 when nothing is selected).  Tokens are comma separated; "a-b"
 * selects a..b inclusive and a lone number selects one residue.  Residue
 * numbers that do not occur in data[0] are silently dropped; tokens that
 * cannot be parsed are reported and skipped.
 */
void make_range(char *range) {
int i, j, r, start, stop, minres, maxres, full;
int allrange[MAX_RES_NUM], allrangenum;
char *copy, *tok;

	_range.num = -1;

	/* distinct residue numbers of data[0] in file order (runs of atoms
	 * sharing a residue number collapse to one entry) */
	j = 0;
	for (i=0; i<data[0].numatoms; i++) {
		if (j > 0 && allrange[j-1] == data[0].line[i].rNum) { continue; }
		if (j >= MAX_RES_NUM) {
			printf("ERROR:  more than %d residues; ignoring the rest\n", MAX_RES_NUM);
			break;
		}
		allrange[j++] = data[0].line[i].rNum;
	}
	if (j == 0) { return; }
	allrangenum = j - 1;	/* isvalid() takes the index of the last entry */

	if (strcmp(range, "all") == 0) {
		for (i=0; i<j; i++) { _range.idx[i] = allrange[i]; }
		_range.num = j - 1;
		return;
	}

	/* only residues inside [minres, maxres] can match, so spans are clamped
	 * to it; this keeps inputs like "1-2000000000" from looping for ages */
	minres = maxres = allrange[0];
	for (i=1; i<j; i++) {
		if (allrange[i] < minres) { minres = allrange[i]; }
		if (allrange[i] > maxres) { maxres = allrange[i]; }
	}

	/* strtok writes into its argument, so tokenise a copy sized to the input */
	copy = malloc(strlen(range) + 1);
	if (copy == NULL) {
		printf("ERROR:  out of memory parsing range: %s\n", range);
		return;
	}
	strcpy(copy, range);

	j = 0;
	full = 0;
	for (tok = strtok(copy, ","); tok != NULL && !full; tok = strtok(NULL, ",")) {
		/* "a-b" (spaces allowed around '-') is a span; a single number,
		 * possibly negative, is one residue.  Both values must be read
		 * before use; the old "!= 0" check let stop go uninitialised. */
		r = sscanf(tok, "%d -%d", &start, &stop);
		if (r == 1) {
			stop = start;
		} else if (r != 2) {
			printf("ERROR parsing range: %s\n", range);
			continue;
		}
		if (start < minres) { start = minres; }
		if (stop > maxres) { stop = maxres; }
		if (start > stop) { continue; }

		for (i=start; ; i++) {
			if (isvalid(i, allrange, allrangenum)) {
				if (j >= MAX_RES_NUM) {
					printf("ERROR:  range selects more than %d residues; ignoring the rest\n", MAX_RES_NUM);
					full = 1;
					break;
				}
				_range.idx[j++] = i;
			}
			/* test before i++ so stop == INT_MAX cannot overflow */
			if (i == stop) { break; }
		}
	}
	_range.num = j - 1;

	free(copy);
}


/*------------------------------------------------------------------------------
 * do the opposite of make_range.  takes a range, and returns a string,
 * which is the range in human readable form.
 *
 * Runs of consecutive residue numbers become "a-b", single numbers stay
 * "a", and the pieces are joined with commas:
 *   idx = {1,3,4,5,7} -> "1,3-5,7",   idx = {5} -> "5".
 * An empty range (range->num < 0) gives "".
 *
 * The caller must supply a large enough token buffer; no bound is applied
 * here.  12 bytes per entry plus 1 always suffices.
 */
void de_make_range(struct Range *range, char *token) {
int i, runstart;
char *out = token;

	*out = '\0';
	if (range->num < 0) { return; }

	/* Append piece by piece through `out` (never sprintf(token, "%s...",
	 * token), which is an overlapping copy and undefined behaviour). */
	runstart = 0;
	for (i=1; i<=range->num+1; i++) {
		/* keep extending the run while the numbers are consecutive */
		if (i <= range->num && range->idx[i] == range->idx[i-1] + 1) { continue; }

		/* the run idx[runstart..i-1] has ended: emit "a" or "a-b" */
		if (out != token) { *out++ = ','; }
		if (runstart == i - 1) {
			out += sprintf(out, "%d", range->idx[runstart]);
		} else {
			out += sprintf(out, "%d-%d", range->idx[runstart], range->idx[i-1]);
		}
		runstart = i;
	}
}


/*------------------------------------------------------------------------------
 * return 1 if residue number res appears in range->idx[0..range->num],
 * otherwise 0 (linear scan; range->num is the index of the last entry).
 */
int matches_res(struct Range *range, int res) {
int k;

	for (k = 0; k <= range->num; k++) {
		if (range->idx[k] == res) {
			return 1;
		}
	}
	return 0;
}


/*------------------------------------------------------------------------------
 * given the list of atoms to include, return true or false if this atom matches
 * ex. pattern = "C N CA" or pattern = "-H -O"
 *
 * pattern is a list of tokens separated by spaces, checked left to right:
 *   "-X"  exclude: if atomname starts with X, return 0 immediately.  Any
 *         exclude token also makes otherwise-unmatched names acceptable,
 *         so a pattern made only of excludes accepts every other atom.
 *   "X"   include: atomname must equal X exactly.
 * Returns 1 if the atom is selected, 0 otherwise.
 *
 * The pattern is scanned in place (no fixed-size copy, no strtok), so it
 * may be of any length, is not modified, and no hidden strtok state is
 * disturbed.
 */
int matches_atom(char *pattern, char *atomname) {
const char *p = pattern;
const char *tok;
size_t len;
int go = 0;

	while (*p != '\0') {
		/* skip the separators in front of the next token */
		while (*p == ' ') { p++; }
		if (*p == '\0') { break; }

		/* the token is tok[0..len-1] */
		tok = p;
		while (*p != '\0' && *p != ' ') { p++; }
		len = (size_t) (p - tok);

		/* this is an exclude.  the atomname cannot start with the token */
		if (tok[0] == '-') {
			go = 1;
			if (strncmp(atomname, tok + 1, len - 1) == 0) { return 0; }
		}

		/* this is an include.  must be exact match. */
		if (strncmp(atomname, tok, len) == 0 && atomname[len] == '\0') { go = 1; }
	}

	return go;
}

/*------------------------------------------------------------------------------
 * return a matrix with the given specifications.
 * example: get_matrix("aria_1", (range 2-4), "N CA C")
 */
MAT * get_matrix(char *struc, struct Range *range, char *atoms ) {
MAT *A;
int realrow = 0;
int i, r, numatms, atomidx, resnum;
struct Pdbline *mypdb;
float tmpcdx[MAX_ATOM_NUM][3];


	for (i=0; i<=_numobjs; i++) {
		if (! strcmp(data[i].obj, struc)) {
			numatms = data[i].numatoms;
			mypdb = data[i].line;
		}
	}
	if (numatms == 0) { 
		printf("ERROR:  cannot find data for %s\n", struc);
		return m_get(0,0); 
	}

	/* for every residue number that we are looking at */
	for (r=0; r<= range->num; r++) {

		resnum = range->idx[r];

		/* get its list of corresponding atoms */
		numatms = _resdict2[resnum];
		for (i=0; i<numatms; i++) {
			atomidx = _resdict[resnum][i];

			if (matches_atom(atoms, mypdb[atomidx].atom)) { 
				tmpcdx[realrow][0] = mypdb[atomidx].x;
				tmpcdx[realrow][1] = mypdb[atomidx].y;
				tmpcdx[realrow][2] = mypdb[atomidx].z;
				realrow++;	
			}
		}
	}

	/* create the official matrix */
	A = m_get(realrow, 3);
	for (i=0; i<realrow; i++) {
		A->me[i][0] = tmpcdx[i][0];
		A->me[i][1] = tmpcdx[i][1];
		A->me[i][2] = tmpcdx[i][2];
	}
	
	return A;

}

/*------------------------------------------------------------------------------
 * debug function: ignores both arguments and returns a newly allocated 3x3
 * matrix filled with fixed constants.
 */
MAT * get_matrix2(char *range, char *atoms ) {
static const double sample[3][3] = {
	{ 18.39548492, -14.33436584, 5.37147713 },
	{ 19.98446655, -14.84322929, 7.20363998 },
	{ 19.86079025, -14.43043613, 5.7812438 }
};
MAT *out = m_get(3,3);
int row, col;

	for (row = 0; row < 3; row++) {
		for (col = 0; col < 3; col++) {
			out->me[row][col] = sample[row][col];
		}
	}
	return out;
}


/*------------------------------------------------------------------------------
 * take a matrix, and centre it
 *  m = rows
 *  n = columns
 *
 * Returns a newly allocated copy of A with each column's mean subtracted,
 * and copies the column means into avg (assumes avg->dim == A->n).  The
 * caller frees the returned matrix.
 */
MAT *centre_matrix(MAT * A, VEC *avg) {
MAT *centred;
VEC *rowbuf, *colsum;
int r, c;

	rowbuf = v_get(A->n);
	colsum = v_get(avg->dim);
	centred = m_get(A->m, A->n);

	/* accumulate the column sums one row at a time */
	for (r = 0; r < A->m; r++) {
		get_row(A, r, rowbuf);
		v_add(colsum, rowbuf, colsum);
	}

	/* sums -> means */
	for (c = 0; c < A->n; c++) {
		colsum->ve[c] /= A->m;
	}

	/* subtract the mean of each column */
	for (r = 0; r < A->m; r++) {
		for (c = 0; c < A->n; c++) {
			centred->me[r][c] = A->me[r][c] - colsum->ve[c];
		}
	}
	v_copy(colsum, avg);

	V_FREE(rowbuf);
	V_FREE(colsum);

	return centred;
}


/*------------------------------------------------------------------------------
 * routine to compute the determinant of a rotation matrix (3x3)
 *
 * Returns the determinant rounded to the nearest integer, so a proper
 * rotation gives 1 and a reflection gives -1 even when rounding error
 * leaves the exact value just inside (-1, 1).  Prints an error and returns
 * 0 if Rot is not 3x3.
 */
int get_determinant(MAT *Rot) {
double det;

	if (Rot->m != 3 || Rot->n != 3) {
		printf("ERROR:  size of Rot = %d x %d\n", (int) Rot->m, (int) Rot->n);
		return(0);
	}

	/* rule of Sarrus */
	det = (Rot->me[0][0] * Rot->me[1][1] * Rot->me[2][2] +
	       Rot->me[0][1] * Rot->me[1][2] * Rot->me[2][0] +
	       Rot->me[0][2] * Rot->me[1][0] * Rot->me[2][1]) -
	      (Rot->me[0][2] * Rot->me[1][1] * Rot->me[2][0] +
	       Rot->me[0][1] * Rot->me[1][0] * Rot->me[2][2] +
	       Rot->me[0][0] * Rot->me[1][2] * Rot->me[2][1]);

	/* Round to nearest.  A plain (int) conversion truncates toward zero,
	 * so a reflection with det = -0.9999999 came out as 0 and was missed
	 * by the caller's "det < 0" test. */
	return (det < 0.0) ? (int) (det - 0.5) : (int) (det + 0.5);
}


/*------------------------------------------------------------------------------
 * do_super - do a superimposition
 *
 * Superimposes every object except `ref` onto `ref`, fitting on the atoms
 * selected by myrange/atoms.  For each object the centred coordinates give
 * C = B_^T A_ (mtrm_mlt); the SVD of C yields the rotation, and a negative
 * determinant (reflection) is corrected by negating the last row of V.
 * The rotation and translation are then applied to all of that object's
 * atoms by update_pdb().  Objects whose selection differs in size from the
 * reference's are reported and left unchanged.
 */
static void do_super(char *ref, struct Range *myrange, char *atoms) {
MAT *A, *B, *A_, *B_, *C, *U, *V, *Vt, *Rot;
VEC *d, *a_avg, *b_avg;
int rows, cols, i, det;


	/* get coord list for ref structure */
	A = get_matrix(ref, myrange, atoms);

	rows = A->m;
	cols = A->n;
	if (rows == 0 || cols == 0) {
		M_FREE(A);	/* nothing selected; this early return used to leak A */
		return;
	}
	C = m_get(cols,cols);
	U = m_get(C->m,C->m);
	V = m_get(C->n,C->n);
	Vt = m_get(C->n,C->n);
	Rot = m_get(cols,cols);	/* the rotation is cols x cols */
	d = v_get(cols);
	a_avg = v_get(cols);
	b_avg = v_get(cols);

	A_ = centre_matrix(A, a_avg);

	/* for each other structure */
	for (i=0; i<objs.size; i++) {

		if (! strcmp(data[i].obj, ref)) { continue; }

		/* get coordinate list */
		B = get_matrix(objs.list[i], myrange, atoms);
		if (B->m != A->m || B->n != A->n) {
			printf("ERROR:  %s: %d atoms selected, reference %s has %d; not superimposed\n",
				objs.list[i], (int) B->m, ref, rows);
			M_FREE(B);
			continue;
		}

		B_ = centre_matrix(B, b_avg);
		mtrm_mlt(B_, A_, C);

		/* singular value decomposition */
		svd(C, U, V, d);

		/* get rotation matrix */
		m_transp(V,Vt);
		m_mlt(Vt, U, Rot);
		m_transp(Rot,Rot);

		/* check for a reflection */
		det = get_determinant(Rot);
		if (det < 0) {
			V->me[2][0] =  0 - V->me[2][0];
			V->me[2][1] =  0 - V->me[2][1];
			V->me[2][2] =  0 - V->me[2][2];
			m_transp(V,Vt);
			m_mlt(Vt, U, Rot);
			m_transp(Rot,Rot);
		}

		/* get translation vector (d is reused as scratch space) */
		vm_mlt(Rot,b_avg,d);
		v_sub(a_avg, d, d);

		update_pdb(i, Rot, d);
		M_FREE(B);
		M_FREE(B_);
	}
	M_FREE(A);
	M_FREE(A_);
	M_FREE(C);
	M_FREE(U);
	M_FREE(V);
	M_FREE(Vt);
	M_FREE(Rot);
	V_FREE(d);
	V_FREE(a_avg);
	V_FREE(b_avg);
}


/*------------------------------------------------------------------------------
 * update our pdb data structures with the rotation and translations
 * to do the superimposition
 *
 * Each atom p of data[idx], taken as a row vector, is replaced in place by
 * p * Rot + tran, i.e. new[c] = p.x*Rot[0][c] + p.y*Rot[1][c] + p.z*Rot[2][c]
 * + tran[c].
 */
void update_pdb(int idx, MAT *Rot, VEC *tran) {
struct Pdbline *atom = data[idx].line;
struct Pdbline *end = atom + data[idx].numatoms;
float px, py, pz;

	for (; atom < end; atom++) {
		/* keep the old coordinates; each new one needs all three */
		px = atom->x;
		py = atom->y;
		pz = atom->z;
		atom->x = Rot->me[0][0] * px + Rot->me[1][0] * py + Rot->me[2][0] * pz + tran->ve[0];
		atom->y = Rot->me[0][1] * px + Rot->me[1][1] * py + Rot->me[2][1] * pz + tran->ve[1];
		atom->z = Rot->me[0][2] * px + Rot->me[1][2] * py + Rot->me[2][2] * pz + tran->ve[2];
	}
}


/*------------------------------------------------------------------------------
 * debug code to print a vector out: msg on its own line, then every entry
 * printed with "%f " on the next line.
 */
void print_vector(VEC *a, char *msg) {
int k = 0;

	printf("%s \n", msg);
	while (k < a->dim) {
		printf("%f ", a->ve[k]);
		k++;
	}
	printf("\n");
}

/*------------------------------------------------------------------------------
 * debug code to print a matrix out: msg, the dimensions, then one line per
 * row with every entry printed as "%f ", followed by a blank line.
 */
void print_matrix(MAT *A, char *msg) {
int r, c;

	printf("%s \n", msg);
	printf("rows = %d  columns = %d \n", A->m, A->n);
	for (r = 0; r < A->m; r++) {
		for (c = 0; c < A->n; c++) {
			printf("%f ", A->me[r][c]);
		}
		putchar('\n');
	}
	putchar('\n');
}


/*------------------------------------------------------------------------------
 * calc_mean - calculate the mean structure.  add this to the other pdb objects.
 */
static void calc_mean() {
int i, j, numatms;
int mean = _numobjs;
float x, y, z;

	/* for each pdb file, for each atom, compute the average */
	numatms = data[0].numatoms;

	for (i=0; i<numatms; i++) {
		x = y = z = 0;
		for (j=0; j<_numobjs; j++) {

			x += data[j].line[i].x;	
			y += data[j].line[i].y;	
			z += data[j].line[i].z;	
		}		
		j -= 1;
		data[mean].line[i].aNum = data[j].line[i].aNum;
		data[mean].line[i].rNum = data[j].line[i].rNum;
		strcpy(data[mean].line[i].atom, data[j].line[i].atom);
		strcpy(data[mean].line[i].res, data[j].line[i].res);

		data[mean].line[i].x = x / _numobjs;
		data[mean].line[i].y = y / _numobjs;
		data[mean].line[i].z = z / _numobjs;
		/*
		if ((debug == 1) && (i < 20)) {
		printf("mean coord [%d] = %f %f %f\n", i, x/_numobjs, y/_numobjs, z/_numobjs); 
		}
		*/
	}
	strcpy(data[mean].obj, "mean");
	data[mean].numatoms = numatms;

}

/*------------------------------------------------------------------------------
 * return rmsd deviation between two set of coordinates
 *
 * A and B are coordinate matrices of the same shape, one atom per row.
 * Returns sqrt(sum of squared element differences / number of rows), or 0
 * when there are no rows (instead of computing 0/0 = NaN).
 */
float calc_rmsd(MAT *A, MAT *B) {
MAT *C;
double sumsq = 0.0;
int rows, i, j;

	C = m_get(A->m,A->n);
	m_sub(A, B, C);	/* C = A - B */
	rows = C->m;

	/* sum of squared differences over every column and row */
	for (j=0; j < C->n; j++) {
		for (i=0; i < C->m; i++) {
			sumsq += C->me[i][j] * C->me[i][j];
		}
	}
	M_FREE(C);

	/* empty selection (e.g. a residue with no matching atoms) */
	if (rows == 0) { return 0.0; }

	return (float) sqrt(sumsq / rows);
}


/*------------------------------------------------------------------------------
 * get_global_rmsd - calculate the rmsd from each structure to the mean
 *
 * Returns the average rmsd between structure data[refidx] (normally the
 * mean, at index _numobjs) and every other input structure, over the atoms
 * selected by myrange/atoms.  Returns 0 if nothing is selected or there is
 * nothing to compare.  When debug == 1 each individual rmsd is printed.
 */
static float get_global_rmsd(int refidx, struct Range *myrange, char *atoms) {
MAT *A, *B;
float rmsd, sumrmsd = 0;
int i, ncompared = 0;

	A = get_matrix(data[refidx].obj, myrange, atoms);

	if (A->m == 0) {
		M_FREE(A);	/* this early return used to leak A */
		return 0.0;
	}

	for (i=0; i<_numobjs; i++) {
		if (i == refidx) { continue; }
		B = get_matrix(data[i].obj, myrange, atoms);
		rmsd = calc_rmsd(A, B);

		if (debug == 1) {
			printf("rmsd from %s to mean = %.3f \n", data[i].obj, rmsd);
		}
		M_FREE(B);
		sumrmsd += rmsd;
		ncompared++;
	}

	M_FREE(A);

	/* divide by the number actually compared: when refidx is one of the
	 * input objects it is skipped, so _numobjs would be one too many */
	if (ncompared == 0) { return 0.0; }

	return(sumrmsd / ncompared);
}

/*------------------------------------------------------------------------------
 * calc_res_rmsd - calculate the rmsds per residue
 */
static void calc_res_rmsd(char *atoms) {
MAT *A, *B;
int i, j, rNum, mean, res0, resN;
float rms;
struct Range myrange;

struct Range {
		int num;
		int idx[MAX_RES_NUM];
};


	mean = _numobjs;
	rNum = -999;
	res0 = data[mean].line[0].rNum;
	resN = data[mean].line[data[0].numatoms-1].rNum;
	for (j=0; j<=resN; j++) {
		_rmsresavg[j] = 0;
	}


	for (i=res0; i<=resN; i++) {

		myrange.num = 0;
		myrange.idx[0] = i;

		A = get_matrix(data[mean].obj, &myrange, atoms);

		for (j=0; j<_numobjs; j++) {
			B = get_matrix(data[j].obj, &myrange, atoms);
			rms = calc_rmsd(A, B);
			M_FREE(B);
			_rmsresavg[i] += rms;
		}
		M_FREE(A);
		_rmsresavg[i] = _rmsresavg[i] / _numobjs;

	}

}

/*------------------------------------------------------------------------------
 * given the _rmsresavg, and the given range, eliminate the residue 
 * with the highest RMSD avereage value.
 */

void new_range(struct Range *myrange) {
int i;
float high = -9.0;
int highidx = -9;
int idx = -9;

	for (i=0; i<=myrange->num; i++) {
		if (_rmsresavg[myrange->idx[i]] > high) {
			high = _rmsresavg[myrange->idx[i]];
			highidx = myrange->idx[i];
			idx = i;
		}
	}

	printf(" -> eliminating res %2d with rmsd %5.2f ", highidx, high);
	/* bubble down the other values */
	for (i=idx; i<=myrange->num-1; i++) {
		myrange->idx[i] = myrange->idx[i+1];
	}
	myrange->num--;
	
}


/*-------------------------------------------------------------------------
 * MAIN ROUTINE.  This is called from the python code.
 */
char * C_do_best_range(char *ref, char *satoms, char *objs, char *range, char *workingdir, double limit) {
int i;
float rmsd;
struct Range newrange;
static char strrange[MAX_LINE_LEN];
struct tms before, after;
FILE *fp;


	init(objs, range, workingdir);
	newrange.num = _range.num;
	for (i=0; i<=newrange.num; i++) {
		newrange.idx[i] = _range.idx[i];
	}

	do_super(ref, &_range, satoms);
	calc_mean();
	debug = 0;


	/* _numobjs is the index of the mean object */
	rmsd = get_global_rmsd(_numobjs, &_range, satoms);
	printf("\nCALCULATING BEST RANGE FOR SUPERIMPOSITION:\n");
	printf(" -> starting rmsd: %.2f\n", rmsd);

	while (rmsd > limit) {

		calc_res_rmsd(satoms);

		/* eliminate worst residue */
		new_range(&newrange);

		/* now repeat the process */
		do_super(ref, &newrange, satoms);

		calc_mean();
		rmsd = get_global_rmsd(_numobjs, &newrange, satoms);

		printf("NEW rmsd = %.2f\n", rmsd);
		
	}
	printf("FINAL RMSD: %.2f\n", rmsd);

	/* return the range */
	de_make_range(&newrange, strrange);
	printf("BEST RANGE: %s\n", strrange);
	return strrange;

}


/* Does nothing.  The functionality is reached through C_do_best_range(),
 * which the Python code calls.  Implicit-int "main ()" is not valid C99+. */
int main(void) {
	return 0;
}
