/*-----------------------------------------------------------------------------

Copyright (C) 2004, 2005, 2006, 2009, 2010.

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA   

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/
#undef COMPILER_HAS_BASE_IOS

#include "snp.h"
#include "snpusr.h"

using namespace scl;
using namespace libsnp;
using namespace snp;
using namespace std;

namespace {

  snpll nlopt_ll;

  class objective_function : public nleqns_base {
  private:
    INTEGER count;
  public:
    objective_function() : count(0) { }
    void reset_count() { count = 0; }
    INTEGER get_count() { return count; }
    bool get_f(const realmat& rho, realmat& f)
    {
      if (f.nrow()!=1) f.resize(1,1);
      nlopt_ll.set_rho(rho);
      REAL log_like = nlopt_ll.log_likehood();
      f[1] = -log_like/nlopt_ll.get_datparms().n;
      ++count;
      return true;
    }
    bool get_F(const realmat& rho, realmat& f, realmat& F)
    {
      if (f.nrow()!=1) f.resize(1,1);
      nlopt_ll.set_rho(rho);
      realmat dllwrho;
      REAL log_like = nlopt_ll.log_likehood(dllwrho);
      f[1] = -log_like/nlopt_ll.get_datparms().n;
      F = -dllwrho/nlopt_ll.get_datparms().n;
      ++count;
      return true;
    }
  };
}

  
int main(int argc, char** argp, char** envp)
{
  const bool debug = false;
  const bool set_H_matrix = false;   // Setting to true degrades performance.

  ofstream detail_ofs("detail.dat");
  if(!detail_ofs) error("Error, snp, detail.dat open failed");
  ostream& detail = detail_ofs;

  ifstream ctrl_ifs("control.dat");
  if(!ctrl_ifs) error("Error, snp, control.dat open failed");
  control ctrl(ctrl_ifs);

  #if defined COMPILER_HAS_BASE_IOS 
    ofstream summary_ofs("summary.dat", ios_base::app);
  #else
    ofstream summary_ofs("summary.dat", ios::app);
  #endif
  if(!summary_ofs) error("Error, snp, summary.dat open failed");
  sumry summary(summary_ofs);
  summary.write_header();

  while(ctrl.read_line()) {

    parmfile pf;

    if(!pf.read_parms(ctrl.get_input_file().c_str(),detail)) {
       detail.flush();
       error("Error, snp, cannot read parmfile");
    }

    if (pf.get_optparms().print) {
      detail << starbox("/Control line for above parmfile//");
      detail << '\n' << ctrl.get_line() << endl;
    }
    
    summary.set_filename(ctrl.get_output_file());
    summary.set_pname(pf.get_optparms().pname);
    summary.set_n(pf.get_datparms().n);

    realmat Y;
    datread_type dr; dr.initialize(pf.get_datparms());
    if (!dr.read_data(Y)) error("Error, snp, read_data failed");

    if (pf.get_optparms().print) {
      detail << starbox("/First 12 observations//");
      detail << Y("",seq(1,12));
      detail << starbox("/Last 12 observations//");
      detail << Y("",seq(pf.get_datparms().n-11, pf.get_datparms().n));
    }

    trnfrm tr(pf.get_datparms(), pf.get_tranparms(), Y);
    pf.set_tranparms(tr.get_tranparms());
    
    if (pf.get_optparms().print) {
      detail << starbox("/Mean and variance of data//");
      if (pf.get_tranparms().diag) {
        detail << "(Variance has been diagonalized.)\n";
      }
      detail << pf.get_tranparms().mean << pf.get_tranparms().variance;
    }

    realmat X = Y;

    tr.normalize(Y);
    tr.normalize(X);
    
    if (pf.get_tranparms().squash == 1) {
      tr.spline(X);
    } else if (pf.get_tranparms().squash == 2) {
      tr.logistic(X);
    }

    if (pf.get_optparms().print) {
      detail << starbox("/First 12 normalized observations//");
      detail << Y("",seq(1,12));
      detail << starbox("/Last 12 normalized observations//");
      detail << Y("",seq(pf.get_datparms().n-11, pf.get_datparms().n));
      if (pf.get_tranparms().squash > 0) {
        detail << starbox("/First 12 transformed observations//");
        detail << X("",seq(1,12));
        detail << starbox("/Last 12 transformed observations//");
        detail << X("",seq(pf.get_datparms().n-11, pf.get_datparms().n));
      }
    }

    // Trying to use a switch here generated an internal compiler error.

    if (pf.get_optparms().task == 0) { //fit

       if (pf.get_optparms().print) detail.flush();

       snpll ll(pf.get_datparms(),
         pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(),
         pf.get_rfunc(), pf.get_afunc_mask(), pf.get_ufunc_mask(),
         pf.get_rfunc_mask());

       ll.set_XY(&X,&Y);
       summary.set_ll_ptr(&ll);  //Allows writing a full summary line

       realmat rho = ll.get_rho();

       realmat H_matrix;

       if (set_H_matrix) { // This doesn't work well.  Delete sometime.
         realmat dllwrho;
         realmat rho_infmat;
         ll.log_likehood(dllwrho, rho_infmat);
         rho_infmat = rho_infmat/ll.get_datparms().n;
         H_matrix = invpsd(rho_infmat);  
         #if defined GNU_GPP_COMPILER
           bool is_finite = true;
           for (INTEGER i=1; i<=H_matrix.size(); ++i) {
             if (!finite(H_matrix[i])) is_finite = false;
           }
           if (!is_finite) {
             fill(H_matrix);
             for (INTEGER i=1; i<=H_matrix.get_rows(); ++i) {
               H_matrix(i,i) = 1.0;
             }
           }  
         #endif
       }
       
       if (debug) {
         realmat dllwrho;
         detail << starbox("/Test log likelihood for coding errors//") <<'\n';
         INTEGER inc = rho.size()/10 > 1 ? rho.size()/10 : 1;
         for (INTEGER i=1; i<=rho.size(); i+=inc) {
           REAL delta = 1.e-3*rho[i];
           REAL tmp = rho[i];
           rho[i] += delta;
           ll.set_rho(rho);
           rho[i] = tmp;
           REAL lla = ll.log_likehood(dllwrho);
           REAL llb = ll.log_likehood();
           detail << "\t i = " << fmt('d',4,i) << fmt('g',17,8,lla) 
                  << fmt('e',26,16,lla-llb) << '\n';
         }
       }

       if (debug) {
         realmat dllwrho;
         string str1 = "/Derivatives of log likelihood wrt rho";
         string str2 = "/index, computed, numerical, relative error//";
         detail << starbox((str1+str2).c_str()) << '\n';
         for (INTEGER i=1; i<=rho.size(); ++i) {
           REAL delta = pow(REAL_EPSILON,0.33333333);
           REAL tmp = rho[i];
           rho[i] += delta;
           ll.set_rho(rho);
           REAL del = ll.log_likehood(dllwrho);        
           rho[i] = tmp - delta;
           ll.set_rho(rho);
           del -= ll.log_likehood(dllwrho);
           del /= 2.0*delta;
           rho[i] = tmp;
           ll.set_rho(rho);
           ll.log_likehood(dllwrho);
           REAL relerr = (dllwrho[i]-del)/(dllwrho[i]+delta);
           detail << "\t i = " << fmt('d',4,i) << fmt('g',17,8,dllwrho[i]) 
                  << fmt('g',17,8,del) << fmt('g',20,8,relerr) << '\n';
         } 
       }

       if (pf.get_optparms().print) detail.flush();

       nlopt_ll = ll;

       objective_function obj;

       nlopt minimizer(obj);
       if (pf.get_optparms().print) {
         minimizer.set_output(true, &detail);
         minimizer.set_warning_messages(true);
       }
       minimizer.set_iter_limit(pf.get_optparms().itmax0);
       minimizer.set_solution_tolerance(pf.get_optparms().toler);

       realmat best_rho = rho;

       INTEGER best_iter = 0; 
       INTEGER evaluations = 0;
       
       if (ctrl.get_nstart()>0) {
         minimizer.set_output(false);
         minimizer.set_warning_messages(false);
         if (pf.get_optparms().print) {
           detail << starbox("/Starting value of rho//");
           detail << rho;
           detail << starbox("/Trials//") << '\n';
           detail.flush();
         }
         best_iter = pf.get_optparms().itmax0;
         REAL best_f = REAL_MAX;  
         INT_32BIT jseed = ctrl.get_jseed();
         for (INTEGER trial=1; trial<=ctrl.get_nstart(); ++trial) {
           ll.set_rho(rho, ctrl.get_fold(), ctrl.get_fnew(), jseed);
           realmat rho_start = ll.get_rho();
           realmat rho_stop;
           realmat sn;
           //obj.get_f(rho_start,sn);
           obj.reset_count();
           if (set_H_matrix) minimizer.set_H_matrix(H_matrix);
           minimizer.minimize(rho_start, rho_stop);
           obj.get_f(rho_stop,sn);

           evaluations += obj.get_count();

           if (pf.get_optparms().print) {
             detail << "     trial =" << fmt('d',4,trial) 
                    << ",  obj =" << fmt('e',23,15,sn[1]) 
                    << ",  eval = " << fmt('d',4,obj.get_count())
                    << ",  iter = " << fmt('d',4,minimizer.get_iter_count())
                    << '\n';
             detail.flush();
           }
           if (sn[1] < best_f) { 
             best_iter = minimizer.get_iter_count();
             best_f = sn[1];
             best_rho = rho_stop;
           }
         }
         if (pf.get_optparms().print) {
           minimizer.set_output(true, &detail);
           minimizer.set_warning_messages(true);
         }
       }

       realmat rho_start = best_rho;
       realmat rho_stop = rho_start;
       
       minimizer.set_iter_limit(pf.get_optparms().itmax1);

       if (set_H_matrix) minimizer.set_H_matrix(H_matrix);

       obj.reset_count();

       if (pf.get_optparms().itmax1 > 0) {
         minimizer.minimize(rho_start, rho_stop);
       }

       realmat sn;
       obj.get_f(rho_stop,sn);

       rho = rho_stop;

       ll.set_rho(rho);
       pf.set_afunc(ll.get_afunc());
       pf.set_ufunc(ll.get_ufunc());
       pf.set_rfunc(ll.get_rfunc());

       summary.set_sn(sn[1]);
       summary.set_iter(best_iter + minimizer.get_iter_count());
       summary.set_termination_code(minimizer.get_termination_code());
       summary.write_line();

       evaluations += obj.get_count();

       if (pf.get_optparms().print) {
         detail << "\n\t Number of function evaluations = " 
                << obj.get_count() << '\n';
         detail << "\n\t Total number of function evaluations = " 
                << evaluations << '\n';
         detail << starbox("/Estimated value of rho//");
         detail << rho;
         detail << starbox("/Stats and diagnostics//");
         detail << '\n';
         ll.write_stats(detail);
       }
                                  
       if (!pf.write_parms(ctrl.get_output_file().c_str(),ctrl.get_line(),ll)){
         detail.flush();
         error("Error, snp, cannot write parmfile"); 
        }

      }
     else if (pf.get_optparms().task == 1) { //res
       residual_type rt(pf.get_optparms(),pf.get_datparms(),pf.get_tranparms(),
         pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(), pf.get_rfunc(), 
         pf.get_afunc_mask(), pf.get_ufunc_mask(), pf.get_rfunc_mask(),
         detail, &tr);

       ofstream os(ctrl.get_output_file().c_str());
       if (!os) error("Error, snp, can't open " + ctrl.get_output_file());

       rt.initialize(&os);
       rt.initialize(ctrl.get_output_file());
       rt.set_XY(&X,&Y);

       if (!rt.calculate()) error("Error, snp, residual write failed");

       summary.write_partial_line();  // Information in ll not set.
     }
     else if (pf.get_optparms().task == 2) { //mu

       mean_type mt(pf.get_optparms(), pf.get_datparms(), pf.get_tranparms(),
         pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(), pf.get_rfunc(), 
         pf.get_afunc_mask(), pf.get_ufunc_mask(), pf.get_rfunc_mask(),
         detail, &tr);

       ofstream os(ctrl.get_output_file().c_str());
       if (!os) error("Error, snp, can't open " + ctrl.get_output_file());

       mt.initialize(&os);
       mt.initialize(ctrl.get_output_file());
       mt.set_XY(&X,&Y);

       if (!mt.calculate()) error("Error, snp, mean write failed");

       summary.write_partial_line();  // Information in ll not set.
     }
     else if (pf.get_optparms().task == 3) { //sig
       variance_type vt(pf.get_optparms(),pf.get_datparms(),pf.get_tranparms(),
         pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(), pf.get_rfunc(), 
         pf.get_afunc_mask(), pf.get_ufunc_mask(), pf.get_rfunc_mask(),
         detail, &tr);

       ofstream os(ctrl.get_output_file().c_str());
       if (!os) error("Error, snp, can't open " + ctrl.get_output_file());

       vt.initialize(&os);
       vt.initialize(ctrl.get_output_file());
       vt.set_XY(&X,&Y);

       if (!vt.calculate()) error("Error, snp, variance write failed");

       summary.write_partial_line();  // Information in ll not set.
     }
     else if (pf.get_optparms().task == 4) { //plt
       plot_type pt(pf.get_optparms(),pf.get_datparms(),pf.get_tranparms(),
         pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(), pf.get_rfunc(), 
         pf.get_afunc_mask(), pf.get_ufunc_mask(), pf.get_rfunc_mask(),
         detail, &tr);

       ofstream os(ctrl.get_output_file().c_str());
       if (!os) error("Error, snp, can't open " + ctrl.get_output_file());

       pt.initialize(&os);
       pt.initialize(ctrl.get_output_file());
       pt.set_XY(&X,&Y);

       if (!pt.calculate()) error("Error, snp, plot write failed");

       summary.write_partial_line();  // Information in ll not set.
     }
     else if (pf.get_optparms().task == 5) { //sim

       simulate_type st(pf.get_optparms(),pf.get_datparms(),pf.get_tranparms(),
         pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(), pf.get_rfunc(), 
         pf.get_afunc_mask(), pf.get_ufunc_mask(), pf.get_rfunc_mask(),
         detail, &tr);

       ofstream os(ctrl.get_output_file().c_str());
       if (!os) error("Error, snp, can't open " + ctrl.get_output_file());

       st.initialize(&os);
       st.initialize(ctrl.get_output_file());
       st.set_XY(&X,&Y);

       if (!st.calculate()) error("Error, snp, simulation write failed");

       summary.write_partial_line();  // Information in ll not set.
     }
     else if (pf.get_optparms().task == 6) { //usr

       user_type ut(pf.get_optparms(),pf.get_datparms(),pf.get_tranparms(),
         pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(), pf.get_rfunc(), 
         pf.get_afunc_mask(), pf.get_ufunc_mask(), pf.get_rfunc_mask(),
         detail, &tr);

       ofstream os(ctrl.get_output_file().c_str());
       if (!os) error("Error, snp, can't open " + ctrl.get_output_file());

       ut.initialize(&os);
       ut.initialize(ctrl.get_output_file());
       ut.set_XY(&X,&Y);

       if (!ut.calculate()) error("Error, snp, user write failed");

       summary.write_partial_line();  // Information in ll not set.
     }
     else { //error
       error("Error, snp, no such task");
     }

  }
  return 0;
}
