/*
  Copyright(C) 2007-2012 National Institute of Information and Communications Technology
*/

/*
  svmtools
  Training program
*/


#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#include <sys/times.h>
#include "exception.h"
#include "svm_common.h"
#include "svm_kernel.h"
#include "svm_smo.h"


/* Default training parameters; each one can be overridden by a command-line
   option parsed in main(). */

/* -c : cost C of constraints violation */
#define DFLT_COST 1.0
/* -t : kernel type (0 linear, 1 polynomial, 2 rbf, 3 sigmoid) */
#define DFLT_KERNEL_TYPE 1
/* -d : kernel degree */
#define DFLT_DEGREE 2.0f
/* -g : kernel gamma */
#define DFLT_GAMMA 1.0f
/* -r : kernel coef */
#define DFLT_COEF 1.0f
/* -m : kernel cache size in MB */
#define DFLT_CACHE_SIZE 40


/*
  Entry point of the training program.

  Command line:
    [<options>...] <example file> <model file> [<weight file>]

  Steps: parse the options into an SVM_TPRM, read the examples with
  svm_readexm(), train with svm_smo(), write the model with svm_createmdl(),
  optionally write one alpha value per example to the weight file, and
  finally print the consumed user CPU time to stdout.

  Returns 0 on success and 1 (after printing the usage text to stderr) when
  the command line is malformed. All other failures are reported through
  exception(), which is declared in exception.h; this code relies on it not
  returning when its first argument is non-zero (NOTE(review): confirm).
*/
int main(int argc, char **argv) {
  int i;
  int write_failed;
  char *exmfile, *mdlfile, *alpfile;
  FILE *fp;
  SVM_EXM *exm;
  SVM_TPRM tprm;
  double *alpha;
  double b;
  struct tms tms_start, tms_finish;
  double utime;

  times(&tms_start);

  /* Set up the default parameters and process the options */

  tprm.c = DFLT_COST;
  tprm.kprm.ktype = DFLT_KERNEL_TYPE;
  tprm.kprm.degree = DFLT_DEGREE;
  tprm.kprm.gamma = DFLT_GAMMA;
  tprm.kprm.coef = DFLT_COEF;
  tprm.alpha_eps = 1e-12;
  tprm.kkt_eps = 1.0e-9;
  tprm.grad_del = 0.001;
  tprm.kkt_del = 0.01;
  tprm.shrink_loop = 100;
  tprm.expand_loop = 10;
  tprm.kc_size = DFLT_CACHE_SIZE;

  /* Options come in "<flag> <value>" pairs. The loop only runs while a value
     is actually present: reading argv[i + 1] when i + 1 == argc would hand
     the terminating NULL to atof()/atoi(). A trailing flag without a value
     therefore leaves a single unparsed argument, which is rejected by the
     positional-argument check below and produces the usage message. */
  for (i = 1; i + 1 < argc; i += 2) {
    if (strcmp(argv[i], "-c") == 0) {
      tprm.c = (float)atof(argv[i + 1]);
      /* NOTE(review): only negative values are rejected although the message
         says "positive"; C == 0 is still accepted. */
      exception(tprm.c < 0.0, "C must be positive");
    } else if (strcmp(argv[i], "-t") == 0) {
      tprm.kprm.ktype = atoi(argv[i + 1]);
      exception(tprm.kprm.ktype < 0 || 3 < tprm.kprm.ktype, "bad kernel type");
    } else if (strcmp(argv[i], "-d") == 0) {
      tprm.kprm.degree = (float)atof(argv[i + 1]);
    } else if (strcmp(argv[i], "-g") == 0) {
      tprm.kprm.gamma = (float)atof(argv[i + 1]);
    } else if (strcmp(argv[i], "-r") == 0) {
      tprm.kprm.coef = (float)atof(argv[i + 1]);
    } else if (strcmp(argv[i], "-m") == 0) {
      tprm.kc_size = atoi(argv[i + 1]);
    } else {
      /* First non-option argument: the positional arguments start here */
      break;
    }
  }

  /* Exactly two or three positional arguments must remain */
  if (argc - i != 2 && argc - i != 3) {
    fprintf(stderr, "usage: %s [<options>...] <example file> <model file> [<weight file>]\n", argv[0]);
    fprintf(stderr, "<options>:\n");
    fprintf(stderr, "\t-c <cost> : cost C of constraints violation [%g]\n", DFLT_COST);
    fprintf(stderr, "\t-t <kernel_type> : kernel type [%d]\n", DFLT_KERNEL_TYPE);
    fprintf(stderr, "\t   0 - linear <x,y>\n");
    fprintf(stderr, "\t   1 - polynomial (gamma <x, y> + coef)^d\n");
    fprintf(stderr, "\t   2 - rbf exp(-gamma ||x - y||^2)\n");
    fprintf(stderr, "\t   3 - sigmoid tanh(gamma <x, y> + coef)\n");
    fprintf(stderr, "\t-d <degree> : degree [%g]\n", DFLT_DEGREE);
    fprintf(stderr, "\t-g <gamma> : gamma [%g]\n", DFLT_GAMMA);
    fprintf(stderr, "\t-r <coef> : coef [%g]\n", DFLT_COEF);
    fprintf(stderr, "\t-m <cache_size> : cache size (MB) [%d]\n", (int)DFLT_CACHE_SIZE);

    return 1;
  }
  exmfile = argv[i];
  mdlfile = argv[i + 1];
  if (argc - i == 3) {
    alpfile = argv[i + 2];
  } else {
    alpfile = NULL;
  }

  /* Read the training data */
  fprintf(stderr, "Reading the example file... ");
  fp = fopen(exmfile, "rt");
  exception(fp == NULL, "cannot open the example file '%s'", exmfile);
  exm = svm_readexm(fp);
  exception(exm == NULL, "cannot read the example file '%s'", exmfile);
  fclose(fp);
  exception(exm->num == 0, "example file contains no data");
  fprintf(stderr, "done (%d examples).\n", exm->num);

  /* Allocate memory: one alpha per example */
  alpha = smalloc(sizeof(double) * exm->num);

  /* Train */
  fprintf(stderr, "Training... ");
  exception(svm_smo(&tprm, exm, alpha, &b), "svm_smo() failed");
  fprintf(stderr, "done.\n");

  /* Write the model file */
  fprintf(stderr, "Writing the model file... ");
  fp = fopen(mdlfile, "wt");
  exception(fp == NULL, "cannot open the model file '%s'", mdlfile);
  exception(svm_createmdl(&tprm.kprm, exm, alpha, b, fp), "cannot write the model file '%s'", mdlfile);
  /* fclose() flushes buffered output, so a failure here means data was lost */
  exception(fclose(fp) != 0, "cannot write the model file '%s'", mdlfile);
  fprintf(stderr, "done.\n");

  /* Write the alpha values (weight file), if requested */
  if (alpfile != NULL) {
    fprintf(stderr, "Writing the weight file... ");
    fp = fopen(alpfile, "wt");
    exception(fp == NULL, "cannot open the weight file '%s'", alpfile);
    for (i = 0; i < exm->num; i++) {
      fprintf(fp, "%.7g\n", alpha[i]);
    }
    /* Collect both write errors and flush-on-close errors before reporting */
    write_failed = ferror(fp);
    if (fclose(fp) != 0) {
      write_failed = 1;
    }
    exception(write_failed, "cannot write the weight file '%s'", alpfile);
    fprintf(stderr, "done.\n");
  }

  free(alpha);

  times(&tms_finish);

  /* Show information: user CPU time in seconds (clock ticks / ticks per second) */

  utime = (tms_finish.tms_utime - tms_start.tms_utime) / (double)sysconf(_SC_CLK_TCK);
  printf("CPU time: %.1f\n", utime);

  return 0;
}
