/*

Description: Tensor computation
             Basic functions to process tensors
             
Author:      Raimundo Sierra
             www.rsierra.com

Copyright:   Copyright (c) 2000 Raimundo Sierra. All rights reserved.
LICENSE:     This program is free software; you can redistribute it and/or modify
             it under the terms of the GNU General Public License as published by
             the Free Software Foundation; either version 2 of the License, or
             (at your option) any later version.

             This program is distributed in the hope that it will be useful,
             but WITHOUT ANY WARRANTY; without even the implied warranty of
             MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
             GNU General Public License for more details.

             You should have received a copy of the GNU General Public License
             along with this program; if not, write to the Free Software
             Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

Institution: Surgical Planning Laboratory
             Department of Radiology
             Brigham and Women's Hospital
             Harvard Medical School
             MA 02115
             USA

Date:        11.2000 - 3.2001

Modifications:
<Date>   <Init>    <Version>    <Description>

*/
#ifndef _tensor
#define _tensor 1
// NOTE(review): <iostream.h> is a pre-standard header that modern compilers no
// longer ship; migrating to <iostream> also requires qualifying ostream (std::).
#include <iostream.h>
#include <float.h>                                // FLT_MAX (used by INF)
#include <limits.h>

// NOTE(review): f2c.h typically defines function-like macros (e.g. min/max);
// keep the standard headers above this block so they are not affected.
extern "C" {
#include "INSTALL/f2c.h"
#include "clapack.h"
}

// Largest finite float value (FLT_MAX). Despite its name this is NOT the
// IEEE-754 infinity; it is presumably used as a "larger than anything" sentinel.
const float INF(FLT_MAX);

#if defined(debug)
  // Instrumentation counters, defined in another translation unit and only
  // declared when building with `debug` defined.
  extern int tensorcounter;                       // number of constructor calls
  extern int tensordestructorcounter;             // number of destructor calls
#endif

// A 3x3 tensor of floats with elementary arithmetic helpers, comparison
// operators, and eigenvalue / singular-value routines (eig() calls CLAPACK).
//
// Matrix layouts below are drawn with [ ] rather than / \ on purpose: a `//`
// comment line that ends in a backslash is a line continuation, which splices
// the following line into the comment (and triggers -Wcomment).
class tensor
{
 private:
  float _t[9];                                    // the nine tensor components
                                                  // NOTE(review): storage order (row- vs
                                                  // column-major) is fixed by the implementation.

  // Sorts the eigenvalues in W by size; used by eig().
  // NOTE(review): presumably writes the resulting index order to i; the sort
  // direction and the array lengths (presumably 3) are set in the implementation.
  void eigsort(const double *W, int *i) const;

 public:

  tensor();                                       // default constructor
  tensor(const tensor &t);                        // copy constructor

  // Construct from 6 values laid out as a symmetric matrix:
  //   [ t1  t4  t5 ]
  //   [ t4  t2  t6 ]
  //   [ t5  t6  t3 ]
  tensor(float t1, float t2, float t3, float t4, float t5, float t6);

  // Construct from all 9 values, given row by row:
  //   [ t1  t2  t3 ]
  //   [ t4  t5  t6 ]
  //   [ t7  t8  t9 ]
  tensor(float t1, float t2, float t3, float t4, float t5, float t6, float t7, float t8, float t9);

  ~tensor();                                      // destructor

  // c.add(a, b) stores a + b in c; helper for the binary + and - operators.
  void add(const tensor &a, const tensor &b);
  // c.add(a, x) adds the constant x to a and stores the result in c
  // (presumably elementwise, c = a + x -- confirm in the implementation).
  void add(const tensor &a, const float x);

  // Multiply two tensors; helper for operator*.
  // NOTE(review): matrix product vs elementwise product is decided by the implementation.
  void multiply(const tensor &a, const tensor &b);
  void multiply(const tensor &a, const float x);  // multiply tensor a by scalar x

  // Divide two tensors; helper for operator/.
  // Operand order is non-obvious (documented as: t.divide(s) means s/t),
  // so prefer the / operator over calling this directly.
  void divide(const tensor &a, const tensor &b);
  void divide(const tensor &a, const float x);    // divide tensor a by scalar x

  tensor& operator= (const tensor &t);     // assignment operator
  bool isZero() const;                     // true if the tensor is zero
  bool operator==(const tensor &t) const;  // compares all elements exactly, no tolerance
  bool operator!=(const tensor &t) const;  // ditto
  bool operator< (const tensor &t) const;  // less than, comparing the scalars returned by measure()
  bool operator> (const tensor &t) const;  // ditto, via measure()
  bool operator<=(const tensor &t) const;  // ditto, via measure()
  bool operator>=(const tensor &t) const;  // ditto, via measure()

  void print() const;                      // prints the tensor

  // Set from 6 values laid out as a symmetric matrix:
  //   [ t1  t4  t5 ]
  //   [ t4  t2  t6 ]
  //   [ t5  t6  t3 ]
  void set(float t1, float t2, float t3, float t4, float t5, float t6);

  // Set from all 9 values, given row by row:
  //   [ t1  t2  t3 ]
  //   [ t4  t5  t6 ]
  //   [ t7  t8  t9 ]
  void set(float t1, float t2, float t3, float t4, float t5, float t6, float t7, float t8, float t9);
  void setZero();                          // set all values to zero
  //void makeSymmetric();                  // (disabled) set values

  // Scalar measurement used by the ordering operators (<, >, <=, >=).
  // Change the body to select a different measurement; it must return a scalar.
  float measure() const { return euclid(); }

  // Euclidean magnitude over all nine elements, i.e. presumably
  // sqrt(sum of a[i][j]^2) (Frobenius norm) -- confirm in the implementation.
  float euclid() const;
  float Max() const;                       // largest element of the tensor
  float Min() const;                       // smallest element of the tensor

  void scale(const float factor);          // scale this tensor by factor
  float det() const;                       // determinant of the tensor
  float trace() const;                     // trace = sum of the diagonal elements

  // Inverse / transpose of the tensor. Both take a result argument and also
  // return a tensor; NOTE(review): how the two relate is set in the implementation.
  tensor inv(tensor &result) const;
  tensor trans(tensor &result) const;

  // Element at position (i, j).
  // NOTE(review): index base (0..2 vs 1..3) and bounds checking are not visible here.
  float get(const int i, const int j) const;

  float cosinAngle(const tensor &s);       // "angle" measurement between this tensor and s

  // Computes eigenvalues and eigenvectors via the CLAPACK function dsygv_().
  // Returns a negative value on error, 1 on success.
  // NOTE(review): vector1..vector3 presumably point to 3 floats each, and a
  // non-zero `sort` presumably orders the results via eigsort() -- confirm.
  int eig( float &eig1, float &eig2, float &eig3, float *vector1, float *vector2, float *vector3, const int sort=0) const;

  // Singular value decomposition: Tensor = q * r * s, where r is the diagonal
  // matrix with elements r1, r2, r3 and q, s are the remaining factors.
  // NOTE(review): return-code convention presumably matches eig() -- confirm.
  int svd( float &r1, float &r2, float &r3, tensor &q, tensor &s);

  // The following accessors do NOT save any processing time compared to eig();
  // they only return one specific value each. Avoid them where possible and
  // prefer the classes in eigen.h / eigenfield.h.

  float* eigvec1(float *vector) const;     // first eigenvector, corresponding to the largest eigenvalue
  float* eigvec2(float *vector) const;     // second eigenvector
  float* eigvec3(float *vector) const;     // third eigenvector, corresponding to the smallest eigenvalue

  float eigval1() const;                   // first (largest) eigenvalue
  float eigval2() const;                   // second eigenvalue
  float eigval3() const;                   // third (smallest) eigenvalue

  // Anisotropy measurement = 1 - (eigval3 / eigval1); see C.F. Westin,
  // "Image Processing of DT MRI".
  float anisotropy() const;

};

// Free operators: binary + - * /, unary -, and stream output.
// The binary forms correspond to the tensor::add / multiply / divide helpers.
// By-value float parameters carry no top-level const here: it is not part of
// the function type, so these declarations still match definitions that use it.

tensor operator+ (const tensor &s, const tensor &t);   // s + t: add two tensors
tensor operator+ (const tensor &t, float x);           // t + x: add constant x to tensor t
tensor operator+ (float x,         const tensor &t);   // x + t: add constant x to tensor t

tensor operator- (const tensor &s, const tensor &t);   // s - t: subtract two tensors
tensor operator- (const tensor &t, float x);           // t - x: subtract constant x from tensor t
tensor operator- (float x,         const tensor &t);   // x - t: subtract tensor t from constant x

tensor operator* (const tensor &s, const tensor &t);   // s * t: multiply two tensors
tensor operator* (const tensor &t, float x);           // t * x: multiply tensor t by constant x
tensor operator* (float x,         const tensor &t);   // x * t: multiply tensor t by constant x

tensor operator/ (const tensor &s, const tensor &t);   // s / t: divide two tensors
tensor operator/ (const tensor &t, float x);           // t / x: divide tensor t by constant x
                                                       // (no x / t overload is declared)

tensor operator- (const tensor &t);                    // -t: unary minus (change sign)

// Stream output. `ostream` is used unqualified, as provided by <iostream.h>.
ostream &operator << (ostream &out, const tensor &t);

#endif
