/* ===================================================================== 
   Predictives and cond expectations in MDP

   to be used with mdp.c in the same directory
/* ===================================================================== */ 
#include<stdio.h>
#include<stdlib.h>
#include<math.h>

#include "rand.h"
#include "nrutil.h"
#include "matrix.h"
#include "vector.h"
#include "rand-nr.h" 
#include "mess.h"
#include "interface.h"
#include "ppmx.h"
#include "predict.h"

/* ***********************************************************
   data structures
   *********************************************************** */
/* File-name buffers for the per-seed input/output files, e.g.
   "py-init_<seeds>.mdp" or "py-p_<seeds>.mdp"; they are filled in
   py_init and print_py. */
char pyinit[30];
char pyp[30];
char pypi[30];
char pys[30];
char pysi[30];
char pyx[30];
char pyy[30];
/* second-response (y2) variants */
char pyp2[30];
char pys2[30];
char pypi2[30];
char pysi2[30];
char pyy2[30];
/* -----------------------------------------------------------
   grid for univariate pdf and exp value
   ----------------------------------------------------------- */
/* Grid on which the posterior predictive of the response is tracked.
   Two response grids are kept: y (response index yi) and y2 (the
   response at index yi+1, filled by py_update only when p0 > 1).
   Row i of every matrix below belongs to covariate vector i (0..nx-1)
   and column j to grid point j of the corresponding y grid. */
struct Y_GRID{
  int ny;              /* grid size for response */
  int nx;              /* number of covariate vectors */
  int yi;              /* index of response (probably 0) */

  int n_update;       /* number of completed py_update calls (running-average count) */
  double **x;         /* nx x (p-p0) matrix of covariates */
  double *y;          /* grid over y; py_init allocates ny+1 entries, ny are used */
  double *y2;          /* grid over the second response; same layout as y */

  double **p,         /* nx x ny running average of pi (predictive pdf on the y grid) */
    **S,              /* running average of Si (1-cdf); allocated nx x (ny+1) */
    /* per-update values, recomputed on every py_update call */
    **pi,             /* nx x ny weighted mixture of normal pdfs for the current update */
    **Si,             /* nx x (ny+1) weighted mixture of 1-cdf for the current update */

    **p2,             /* as p, over the y2 grid */
    **S2,             /* as S, over the y2 grid */
    
    **pi2,            /* as pi, over the y2 grid */
    **Si2;            /* as Si, over the y2 grid */
};

struct Y_GRID *yalloc(void)
/* Allocates an (uninitialised) Y_GRID on the heap and returns it.
   The caller owns the result.  On allocation failure the program is
   terminated with a message instead of returning NULL, because callers
   dereference the result immediately. */
{
  struct Y_GRID *grid_mem = malloc(sizeof *grid_mem);

  if (grid_mem == NULL) {
    fprintf(stderr, "yalloc: out of memory allocating Y_GRID\n");
    exit(EXIT_FAILURE);
  }
  return grid_mem;
}

/* Global dimensions, initialized in py_init:
   p  -- length of a data record,
   p0 -- number of leading (response) columns,
   px -- number of covariate columns, always p - p0. */
int p;
int p0;
int px;

/* ***********************************************************
   univariate predictive
   *********************************************************** */
/* -----------------------------------------------------------
   initialize univ predictive
   ----------------------------------------------------------- */

/* Set entries [0..ncol-1] of rows [0..nrow-1] of mat to zero. */
static void py_zero_matrix(double **mat, int nrow, int ncol)
{
  int r, col;

  for (r = 0; r < nrow; r++)
    for (col = 0; col < ncol; col++)
      mat[r][col] = 0.0;
}

/* Initialise the univariate predictive grid.

   Sets the globals p0, p and px (px = p - p0) and then reads the grid
   specification from "py-init_<seeds>.mdp".  The file must provide the
   keys ny, nx, yi, ygrid (2 values), ygrid2 (2 values) and xgrid
   (nx x px).

   mdp   -- model parameters; tr_p1 + tr_p2 gives px when simu != 0
   dta   -- data; p0 (and p when simu == 0) are taken from here
   seeds -- numeric suffix of the parameter file name
   simu  -- 0: dimensions from dta, otherwise from mdp

   All accumulator matrices are zero-initialised.  py_update's first
   running-average step computes 0*p + pi, and indeterminate heap
   contents (possibly NaN) would otherwise poison the estimate.

   Returns a newly allocated Y_GRID owned by the caller. */
struct Y_GRID *py_init(struct MDP_PARS mdp, struct DTA dta, int seeds, int simu)
{
  int nx, ny, yi;
  double yrange[2];
  double yrange2[2];
  struct Y_GRID *py;

  /* global dimensions: p0 response columns followed by px covariates */
  p0 = dta.p0;
  if (simu == 0) {
    p = dta.p;
    px = p - p0;
  } else {
    px = mdp.tr_p1 + mdp.tr_p2;
    p = px + p0;
  }

  /* read pars; seeds == 0 yields "py-init_0.mdp" as before */
  snprintf(pyinit, sizeof pyinit, "py-init_%d.mdp", seeds);
  openIn(pyinit);
  scanInt(" ny ", &ny);  /* grid size for response */
  scanInt(" nx ", &nx);  /* number of cov vectors */
  scanInt(" yi ", &yi);  /* index of response */

  /* alloc mem */
  py = yalloc();
  if (py == NULL) {
    fprintf(stderr, "py_init: cannot allocate Y_GRID\n");
    exit(EXIT_FAILURE);
  }
  py->y = dvector(0, ny);
  py->y2 = dvector(0, ny);
  printf("nx %d, ny %d \n", nx, ny);

  py->x = dmatrix(0, nx - 1, 0, px - 1);

  py->p = dmatrix(0, nx - 1, 0, ny - 1);
  py->p2 = dmatrix(0, nx - 1, 0, ny - 1);

  py->S = dmatrix(0, nx - 1, 0, ny);
  py->S2 = dmatrix(0, nx - 1, 0, ny);

  py->pi = dmatrix(0, nx - 1, 0, ny - 1);
  py->pi2 = dmatrix(0, nx - 1, 0, ny - 1);

  py->Si = dmatrix(0, nx - 1, 0, ny);
  py->Si2 = dmatrix(0, nx - 1, 0, ny);

  /* the running averages start from a well-defined zero state */
  py_zero_matrix(py->p, nx, ny);
  py_zero_matrix(py->p2, nx, ny);
  py_zero_matrix(py->pi, nx, ny);
  py_zero_matrix(py->pi2, nx, ny);
  py_zero_matrix(py->S, nx, ny + 1);
  py_zero_matrix(py->S2, nx, ny + 1);
  py_zero_matrix(py->Si, nx, ny + 1);
  py_zero_matrix(py->Si2, nx, ny + 1);

  /* init members */
  py->ny = ny;
  py->nx = nx;
  py->yi = yi;
  py->n_update = 0;

  /* read y grids: each given as a (lo, hi) range */
  scanDoubleArray(" ygrid ", yrange, 2);
  grid(yrange[0], yrange[1], ny, py->y);

  scanDoubleArray(" ygrid2 ", yrange2, 2);
  grid(yrange2[0], yrange2[1], ny, py->y2);

  /* read x grid */
  scanDoubleMatrix(" xgrid ", py->x, nx, px);

  return py;
}

/* -----------------------------------------------------------
   print_py
/* ----------------------------------------------------------- */
/*   print E(z|xy) */
void print_py(struct Y_GRID *py, struct MDP_PARS mdp, int seeds,int simu)
{
  static int first=1;
  sprintf(pyp,"py-p_%d.mdp",seeds);
  sprintf(pys,"py-S_%d.mdp",seeds);
  sprintf(pypi,"py-pi_%d.mdp",seeds);
  sprintf(pysi,"py-Si_%d.mdp",seeds);

  sprintf(pyp2,"py-p2_%d.mdp",seeds);
  sprintf(pys2,"py-S2_%d.mdp",seeds);
  sprintf(pypi2,"py-pi2_%d.mdp",seeds);
  sprintf(pysi2,"py-Si2_%d.mdp",seeds);


  
    /* print y-grid */
  if (seeds==0)
  {
     openOut("py-y.mdp");
    writeDoubleArray(py->y,py->ny,1);
    closeOut();

    openOut("py-y2.mdp");
    writeDoubleArray(py->y2,py->ny,1);
    closeOut();
  }
  else{
    sprintf(pyy,"py-y_%d.mdp",seeds);
    openOut(pyy);
    writeDoubleArray(py->y,py->ny,1);
    closeOut();

    sprintf(pyy2,"py-y2_%d.mdp",seeds);
    openOut(pyy2);
    writeDoubleArray(py->y2,py->ny,1);
    closeOut();
  }
    if (first){
    /* print y-grid */
   
  
      openOut("py-x.mdp");
      writeDoubleMatrix(py->x,py->nx,px);
      closeOut();
   
    
    
    
  }

  if (seeds==0)
  {
      /* write p-grid */
    openOut("py-p.mdp");
    writeDoubleMatrix(py->p,py->nx,py->ny);
    closeOut();
    openOut("py-S.mdp");
    writeDoubleMatrix(py->S,py->nx,py->ny);
    closeOut();
    openOut("py-S2.mdp");
    writeDoubleMatrix(py->S2,py->nx,py->ny);
    closeOut();
    if (first) 
      openOut("py-pi.mdp");
    else
      openAppend("py-pi.mdp");
    writeDoubleMatrix(py->pi,py->nx,py->ny);
    closeOut();
    if (first) 
      openOut("py-Si.mdp");
    else
      openAppend("py-Si.mdp");
    writeDoubleMatrix(py->Si,py->nx,py->ny);
    closeOut();
    if (first) 
      openOut("py-Si2.mdp");
    else
      openAppend("py-Si2.mdp");
    writeDoubleMatrix(py->Si2,py->nx,py->ny);
    closeOut();

  }
  else{
    openOut(pyp);
      writeDoubleMatrix(py->p,py->nx,py->ny);
      closeOut();
      
      openOut(pys);
      writeDoubleMatrix(py->S,py->nx,py->ny);
      closeOut();
      if (first) {
       
        openOut(pypi);
      }
      else{
        openAppend(pypi);
      }
      writeDoubleMatrix(py->pi,py->nx,py->ny);
      closeOut();
      if (first){
         openOut(pysi);
      } 
      else{
         openAppend(pysi);
      }
       
      writeDoubleMatrix(py->Si,py->nx,py->ny);
      closeOut();

/*    TTE2                                   */
      openOut(pyp2);
      writeDoubleMatrix(py->p2,py->nx,py->ny);
      closeOut();
      
      openOut(pys2);
      writeDoubleMatrix(py->S2,py->nx,py->ny);
      closeOut();
      if (first) {
       
        openOut(pypi2);
      }
      else{
        openAppend(pypi2);
      }
      writeDoubleMatrix(py->pi2,py->nx,py->ny);
      closeOut();
      if (first){
         openOut(pysi2);
      } 
      else{
         openAppend(pysi2);
      }
       
      writeDoubleMatrix(py->Si2,py->nx,py->ny);
      closeOut();



  }
  
 
  
  first=0;
  
}


/* -----------------------------------------------------------
   update
/* ----------------------------------------------------------- */
/*   get expected values over grid xgrid, ygrid (global??) */
     void py_update(int time, struct Y_GRID *py, 
		    struct MDP_PARS mdp, struct DTA dta,
		    struct SIM_PAR sim, int postMCMC,int simu)
{
  int i,j,k,
    yi,nx,ny,K,n,
    n_updates;
  double kd,alpha,an,lpr,
    sxx,sx,m, sxx2,sx2,m2,
    wmx,s,
    *yline,
    y,r,*w, *c,
    mxh, hC, lgr;


  /* set aux pars */  
  yi = py->yi;
  ny = py->ny;
  nx = py->nx;
  alpha = mdp.alpha;
  kd = 1.0*mdp.n_class;
  an = mdp.alpha+1.0*dta.n_obs;
  n = dta.n_obs;
  K = mdp.n_class;
  n_updates = py->n_update;

  /* alloc mem */
  w = dvector(0,K);
  c = dvector(0,K);
 
  yline= dvector(0,p);
  

  printf("nx %d, px %d \n",nx,px);
  
  x_zero(yline,p);

 if (postMCMC==0)
  {
  /* make weights */
  for (i=0;i<nx;i++){
    for(j=0;j<px;j++){

     /* fill covariates into fake data record */
      yline[p0+j]=py->x[i][j];
     
    }
    if (simu==0)
    {
        c[K] = logpr(yline,i,K);
        w[K]=  log(alpha/an);
        for(k=0;k<K;k++){
          c[k] = logpr(yline,i,k);
          w[k]=  log(mdp.count[k]/an);
          if (k==0)        
            mxh=c[k];
          else 
        if (c[k] > mxh) mxh=c[k];
        }
    }
    else{
      c[K] =  logpr_simu(yline,i,K);
        w[K]=  log(alpha/an);
        for(k=0;k<K;k++){
          c[k] =  logpr_simu(yline,i,k);
          w[k]=  log(mdp.count[k]/an);
          if (k==0)        
            mxh=c[k];
          else 
        if (c[k] > mxh) mxh=c[k];
        }
     
    }
    
    
    
    if (sim.dahl==1){ /* Dahl's method */
      c[K] = log(1.0*mdp.alpha);
      w[K] = 0;
      for(k=0, hC=0.0; k<K;k++)
	hC += exp(c[k]-mxh);
      lgr = log( n/hC ); /* multiply with r to assure sum pr*q = n-1 */
      for(k=0,wmx=0;k<K;k++){
	c[k] = c[k] - mxh /*scaling*/ + lgr; /* to stdze */
	w[k] = 0;
      }
    }


    for(k=0;k<=K;k++){
      w[k] = c[k]+w[k];
      if ( (k==0) | (w[k]> wmx)) wmx=w[k];
    }
    for(s=k=0;k<=K;k++){
      w[k] = exp(w[k]-wmx);
      s+=w[k];
    }
    x_div_r(w,s,w,K+1);

    // printf("p0= %d",p0);
    /* predictive if y from G0 */
    if (p0==1){
   
      sxx = mdp.Vt1[mdp.n_class];
      m = mdp.mean1;
    } else{  
      sxx = mdp.Vt[mdp.n_class][yi][yi];
      m = mdp.mean[yi];

      sxx2 = mdp.Vt[mdp.n_class][yi+1][yi+1];
      m2 = mdp.mean[yi+1];

    }
   
    sx = sqrt(sxx);
    sx2=sqrt(sxx2);

    //  printf("m %f and sx %f , m2 %f and sx2 %f\n", m,sx, m2,sx2);
    for (j = 0; j < ny; j++){
      py->pi[i][j] = w[K]*pdfnorm(py->y[j], m, sx);
      py->Si[i][j] = w[K]*(1.0-cdfnorm(py->y[j], m, sx) );

       py->pi2[i][j] = w[K]*pdfnorm(py->y2[j], m2, sx2);
      py->Si2[i][j] = w[K]*(1.0-cdfnorm(py->y2[j], m2, sx2) );
    }// for j
    
  
    /* pred if y is in one of the existing classes */
    for (k = 0; k < mdp.n_class; k++){
      if (p0>1){
       
	      m = mdp.mu[k][yi];
	      sxx = mdp.V[k][yi][yi];

        m2=mdp.mu[k][yi+1];
        sxx2=mdp.V[k][yi+1][yi+1];;
      } else {
	      m = mdp.mu1[k];
	      sxx = mdp.V1[k];
      }
      sx = sqrt(sxx);
      sx2=sqrt(sxx2);
      for (j = 0; j < ny; j++){
	py->pi[i][j] += w[k]*pdfnorm(py->y[j], m, sx);
	py->Si[i][j] += w[k]*(1.0-cdfnorm(py->y[j], m, sx) );

  py->pi2[i][j] += w[k]*pdfnorm(py->y2[j], m2, sx2);
	py->Si2[i][j] += w[k]*(1.0-cdfnorm(py->y2[j], m2, sx2) );
      }// for j
    }
   } // for i
  }
  else{
    for (i=0;i<nx;i++){
    for(j=0;j<px;j++) /* fill covariates into fake data record */
      yline[p0+j]=py->x[i][j];
    // c[K] = logpr(yline,i,K);
    // w[K]=  log(alpha/an);
    if(simu==0){
      for(k=0;k<K;k++){
          c[k] = logpr(yline,i,k);
          // printf("c[k]:%f \n",c[k]);
          w[k]=  log(mdp.count[k]/an);
          if (k==0)        
            mxh=c[k];
          else 
            if (c[k] > mxh) mxh=c[k];
      }
    }
    else{
        for(k=0;k<K;k++){
          c[k] = logpr_simu(yline,i,k);
          // printf("c[k]:%f \n",c[k]);
          w[k]=  log(mdp.count[k]/an);
          if (k==0)        
            mxh=c[k];
          else 
            if (c[k] > mxh) mxh=c[k];
      }

    }
  
  
   
    for(k=0;k<K;k++){
      w[k] = c[k]+w[k];
      if ( (k==0) | (w[k]> wmx)) wmx=w[k];
    }
    for(s=k=0;k<K;k++){
      w[k] = exp(w[k]-wmx);
      s+=w[k];
    }
    x_div_r(w,s,w,K);
     /*Skip the new class survival estiamte */

      //the probability of new class is zero
     for (j = 0; j < ny; j++){
      py->pi[i][j] =0;//
      py->Si[i][j] =0; //

      py->pi2[i][j] =0;//
      py->Si2[i][j] =0; //
      // printf("py->Si[i][j], %f, \n",py->Si[i][j]);
    }// for j
    
  
    /* pred if y is in one of the existing classes */
    for (k = 0; k < mdp.n_class; k++){
      if (p0>1){
	      m = mdp.mu[k][yi];
	      sxx = mdp.V[k][yi][yi];

        m2 = mdp.mu[k][yi+1];
	      sxx2 = mdp.V[k][yi+1][yi+1];
      } else {
	      m = mdp.mu1[k];
	      sxx = mdp.V1[k];
      }
      sx = sqrt(sxx);
      sx2 = sqrt(sxx2);
      for (j = 0; j < ny; j++){
	    py->pi[i][j] += w[k]*pdfnorm(py->y[j], m, sx);//
	    py->Si[i][j] += w[k]*(1.0-cdfnorm(py->y[j], m, sx) );// 

      py->pi2[i][j] += w[k]*pdfnorm(py->y2[j], m2, sx2);//
	    py->Si2[i][j] += w[k]*(1.0-cdfnorm(py->y2[j], m2, sx2) );// 
      }// for j
    }
   }// for i

  }

  /* update py->p */
  n_updates = py->n_update;
  r = n_updates/(n_updates+1.0);
  for (i = 0; i < nx; i++)
    for(j=0;j<ny;j++){
      py->p[i][j] = r*py->p[i][j] + (1-r)*py->pi[i][j];
      py->S[i][j] = r*py->S[i][j] + (1-r)*py->Si[i][j];

       py->p2[i][j] = r*py->p2[i][j] + (1-r)*py->pi2[i][j];
      py->S2[i][j] = r*py->S2[i][j] + (1-r)*py->Si2[i][j];

  }
  
  /* increment n_update */
  py->n_update = n_updates+1;

  /* free alloc mem */
  free_dvector(w,0,K);
  free_dvector(c,0,K);
  free_dvector(yline,0,p);
  return;
}


/* =====================================================================
   problem specific pred routines
   ===================================================================== */

/* Module-private predictive grid shared by pred_init, pred_update and
   pred_finish.  pred_init assigns it when mdp.py == 1; otherwise it
   stays NULL. */
static struct Y_GRID *py = NULL;

/* ***********************************************************
   pred_init
   *********************************************************** */
/* Log "pred_init", then build the univariate predictive grid via
   py_init when mdp.py == 1.  With any other mdp.py value, nothing else
   happens. */
void pred_init(struct MDP_PARS mdp, struct DTA dta, int seeds, int simu)
{
  message("pred_init");
  if (mdp.py != 1)
    return;
  py = py_init(mdp, dta, seeds, simu);
}

/* ***********************************************************
   pred_update
   *********************************************************** */
/* Advance the univariate predictive by one update (only when
   mdp.py == 1), then always hand over to pred_finish. */
void pred_update(int time, struct MDP_PARS mdp, struct DTA dta,
                 struct SIM_PAR sim, int postMCMC, int seeds, int simu)
{
  const int want_py = (mdp.py == 1);

  if (want_py)
    py_update(time, py, mdp, dta, sim, postMCMC, simu);
  pred_finish(mdp, seeds, simu);
}
/* ***********************************************************
   pred_finish
   *********************************************************** */
/* Write the univariate predictive summaries via print_py, but only when
   mdp.py == 1. */
void pred_finish(struct MDP_PARS mdp, int seeds, int simu)
{
  if (mdp.py != 1)
    return;
  print_py(py, mdp, seeds, simu);
}

