/*
  Copyright(C) 2007-2012 National Institute of Information and Communications Technology
*/

/*
  svmtools
  Sequential minimum optimization module
*/


#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include "exception.h"
#include "svm_common.h"
#include "svm_kernel.h"
#include "svm_smo.h"


/* Smaller / larger of two values.  These are function-like macros: an
   argument may be evaluated twice, so never pass expressions with side
   effects.  On ties both MIN and MAX yield the first argument. */
#define MIN(x, y) (((x) <= (y)) ? (x) : (y))
#define MAX(x, y) (((x) >= (y)) ? (x) : (y))


/* File-local helpers of the SMO solver; they are defined after svm_smo(). */
static int selwrkset(SVM_TPRM *tprm, SVM_EXM *exm, double *g, int *alpha_stat, int *act, int act_num, int *i1, int *i2);
static int solve(SVM_TPRM *tprm, SVM_EXM *exm, double *alpha, double *g, int ii1, int ii2, double q11, double q22, double q12, double *a1, double *a2);
static int update(SVM_TPRM *tprm, SVM_EXM *exm, double *alpha, double *g, int *alpha_stat, int *act, int act_num, int *keep, double *lambda_eq, double *err_max, int ii1, int ii2, float *q1, float *q2, double a1, double a2);
static int shrink(SVM_TPRM *tprm, SVM_EXM *exm, SVM_KC *kc, int *act, int *act_num, int *keep);
static int unshrink(SVM_TPRM *tprm, SVM_EXM *exm, SVM_KC **kc, int *act, int *act_num, int *keep, double *g, double *alpha, int *alpha_stat, double lambda_eq, double *err_max, int *err_num);
static double calcb(SVM_EXM *exm, double *g, int *alpha_stat);


/*
  SVM training by sequential minimal optimization (SMO) with shrinking.

  tprm  - training parameters (C, tolerances, shrinking schedule, kernel
          parameters, kernel-cache size)
  exm   - training examples: exm->num examples with exm->label / exm->sv
  alpha - output: dual coefficient of every example (exm->num entries)
  b     - output: bias term

  Returns 0 on success, 1 if the kernel cache (or the workspace needed while
  unshrinking) cannot be built.  All memory allocated here is released on
  every return path.
*/
int svm_smo(SVM_TPRM *tprm, SVM_EXM *exm, double *alpha, double *b) {
  int i;
  int *alpha_stat;
  double *g;
  int *act, act_num, *keep;
  int loop;
  int i1, i2;
  float *q1, *q2;
  int shrinking = 1;
  int count = 1;
  int miss;
  SVM_KC *kc;
  double err_max = DBL_MAX;
  int ret = 0;

  /* Initialize alpha */
  for (i = 0; i < exm->num; i++) alpha[i] = 0.0;

  /* No training data: nothing to learn */
  if (exm->num == 0) {
    *b = 0.0;
    return 0;
  }

  /* Initialize the kernel cache */
  kc = svm_knl_kcinit(&tprm->kprm, exm->num, exm->label, exm->sv, tprm->kc_size);
  if (kc == NULL) return 1;

  /* Initialize alpha_stat: every alpha starts at its lower bound */
  alpha_stat = smalloc(sizeof(int) * exm->num);
  for (i = 0; i < exm->num; i++) alpha_stat[i] = -1;  /* lower bound */

  /* Initialize the gradient (-1 for every example while all alphas are 0) */
  g = smalloc(sizeof(double) * exm->num);
  for (i = 0; i < exm->num; i++) g[i] = -1.0;

  /* Shrinking state: act[0..act_num-1] lists the active examples and
     keep[] counts how long each one has stayed inactive-eligible */
  act = smalloc(sizeof(int) * exm->num);
  for (i = 0; i < exm->num; i++) act[i] = i;
  act_num = exm->num;
  keep = smalloc(sizeof(int) * exm->num);
  for (i = 0; i < exm->num; i++) keep[i] = 0;

  /* SVM training loop */
  for (loop = 1; ; loop++) {
    double a1, a2;
    double lambda_eq;
    int err_num;

    /* Progress report at iterations 1, 10, 100, ... */
    if (loop == count) {
      if (svm_verbose) fprintf(stderr, "\n%d (active=%d, cache=%d) ", loop, act_num, kc->row);
      count *= 10;
    }

    /* Working-set selection; stop when no sufficiently violating pair is left.
       NOTE(review): this exit does not unshrink, so examples that are still
       shrunk are never re-checked and their g[] entries (used by calcb() and
       the misclassification count below) may be stale. */
    if (selwrkset(tprm, exm, g, alpha_stat, act, act_num, &i1, &i2)) {
      break;
    }

    /* Cached kernel rows of the pair; i1/i2 are positions in act[] */
    q1 = svm_knl_kckernel(kc, i1, NULL);
    q2 = svm_knl_kckernel(kc, i2, q1);

    /* Solve the two-variable QP.  The cross term is sign-corrected by the
       label product (presumably the cache stores y_i*y_j*K; confirm in
       svm_kernel). */
    solve(tprm, exm, alpha, g, act[i1], act[i2], q1[i1], q2[i2], (exm->label[act[i1]] == exm->label[act[i2]]) ? q1[i2] : -q1[i2], &a1, &a2);

    update(tprm, exm, alpha, g, alpha_stat, act, act_num, keep, &lambda_eq, &err_max, act[i1], act[i2], q1, q2, a1, a2);

    /* Shrinking every shrink_loop iterations (until unshrinking turns it off) */
    if (loop % tprm->shrink_loop == 0 && shrinking) {
      shrink(tprm, exm, kc, act, &act_num, keep);
      if ((loop / tprm->shrink_loop) % tprm->expand_loop == 0) {
        svm_knl_kcexpand(kc);
      }
    }

    /* Convergence check on the maximum KKT violation of the active set */
    if (err_max < tprm->kkt_del) {
      /* Nothing is shrunk, so no unshrinking is needed */
      if (act_num == exm->num) break;

      shrinking = 0;  /* turn shrinking off */
      if (svm_verbose) fprintf(stderr, "\n%d (active=%d, cache=%d) Unshrinking... ", loop, act_num, kc->row);
      if (unshrink(tprm, exm, &kc, act, &act_num, keep, g, alpha, alpha_stat, lambda_eq, &err_max, &err_num)) {
        /* unshrink() deletes the old cache before anything else and leaves
           no usable cache behind on failure, so kc must not be deleted here;
           only the work arrays still need to be released. */
        ret = 1;
        goto cleanup;
      }
      if (svm_verbose) fprintf(stderr, "(num: %d, max: %f) ", err_num, err_max);
      if (svm_verbose) fprintf(stderr, "\n%d (active=%d, cache=%d) ", loop, act_num, kc->row);
      if (err_max < tprm->kkt_del) break;
    }
  }

  /* Snap bounded alphas exactly onto their bounds */
  for (i = 0; i < exm->num; i++) {
    if (alpha_stat[i] == +1) alpha[i] = tprm->c;
    else if (alpha_stat[i] == -1) alpha[i] = 0.0;
  }

  /* Compute b */
  *b = calcb(exm, g, alpha_stat);

  if (svm_verbose) fprintf(stderr, "converged (%d iterations) ", loop);

  /* Count misclassified training examples: z = g + y*b + 1, which (assuming
     g = y*f(x) - 1) is the signed margin y*(f(x) + b). */
  miss = 0;
  for (i = 0; i < exm->num; i++) {
    double z;

    z = g[i] + exm->label[i] * *b + 1.0;
    if (z <= 0.0) miss++;
  }

  if (svm_verbose) fprintf(stderr, "(missclassified: %d) ", miss);

 cleanup:
  /* Release memory */
  free(alpha_stat);
  free(g);
  free(act);
  free(keep);

  /* On failure the cache has already been disposed of by unshrink() */
  if (ret == 0) svm_knl_kcdelete(kc);

  return ret;
}


/*
  Working-set selection (maximal violating pair).

  Scans the active examples act[0..act_num-1] and records, as positions in
  act[]:
    *i1 - the example maximizing -y*g among those whose alpha can still move
          "up" (y > 0 and not at the upper bound, or y <= 0 and not at the
          lower bound);
    *i2 - the example maximizing  y*g among those whose alpha can still move
          "down" (y > 0 and not at the lower bound, or y <= 0 and not at the
          upper bound).
  Returns 1 without touching *i1 / *i2 when the sum of the two maxima is
  below tprm->grad_del (optimality reached), otherwise 0.
*/

static int selwrkset(SVM_TPRM *tprm, SVM_EXM *exm, double *g, int *alpha_stat, int *act, int act_num, int *i1, int *i2) {
  double best_up = -DBL_MAX;
  double best_down = -DBL_MAX;
  int up_pos = -1;
  int down_pos = -1;
  int k;

  for (k = 0; k < act_num; k++) {
    const int idx = act[k];
    const int st = alpha_stat[idx];
    int up_ok, down_ok;
    double up_val, down_val;

    if (exm->label[idx] > 0) {
      up_ok = (st != +1);
      up_val = -g[idx];
      down_ok = (st != -1);
      down_val = g[idx];
    } else {
      up_ok = (st != -1);
      up_val = g[idx];
      down_ok = (st != +1);
      down_val = -g[idx];
    }

    if (up_ok && up_val > best_up) {
      best_up = up_val;
      up_pos = k;
    }
    if (down_ok && down_val > best_down) {
      best_down = down_val;
      down_pos = k;
    }
  }

  if (best_up + best_down < tprm->grad_del) {
    return 1;  /* converged: no pair violates optimality enough */
  }

  *i1 = up_pos;
  *i2 = down_pos;
  return 0;
}


/*
  Solve the two-variable QP for the pair (ii1, ii2) analytically.

  q11 and q22 are the diagonal kernel terms and q12 the cross term supplied by
  the caller.  The pair moves along the line y1*a1 + y2*a2 = const, with a2
  clipped to the segment allowed by 0 <= a1, a2 <= C.  When the curvature
  eta = q11 + q22 - 2*q12 is not positive (e.g. the vectors of ii1 and ii2
  coincide, eta = 0), a2 jumps to whichever segment end the gradient favours.
  The new values are written to *a1 / *a2; alpha[] itself is not modified.
  Always returns 0.
*/

static int solve(SVM_TPRM *tprm, SVM_EXM *exm, double *alpha, double *g, int ii1, int ii2, double q11, double q22, double q12, double *a1, double *a2) {
  const int y1 = exm->label[ii1];
  const int y2 = exm->label[ii2];
  const double s = y1 * y2;
  const double d = y1 * alpha[ii1] + y2 * alpha[ii2];  /* conserved quantity */
  const double a2_at_c = (d - tprm->c * y1) / y2;      /* a2 when a1 == C */
  const double a2_at_0 = d / y2;                       /* a2 when a1 == 0 */
  double lo, hi;
  double eta, dir, new_a2;

  if (y1 == y2) {
    lo = a2_at_c;
    hi = a2_at_0;
  } else {
    lo = a2_at_0;
    hi = a2_at_c;
  }
  if (lo <= 0.0) lo = 0.0;
  if (hi >= tprm->c) hi = tprm->c;

  eta = q11 + q22 - 2.0 * q12;
  dir = s * g[ii1] - g[ii2];

  if (eta > 0.0) {
    new_a2 = alpha[ii2] + dir / eta;
    if (new_a2 < lo) new_a2 = lo;
    else if (hi < new_a2) new_a2 = hi;
  } else {
    /* Flat or concave direction: go to the end of the feasible segment */
    new_a2 = (dir > 0.0) ? hi : lo;
  }

  *a2 = new_a2;
  *a1 = (d - y2 * new_a2) / y1;

  return 0;
}


/*
  Apply the step computed by solve() to the pair (ii1, ii2).

  - Reclassifies the two multipliers: alpha_stat = -1 when a <= alpha_eps
    (lower bound), +1 when a >= c - alpha_eps (upper bound), 0 otherwise (free).
  - Adds q1[k]*delta1 + q2[k]*delta2 to the gradient of every active example,
    where q1/q2 are indexed by the position k in act[].
  - Sets *lambda_eq to the mean of -y*g over the free active examples
    (0 when there are none).
  - For every active example, with z = g + lambda_eq (y > 0) or
    g - lambda_eq (y <= 0): increments keep[] while a bounded example satisfies
    its KKT condition by more than kkt_eps (otherwise resets it to 0), and
    stores the largest KKT violation into *err_max.
  - Finally writes a1 / a2 into alpha[].
  Always returns 0.
*/

static int update(SVM_TPRM *tprm, SVM_EXM *exm, double *alpha, double *g, int *alpha_stat, int *act, int act_num, int *keep, double *lambda_eq, double *err_max, int ii1, int ii2, float *q1, float *q2, double a1, double a2) {
  int k, idx;
  double step1, step2;
  double free_sum = 0.0;
  int free_cnt = 0;
  double worst = 0.0;

  /* New bound status of the two modified multipliers */
  alpha_stat[ii1] = (a1 <= tprm->alpha_eps) ? -1 : ((a1 >= tprm->c - tprm->alpha_eps) ? +1 : 0);
  alpha_stat[ii2] = (a2 <= tprm->alpha_eps) ? -1 : ((a2 >= tprm->c - tprm->alpha_eps) ? +1 : 0);

  step1 = a1 - alpha[ii1];
  step2 = a2 - alpha[ii2];

  /* Gradient update, accumulating -y*g over the free examples */
  for (k = 0; k < act_num; k++) {
    idx = act[k];
    g[idx] += q1[k] * step1 + q2[k] * step2;
    if (alpha_stat[idx] != 0) continue;
    free_sum -= g[idx] * exm->label[idx];
    free_cnt++;
  }

  /* Estimate of lambda^eq */
  *lambda_eq = (free_cnt > 0) ? free_sum / free_cnt : 0.0;

  /* keep[] bookkeeping and maximum KKT violation */
  for (k = 0; k < act_num; k++) {
    double z;
    int st;

    idx = act[k];
    st = alpha_stat[idx];
    z = (exm->label[idx] > 0) ? g[idx] + *lambda_eq : g[idx] - *lambda_eq;

    if ((st > 0 && -z > tprm->kkt_eps) || (st < 0 && z > tprm->kkt_eps)) {
      keep[idx]++;
    } else {
      keep[idx] = 0;
    }

    if (st >= 0 && z > worst) worst = z;
    if (st <= 0 && -z > worst) worst = -z;
  }
  *err_max = worst;

  /* Commit the new multipliers */
  alpha[ii1] = a1;
  alpha[ii2] = a2;

  return 0;
}


/*
  Shrinking: drop from the active set every example whose keep[] counter has
  reached tprm->shrink_loop.  Each such example is marked with keep = -1, the
  kernel cache is told to hide its position (svm_knl_kchide), and its slot in
  act[] is filled with the last active entry, so act[] is reordered and
  *act_num shrinks.  Always returns 0.
*/

static int shrink(SVM_TPRM *tprm, SVM_EXM *exm, SVM_KC *kc, int *act, int *act_num, int *keep) {
  int pos = 0;

  while (pos < *act_num) {
    const int idx = act[pos];

    if (keep[idx] < tprm->shrink_loop) {
      pos++;
      continue;
    }

    /* Remove idx; the entry moved into act[pos] is examined next */
    keep[idx] = -1;
    svm_knl_kchide(kc, pos);
    act[pos] = act[*act_num - 1];
    (*act_num)--;
  }

  return 0;
}


/*
  Compute the bias b from the final gradient.

  yg = -y*g is the value of b that puts example i exactly on its margin.
  - Free examples (alpha_stat not -1/+1) pin b directly; their mean is used.
  - Otherwise b is only bracketed by the bounded examples:
      lower limits (b >= yg): y > 0 at the lower bound, y < 0 at the upper bound;
      upper limits (b <= yg): the remaining bounded examples.
    The midpoint of the tightest limits is returned.  When only one side has
    any limit (e.g. single-class training data), that limit itself is
    returned.  Averaging it with the unused +-DBL_MAX sentinel would give a
    bias of about +-DBL_MAX/2.  With no examples at all the result is 0.
*/

static double calcb(SVM_EXM *exm, double *g, int *alpha_stat) {
  int i;
  double b_low = 0.0, b_high = 0.0;  /* tightest lower / upper limit on b */
  int has_low = 0, has_high = 0;
  double nb_sum = 0.0;
  int nb_num = 0;

  for (i = 0; i < exm->num; i++) {
    double yg;
    int is_lower_limit;

    yg = -exm->label[i] * g[i];

    if (alpha_stat[i] == -1) {
      is_lower_limit = (exm->label[i] > 0);
    } else if (alpha_stat[i] == +1) {
      is_lower_limit = (exm->label[i] < 0);
    } else {
      /* free support vector */
      nb_num++;
      nb_sum += yg;
      continue;
    }

    if (is_lower_limit) {
      b_low = has_low ? MAX(b_low, yg) : yg;
      has_low = 1;
    } else {
      b_high = has_high ? MIN(b_high, yg) : yg;
      has_high = 1;
    }
  }

  if (nb_num > 0) return nb_sum / nb_num;
  if (has_low && has_high) return (b_low + b_high) / 2.0;
  if (has_low) return b_low;
  if (has_high) return b_high;
  return 0.0;
}


/*
  Unshrinking: restore shrunk examples.

  1. Deletes the current kernel cache and sets *kc to NULL so the caller never
     holds a dangling pointer, whatever happens afterwards.
  2. Recomputes from scratch the gradient g[] of every inactive (shrunk)
     example, using all support vectors (alpha_stat >= 0) and the current
     alpha[].
  3. Evaluates the same KKT quantity as update(): z = g + lambda_eq for y > 0,
     g - lambda_eq otherwise, and zz = z * alpha_stat.  Every inactive example
     with zz > -0.5 is appended to act[].  Those with zz > 0 violate KKT: they
     are counted in *err_num, and the largest zz is stored in *err_max.
  4. Builds a new kernel cache for the enlarged active set, in act[] order.

  Returns 0 on success (*kc is the new cache) and 1 on failure (*kc is NULL
  or was not rebuilt; no memory allocated here is leaked).
*/

static int unshrink(SVM_TPRM *tprm, SVM_EXM *exm, SVM_KC **kc, int *act, int *act_num, int *keep, double *g, double *alpha, int *alpha_stat, double lambda_eq, double *err_max, int *err_num) {
  int i, ii;
  int actlb_num, actsv_num, inactsv_num, inactlb_num;
  int sv_num;
  int *wrk;
  int *tmplabel;
  SVM_SV **tmpsv;
  float *q, *q_ptr, *q_end;
  double *beta, *beta_ptr;
  SVM_KWRK *kwrk;

  /* The old cache only covers the old active set: release it and clear the
     caller's pointer so it cannot dangle if a later step fails. */
  svm_knl_kcdelete(*kc);
  *kc = NULL;

  wrk = smalloc(sizeof(int) * exm->num);

  /* Classify examples as active/inactive and support vector/lower bound.
     Layout of wrk: [active lb][active SV][inactive SV][inactive lb]
     (active SVs are filled backwards from *act_num, inactive lbs backwards
     from exm->num).  The indexing assumes the examples with keep[i] != -1
     are exactly the *act_num active ones; shrink() marks every example it
     removes with keep = -1. */
  actlb_num = 0;
  actsv_num = 0;
  inactsv_num = 0;
  inactlb_num = 0;
  for (i = 0; i < exm->num; i++) {
    if (keep[i] != -1) {
      if (alpha_stat[i] < 0) {
        wrk[actlb_num++] = i;
      } else {
        wrk[*act_num - ++actsv_num] = i;
      }
    } else {
      if (alpha_stat[i] >= 0) {
        wrk[*act_num + inactsv_num++] = i;
      } else {
        wrk[exm->num - ++inactlb_num] = i;
      }
    }
  }

  if (svm_verbose) fprintf(stderr, "(%d:%d/%d:%d) ", actsv_num, actlb_num, inactsv_num, inactlb_num);

  /* KKT check: the support vectors (active and inactive) are contiguous in
     wrk starting at actlb_num */
  sv_num = actsv_num + inactsv_num;
  tmpsv = smalloc(sizeof(SVM_SV *) * sv_num);
  for (i = 0; i < sv_num; i++) tmpsv[i] = exm->sv[wrk[actlb_num + i]];
  if (svm_invidx) {
    kwrk = svm_knl_kwrkinit_inv(&tprm->kprm, sv_num, tmpsv);
  } else {
    kwrk = svm_knl_kwrkinit(&tprm->kprm, sv_num, tmpsv);
  }
  if (kwrk == NULL) {
    /* Release the work arrays instead of leaking them */
    free(tmpsv);
    free(wrk);
    return 1;
  }
  q = smalloc(sizeof(float) * sv_num);
  beta = smalloc(sizeof(double) * sv_num);
  for (i = 0; i < sv_num; i++) beta[i] = exm->label[wrk[actlb_num + i]] * alpha[wrk[actlb_num + i]];
  q_end = q + sv_num;
  *err_max = 0.0;
  *err_num = 0;
  for (i = 0; i < inactsv_num + inactlb_num; i++) {
    double z, zz;
    double gg;

    ii = wrk[actlb_num + actsv_num + i];

    /* Gradient: sum_j beta_j * q_j, where q presumably holds the kernel
       values between example ii and each support vector (see svm_kernel) */
    svm_knl_kwrkkernel(&tprm->kprm, kwrk, exm->sv[ii], q);
    gg = 0.0;
    for (q_ptr = q, beta_ptr = beta; q_ptr < q_end; q_ptr++, beta_ptr++) gg += *q_ptr * *beta_ptr;
    z = gg + lambda_eq;
    if (exm->label[ii] < 0) {
      gg = -gg;
      z = -z;
    }
    gg += -1.0;
    z += -1.0;
    g[ii] = gg;

    /* Reactivate unless the example clearly satisfies its bound */
    zz = z * alpha_stat[ii];
    if (zz > -0.5) {
      act[(*act_num)++] = ii;
      /* KKT conditions are violated */
      if (zz > 0.0) {
        /* Update the error statistics */
        if (zz > *err_max) *err_max = zz;
        (*err_num)++;
      }
    }
  }

  svm_knl_kwrkdelete(kwrk);
  free(tmpsv);
  free(beta);
  free(q);
  free(wrk);

  /* Reset the cache for the new active set */
  tmplabel = smalloc(sizeof(int) * *act_num);
  tmpsv = smalloc(sizeof(SVM_SV *) * *act_num);
  for (i = 0; i < *act_num; i++) {
    tmplabel[i] = exm->label[act[i]];
    tmpsv[i] = exm->sv[act[i]];
  }
  *kc = svm_knl_kcinit(&tprm->kprm, *act_num, tmplabel, tmpsv, tprm->kc_size);
  free(tmpsv);
  free(tmplabel);
  if (*kc == NULL) return 1;

  return 0;
}
