// Copyright (C) 2010, Guy Barrand. All rights reserved.
// See the file tools.license for terms.

#ifndef tools_qrot
#define tools_qrot

// rotation done with quaternion.

namespace tools {

// Rotation stored as a quaternion q = w + x*i + y*j + z*k, kept in a VEC4
// as (v0,v1,v2,v3) = (x,y,z,w).
//
// Requirements visible from the code below :
//  - VEC4 : elem_t typedef, ctor (x,y,z,w), v0()..v3() getters and setters,
//           set_value(x,y,z,w), set_value(index,value), length(),
//           normalize(), multiply(scalar), equal(VEC4).
//  - VEC3 : default and (x,y,z) ctors, v0()..v2(), operator[], set_value(),
//           length(), normalize() (returning the previous length), dot(),
//           cross(other,result), operator*=(scalar).
//
// Math functions (sin, cos, sqrt, ...) are passed as function pointers so
// that this header does not depend on <cmath>.
template <class VEC3,class VEC4>
class qrot {
protected:
//typedef typename VEC3::elem_t T3;
  typedef typename VEC4::elem_t T; //we assume = T3
public:
  // Identity : zero rotation around the positive Z axis.
  qrot()
  :m_quat(0,0,0,1)
  {}

  // Rotation of a_radians around a_axis.
  // If a_axis is the null vector, set_value() fails without touching m_quat,
  // so the object stays at the identity it is initialised with here
  // (instead of keeping a default constructed, possibly null, quaternion).
  qrot(const VEC3& a_axis,T a_radians,T(*a_sin)(T),T(*a_cos)(T))
  :m_quat(0,0,0,1)
  {
    set_value(a_axis,a_radians,a_sin,a_cos);
  }

  // Rotation bringing the direction of a_from onto the direction of a_to.
  // If one of the vectors is null, set_value() fails without touching m_quat
  // and the object stays at the identity.
  qrot(const VEC3& a_from,const VEC3& a_to,T(*a_sqrt)(T),T(*a_fabs)(T))
  :m_quat(0,0,0,1)
  {
    set_value(a_from,a_to,a_sqrt,a_fabs);
  }

  virtual ~qrot(){}
public:
  qrot(const qrot& a_from)
  :m_quat(a_from.m_quat)
  {}
  qrot& operator=(const qrot& a_from){
    m_quat = a_from.m_quat;
    return *this;
  }
protected:
  // Build from raw components (x,y,z,w); the quaternion is normalized.
  // If normalization fails (null quaternion) we fall back to the identity
  // rather than keeping a quaternion that represents no rotation.
  qrot(T a_q0,T a_q1,T a_q2,T a_q3)
  :m_quat(a_q0,a_q1,a_q2,a_q3)
  {
    if(!m_quat.normalize()) m_quat.set_value(0,0,0,1);
  }

public:
  // Multiplies the quaternions : this becomes the Hamilton product
  // (a_q x this), with a_q as left factor, and is then renormalized.
  // With the matrix convention of value() (matrix times column vector),
  // the result applies the previous "this" first and a_q afterwards.
  // Note that order is important when combining quaternions.
  // Formula from <http://www.lboro.ac.uk/departments/ma/gallery/quat/>
  qrot& operator*=(const qrot& a_q) {
    // Copy everything first : a_q may be *this.
    const T tx = m_quat.v0();
    const T ty = m_quat.v1();
    const T tz = m_quat.v2();
    const T tw = m_quat.v3();

    const T qx = a_q.m_quat.v0();
    const T qy = a_q.m_quat.v1();
    const T qz = a_q.m_quat.v2();
    const T qw = a_q.m_quat.v3();

    m_quat.set_value(qw*tx + qx*tw + qy*tz - qz*ty,
                     qw*ty - qx*tz + qy*tw + qz*tx,
                     qw*tz + qx*ty - qy*tx + qz*tw,
                     qw*tw - qx*tx - qy*ty - qz*tz);
    m_quat.normalize();
    return *this;
  }

  // Component comparison, delegated to VEC4::equal().
  bool operator==(const qrot& a_r) const {
    return m_quat.equal(a_r.m_quat);
  }
  bool operator!=(const qrot& a_r) const {
    return !operator==(a_r);
  }

  // Same composition rule as operator*=, on a copy.
  qrot operator*(const qrot& a_r) const {
    qrot tmp(*this);
    tmp *= a_r;
    return tmp;
  }

  // In place inversion. Returns false (object untouched) if the
  // quaternion has a null length.
  bool invert(){
    // inverse() reads all components before writing its result,
    // so passing *this as output is safe.
    return inverse(*this);
  }

  // Non-destructively inverses the rotation and stores the result in a_r.
  // The result is the conjugate divided by the length, so it is of unit
  // length whenever this one is not null.
  // Returns false (a_r untouched) if the quaternion has a null length.
  bool inverse(qrot& a_r) const {
    const T length = m_quat.length();
    if(length==T()) return false;

    // Optimize by doing 1 div and 4 muls instead of 4 divs.
    const T inv = one() / length;

    a_r.m_quat.set_value(-m_quat.v0() * inv,
                         -m_quat.v1() * inv,
                         -m_quat.v2() * inv,
                          m_quat.v3() * inv);

    return true;
  }


  // Reset rotation with the given axis-of-rotation and rotation angle.
  // a_axis does not need to be normalized. Returns false, leaving the
  // rotation untouched, if a_axis is the null vector.
  // From <http://www.automation.hut.fi/~jaro/thesis/hyper/node9.html>.
  bool set_value(const VEC3& a_axis,T a_radians,T(*a_sin)(T),T(*a_cos)(T)) {
    if(a_axis.length()==T()) return false;

    // q = (axis * sin(angle/2), cos(angle/2))
    m_quat.v3(a_cos(a_radians/2));
    const T sineval = a_sin(a_radians/2);

    VEC3 a = a_axis;
    a.normalize();
    m_quat.v0(a.v0() * sineval);
    m_quat.v1(a.v1() * sineval);
    m_quat.v2(a.v2() * sineval);
    return true;
  }

  // Reset rotation to the one bringing the direction of a_from onto the
  // direction of a_to. Returns false, leaving the rotation untouched,
  // if one of the two vectors is null.
  // code taken from coin3d/SbRotation.
  //NOTE : coin3d/SbMatrix logic is transposed relative to us.
  bool set_value(const VEC3& a_from,const VEC3& a_to,T(*a_sqrt)(T),T(*a_fabs)(T)) {
    VEC3 from(a_from);
    if(from.normalize()==T()) return false;
    VEC3 to(a_to);
    if(to.normalize()==T()) return false;

    const T dot = from.dot(to);
    VEC3 crossvec;from.cross(to,crossvec);
    const T crosslen = crossvec.normalize();

    if(crosslen == T()) { // Parallel vectors
      // Check if they are pointing in the same direction.
      if (dot > T()) {
        m_quat.set_value(0,0,0,1);
      } else {
        // Ok, so they are parallel and pointing in the opposite direction
        // of each other : half a turn around any axis orthogonal to from.
        // Try crossing with x axis.
        VEC3 t;from.cross(VEC3(1,0,0),t);
        // If not ok (from is along x), cross with y axis.
        if(t.normalize() == T()) {
          from.cross(VEC3(0,1,0),t);
          t.normalize();
        }
        m_quat.set_value(t[0],t[1],t[2],0);
      }
    } else { // Vectors are not parallel
      // With dot = cos(angle) : sin(angle/2) = sqrt((1-dot)/2) and
      // cos(angle/2) = sqrt((1+dot)/2).
      // The fabs() wrapping is to avoid problems when `dot' "overflows"
      // a tiny wee bit, which can lead to sqrt() returning NaN.
      crossvec *= (T)a_sqrt(half() * a_fabs(one() - dot));
      // The fabs() wrapping is to avoid problems when `dot' "underflows"
      // a tiny wee bit, which can lead to sqrt() returning NaN.
      m_quat.set_value(crossvec[0], crossvec[1], crossvec[2],(T)a_sqrt(half()*a_fabs(one()+dot)));
    }

    return true;
  }


  // Extract the axis and the angle (in [0,2*pi]) of the rotation.
  //WARNING : the last argument is a_acos and NOT a_cos.
  //WARNING : can fail. Returns false, with a_axis = (0,0,1) and
  //          a_radians = 0, if w is outside [-1,1] or if the angle is
  //          0 or 2*pi (the axis is then undefined).
  bool value(VEC3& a_axis,T& a_radians,T(*a_sin)(T),T(*a_acos)(T)) const {
    if( (m_quat.v3()<minus_one()) || (m_quat.v3()> one()) ) {
      a_axis.set_value(0,0,1);
      a_radians = 0;
      return false;
    }

    a_radians = a_acos(m_quat.v3()) * 2; //in [0,2*pi]
    const T sineval = a_sin(a_radians/2);

    if(sineval==T()) { //a_radian = 0 or 2*pi.
      a_axis.set_value(0,0,1);
      a_radians = 0;
      return false;
    }
    a_axis.set_value(m_quat.v0()/sineval,
                     m_quat.v1()/sineval,
                     m_quat.v2()/sineval);
    return true;
  }

  // Set the rotation from the components of the given matrix.
  // MAT4 must provide the vRC() getters and value(row,col).
  // The matrix is expected to be the one produced by value(MAT4&),
  // possibly with v33 != 1 (the result is then rescaled by 1/sqrt(v33)).
  // See tests/qrot.icc.
  // code taken from coin3d/SbRotation.
  template <class MAT4>
  void set_value(const MAT4& a_m,T(*a_sqrt)(T)) {
    const T scalerow = a_m.v00() + a_m.v11() + a_m.v22();
    if (scalerow > T()) {
      // Positive trace : w is the best conditioned component.
      T _s = a_sqrt(scalerow + a_m.v33());
      m_quat.v3(_s * half());
      _s = half() / _s;

      m_quat.v0((a_m.v21() - a_m.v12()) * _s);
      m_quat.v1((a_m.v02() - a_m.v20()) * _s);
      m_quat.v2((a_m.v10() - a_m.v01()) * _s);
    } else {
      // i : index of the largest diagonal term, (i,j,k) cyclic.
      unsigned int i = 0;
      if (a_m.v11() > a_m.v00()) i = 1;
      if (a_m.v22() > a_m.value(i,i)) i = 2;

      const unsigned int j = (i+1)%3;
      const unsigned int k = (j+1)%3;

      T _s = a_sqrt((a_m.value(i,i) - (a_m.value(j,j) + a_m.value(k,k))) + a_m.v33());

      m_quat.set_value(i,_s * half());
      _s = half() / _s;

      m_quat.v3((a_m.value(k,j) - a_m.value(j,k)) * _s);
      m_quat.set_value(j,(a_m.value(j,i) + a_m.value(i,j)) * _s);
      m_quat.set_value(k,(a_m.value(k,i) + a_m.value(i,k)) * _s);
    }

    if (a_m.v33()!=one()) {
      m_quat.multiply(one()/a_sqrt(a_m.v33()));
    }
  }

  // Return this rotation in the form of a 4x4 matrix, to be applied to
  // column vectors (see mul_vec()). The upper-left 3x3 part is the one
  // of value_3(), the rest of the last row and column is zero and v33 is
  // the squared length of the quaternion (should be 1).
  //NOTE : in coin3d/SbRotation, it looks as if "w <-> -w", but coin3d/SbMatrix logic is transposed relative to us.
  template <class MAT4>
  void value(MAT4& a_m) const {
    const T norm2 = value_3(a_m);

    // last column :
    a_m.v03(0);
    a_m.v13(0);
    a_m.v23(0);

    // fourth row :
    a_m.v30(0);
    a_m.v31(0);
    a_m.v32(0);
    a_m.v33(norm2);
  }

  // Return this rotation in the form of a 3D matrix (only the v00..v22
  // setters of MAT3 are used). The returned value is the squared length
  // of the quaternion (should be 1).
  //NOTE : in coin3d/SbRotation, it looks as if "w <-> -w", but coin3d/SbMatrix logic is transposed relative to us.
  template <class MAT3>
  T value_3(MAT3& a_m) const {
    const T x = m_quat.v0();
    const T y = m_quat.v1();
    const T z = m_quat.v2();
    const T w = m_quat.v3();
    // q = w + x * i + y * j + z * k

    // first row :
    a_m.v00(w*w + x*x - y*y - z*z);
    a_m.v01(2*x*y - 2*w*z);
    a_m.v02(2*x*z + 2*w*y);

    // second row :
    a_m.v10(2*x*y + 2*w*z);
    a_m.v11(w*w - x*x + y*y - z*z);
    a_m.v12(2*y*z - 2*w*x);

    // third row :
    a_m.v20(2*x*z - 2*w*y);
    a_m.v21(2*y*z + 2*w*x);
    a_m.v22(w*w - x*x - y*y + z*z);

    return w*w + x*x + y*y + z*z;  //should be 1.
  }

  // a_out = rotation applied to a_in. a_out may be the same object as a_in.
  void mul_vec(const VEC3& a_in,VEC3& a_out) const {
    T vx = a_in.v0();
    T vy = a_in.v1();
    T vz = a_in.v2();
    mul_3(vx,vy,vz);
    a_out.set_value(vx,vy,vz);
  }

  // Rotate a_v in place.
  void mul_vec(VEC3& a_v) const {
    T vx = a_v.v0();
    T vy = a_v.v1();
    T vz = a_v.v2();
    mul_3(vx,vy,vz);
    a_v.set_value(vx,vy,vz);
  }

  // Rotate the vector (a_x,a_y,a_z) in place : it is multiplied, as a
  // column vector, by the 3x3 matrix of value_3().
  void mul_3(T& a_x,T& a_y,T& a_z) const {
    const T x = m_quat.v0();
    const T y = m_quat.v1();
    const T z = m_quat.v2();
    const T w = m_quat.v3();

    // All three results are computed before any input is overwritten.
    const T v0 =     (w*w + x*x - y*y - z*z) * a_x
                    +        (2*x*y - 2*w*z) * a_y
                    +        (2*x*z + 2*w*y) * a_z;

    const T v1 =             (2*x*y + 2*w*z) * a_x
                    +(w*w - x*x + y*y - z*z) * a_y
                    +        (2*y*z - 2*w*x) * a_z;

    const T v2 =             (2*x*z - 2*w*y) * a_x
                    +        (2*y*z + 2*w*x) * a_y
                    +(w*w - x*x - y*y + z*z) * a_z;

    a_x = v0;
    a_y = v1;
    a_z = v2;
  }

public: //for io::streamer
  // Direct access to the (x,y,z,w) storage. The non-const version lets the
  // caller write components without any normalization.
  const VEC4& quat() const {return m_quat;}
  VEC4& quat() {return m_quat;}

protected:
  // Constants in the scalar type T.
  static T one() {return T(1);}
  static T minus_one() {return T(-1);}
  static T half() {return T(0.5);}

protected:
  VEC4 m_quat; // (x,y,z,w)

public:
  //NOTE : don't handle a static object because of mem balance.
  //static const qrot<double>& identity() {
  //  static const qrot<double> s_v(0,0,0,1);
  //  return s_v;
  //}

//private:static void check_instantiation() {qrot<float> v;}
};

}

#endif
