#include "libscl.h"

using namespace std;
using namespace scl;

// Sum-of-squares objective built from two exponential residuals, with an
// analytically coded gradient.
//
//   e1(x) = 2 - exp(x1 + x2)
//   e2(x) = 2 - exp(2*x2)
//   f(x)  = e(x)' e(x)            (1 by 1)
//   F(x)  = 2 e(x)' E(x)          (1 by 2), where E = (d/dx') e(x)
//
// Indices are 1-based, as everywhere else in this file.
class nleqns : public nleqns_base {
public:
  // Evaluates the objective at x (2 by 1) and stores it in f (1 by 1).
  // Calls error() if x does not have two rows; otherwise returns true.
  bool get_f(const realmat& x, realmat& f) 
  {
    if (x.nrow() != 2) error("Error, nleqns, wrong dimension for x");
    if (f.nrow() != 1) f.resize(1,1);
    realmat e(2,1);
    e[1] = 2.0 - exp(1.0*x[1] + 1.0*x[2]);
    e[2] = 2.0 - exp(2.0*x[2]);
    f = T(e)*e;
    return true;
  }
  // Evaluates the objective f (1 by 1) and its gradient F (1 by 2) at
  // x (2 by 1).  Calls error() if x does not have two rows; otherwise
  // returns true.
  bool get_F(const realmat& x, realmat& f, realmat& F)
  {
    if (x.nrow() != 2) error("Error, nleqns, wrong dimension for x");
    if (f.nrow() != 1) f.resize(1,1);
    if (F.nrow() != 1 || F.ncol() != 2) F.resize(1,2);
    realmat e(2,1);
    e[1] = 2.0 - exp(1.0*x[1] + 1.0*x[2]);
    e[2] = 2.0 - exp(2.0*x[2]);
    // Jacobian of the residuals.  Because e1 = 2 - exp(x1+x2), both partial
    // derivatives of e1 equal -exp(x1+x2) = e1 - 2.  The (1,2) entry used to
    // be coded as e2 - 2, which is only right when x1 == x2.
    // Likewise d e2/d x2 = -2 exp(2 x2) = 2 (e2 - 2), and e2 is free of x1.
    realmat E(2,2);
    E(1,1) = e[1]-2.0; E(1,2) = e[1]-2.0;
    E(2,1) = 0.0;      E(2,2) = 2.0*(e[2]-2.0);
    f = T(e)*e;
    F = 2.0*T(e)*E;   // chain rule for e'e
    return true;
  }
};

class nleqns_df : public nleqns_base {
public:
  bool get_f(const realmat& x, realmat& f) 
  {
    if (x.nrow() != 2) error("Error, nleqns, wrong dimension for x");
    if (f.nrow() != 1) f.resize(1,1);
    realmat e(2,1);
    e[1] = 2.0 - exp(1.0*x[1] + 1.0*x[2]);
    e[2] = 2.0 - exp(2.0*x[2]);
    f = T(e)*e;
    return true;
  }
  bool get_F(const realmat& x, realmat& f, realmat& F)
  {
    if (! this->get_f(x,f) ) return false;
    return nleqns_base::df(x,F);
  }
};

class rosenbrock : public nleqns_base {
public:
  bool get_f(const realmat& x, realmat& f) 
  {
    if (x.nrow() != 2) error("Error, nleqns, wrong dimension for x");
    if (f.nrow() != 1) f.resize(1,1);
    f[1] = 100.0*pow(x[2] - pow(x[1],2),2) + pow(1.0 - x[1],2);
    return true;
  }
  bool get_F(const realmat& x, realmat& f, realmat& F)
  {
    if (x.nrow() != 2) error("Error, nleqns, wrong dimension for x");
    if (f.nrow() != 1) f.resize(1,1);
    if (F.nrow() != 1 || F.ncol() != 2) F.resize(1,2);
    f[1] = 100.0*pow(x[2] - pow(x[1],2),2) + pow(1.0 - x[1],2);
    F(1,1) = -400.0*(x[2] - pow(x[1],2))*x[1] - 2.0*(1.0 - x[1]);
    F(1,2) = 200.0*(x[2] - pow(x[1],2));
    return true;
  }
};


int main(int argc, char** argp, char** envp)
{
  cout << starbox("/Using Rosenbrock//");

  rosenbrock rb;

  realmat x0(2,1,0.0);

  realmat f0, F0; 

  rb.get_F(x0, f0, F0);

  realmat direction(2,1); 
  direction[1] = 1.0;
  direction[2] = 0.0;

  realmat x, f , F;

  linesrch rbs(rb);

  REAL guess = 1.0;
  for (REAL a=0; a<=guess; a+=guess/25.0) {
    realmat u = x0 + a*direction;
    realmat g; rb.get_f(u,g);
    cout << a << ' ' << g[1] << '\n';
  }
  rbs.set_initial_guess(guess);
  //rbs.set_solution_tolerance(REAL_MIN);
  rbs.set_solution_tolerance(EPS);
  REAL alpha = rbs.search(direction, 0.0, x0, f0, F0, x, f, F);
  cout << "\n\t termination_code = " << rbs.get_termination_code() << '\n';
  cout << "\n\t alpha = " << alpha << '\n';

  rbs.set_initial_guess(0.1);
  //rbs.set_solution_tolerance(REAL_MIN);
  rbs.set_solution_tolerance(EPS);
  rbs.set_warning_messages(true);
  alpha = rbs.search(direction, 0.0, x0, f0, F0, x, f, F);
  cout << "\n\t termination_code = " << rbs.get_termination_code() << '\n';
  cout << "\n\t alpha = " << alpha << '\n';


  cout << starbox("/Using exponential with analytic derivatives//");

  nleqns obj;

  x0.resize(2,1,0.0); 

  obj.get_F(x0, f0, F0);

  direction = -T(F0);

  linesrch objs(obj);
  
  guess = 0.5;
  for (REAL a=0; a<=guess; a+=guess/25.0) {
    realmat u = x0 + a*direction;
    realmat g; obj.get_f(u,g);
    cout << a << ' ' << g[1] << '\n';
  }
  objs.set_initial_guess(guess);
  objs.set_warning_messages(true);
  alpha = objs.search(direction, 0.0, x0, f0, F0, x, f, F);

  cout << "\n\t alpha = " << alpha << '\n';

  objs.set_initial_guess(1.0);
  alpha = objs.search(direction, 0.0, x0, f0, F0, x, f, F);

  cout << "\n\t alpha = " << alpha << '\n';


  cout << starbox("/Using exponential with numerical differentiator df//");
  
  nleqns_df obj_df;

  obj_df.get_F(x0, f0, F0);

  direction = -T(F0);

  linesrch obj_dfs(obj_df);
  guess = 0.5;
  for (REAL a=0; a<=guess; a+=guess/25.0) {
    realmat u = x0 + a*direction;
    realmat g; obj_df.get_f(u,g);
    cout << a << ' ' << g[1] << '\n';
  }
  obj_dfs.set_initial_guess(guess);
  obj_dfs.set_warning_messages(true);
  alpha = obj_dfs.search(direction, 0.0, x0, f0, F0, x, f, F);

  cout << "\n\t alpha = " << alpha << '\n';


  return 0;

}
