/*-----------------------------------------------------------------------------

Copyright (C) 2004, 2006, 2009.

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA   

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/
#undef COMPILER_HAS_BASE_IOS

#include "snp.h"
#include "snpusr.h"
#include "gsl/gsl_vector.h"
#include "gsl/gsl_multimin.h"

using namespace scl;
using namespace libsnp;
using namespace snp;
using namespace std;

namespace {

  snpll gsl_ll;

  double gsl_f(const gsl_vector* v, void* params)
  {
    realmat rho = gsl_ll.get_rho(); 
    INTEGER lrho = rho.size();
    for (INTEGER i=1; i<=lrho; ++i) rho[i] = gsl_vector_get(v,i-1);
    gsl_ll.set_rho(rho);
    REAL log_like = gsl_ll.log_likehood();
    return -log_like/gsl_ll.get_datparms().n;
  }

  void gsl_df(const gsl_vector* v, void* params, gsl_vector* df)
  {
    realmat rho = gsl_ll.get_rho(); 
    INTEGER lrho = rho.size();
    for (INTEGER i=1; i<=lrho; ++i) rho[i] = gsl_vector_get(v,i-1);
    gsl_ll.set_rho(rho);
    realmat dllwrho;
    gsl_ll.log_likehood(dllwrho);
    for (INTEGER i=1; i<=lrho; ++i) {
      gsl_vector_set(df,i-1, -dllwrho[i]/gsl_ll.get_datparms().n);
    }
  }

  void gsl_fdf(const gsl_vector* v, void* params, double* f, gsl_vector* df)
  {
    realmat rho = gsl_ll.get_rho(); 
    INTEGER lrho = rho.size();
    for (INTEGER i=1; i<=lrho; ++i) rho[i] = gsl_vector_get(v,i-1);
    gsl_ll.set_rho(rho);
    realmat dllwrho;
    REAL log_like = gsl_ll.log_likehood(dllwrho);
    for (INTEGER i=1; i<=lrho; ++i) {
      gsl_vector_set(df,i-1, -dllwrho[i]/gsl_ll.get_datparms().n);
    }
    *f = -log_like/gsl_ll.get_datparms().n;
  }
}

  
int main(int argc, char** argp, char** envp)
{
  bool debug = false;

  ofstream detail_ofs("detail.dat");
  if(!detail_ofs) error("Error, snp, detail.dat open failed");
  ostream& detail = detail_ofs;

  ifstream ctrl_ifs("control.dat");
  if(!ctrl_ifs) error("Error, snp, control.dat open failed");
  control ctrl(ctrl_ifs);

  #if defined COMPILER_HAS_BASE_IOS 
    ofstream summary_ofs("summary.dat", ios_base::app);
  #else
    ofstream summary_ofs("summary.dat", ios::app);
  #endif
  if(!summary_ofs) error("Error, snp, summary.dat open failed");
  sumry summary(summary_ofs);
  summary.write_header();

  while(ctrl.read_line()) {

    parmfile pf;

    if(!pf.read_parms(ctrl.get_input_file().c_str(),detail)) {
       detail.flush();
       error("Error, snp, cannot read parmfile");
    }
    
    summary.set_filename(ctrl.get_output_file());
    summary.set_pname(pf.get_optparms().pname);
    summary.set_n(pf.get_datparms().n);

    realmat Y;
    datread_type dr; dr.initialize(pf.get_datparms());
    if (!dr.read_data(Y)) error("Error, snp, read_data failed");

    if (pf.get_optparms().print) {
      detail << starbox("/First 12 observations//");
      detail << Y("",seq(1,12));
      detail << starbox("/Last 12 observations//");
      detail << Y("",seq(pf.get_datparms().n-11, pf.get_datparms().n));
    }

    trnfrm tr(pf.get_datparms(), pf.get_tranparms(), Y);
    pf.set_tranparms(tr.get_tranparms());
    
    if (pf.get_optparms().print) {
      detail << starbox("/Mean and variance of data//");
      if (pf.get_tranparms().diag) {
        detail << "(Variance has been diagonalized.)\n";
      }
      detail << pf.get_tranparms().mean << pf.get_tranparms().variance;
    }

    realmat X = Y;

    tr.normalize(Y);
    tr.normalize(X);

    if (pf.get_tranparms().squash == 1) {
      tr.spline(X);
    } else if (pf.get_tranparms().squash == 2) {
      tr.logistic(X);
    }

    if (pf.get_optparms().print) {
      detail << starbox("/First 12 normalized observations//");
      detail << Y("",seq(1,12));
      detail << starbox("/Last 12 normalized observations//");
      detail << Y("",seq(pf.get_datparms().n-11, pf.get_datparms().n));
      if (pf.get_tranparms().squash > 0) {
        detail << starbox("/First 12 transformed observations//");
        detail << Y("",seq(1,12));
        detail << starbox("/Last 12 transformed observations//");
        detail << Y("",seq(pf.get_datparms().n-11, pf.get_datparms().n));
      }
    }

    // Trying to use a switch here generated an internal compiler error.

    if (pf.get_optparms().task == 0) { //fit

       detail.flush();

       snpll ll(pf.get_datparms(),
         pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(),
         pf.get_rfunc(), pf.get_afunc_mask(), pf.get_ufunc_mask(),
         pf.get_rfunc_mask());

       ll.set_XY(&X,&Y);
       summary.set_ll_ptr(&ll);  //Allows writing a full line

       realmat rho = ll.get_rho();
       
       if (debug) {
         realmat dllwrho;
         string str1 = "/Derivatives of log likelihood wrt rho";
         string str2 = "/index, computed, numerical, relative error//";
         detail << starbox((str1+str2).c_str()) << '\n';
         for (INTEGER i=1; i<=rho.size(); ++i) {
           REAL delta = 1.0e-8;
           REAL tmp = rho[i];
           rho[i] += delta;
           ll.set_rho(rho);
           REAL del = ll.log_likehood(dllwrho);        
           rho[i] = tmp - delta;
           ll.set_rho(rho);
           del -= ll.log_likehood(dllwrho);
           del /= 2.0*delta;
           rho[i] = tmp;
           ll.set_rho(rho);
           ll.log_likehood(dllwrho);
           REAL relerr = (dllwrho[i]-del)/(dllwrho[i]+delta);
           detail << "\t i = " << fmt('d',4,i) << fmt('g',17,8,dllwrho[i]) 
                  << fmt('g',17,8,del) << fmt('g',20,8,relerr) << '\n';
         } 
       }

       detail.flush();

       gsl_ll = ll;

       INTEGER lrho = rho.size();

       const gsl_multimin_fdfminimizer_type* T;
       gsl_multimin_fdfminimizer* s;

       gsl_vector* gsl_rho = gsl_vector_alloc(lrho);

       gsl_multimin_function_fdf gsl_func;
       gsl_func.f = &gsl_f;
       gsl_func.df = &gsl_df;
       gsl_func.fdf = &gsl_fdf;
       gsl_func.n = lrho;
       gsl_func.params = 0;

       T = gsl_multimin_fdfminimizer_vector_bfgs;
       s = gsl_multimin_fdfminimizer_alloc(T,lrho);

       realmat best_rho = rho;

       INTEGER best_iter = 0; 
       if (ctrl.get_nstart()>0) {
         if (pf.get_optparms().print) {
           detail << starbox("/Starting value of rho//");
           detail << rho;
           detail << starbox("/Trials//") << '\n';
           detail.flush();
         }
         best_iter =  pf.get_optparms().itmax0;
         REAL best_f = REAL_MAX;  
         INT_32BIT jseed = ctrl.get_jseed();
         for (INTEGER trial=1; trial<=ctrl.get_nstart(); ++trial) {
           ll.set_rho(rho, ctrl.get_fold(), ctrl.get_fnew(), jseed);
           rho = ll.get_rho();
           for (INTEGER i=1; i<=lrho; ++i) gsl_vector_set(gsl_rho, i-1, rho[i]);
           gsl_multimin_fdfminimizer_set
             (s, &gsl_func, gsl_rho, 0.001, pf.get_optparms().toler);
           int iter = 0;
           int status;
           do {
             ++iter;
             status = gsl_multimin_fdfminimizer_iterate(s);
             if (status) break;
             status = gsl_multimin_test_gradient(s->gradient, 1e-3);
           } 
           while (status == GSL_CONTINUE && iter < pf.get_optparms().itmax0);
           if (pf.get_optparms().print) {
             detail << "\t trial = " << trial 
                    << ",  -loglike/n  = " << fmt('e',25,17,s->f) 
                    << "  =  " << s->f <<'\n';
             detail.flush();
           }
           if (s->f < best_f) { 
             best_iter = iter;
             best_f = s->f;
             for (INTEGER i=1; i<=lrho; ++i) {
               best_rho[i] = gsl_vector_get(s->x, i-1);
             }
           }
         }
       }

       rho = best_rho;
       
       if (pf.get_optparms().print) {
         detail << starbox("/Starting value of rho//");
         detail << rho;
         detail.flush();
         detail << starbox("/BFGS iterations//") << '\n';
       }

       for (INTEGER i=1; i<=lrho; ++i) gsl_vector_set(gsl_rho, i-1, rho[i]);
       gsl_multimin_fdfminimizer_set
         (s, &gsl_func, gsl_rho, 0.001, pf.get_optparms().toler);

       int iter = 0;
       int status;
       do {
         ++iter;
         status = gsl_multimin_fdfminimizer_iterate(s);
         if (status) break;
         status = gsl_multimin_test_gradient(s->gradient, 1e-3);
         if (pf.get_optparms().print) {
           detail << "\t iter = " << iter 
                  << ",  -loglike/n  = " << fmt('e',25,17,s->f)
                  << "  =  " << s->f <<'\n';
           if (status==GSL_SUCCESS) detail << "\t Minimum found\n";
           detail.flush();
         }
       } 
       while (status == GSL_CONTINUE && iter < pf.get_optparms().itmax1);

       for (INTEGER i=1; i<=lrho; ++i) rho[i] = gsl_vector_get(s->x, i-1);

       ll.set_rho(rho);
       pf.set_afunc(ll.get_afunc());
       pf.set_ufunc(ll.get_ufunc());
       pf.set_rfunc(ll.get_rfunc());

       summary.set_sn(s->f);
       summary.set_iter(best_iter + iter);
       summary.set_termination_code( status == GSL_SUCCESS ? 0 : 1);
       summary.write_line();

       if (pf.get_optparms().print) {
         detail << starbox("/Estimated value of rho//");
         detail << rho;
         detail << starbox("/Stats and diagnostics//");
         detail << '\n';
         ll.write_stats(detail);
       }
                                  
       if (!pf.write_parms(ctrl.get_output_file().c_str(),ctrl.get_line(),ll)){
         detail.flush();
         error("Error, snp, cannot write parmfile"); 
        }

       gsl_multimin_fdfminimizer_free(s);
       gsl_vector_free(gsl_rho);

      }
     else if (pf.get_optparms().task == 1) { //res
       residual_type rt(pf.get_optparms(),pf.get_datparms(),pf.get_tranparms(),
         pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(), pf.get_rfunc(), 
         pf.get_afunc_mask(), pf.get_ufunc_mask(), pf.get_rfunc_mask(),
         detail, &tr);

       ofstream os(ctrl.get_output_file().c_str());
       if (!os) error("Error, snp, can't open " + ctrl.get_output_file());

       rt.initialize(&os);
       rt.initialize(ctrl.get_output_file());
       rt.set_XY(&X,&Y);

       if (!rt.calculate()) error("Error, snp, residual write failed");

       summary.write_partial_line();  // Information in ll not set.
     }
     else if (pf.get_optparms().task == 2) { //mu

       mean_type mt(pf.get_optparms(), pf.get_datparms(), pf.get_tranparms(),
         pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(), pf.get_rfunc(), 
         pf.get_afunc_mask(), pf.get_ufunc_mask(), pf.get_rfunc_mask(),
         detail, &tr);

       ofstream os(ctrl.get_output_file().c_str());
       if (!os) error("Error, snp, can't open " + ctrl.get_output_file());

       mt.initialize(&os);
       mt.initialize(ctrl.get_output_file());
       mt.set_XY(&X,&Y);

       if (!mt.calculate()) error("Error, snp, mean write failed");

       summary.write_partial_line();  // Information in ll not set.
     }
     else if (pf.get_optparms().task == 3) { //sig
       variance_type vt(pf.get_optparms(),pf.get_datparms(),pf.get_tranparms(),
         pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(), pf.get_rfunc(), 
         pf.get_afunc_mask(), pf.get_ufunc_mask(), pf.get_rfunc_mask(),
         detail, &tr);

       ofstream os(ctrl.get_output_file().c_str());
       if (!os) error("Error, snp, can't open " + ctrl.get_output_file());

       vt.initialize(&os);
       vt.initialize(ctrl.get_output_file());
       vt.set_XY(&X,&Y);

       if (!vt.calculate()) error("Error, snp, variance write failed");

       summary.write_partial_line();  // Information in ll not set.
     }
     else if (pf.get_optparms().task == 4) { //plt
       plot_type pt(pf.get_optparms(),pf.get_datparms(),pf.get_tranparms(),
         pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(), pf.get_rfunc(), 
         pf.get_afunc_mask(), pf.get_ufunc_mask(), pf.get_rfunc_mask(),
         detail, &tr);

       ofstream os(ctrl.get_output_file().c_str());
       if (!os) error("Error, snp, can't open " + ctrl.get_output_file());

       pt.initialize(&os);
       pt.initialize(ctrl.get_output_file());
       pt.set_XY(&X,&Y);

       if (!pt.calculate()) error("Error, snp, plot write failed");

       summary.write_partial_line();  // Information in ll not set.
     }
     else if (pf.get_optparms().task == 5) { //sim

       simulate_type st(pf.get_optparms(),pf.get_datparms(),pf.get_tranparms(),
         pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(), pf.get_rfunc(), 
         pf.get_afunc_mask(), pf.get_ufunc_mask(), pf.get_rfunc_mask(),
         detail, &tr);

       ofstream os(ctrl.get_output_file().c_str());
       if (!os) error("Error, snp, can't open " + ctrl.get_output_file());

       st.initialize(&os);
       st.initialize(ctrl.get_output_file());
       st.set_XY(&X,&Y);

       if (!st.calculate()) error("Error, snp, simulation write failed");

       summary.write_partial_line();  // Information in ll not set.
     }
     else if (pf.get_optparms().task == 6) { //usr

       user_type ut(pf.get_optparms(),pf.get_datparms(),pf.get_tranparms(),
         pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(), pf.get_rfunc(),
         pf.get_afunc_mask(), pf.get_ufunc_mask(), pf.get_rfunc_mask(),
         detail, &tr);

       ofstream os(ctrl.get_output_file().c_str());
       if (!os) error("Error, snp, can't open " + ctrl.get_output_file());

       ut.initialize(&os);
       ut.initialize(ctrl.get_output_file());
       ut.set_XY(&X,&Y);

       if (!ut.calculate()) error("Error, snp, user write failed");

       summary.write_partial_line();  // Information in ll not set.
     }
     else { //error
       error("Error, snp, no such task");
     }

  }
  return 0;
}
