/*-----------------------------------------------------------------------------

Copyright (C) 2004, 2006, 2007, 2009.

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA   

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/

#include "libscl.h"
#include "snpusr.h"

using namespace scl;
using namespace libsnp;
using namespace snp;
using namespace std;

// Default constructor; the data parameters are supplied later through
// initialize().
snp::datread::datread()
{
}

// Keeps a copy of the data parameters (file name dsn, dimensions M and n,
// and the list of fields) that read_data() consults.
void snp::datread::initialize(datparms dp)
{
  dpm = dp;
}

// Reads dpm.n records from the file named dpm.dsn into data, which is
// resized to dpm.M by dpm.n when its shape differs.  Each record is read
// as whitespace-separated tokens: token j is stored in data(i,t) when
// dpm.fields[i] == j and is discarded otherwise; anything after the
// largest requested field, up to the end of the line, is skipped.
//
// Returns false if the file cannot be opened, if a field number is not
// positive, or if any token cannot be read/parsed.  A final record that
// lacks a trailing newline is accepted.
bool snp::datread::read_data(realmat& data)
{
  if (data.nrow() != dpm.M || data.ncol() != dpm.n) {
    data.resize(dpm.M, dpm.n);
  }

  ifstream dat_stream(dpm.dsn.c_str());
  if (!dat_stream) return false;

  // Largest requested field; fields are 1-based token positions, so a
  // value below 1 would index idx out of range.
  INTEGER max_field = 0;
  for (INTEGER i=1; i<=dpm.fields.size(); ++i) {
    if (dpm.fields[i] < 1) return false;
    if (dpm.fields[i] > max_field) max_field = dpm.fields[i];
  }

  // idx[j] is the row of data that receives token j, or 0 to discard it.
  intvec idx(max_field,0);
  for (INTEGER i=1; i<=dpm.fields.size(); ++i) idx[dpm.fields[i]] = i;

  string discard;

  for (INTEGER t=1; t<=dpm.n; ++t) {
    for (INTEGER j=1; j<=max_field; ++j) {
      if (idx[j]==0) {
        dat_stream >> discard;
      }
      else {
        dat_stream >> data(idx[j],t);
      }
    }
    if (dat_stream.fail()) return false;

    // Skip the rest of the line.  If the last token of the file was just
    // read, eofbit is already set and getline would set failbit, wrongly
    // turning a final line without '\n' into an error.
    if (!dat_stream.eof()) getline(dat_stream,discard);
  }

  return !dat_stream.fail();
}

// Stores copies of the option, data and transformation parameters, the
// snpden object and the afunc/ufunc/rfunc objects, a reference to the
// detail stream and the transformation pointer.  afm, ufm and rfm are
// accepted for a uniform constructor signature but are not used here.
// X, Y and os are null-initialized so that calculate()'s
// "data not initialized" check is meaningful before set_XY() is called.
snp::resmusig::resmusig(optparms op, datparms dp, tranparms tp, snpden fn, 
  afunc afn, ufunc ufn, rfunc rfn, afunc afm, ufunc ufm, rfunc rfm, 
  ostream& dos, const trnfrm* trn)
: opm(op), dpm(dp), tpm(tp), f(fn), af(afn), uf(ufn), rf(rfn),
  detail(dos), tr(trn), X(0), Y(0), os(0)
{ }

// Records the (non-owning) data pointers after checking that x and y have
// the same dimensions.
void snp::resmusig::set_XY(const realmat* x, const realmat* y)
{
  if (x->nrow() != y->nrow()) error("Error, resmusig, row dim of x & y differ");
  if (x->ncol() != y->ncol()) error("Error, resmusig, col dim of x & y differ");
  X = x;  Y = y;
}

// Records the output stream and reports whether it is usable.
bool snp::resmusig::initialize(ostream* out_stream)
{
  os = out_stream;
  return (*os).good();
}

// For each observation t = drop+1..n, writes to *os (one value per line,
// fmt('e',26,17)) the quantity selected by opm.task:
//   1     : standardized residuals z = P*(y - mu), where P = 1/sqrt(sig)
//           when M == 1 and otherwise P = S^{-1/2} U' from svd(sig,U,S,V);
//   2     : the mean mu returned by f.musig, mapped by tr->unnormalize;
//   other : the upper triangle of sig, column by column, mapped by
//           tr->unscale.
// Observations 3..drop only advance the state of af, uf, rf and f.
// Returns the state of *os after writing.
bool snp::resmusig::calculate()
{ 
  if (Y==0||X==0) error("Error, resmusig, data not initialized");
  if (dpm.M != Y->nrow() || dpm.M != X->nrow() || dpm.M != f.get_ly())
    error("Error, resmusig, this should never happen");

  // Derivative outputs required by the functor call signatures.
  realmat dawa0,dawA;
  realmat duwb0, duwb0_lag;
  kronprd duwB, duwB_lag;
  realmat dRwb0,dRwB;
  realmat dRwRparms;

  realmat y(dpm.M,1);
  realmat x(dpm.M,1);
  realmat u(dpm.M,1);
  realmat y_lag(dpm.M,1);
  realmat x_lag(dpm.M,1);
  realmat u_lag(dpm.M,1);

  af.initialize_state();
  uf.initialize_state();
  rf.initialize_state();

  for (INTEGER i=1; i<=dpm.M; ++i) {
    x[i] = (*X)(i,1);
  }

  u = uf(x,duwb0,duwB);

  // Advance the state through the observations that are dropped.
  for (INTEGER t=3; t<=dpm.drop; ++t) {
    for (INTEGER i=1; i<=dpm.M; ++i) {
      y[i] = (*Y)(i,t);
      x[i] = (*X)(i,t);
      y_lag[i] = (*Y)(i,t-1);
      x_lag[i] = (*X)(i,t-1);
    }

    u_lag = u;
    duwb0_lag = duwb0;
    duwB_lag = duwB;

    u = uf(x_lag,duwb0,duwB);

    f.set_R(rf(y_lag,u_lag,x_lag,duwb0_lag,duwB_lag,dRwb0,dRwB,dRwRparms));
    f.set_a(af(x_lag,dawa0,dawA));
    f.set_u(u);
  }

  realmat mu, sig;
  
  for (INTEGER t=dpm.drop+1; t<=dpm.n; ++t) {
    for (INTEGER i=1; i<=dpm.M; ++i) {
      y_lag[i] = y[i];
      x_lag[i] = x[i];
      y[i] = (*Y)(i,t);
      x[i] = (*X)(i,t);
    }

    u_lag = u;
    duwb0_lag = duwb0;
    duwB_lag = duwB;

    u = uf(x_lag,duwb0,duwB);

    f.set_R(rf(y_lag,u_lag,x_lag,duwb0_lag,duwB_lag,dRwb0,dRwB,dRwRparms));
    f.set_a(af(x_lag,dawa0,dawA));
    f.set_u(u);

    f.musig(mu,sig);

    if (opm.task == 1) {  //res
      // Only the whitening matrix P is needed; for a symmetric positive
      // definite sig (so that V == U) it satisfies P*sig*P' = I.
      realmat P(dpm.M,dpm.M);
      if (dpm.M == 1) {
        P[1] = 1.0/sqrt(sig[1]);
      }
      else {
        realmat U,S,V;
        svd(sig,U,S,V);
        for (INTEGER i=1; i<=dpm.M; ++i) {
          S[i] = 1.0/sqrt(S[i]);
        }
        P = diag(S)*T(U);
      }
      realmat z = P*(y-mu);
      for (INTEGER i=1; i<=z.size(); ++i) {
        (*os) << fmt('e',26,17,z[i]) << '\n';
      }
    }
    else if (opm.task == 2) {  //mu
      (*tr).unnormalize(mu);
      for (INTEGER i=1; i<=mu.size(); ++i) {
        (*os) << fmt('e',26,17,mu[i]) << '\n';
      }
    }
    else { //sig
      (*tr).unscale(sig);
      for (INTEGER j=1; j<=sig.get_cols(); ++j) {
        for (INTEGER i=1; i<=j; ++i) {
          (*os) << fmt('e',26,17,sig(i,j)) << '\n';
        } 
      }
    }
  }
  return (*os).good();
}

// No cleanup is performed; reports whether the output stream is still good.
bool snp::resmusig::finalize()
{
  return os->good();
}

// Stores copies of the option, data and transformation parameters, the
// snpden object and the afunc/ufunc/rfunc objects, a reference to the
// detail stream and the transformation pointer.  afm, ufm and rfm are
// accepted for a uniform constructor signature but are not used here.
// X, Y and os are null-initialized so that calculate()'s
// "data not initialized" check is meaningful before set_XY() is called.
snp::simulate::simulate(optparms op, datparms dp, tranparms tp, snpden fn, 
  afunc afn, ufunc ufn, rfunc rfn, afunc afm, ufunc ufm, rfunc rfm,
  ostream& dos, const trnfrm* trn)
: opm(op), dpm(dp), tpm(tp), f(fn), af(afn), uf(ufn), rf(rfn),
  detail(dos), tr(trn), X(0), Y(0), os(0)
{ }

// Records the (non-owning) data pointers after checking that x and y have
// the same dimensions.
void snp::simulate::set_XY(const realmat* x, const realmat* y)
{
  if (x->nrow() != y->nrow()) error("Error, simulate, row dim of x & y differ");
  if (x->ncol() != y->ncol()) error("Error, simulate, col dim of x & y differ");
  X = x;  Y = y;
}

// Records the output stream and reports whether it is usable.
bool snp::simulate::initialize(ostream* out_stream)
{
  os = out_stream;
  return (*os).good();
}

// Writes a path of M-vectors to *os, one value per line in fmt('e',26,17),
// each vector mapped through tr->unnormalize before it is written:
//   t = 1, 2 and 3..drop : the observed data Y(.,t);
//   t = drop+1..n+extra  : draws from f.sampy(seed), with the state of
//                          af, uf, rf and f advanced on the simulated values.
// The seed starts at opm.iseed.  Returns the state of *os after writing.
bool snp::simulate::calculate()
{ 
  if (Y==0||X==0) error("Error, simulate, data not initialized");
  if (dpm.M != Y->nrow() || dpm.M != X->nrow() || dpm.M != f.get_ly())
    error("Error, simulate, this should never happen");

  INT_32BIT seed = opm.iseed;

  realmat y(dpm.M,1);
  realmat x(dpm.M,1);
  realmat u(dpm.M,1);
  realmat y_lag(dpm.M,1);
  realmat x_lag(dpm.M,1);
  realmat u_lag(dpm.M,1);

  af.initialize_state();
  uf.initialize_state();
  rf.initialize_state();

  // Initial u from the first column of X.
  for (INTEGER i=1; i<=dpm.M; ++i) {
    x[i] = (*X)(i,1);
  }

  u = uf(x);

  // Echo the first two observations (unnormalized) without updating state.
  for (INTEGER t=1; t<=2; ++t) {
    for (INTEGER i=1; i<=dpm.M; ++i) {
      y[i] = (*Y)(i,t);
    }
    realmat v = y;
    (*tr).unnormalize(v);
    for (INTEGER i=1; i<=v.size(); ++i) {
      (*os) << fmt('e',26,17,v[i]) << '\n';
    }
  }

  // Advance the state through the observed data t = 3..drop, echoing y.
  // The order of the uf, rf, af calls matters: they carry internal state.
  for (INTEGER t=3; t<=dpm.drop; ++t) {
    for (INTEGER i=1; i<=dpm.M; ++i) {
      y[i] = (*Y)(i,t);
      x[i] = (*X)(i,t);
      y_lag[i] = (*Y)(i,t-1);
      x_lag[i] = (*X)(i,t-1);
    }

    u_lag = u;

    u = uf(x_lag);

    f.set_R(rf(y_lag,u_lag,x_lag));
    f.set_a(af(x_lag));
    f.set_u(u);
    
    realmat v = y;

    (*tr).unnormalize(v);
    for (INTEGER i=1; i<=v.size(); ++i) {
      (*os) << fmt('e',26,17,v[i]) << '\n';
    }
  }

  // Simulation phase: each draw becomes next period's lagged y; x is the
  // draw passed through the transform selected by tpm.squash
  // (1 = tr->spline, 2 = tr->logistic, otherwise x = y).
  for (INTEGER t=dpm.drop+1; t<=dpm.n+opm.extra; ++t) {
    for (INTEGER i=1; i<=dpm.M; ++i) {
      y_lag[i] = y[i];
      x_lag[i] = x[i];
    }

    u_lag = u;

    u = uf(x_lag);

    f.set_R(rf(y_lag,u_lag,x_lag));
    f.set_a(af(x_lag));
    f.set_u(u);


    y = f.sampy(seed);
    x = y;

    if (tpm.squash == 1) {
      tr->spline(x);
    } else if (tpm.squash == 2) {
      tr->logistic(x);
    }

    realmat v = y;
    tr->unnormalize(v);
    for (INTEGER i=1; i<=v.size(); ++i) {
      (*os) << fmt('e',26,17,v[i]) << '\n';
    }
  }
  return (*os).good();
}

// No cleanup is performed; reports whether the output stream is still good.
bool snp::simulate::finalize()
{
  return os->good();
}

// Stores copies of the option, data and transformation parameters, the
// snpden object and the afunc/ufunc/rfunc objects, a reference to the
// detail stream and the transformation pointer.  afm, ufm and rfm are
// accepted for a uniform constructor signature but are not used here.
// X, Y and os are null-initialized so that calculate()'s
// "data not initialized" check is meaningful before set_XY() is called.
snp::plot::plot(optparms op, datparms dp, tranparms tp, snpden fn, 
  afunc afn, ufunc ufn, rfunc rfn, afunc afm, ufunc ufm, rfunc rfm,
  ostream& dos, const trnfrm* trn)
: opm(op), dpm(dp), tpm(tp), f(fn), af(afn), uf(ufn), rf(rfn),
  detail(dos), tr(trn), X(0), Y(0), os(0)
{ }

// Records the (non-owning) data pointers after checking that x and y have
// the same dimensions.
void snp::plot::set_XY(const realmat* x, const realmat* y)
{
  if (x->nrow() != y->nrow()) error("Error, plot, row dim of x & y differ");
  if (x->ncol() != y->ncol()) error("Error, plot, col dim of x & y differ");
  X = x;  Y = y;
}

// Records the output stream and reports whether it is usable.
bool snp::plot::initialize(ostream* out_stream)
{
  os = out_stream;
  return (*os).good();
}

// Tabulates the conditional density f on a grid for plotting.
//
// The state of af, uf, rf and f is first advanced to a conditioning set:
//   dpm.cond == 0 : the data through t = n, then maxlag+1 further steps in
//                   which the new observation is the normalized tpm.mean;
//   otherwise     : the data through t = cond.
// Then, with mu and sig from f.musig (mapped by tr->unnormalize and
// tr->unscale), the following are written to *os, one value per line:
// mu (M values), sig (M*M values, column by column), the grid increments
// yinc[i] = sfac*sqrt(sig(i,i))/ngrid (M values), and for each of the
// (2*ngrid+1)^M grid points y = mu + yinc*k, k in [-ngrid, ngrid]^M, the M
// coordinates of y followed by the density at y (including tr->get_detP()).
// When opm.print is set, the grid is also checked as a quadrature rule
// against the analytic mean and variance, and the result goes to detail.
// Returns the state of *os after writing.
bool snp::plot::calculate()
{ 
  if (Y==0||X==0) error("Error, plot, data not initialized");
  if (dpm.M != Y->nrow() || dpm.M != X->nrow() || dpm.M != f.get_ly())
    error("Error, plot, this should never happen");

  // Derivative outputs required by the functor call signatures.
  realmat dawa0,dawA;
  realmat duwb0, duwb0_lag;
  kronprd duwB, duwB_lag;
  realmat dRwb0,dRwB;
  realmat dRwRparms;

  realmat y(dpm.M,1);
  realmat x(dpm.M,1);
  realmat u(dpm.M,1);
  realmat y_lag(dpm.M,1);
  realmat x_lag(dpm.M,1);
  realmat u_lag(dpm.M,1);

  af.initialize_state();
  uf.initialize_state();
  rf.initialize_state();

  for (INTEGER i=1; i<=dpm.M; ++i) {
    x[i] = (*X)(i,1);
  }

  u = uf(x,duwb0,duwB);

  // Advance the state through the observations that are dropped.
  for (INTEGER t=3; t<=dpm.drop; ++t) {
    for (INTEGER i=1; i<=dpm.M; ++i) {
      y[i] = (*Y)(i,t);
      x[i] = (*X)(i,t);
      y_lag[i] = (*Y)(i,t-1);
      x_lag[i] = (*X)(i,t-1);
    }

    u_lag = u;
    duwb0_lag = duwb0;
    duwB_lag = duwB;

    u = uf(x_lag,duwb0,duwB);

    f.set_R(rf(y_lag,u_lag,x_lag,duwb0_lag,duwB_lag,dRwb0,dRwB,dRwRparms));
    f.set_a(af(x_lag,dawa0,dawA));
    f.set_u(u);
  }

  // Longest lag used by any of the functors; bounds how much of the
  // conditioning set is reported.
  INTEGER maxlag = 0;
  maxlag = maxlag > af.get_lag() ? maxlag : af.get_lag(); 
  maxlag = maxlag > uf.get_lag() ? maxlag : uf.get_lag(); 
  maxlag = maxlag > rf.get_p()+1 ? maxlag : rf.get_p()+1; 
  maxlag = maxlag > rf.get_q()+1 ? maxlag : rf.get_q()+1; 
  maxlag = maxlag > rf.get_v()+1 ? maxlag : rf.get_v()+1; 
  maxlag = maxlag > rf.get_w()+1 ? maxlag : rf.get_w()+1; 

  if (opm.print) {
    detail << starbox("/Plot conditioning set//");
  }

  if (dpm.cond == 0) {
    for (INTEGER t=dpm.drop+1; t<=dpm.n; ++t) {
      for (INTEGER i=1; i<=dpm.M; ++i) {
        y_lag[i] = y[i];
        x_lag[i] = x[i];
        y[i] = (*Y)(i,t);
        x[i] = (*X)(i,t);
      }
    
      u_lag = u;
      duwb0_lag = duwb0;
      duwB_lag = duwB;

      u = uf(x_lag,duwb0,duwB);

      f.set_R(rf(y_lag,u_lag,x_lag,duwb0_lag,duwB_lag,dRwb0,dRwB,dRwRparms));
      f.set_a(af(x_lag,dawa0,dawA));
      f.set_u(u);
    }
    // Continue the state with the normalized tpm.mean as the new value.
    for (INTEGER t=1; t<=maxlag+1; ++t) {
      for (INTEGER i=1; i<=dpm.M; ++i) {
        y_lag[i] = y[i];
        x_lag[i] = x[i];
      }
      y = tpm.mean;
      tr->normalize(y);
      x = y;
      if (tpm.squash == 1) {
        tr->spline(x);
      } else if (tpm.squash == 2) {
        tr->logistic(x);
      }

      u_lag = u;
      duwb0_lag = duwb0;
      duwB_lag = duwB;

      u = uf(x_lag,duwb0,duwB);

      f.set_R(rf(y_lag,u_lag,x_lag,duwb0_lag,duwB_lag,dRwb0,dRwB,dRwRparms));
      f.set_a(af(x_lag,dawa0,dawA));
      f.set_u(u);

      if (opm.print && t>1) {
        realmat v = y_lag;
        (*tr).unnormalize(v);
        detail << cbind(v,x_lag);
      }
    }
  }
  else {
    // Condition on the data through t = cond (the t = cond+1 pass only
    // shifts Y(.,cond) into the lags).
    for (INTEGER t=dpm.drop+1; t<=dpm.cond+1; ++t) {
      for (INTEGER i=1; i<=dpm.M; ++i) {
        y_lag[i] = y[i];
        x_lag[i] = x[i];
        if (t<=dpm.cond) {
          y[i] = (*Y)(i,t);
          x[i] = (*X)(i,t);
        }
      }

      u_lag = u;
      duwb0_lag = duwb0;
      duwB_lag = duwB;

      u = uf(x_lag,duwb0,duwB);

      f.set_R(rf(y_lag,u_lag,x_lag,duwb0_lag,duwB_lag,dRwb0,dRwB,dRwRparms));
      f.set_a(af(x_lag,dawa0,dawA));
      f.set_u(u);

      if (opm.print && dpm.cond+1-maxlag < t) {
        realmat v = y_lag;
        (*tr).unnormalize(v);
        detail << cbind(v,x_lag);
      }
    }
  }

  realmat mu, sig;
  f.musig(mu,sig);
  
  (*tr).unnormalize(mu);
  (*tr).unscale(sig);

  // Step sizes: ngrid steps on each side cover sfac standard deviations.
  realmat yinc(dpm.M,1);
  for (INTEGER i=1; i<=dpm.M; ++i) {
    yinc[i] = opm.sfac*sqrt(sig(i,i))/opm.ngrid;
  }

  if (opm.print) {
    detail << starbox("/Conditional mean, variance, and plot increment//");
    detail << mu << sig << yinc;
  }

  for (INTEGER i=1; i<=dpm.M; ++i) {
    (*os) << mu[i] << '\n';
  }

  for (INTEGER j=1; j<=dpm.M; ++j) {
    for (INTEGER i=1; i<=dpm.M; ++i) {
       (*os) << sig(i,j) << '\n';
    }
  }

  for (INTEGER i=1; i<=dpm.M; ++i) {
    (*os) << yinc[i] << '\n';
  }

  // Volume of one grid cell, used as the quadrature weight.
  REAL dy = 1.0;
  for (INTEGER i=1; i<=dpm.M; ++i) { dy *= yinc[i]; }
  REAL sum_hat = 0.0;
  realmat mu_hat(dpm.M,1,0.0);
  realmat sig_hat(dpm.M,dpm.M,0.0);

  realmat v(dpm.M,1);

  // Odometer over k in [-ngrid, ngrid]^M; idx[1] starts one below so the
  // first increment yields (-ngrid, ..., -ngrid).  The loop ends after
  // processing the point (ngrid, ..., ngrid).
  intvec top(dpm.M, opm.ngrid);
  intvec idx(dpm.M,-opm.ngrid); 
  --idx[1];
  
  while (idx != top) {

    for (INTEGER i=1; i<=dpm.M; ++i) {
      if (idx[i] < opm.ngrid) {
        idx[i] += 1;
        break;
      }
      else {
        idx[i] = -opm.ngrid;
      }
    } 

    // Grid point centered on the conditional mean, so that the grid spans
    // mu[i] +/- sfac standard deviations.  Centering at zero would miss
    // the density mass whenever mu is not near zero.
    for (INTEGER i=1; i<=dpm.M; ++i) {
      v[i] = y[i] = mu[i] + yinc[i]*idx[i];
      (*os) << y[i] << '\n';
    }

    (*tr).normalize(v);
    
    realmat dlfwa, dlfwu, dlfwR;

    // Density at y in original units: f at the normalized point times the
    // Jacobian factor supplied by the transformation.
    REAL lf = f.log_f(v, dlfwa, dlfwu, dlfwR);
    REAL f0 = exp(lf);
    f0 *= (*tr).get_detP();

    (*os) << f0 << '\n';

    sum_hat += f0*dy;
    mu_hat += y*(f0*dy);
    sig_hat += (y*T(y))*(f0*dy);
  } 
  sig_hat -= mu_hat*T(mu_hat);

  if (opm.print) {
    detail << starbox("/Check on suitability as a quadrature rule//");

    detail << starbox("/Values computed analytically//");
    detail << '\n';
    detail << "\t Integral of density = " << fmt('f',8,5,1.0) << '\n';
    detail << '\n';
    detail << "\t Mean = " << mu;
    detail << "\t Variance = " << sig;

    detail << starbox("/Values computed by quadrature//");
    detail << '\n';
    detail << "\t Integral of density = " << fmt('f',8,5,sum_hat) << '\n';
    detail << '\n';
    detail << "\t Mean = " << mu_hat;
    detail << "\t Variance = " << sig_hat;

    sum_hat = (sum_hat - 1.0)/(1.0 + EPS);
    for (INTEGER j=1; j<=dpm.M; ++j) {
      mu_hat[j] = (mu_hat[j] - mu[j])/(mu[j] + EPS);
      for (INTEGER i=1; i<=dpm.M; ++i) {
        sig_hat(i,j) = (sig_hat(i,j) - sig(i,j))/(sig(i,j) + EPS);
      }
    }

    detail << starbox("/Relative error//");
    detail << '\n';
    detail << "\t Integral of density = " << fmt('f',8,5,sum_hat) << '\n';
    detail << '\n';
    detail << "\t Mean = " << mu_hat;
    detail << "\t Variance = " << sig_hat;
    detail << '\n';
    detail.flush();
  }
    
  return (*os).good();
}

// No cleanup is performed; reports whether the output stream is still good.
bool snp::plot::finalize()
{
  return os->good();
}

// Stores copies of the option, data and transformation parameters, the
// snpden object and the afunc/ufunc/rfunc objects, a reference to the
// detail stream and the transformation pointer.  afm, ufm and rfm are
// accepted for a uniform constructor signature but are not used here.
// X, Y and os are null-initialized so that calculate()'s
// "data not initialized" check is meaningful before set_XY() is called.
snp::leverage::leverage(optparms op, datparms dp, tranparms tp, snpden fn, 
  afunc afn, ufunc ufn, rfunc rfn, afunc afm, ufunc ufm, rfunc rfm,
  ostream& dos, const trnfrm* trn)
: opm(op), dpm(dp), tpm(tp), f(fn), af(afn), uf(ufn), rf(rfn),
  detail(dos), tr(trn), X(0), Y(0), os(0)
{ }

// Records the (non-owning) data pointers after checking that x and y have
// the same dimensions.
void snp::leverage::set_XY(const realmat* x, const realmat* y)
{
  if (x->nrow() != y->nrow()) error("Error, leverage, row dim of x & y differ");
  if (x->ncol() != y->ncol()) error("Error, leverage, col dim of x & y differ");
  X = x;  Y = y;
}

// Records the output stream and reports whether it is usable.
bool snp::leverage::initialize(ostream* out_stream)
{
  os = out_stream;
  return (*os).good();
}

// For each observation t = drop+1..n, shifts the first element of the
// lagged observation y_lag by delta on a grid of 2*ngrid+1 points in
// [-5, 5] (ngrid = 50), evaluates sig from fd.musig using copies of f, af,
// uf and rf (so the real state is untouched), and accumulates the upper
// triangle of tr->unscale(sig) divided by (n - drop).  The real state of
// f, af, uf, rf is then advanced with the unperturbed data.
// Writes one line per grid point to *os: delta followed by the f.get_lR()
// averaged entries (upper triangle packed column by column, index
// j*(j-1)/2 + i).  Returns the state of *os after writing.
bool snp::leverage::calculate()
{ 
  if (Y==0||X==0) error("Error, leverage, data not initialized");
  if (dpm.M != Y->nrow() || dpm.M != X->nrow() || dpm.M != f.get_ly())
    error("Error, leverage, this should never happen");

  INTEGER ngrid = 50;
  realmat delta(1,2*ngrid+1);
  realmat average(f.get_lR(),2*ngrid+1,0.0);


  // Derivative outputs required by the functor call signatures.
  realmat dawa0,dawA;
  realmat duwb0, duwb0_lag;
  kronprd duwB, duwB_lag;
  realmat dRwb0,dRwB;
  realmat dRwRparms;

  realmat y(dpm.M,1);
  realmat x(dpm.M,1);
  realmat u(dpm.M,1);
  realmat y_lag(dpm.M,1);
  realmat x_lag(dpm.M,1);
  realmat u_lag(dpm.M,1);

  af.initialize_state();
  uf.initialize_state();
  rf.initialize_state();

  for (INTEGER i=1; i<=dpm.M; ++i) {
    x[i] = (*X)(i,1);
  }

  u = uf(x,duwb0,duwB);

  // Advance the state through the observations that are dropped.
  for (INTEGER t=3; t<=dpm.drop; ++t) {
    
    for (INTEGER i=1; i<=dpm.M; ++i) {
      y[i] = (*Y)(i,t);
      x[i] = (*X)(i,t);
      y_lag[i] = (*Y)(i,t-1);
      x_lag[i] = (*X)(i,t-1);
    }

    u_lag = u;
    duwb0_lag = duwb0;
    duwB_lag = duwB;

    u = uf(x_lag,duwb0,duwB);

    f.set_R(rf(y_lag,u_lag,x_lag,duwb0_lag,duwB_lag,dRwb0,dRwB,dRwRparms));
    f.set_a(af(x_lag,dawa0,dawA));
    f.set_u(u);

  }

  for (INTEGER t=dpm.drop+1; t<=dpm.n; ++t) {

    for (INTEGER i=1; i<=dpm.M; ++i) {
      y_lag[i] = y[i];
      x_lag[i] = x[i];
      y[i] = (*Y)(i,t);
      x[i] = (*X)(i,t);
    }

    u_lag = u;
    duwb0_lag = duwb0;
    duwB_lag = duwB;

    for (INTEGER grid=-ngrid; grid<=ngrid; ++grid) {

      // Fresh copies for each perturbation so the real state is unchanged.
      snpden fd = f;
      afunc afd = af;
      ufunc ufd = uf;
      rfunc rfd = rf;
  
      INTEGER idx = grid + ngrid + 1;
      delta[idx] = 5.0*REAL(grid)/REAL(ngrid);
      
      // Perturb only the first element of the lagged y.
      realmat yd = y_lag;
      yd[1] += delta[idx];
  
      realmat xd = yd;

      if (tpm.squash == 1) {
        tr->spline(xd);
      } else if (tpm.squash == 2) {
        tr->logistic(xd);
      }
   
      // NOTE(review): this overwrites duwb0/duwB; they are recomputed by
      // uf below before being used, so the real state is not affected.
      fd.set_R(rfd(yd,u_lag,xd,duwb0_lag,duwB_lag,dRwb0,dRwB,dRwRparms));
      fd.set_a(afd(xd,dawa0,dawA));
      fd.set_u(ufd(xd,duwb0,duwB));
    
      realmat mu, sig;
      fd.musig(mu,sig);
    
      tr->unscale(sig);
      
      for (INTEGER j=1; j<=sig.get_cols(); ++j) {
        for (INTEGER i=1; i<=j; ++i) {
          INTEGER ij = (j*(j-1))/2 + i;
          average(ij,idx) += sig(i,j)/REAL(dpm.n-dpm.drop);
        } 
      }
    }
    
    // Advance the real state with the unperturbed data.
    u = uf(x_lag,duwb0,duwB);

    f.set_R(rf(y_lag,u_lag,x_lag,duwb0_lag,duwB_lag,dRwb0,dRwB,dRwRparms));
    f.set_a(af(x_lag,dawa0,dawA));
    f.set_u(u);
  }

  // NOTE(review): delta is only filled inside the t loop; if n <= drop it
  // is written here without having been set.
  for (INTEGER grid=-ngrid; grid<=ngrid; ++grid) {
    INTEGER idx = grid + ngrid + 1;
    (*os) << delta[idx] << ' ';
    for (INTEGER i=1; i<=f.get_lR(); ++i) (*os) << average(i,idx) << ' ';
    (*os) << '\n';
  }

  return (*os).good();
}

// No cleanup is performed; reports whether the output stream is still good.
bool snp::leverage::finalize()
{
  return os->good();
}

// Stores copies of the option, data and transformation parameters, the
// snpden object and the afunc/ufunc/rfunc objects, a reference to the
// detail stream and the transformation pointer.  afm, ufm and rfm are
// accepted for a uniform constructor signature but are not used here.
// X, Y and os are null-initialized so that calculate()'s
// "data not initialized" check is meaningful before set_XY() is called.
snp::quadrature::quadrature(optparms op, datparms dp, tranparms tp, snpden fn, 
  afunc afn, ufunc ufn, rfunc rfn, afunc afm, ufunc ufm, rfunc rfm,
  ostream& dos, const trnfrm* trn)
: opm(op), dpm(dp), tpm(tp), f(fn), af(afn), uf(ufn), rf(rfn),
  detail(dos), tr(trn), X(0), Y(0), os(0)
{ }

// Records the (non-owning) data pointers after checking that x and y have
// the same dimensions.
void snp::quadrature::set_XY(const realmat* x, const realmat* y)
{
  if (x->nrow()!=y->nrow()) error("Error, quadrature, nrow of x & y differ");
  if (x->ncol()!=y->ncol()) error("Error, quadrature, ncol of x & y differ");
  X = x;  Y = y;
}

// Records the output stream and reports whether it is usable.
bool snp::quadrature::initialize(ostream* out_stream)
{
  os = out_stream;
  return (*os).good();
}

// Produces a quadrature rule for the conditional density f.
//
// The state of af, uf, rf and f is first advanced to a conditioning set:
//   dpm.cond == 0 : the data through t = n, then maxlag+1 further steps in
//                   which the new observation is the normalized tpm.mean;
//   otherwise     : the data through t = cond.
// f.quad(opm.ngrid, ...) then supplies abcissae (one column per point)
// and weights.  The abcissae are mapped through tr->unnormalize, and for
// each point its M abcissae followed by its weight are written to *os in
// fmt('f',25,17).  When opm.print is set, the rule's integral, mean and
// variance are compared with f.musig and reported to detail.
// Returns the state of *os after writing.
bool snp::quadrature::calculate()
{ 
  if (Y==0||X==0) error("Error, quadrature, data not initialized");
  if (dpm.M != Y->nrow() || dpm.M != X->nrow() || dpm.M != f.get_ly())
    error("Error, quadrature, this should never happen");

  // Derivative outputs required by the functor call signatures.
  realmat dawa0,dawA;
  realmat duwb0, duwb0_lag;
  kronprd duwB, duwB_lag;
  realmat dRwb0,dRwB;
  realmat dRwRparms;

  realmat y(dpm.M,1);
  realmat x(dpm.M,1);
  realmat u(dpm.M,1);
  realmat y_lag(dpm.M,1);
  realmat x_lag(dpm.M,1);
  realmat u_lag(dpm.M,1);

  af.initialize_state();
  uf.initialize_state();
  rf.initialize_state();

  for (INTEGER i=1; i<=dpm.M; ++i) {
    x[i] = (*X)(i,1);
  }

  u = uf(x,duwb0,duwB);

  // Advance the state through the observations that are dropped.
  for (INTEGER t=3; t<=dpm.drop; ++t) {
    for (INTEGER i=1; i<=dpm.M; ++i) {
      y[i] = (*Y)(i,t);
      x[i] = (*X)(i,t);
      y_lag[i] = (*Y)(i,t-1);
      x_lag[i] = (*X)(i,t-1);
    }

    u_lag = u;
    duwb0_lag = duwb0;
    duwB_lag = duwB;

    u = uf(x_lag,duwb0,duwB);

    f.set_R(rf(y_lag,u_lag,x_lag,duwb0_lag,duwB_lag,dRwb0,dRwB,dRwRparms));
    f.set_a(af(x_lag,dawa0,dawA));
    f.set_u(u);
  }

  // Longest lag used by any of the functors; bounds how much of the
  // conditioning set is reported.
  INTEGER maxlag = 0;
  maxlag = maxlag > af.get_lag() ? maxlag : af.get_lag(); 
  maxlag = maxlag > uf.get_lag() ? maxlag : uf.get_lag(); 
  maxlag = maxlag > rf.get_p()+1 ? maxlag : rf.get_p()+1; 
  maxlag = maxlag > rf.get_q()+1 ? maxlag : rf.get_q()+1; 
  maxlag = maxlag > rf.get_v()+1 ? maxlag : rf.get_v()+1; 
  maxlag = maxlag > rf.get_w()+1 ? maxlag : rf.get_w()+1; 

  if (opm.print) {
    detail<<starbox("/Quadrature conditioning set//");
  }

  if (dpm.cond == 0) {
    for (INTEGER t=dpm.drop+1; t<=dpm.n; ++t) {
      for (INTEGER i=1; i<=dpm.M; ++i) {
        y_lag[i] = y[i];
        x_lag[i] = x[i];
        y[i] = (*Y)(i,t);
        x[i] = (*X)(i,t);
      }
    
      u_lag = u;
      duwb0_lag = duwb0;
      duwB_lag = duwB;

      u = uf(x_lag,duwb0,duwB);

      f.set_R(rf(y_lag,u_lag,x_lag,duwb0_lag,duwB_lag,dRwb0,dRwB,dRwRparms));
      f.set_a(af(x_lag,dawa0,dawA));
      f.set_u(u);
    }
    // Continue the state with the normalized tpm.mean as the new value.
    for (INTEGER t=1; t<=maxlag+1; ++t) {
      for (INTEGER i=1; i<=dpm.M; ++i) {
        y_lag[i] = y[i];
        x_lag[i] = x[i];
      }
      y = tpm.mean;
      tr->normalize(y);
      x = y;
      if (tpm.squash == 1) {
        tr->spline(x);
      } else if (tpm.squash == 2) {
        tr->logistic(x);
      }

      u_lag = u;
      duwb0_lag = duwb0;
      duwB_lag = duwB;

      u = uf(x_lag,duwb0,duwB);

      f.set_R(rf(y_lag,u_lag,x_lag,duwb0_lag,duwB_lag,dRwb0,dRwB,dRwRparms));
      f.set_a(af(x_lag,dawa0,dawA));
      f.set_u(u);

      if (opm.print && t>1) {
        realmat v = y_lag;
        (*tr).unnormalize(v);
        detail << cbind(v,x_lag);
      }
    }
  }
  else {
    // Condition on the data through t = cond (the t = cond+1 pass only
    // shifts Y(.,cond) into the lags).
    for (INTEGER t=dpm.drop+1; t<=dpm.cond+1; ++t) {
      for (INTEGER i=1; i<=dpm.M; ++i) {
        y_lag[i] = y[i];
        x_lag[i] = x[i];
        if (t<=dpm.cond) {
          y[i] = (*Y)(i,t);
          x[i] = (*X)(i,t);
        }
      }

      u_lag = u;
      duwb0_lag = duwb0;
      duwB_lag = duwB;

      u = uf(x_lag,duwb0,duwB);

      f.set_R(rf(y_lag,u_lag,x_lag,duwb0_lag,duwB_lag,dRwb0,dRwB,dRwRparms));
      f.set_a(af(x_lag,dawa0,dawA));
      f.set_u(u);

      if (opm.print && dpm.cond+1-maxlag < t) {
        realmat v = y_lag;
        (*tr).unnormalize(v);
        detail << cbind(v,x_lag);
      }
    }
  }

  realmat abcissae;
  realmat weights;

  f.quad(opm.ngrid,abcissae,weights);

  INTEGER npts = weights.size();

  if (opm.print) {
    detail << starbox("/Quadrature rule//");
    detail << '\n';
    detail << '\t' << "Order of rule = " << opm.ngrid << '\n';
    detail << '\t' << "Number of points = " << npts << '\n';
    detail << '\t' << "Dimension of abcissae = " << dpm.M << '\n';
    detail << '\n';
    detail << "\tOutput format is " << dpm.M << " values of the abcissae\n";
    detail << "\tfollowed by the weight, end to end, for a\n";
    detail << "\ttotal file length of " << (dpm.M + 1)*npts << '\n';
  }

  // Map each abcissa column back to original units (y is scratch here).
  for (INTEGER j=1; j<=npts; ++j) {
    for (INTEGER i=1; i<=dpm.M; ++i) y[i] = abcissae(i,j);
    (*tr).unnormalize(y);
    for (INTEGER i=1; i<=dpm.M; ++i) abcissae(i,j) = y[i];
  } 

  realmat mu, sig;
  f.musig(mu,sig);
  
  (*tr).unnormalize(mu);
  (*tr).unscale(sig);

  if (opm.print) {
    detail << starbox("/Check on quadrature rule//");

    detail << starbox("/Values computed analytically//");
    detail << '\n';
    detail << "\t Integral of density = " << fmt('f',8,5,1.0) << '\n';
    detail << '\n';
    detail << "\t Mean = " << mu;
    detail << "\t Variance = " << sig;

    REAL sum_weights = 0;
    for (INTEGER j=1; j<=weights.size(); ++j) sum_weights += weights[j];

    // Weighted first and second moments of the abcissae.
    realmat mu_hat(dpm.M,1,0.0);
    realmat sig_hat(dpm.M,dpm.M,0.0);
    realmat y_j;
    for (INTEGER j=1; j<=weights.size(); ++j) {
      y_j = abcissae("",j);
      mu_hat += y_j*weights[j];
      sig_hat += (y_j*T(y_j))*weights[j];
    }
    sig_hat -= mu_hat*T(mu_hat);

    detail << starbox("/Values computed by quadrature//");
    detail << '\n';
    detail << "\t Integral of density = " << fmt('f',8,5,sum_weights) << '\n';
    detail << '\n';
    detail << "\t Mean = " << mu_hat;
    detail << "\t Variance = " << sig_hat;

    sum_weights = (sum_weights - 1.0)/(1.0 + EPS);
    for (INTEGER j=1; j<=dpm.M; ++j) {
      mu_hat[j] = (mu_hat[j] - mu[j])/(mu[j] + EPS);
      for (INTEGER i=1; i<=dpm.M; ++i) {
        sig_hat(i,j) = (sig_hat(i,j) - sig(i,j))/(sig(i,j) + EPS);
      }
    }

    detail << starbox("/Relative error//");
    detail << '\n';
    detail << "\t Integral of density = " << fmt('f',8,5,sum_weights) << '\n';
    detail << '\n';
    detail << "\t Mean = " << mu_hat;
    detail << "\t Variance = " << sig_hat;
    detail << '\n';
    detail.flush();
  }

  for (INTEGER j=1; j<=npts; ++j) {
    for (INTEGER i=1; i<=dpm.M; ++i) {
      *os << fmt('f',25,17,abcissae(i,j)) << '\n';
    }
    *os << fmt('f',25,17,weights[j]) << '\n';
  }

  return (*os).good();
}

// No cleanup is performed; reports whether the output stream is still good.
bool snp::quadrature::finalize()
{
  return os->good();
}

// Stores copies of the option, data and transformation parameters, the
// snpden object, the afunc/ufunc/rfunc objects and their mask
// counterparts, a reference to the detail stream and the transformation
// pointer.  X, Y and os are null-initialized so that calculate()'s
// "data not initialized" check is meaningful before set_XY() is called.
snp::rhostats::rhostats(optparms op, datparms dp, tranparms tp, snpden fn, 
  afunc afn, ufunc ufn, rfunc rfn, afunc afm, ufunc ufm, rfunc rfm,
  ostream& dos, const trnfrm* trn)
: opm(op), dpm(dp), tpm(tp), f(fn), 
  af(afn), uf(ufn), rf(rfn),
  af_mask(afm), uf_mask(ufm), rf_mask(rfm),
  detail(dos), tr(trn), X(0), Y(0), os(0)
{ }

// Records the (non-owning) data pointers after checking that x and y have
// the same dimensions.
void snp::rhostats::set_XY(const realmat* x, const realmat* y)
{
  if (x->nrow() != y->nrow()) error("Error, rhostats, row dim of x & y differ");
  if (x->ncol() != y->ncol()) error("Error, rhostats, col dim of x & y differ");
  X = x;  Y = y;
}

// Records the stream that receives ll.write_stats() output.
bool snp::rhostats::initialize(ostream* out_stream)
{
  os = out_stream;
  return os->good();
}

// Records the base name used for the .parm/.logl/.size/.grad/.infm/.hess
// files written by calculate().
bool snp::rhostats::initialize(string fn)
{
  filename = fn;
  return true;
}

bool snp::rhostats::calculate()
{ 
  if (Y==0||X==0) error("Error, rhostats, data not initialized");
  if (dpm.M != Y->nrow() || dpm.M != X->nrow() || dpm.M != f.get_ly())
    error("Error, rhostats, this should never happen");

  snpll ll(dpm,f,af,uf,rf,af_mask,uf_mask,rf_mask);
  ll.set_XY(X,Y);

  realmat rho = ll.get_rho();
  realmat dllwrho;
  realmat rho_infmat;
  realmat rho_hessian;
  REAL logl = ll.log_likehood(dllwrho,rho_infmat);
  REAL n = ll.get_datparms().n;
  ll.rho_hessian(rho_hessian);
  if (!vecwrite((filename+".parm").c_str(),rho)) return false;
  if (!vecwrite((filename+".logl").c_str(),realmat(1,1,logl))) return false;
  if (!vecwrite((filename+".size").c_str(),realmat(1,1,n))) return false;
  if (!vecwrite((filename+".grad").c_str(),dllwrho)) return false;
  if (!vecwrite((filename+".infm").c_str(),rho_infmat)) return false;
  if (!vecwrite((filename+".hess").c_str(),rho_hessian)) return false;;
  ll.write_stats(*os);
  return os->good();
}

bool snp::rhostats::finalize()
{ 
  return os->good();
}

