////////////////////////////////////////////////////////////////////////////////
//  implementation of a generic matrix class                                  //  
//  LAST EDIT: Fri Feb 10 15:40:20 1995 by ekki(@prakinf.tu-ilmenau.de)
////////////////////////////////////////////////////////////////////////////////
//  This file belongs to the YART implementation. Copying, distribution and   //
//  legal info is in the file COPYRGHT which should be distributed with this  //
//  file. If COPYRGHT is not available or for more info please contact:       //
//                                                                            //  
//		yart@prakinf.tu-ilmenau.de                                    //
//                                                                            //  
// (C) Copyright 1993 - 1995 YART team                                        //
////////////////////////////////////////////////////////////////////////////////

#include "genmatrix.h"
#include <math.h>

// Pivot threshold used by _Matrix::triang(): if the chosen pivot's absolute
// value is below EPSILON, the matrix is treated as not having full rank.
// A definition supplied before this point (e.g. on the compiler command line)
// takes precedence over the default below.
#ifndef EPSILON
#define EPSILON 1e-12
#endif

void _Matrix::print(FILE* f) const {
    int i,j;
    for( i = 0; i < rows; i++ ) {
	  for( j = 0; j < cols; j++ )
	      fprintf(f,"%.3f ", v[i]->v[j] );
	  fprintf(f,"\n");
      }
    fprintf(f,"\n");
}

double& _Matrix::elem(int i, int j) const { return v[i]->v[j]; }

void _Matrix::flip_rows(int i,int j) {
    _Vector* p = v[i];
    v[i] = v[j];
    v[j] = p;
}

// Re-dimensions the matrix to dim1 x dim2, discarding all previous contents.
// Negative dimensions only produce a diagnostic. A matrix with no rows ends
// up with v == 0, matching the constructors.
void _Matrix::newGeometry(int dim1, int dim2) {
    if (dim1<0 || dim2<0) 
        printf("Matrix would have negative dimension.\n");

    // Release the old storage: each row vector, then the pointer array.
    if (v) {
        for (int i = 0; i < rows; i++) delete v[i];
        delete[] v;   // allocated with new[], so it must be released with delete[]
    }

    rows = dim1;
    cols = dim2;

    if (rows > 0) {
        v = new _Vector*[rows];
        for (int i = 0; i < rows; i++) v[i] = new _Vector(cols);
    }
    else v = 0;
}

// Constructs a dim1 x dim2 matrix made of dim1 freshly constructed row
// vectors of length dim2. For dim1 <= 0 no storage is allocated (v == 0).
// Negative dimensions only produce a diagnostic.
_Matrix::_Matrix(int dim1, int dim2) {
    const bool negative = (dim1 < 0) || (dim2 < 0);
    if (negative)
        printf("_Matrix: negative dimension.\n");

    rows = dim1;
    cols = dim2;
    v = 0;
    if (rows <= 0) return;

    v = new _Vector*[rows];
    for (int r = 0; r < rows; ++r)
        v[r] = new _Vector(cols);
}

// Deep copy: every row vector of p is copy-constructed into new storage.
// A source without rows produces a matrix with v == 0.
_Matrix::_Matrix(const _Matrix& p) {
    rows = p.rows;
    cols = p.cols;
    v = 0;
    if (rows <= 0) return;

    v = new _Vector*[rows];
    for (int r = 0; r < rows; ++r)
        v[r] = new _Vector(*p.v[r]);
}

// Constructs a dim1 x dim2 matrix and copies the entries from the C-style
// array p, where p[i][j] becomes element (i, j). p must provide at least
// dim1 rows of dim2 values each.
// As in the (int,int) constructor, negative dimensions produce a diagnostic,
// and a matrix without rows keeps v == 0 instead of allocating a
// zero- or negative-length array.
_Matrix::_Matrix(int dim1, int dim2, double** p) {
    if (dim1<0 || dim2<0)
        printf("_Matrix: negative dimension.\n");

    rows = dim1;
    cols = dim2;

    if (rows <= 0) {
        v = 0;
        return;
    }

    v = new _Vector*[rows];
    for (int i = 0; i < rows; i++) {
        v[i] = new _Vector(cols);
        for (int j = 0; j < cols; j++) elem(i,j) = p[i][j];
    }
}

// Releases every row vector and then the row-pointer array.
// The pointer array is allocated with new[] everywhere in this file, so it
// must be released with delete[]; plain delete was undefined behaviour.
_Matrix::~_Matrix() {
    if (v) {
        for (int i = 0; i < rows; i++) delete v[i];
        delete[] v;
    }
}

// Prints a diagnostic if mat does not have exactly this matrix's shape.
// Execution continues either way; callers are not stopped.
void _Matrix::check_dimensions(const _Matrix& mat) const { 
    const bool sameShape = (rows == mat.rows) && (cols == mat.cols);
    if (!sameShape)
        printf("incompatible matrix types.\n");
}

// Builds a column matrix (vec.d x 1) from vector vec.
// An empty vector produces a matrix with v == 0, consistent with the other
// constructors, instead of a zero-length pointer array.
_Matrix::_Matrix(const _Vector& vec) {
    rows = vec.d;
    cols = 1;

    if (rows <= 0) {
        v = 0;
        return;
    }

    v = new _Vector*[rows];
    for (int i = 0; i < rows; i++) {
        v[i] = new _Vector(1);
        elem(i,0) = vec[i];
    }
}

// Copy assignment: takes on mat's shape and copies all of its elements.
// Storage is reallocated only when the shapes differ.
// Fixes relative to the old version:
//  - the pointer array is released with delete[], matching new[];
//  - a source without rows leaves v == 0, like the constructors;
//  - 'register' is dropped (ill-formed since C++17).
_Matrix& _Matrix::operator=(const _Matrix& mat) {
    if (rows != mat.rows || cols != mat.cols) {
        if (v) {
            for (int i = 0; i < rows; i++) delete v[i];
            delete[] v;
        }
        rows = mat.rows;
        cols = mat.cols;
        if (rows > 0) {
            v = new _Vector*[rows];
            for (int i = 0; i < rows; i++) v[i] = new _Vector(cols);
        }
        else v = 0;
    }

    for (int i = 0; i < rows; i++)
        for (int j = 0; j < cols; j++) elem(i,j) = mat.elem(i,j);

    return *this;
}

// Exact equality: returns 1 if x has the same shape and every element
// compares equal with ==, otherwise 0. No tolerance is applied.
// ('register' removed: it is ill-formed since C++17.)
int _Matrix::operator==(const _Matrix& x) const {
    if (rows != x.rows || cols != x.cols) return 0;

    for (int i = 0; i < rows; i++)
        for (int j = 0; j < cols; j++)
            if (elem(i,j) != x.elem(i,j)) return 0;

    return 1;
}

// Returns a reference to row vector i.
// An out-of-range index only prints a diagnostic; the access still happens.
_Vector& _Matrix::row(int i) const { 
    const bool inRange = (i >= 0) && (i < rows);
    if (!inRange)  printf("_Matrix: row index out of range\n");
    _Vector* r = v[i];
    return *r;
}

// Checked, writable element access to (i, j).
// Out-of-range indices only print a diagnostic; the access still happens.
double& _Matrix::operator()(int i, int j) {
    const bool badRow = !(i >= 0 && i < rows);
    const bool badCol = !(j >= 0 && j < cols);
    if (badRow)  printf("_Matrix: row index out of range\n");
    if (badCol)  printf("_Matrix: col index out of range\n");
    return elem(i,j);
}

// Checked, read-only element access: returns the value at (i, j).
// Out-of-range indices only print a diagnostic; the access still happens.
double _Matrix::operator()(int i, int j) const {
    if (i < 0 || rows <= i)  printf("_Matrix: row index out of range\n");
    if (j < 0 || cols <= j)  printf("_Matrix: col index out of range\n");
    const double value = elem(i,j);
    return value;
}

// Returns a copy of column i as a vector of length rows.
// An out-of-range index only prints a diagnostic.
_Vector _Matrix::col(int i) const {
    if ( i<0 || i>=cols )  printf("_Matrix: col index out of range\n");
    _Vector result(rows);
    for (int r = 0; r < rows; ++r)
        result.v[r] = elem(r,i);
    return result;
}

// Converts a single-column matrix into a vector (its column 0).
// Any other shape prints a diagnostic and still returns column 0.
_Matrix::operator _Vector() const {
    const bool isColumn = (cols == 1);
    if (!isColumn)
        printf("error: cannot make vector from _Matrix\n");
    return col(0);
}

// Element-wise sum; returns a new rows x cols matrix.
// NOTE(review): a shape mismatch only prints (check_dimensions); the loop
// still indexes mat with this matrix's bounds.
// ('register' removed: it is ill-formed since C++17.)
_Matrix _Matrix::operator+(const _Matrix& mat) {
    check_dimensions(mat);
    _Matrix result(rows,cols);
    for (int i = 0; i < rows; i++)
        for (int j = 0; j < cols; j++)
            result.elem(i,j) = elem(i,j) + mat.elem(i,j);
    return result;
}

// In-place element-wise addition of mat; returns *this.
// NOTE(review): a shape mismatch only prints (check_dimensions).
// ('register' removed: it is ill-formed since C++17.)
_Matrix& _Matrix::operator+=(const _Matrix& mat) {
    check_dimensions(mat);
    for (int i = 0; i < rows; i++)
        for (int j = 0; j < cols; j++)
            elem(i,j) += mat.elem(i,j);
    return *this;
}

// In-place element-wise subtraction of mat; returns *this.
// NOTE(review): a shape mismatch only prints (check_dimensions).
// ('register' removed: it is ill-formed since C++17.)
_Matrix& _Matrix::operator-=(const _Matrix& mat) {
    check_dimensions(mat);
    for (int i = 0; i < rows; i++)
        for (int j = 0; j < cols; j++)
            elem(i,j) -= mat.elem(i,j);
    return *this;
}

// Element-wise difference (*this - mat); returns a new rows x cols matrix.
// NOTE(review): a shape mismatch only prints (check_dimensions).
// ('register' removed: it is ill-formed since C++17.)
_Matrix _Matrix::operator-(const _Matrix& mat) {
    check_dimensions(mat);
    _Matrix result(rows,cols);
    for (int i = 0; i < rows; i++)
        for (int j = 0; j < cols; j++)
            result.elem(i,j) = elem(i,j) - mat.elem(i,j);
    return result;
}

// Unary minus: returns a new matrix with every element negated.
// ('register' removed: it is ill-formed since C++17.)
_Matrix _Matrix::operator-() {
    _Matrix result(rows,cols);
    for (int i = 0; i < rows; i++)
        for (int j = 0; j < cols; j++)
            result.elem(i,j) = -elem(i,j);
    return result;
}

// Scalar multiplication: returns a new matrix with every element times f.
// ('register' removed: it is ill-formed since C++17.)
_Matrix _Matrix::operator*(double f) {
    _Matrix result(rows,cols);
    for (int i = 0; i < rows; i++)
        for (int j = 0; j < cols; j++)
            result.elem(i,j) = elem(i,j) * f;
    return result;
}

// Matrix product (*this) * mat; returns a rows x mat.cols matrix.
// Entry (j, i) is row j of *this combined with column i of mat through the
// _Vector operator* (presumably a dot product; defined in the project header).
// Each column of mat is extracted once and reused for all rows; the old code
// rebuilt the column vector for every (row, column) pair.
// A shape mismatch only prints a diagnostic.
// ('register' removed: it is ill-formed since C++17.)
_Matrix _Matrix::operator*(const _Matrix& mat) {
    if (cols!=mat.rows)
        printf("_Matrix multiplication: incompatible _Matrix types\n");

    _Matrix result(rows, mat.cols);

    for (int i = 0; i < mat.cols; i++) {
        _Vector column = mat.col(i);   // hoisted: invariant over the row loop
        for (int j = 0; j < rows; j++)
            result.elem(j,i) = *v[j] * column;
    }

    return result;
}

// Determinant computed via triang(): the product of the diagonal of the
// upper-triangular form, with the sign flipped once for each row exchange
// (odd number of flips -> negated).
// Returns 0 when triang() reports a pivot below EPSILON.
// Fixes: the second loop no longer reuses a for-scoped 'i' (ill-formed in
// standard C++), and the new[] arrays are released with delete[].
double _Matrix::det() const {
    if (rows!=cols)  
        printf("_Matrix::det: _Matrix not quadratic.\n");

    const int n = rows;

    // triang() always works on an augmented system; one zero-information
    // right-hand-side column is enough here.
    _Matrix M(n,1);

    int flips;
    double** A = triang(M,flips);

    if (A == NULL) return 0;

    double Det = 1;
    for (int i = 0; i < n; i++) Det *= A[i][i];

    for (int i = 0; i < n; i++) delete[] A[i];
    delete[] A;

    return (flips % 2) ? -Det : Det;
}

double** _Matrix::triang(const _Matrix& M, int& flips) const {
    register double **p, **q;
    register double *l, *r, *s;

    register double pivot_el,tmp;
    
    register int i,j, col, row;

    int n = rows;
    int d = M.cols;
    int m = n+d;

    double** A = new double*[n];
    
    p = A;

    for(i=0;i<n;i++) {
	*p = new double[m];
	l = *p++;
	for(j=0;j<n;j++) *l++ = elem(i,j);
	for(j=0;j<d;j++) *l++ = M.elem(i,j);
    }

    flips = 0;

    for (col=0, row=0; row<n; row++, col++) { 
	// search for row j with maximal absolute entry in current col
	j = row;
	for (i=row+1; i<n; i++)
	    if (fabs(A[j][col]) < fabs(A[i][col])) j = i; // IIIT WAS HEERE...fabs()
	
	if ( n > j && j > row) { 
	    double* p = A[j];
	    A[j] = A[row];
	    A[row] = p;
	    flips++;
	}

	tmp = A[row][col];
	q  = &A[row];

	if (fabs(tmp) < EPSILON) {
	    // _Matrix has not full rank
	    p = A;
	    for(i=0;i<n;i++) delete *p;
	    delete A;
	    return NULL;
	}
	
	for (p = &A[n-1]; p != q; p--) { 
	    l = *p+col;
	    s = *p+m;	
	    r = *q+col;
	    
	    if (*l != 0.0) {
		pivot_el = *l/tmp;
		while(l < s) *l++ -= *r++ * pivot_el;
	    }
	}
    }
    return A;
}

// Inverse of a square matrix, obtained by solving (*this) * X = I.
// A non-square matrix only prints a diagnostic.
// NOTE(review): relies on _Matrix(n,n) producing zero off-diagonal entries,
// since only the diagonal is set to 1 here — confirm in _Vector's ctor.
_Matrix _Matrix::inv() const {
    if (rows!=cols)  
        printf("_Matrix::inv: _Matrix not quadratic.\n");

    const int n = rows;
    _Matrix identity(n,n);
    for (int k = n; k-- > 0; )
        identity(k,k) = 1;

    return solve(identity);
}

// Solves (*this) * X = M for X, where *this is n x n and M is n x d.
// triang() reduces the augmented system [ *this | M ] to upper-triangular
// form with partial pivoting; back substitution is then applied to the
// right-hand-side columns only.
// Returns the n x d solution. If triang() reports a pivot below EPSILON,
// a message is printed and a freshly constructed n x d matrix is returned
// instead of dereferencing the NULL result.
// Dimension mismatches only print a diagnostic.
// Also fixed: new[] arrays are released with delete[]; 'register' removed
// (ill-formed since C++17).
_Matrix _Matrix::solve(const _Matrix& M) const {
    if (rows != cols || rows != M.rows)
        printf( "Solve: wrong dimensions\n");

    const int n = rows;      // order of the coefficient matrix
    const int d = M.cols;    // number of right-hand sides
    const int m = n + d;     // width of each augmented row

    int flips;               // row exchanges; irrelevant for solving
    double** A = triang(M, flips);

    if (A == NULL) {
        printf("_Matrix::solve: _Matrix has not full rank.");
        return _Matrix(n, d);
    }

    // Back substitution, last pivot first.
    for (int c = n - 1; c >= 0; c--) {
        double* pivotRow = A[c];
        double* const end = pivotRow + m;
        const double pivot = pivotRow[c];

        // Divide the right-hand side of the pivot row by the pivot.
        // Only columns n..m-1 are updated; the coefficient part is never read
        // again for columns >= c.
        for (double* x = pivotRow + n; x < end; x++) *x /= pivot;

        // Remove column c's contribution from every row above the pivot.
        for (int r = 0; r < c; r++) {
            const double factor = A[r][c];
            double* dst = A[r] + n;
            for (const double* src = pivotRow + n; src < end; src++)
                *dst++ -= *src * factor;
        }
    }

    // Copy the solved right-hand-side columns out and free the work array.
    _Matrix result(n,d);
    for (int i = 0; i < n; i++) {
        const double* sol = A[i] + n;
        for (int j = 0; j < d; j++) result.elem(i,j) = sol[j];
        delete[] A[i];
    }
    delete[] A;

    return result;
}

// Transpose: returns a new cols x rows matrix t with t(c, r) == (*this)(r, c).
_Matrix _Matrix::trans() const {
    _Matrix t(cols, rows);
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            t.elem(c, r) = elem(r, c);
    return t;
}
