[176] | 1 | #include <itpp/itbase.h> |
---|
| 2 | #include "merger.h" |
---|
| 3 | #include "arx.h" |
---|
| 4 | |
---|
| 5 | vec merger::lognorm_merge ( mat &lW ) { |
---|
| 6 | int nu=lW.rows(); |
---|
| 7 | vec mu = sum ( lW ) /nu; //mean of logs |
---|
| 8 | vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 ); |
---|
| 9 | double coef=0.0; |
---|
| 10 | switch ( nu ) { |
---|
| 11 | case 2: |
---|
| 12 | coef=sqrt ( beta*2 ) * ( 1-0.5*sqrt ( ( 4*beta-3 ) /beta ) ); |
---|
| 13 | return exp ( coef*sqrt ( lam ) + mu ); |
---|
| 14 | break; |
---|
| 15 | case 3://Ration of Bessel |
---|
| 16 | break; |
---|
| 17 | case 4: |
---|
| 18 | break; |
---|
[182] | 19 | default: // Approximate conditional density |
---|
| 20 | break; |
---|
[176] | 21 | } |
---|
[180] | 22 | return vec ( 0 ); |
---|
[176] | 23 | } |
---|
| 24 | |
---|
| 25 | void merger::merge ( const epdf* g0 ) { |
---|
[180] | 26 | it_file dbg ( "merger_debug.it" ); |
---|
[176] | 27 | |
---|
[180] | 28 | it_assert_debug ( rv.equal ( g0->_rv() ),"Incompatible g0" ); |
---|
| 29 | //Empirical density - samples |
---|
| 30 | eEmp eSmp ( rv,Ns ); |
---|
| 31 | eSmp.set_parameters ( ones ( Ns ), g0 ); |
---|
| 32 | Array<vec> &Smp = eSmp._samples(); //aux |
---|
| 33 | vec &w = eSmp._w(); //aux |
---|
| 34 | |
---|
[192] | 35 | mat Smp_ex =ones ( rv.count() +1,Ns ); // Extended samples for the ARX model - the last row is ones |
---|
[180] | 36 | for ( int i=0;i<Ns;i++ ) { set_col_part ( Smp_ex,i,Smp ( i ) );} |
---|
| 37 | |
---|
[182] | 38 | dbg << Name ( "Smp_0" ) << Smp_ex; |
---|
[180] | 39 | |
---|
| 40 | // Stuff for merging |
---|
[176] | 41 | vec lw_src ( Ns ); |
---|
| 42 | vec lw_mix ( Ns ); |
---|
[180] | 43 | mat lW=zeros ( n,Ns ); |
---|
[176] | 44 | vec vec0 ( 0 ); |
---|
| 45 | |
---|
[180] | 46 | // Initial component in the mixture model |
---|
[176] | 47 | mat V0=1e-8*eye ( rv.count() +1 ); |
---|
[180] | 48 | ARX A0 ( RV ( "{th_r }", vec_1 ( rv.count() * ( rv.count() +1 ) ) ),\ |
---|
[182] | 49 | V0, rv.count() *rv.count() +3.0 ); //initial guess of Mix: zero mean, large variance |
---|
[176] | 50 | |
---|
[180] | 51 | // ============= MAIN LOOP ================== |
---|
[176] | 52 | bool converged=false; |
---|
[180] | 53 | int niter = 0; |
---|
| 54 | char str[100]; |
---|
[182] | 55 | |
---|
| 56 | epdf* Mpred; |
---|
| 57 | vec Mix_pdf ( Ns ); |
---|
[176] | 58 | while ( !converged ) { |
---|
[180] | 59 | //Re-estimate Mix |
---|
| 60 | //Re-Initialize Mixture model |
---|
| 61 | Mix.init ( &A0, Smp_ex, Nc ); |
---|
[192] | 62 | Mix.bayesB ( Smp_ex, ones ( Ns ) );//w*Ns ); |
---|
| 63 | Mpred = Mix.predictor ( rv ); // Allocation => must be deleted at the end!! |
---|
| 64 | |
---|
[180] | 65 | // Generate new samples |
---|
[182] | 66 | eSmp.set_samples ( Mpred ); |
---|
[180] | 67 | for ( int i=0;i<Ns;i++ ) { set_col_part ( Smp_ex,i,Smp ( i ) );} |
---|
| 68 | |
---|
| 69 | sprintf ( str,"Mpdf%d",niter ); |
---|
[182] | 70 | for ( int i=0;i<Ns;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );} |
---|
[180] | 71 | dbg << Name ( str ) << Mix_pdf; |
---|
| 72 | |
---|
| 73 | sprintf ( str,"Smp%d",niter ); |
---|
| 74 | dbg << Name ( str ) << Smp_ex; |
---|
| 75 | |
---|
| 76 | //Importace weighting |
---|
[176] | 77 | for ( int i=0;i<n;i++ ) { |
---|
[182] | 78 | lw_src=0.0; |
---|
| 79 | //======== Same RVs =========== |
---|
[176] | 80 | //Split according to dependency in rvs |
---|
[192] | 81 | if ( mpdfs ( i )->_rv().count() ==rv.count() ) { |
---|
[176] | 82 | // no need for conditioning or marginalization |
---|
[182] | 83 | for ( int j=0;j<Ns; j++ ) { // Smp is Array<> => for cycle |
---|
[180] | 84 | lw_src ( j ) =mpdfs ( i )->_epdf().evalpdflog ( Smp ( j ) ); |
---|
| 85 | } |
---|
[176] | 86 | } |
---|
[182] | 87 | else { |
---|
| 88 | // compute likelihood of marginal on the conditional variable |
---|
| 89 | if ( mpdfs ( i )->_rvc().count() >0 ) { |
---|
| 90 | // Make marginal on rvc_i |
---|
| 91 | epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() ); |
---|
[192] | 92 | //compute vector of lw_src |
---|
| 93 | for ( int k=0;k<Ns;k++ ) { |
---|
| 94 | lw_src ( k ) += tmp_marg->evalpdflog ( dls ( i )->get_val ( Smp ( i ) ) ); |
---|
| 95 | } |
---|
[182] | 96 | delete tmp_marg; |
---|
| 97 | } |
---|
| 98 | // Compute likelihood of the missing variable |
---|
[192] | 99 | if ( rv.count() > ( mpdfs ( i )->_rv().count() + mpdfs ( i )->_rvc().count() ) ) { |
---|
| 100 | // There are variales unknown to mpdfs(i) : rvzs |
---|
| 101 | mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) ); |
---|
[182] | 102 | // Compute likelihood |
---|
| 103 | for ( int k= 0; k<Ns; k++ ) { |
---|
[192] | 104 | lw_src ( k ) += log ( |
---|
| 105 | tmp_cond->evalcond ( |
---|
| 106 | zdls ( i )->get_val ( Smp ( k ) ), |
---|
| 107 | zdls ( i )->get_cond ( Smp ( k ) ) ) ); |
---|
[182] | 108 | } |
---|
| 109 | delete tmp_cond; |
---|
| 110 | } |
---|
| 111 | // Compute likelihood of the partial source |
---|
| 112 | for ( int k= 0; k<Ns; k++ ) { |
---|
[192] | 113 | mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp(k) ) ); |
---|
| 114 | lw_src ( k ) += mpdfs ( i )->_epdf().evalpdflog ( dls ( i )->get_val ( Smp(k) ) ); |
---|
[182] | 115 | } |
---|
| 116 | } |
---|
[180] | 117 | lW.set_row ( i, lw_src ); // do not divide by mix |
---|
[176] | 118 | } |
---|
[180] | 119 | //Importance of the mixture |
---|
| 120 | for ( int j=0;j<Ns;j++ ) { |
---|
| 121 | lw_mix ( j ) =Mix.logpred ( Smp_ex.get_col ( j ) ); |
---|
| 122 | } |
---|
| 123 | sprintf ( str,"lW%d",niter ); |
---|
| 124 | dbg << Name ( str ) << lW; |
---|
| 125 | |
---|
| 126 | w = lognorm_merge ( lW ); //merge |
---|
[192] | 127 | |
---|
[180] | 128 | sprintf ( str,"w%d",niter ); |
---|
| 129 | dbg << Name ( str ) << w; |
---|
[182] | 130 | sprintf ( str,"lw_m%d",niter ); |
---|
| 131 | dbg << Name ( str ) << lw_mix; |
---|
[180] | 132 | |
---|
| 133 | //Importance weighting |
---|
| 134 | w /=exp ( lw_mix ); // hoping that it is not numerically sensitive... |
---|
| 135 | //renormalize |
---|
| 136 | w /=sum ( w ); |
---|
| 137 | |
---|
[182] | 138 | sprintf ( str,"w_is_%d",niter ); |
---|
| 139 | dbg << Name ( str ) << w; |
---|
| 140 | |
---|
[180] | 141 | eSmp.resample(); // So that it can be used in bayes |
---|
| 142 | for ( int i=0;i<Ns;i++ ) { set_col_part ( Smp_ex,i,Smp ( i ) );} |
---|
| 143 | |
---|
| 144 | sprintf ( str,"Smp_res%d",niter ); |
---|
| 145 | dbg << Name ( str ) << Smp; |
---|
| 146 | |
---|
| 147 | // ==== stopping rule === |
---|
| 148 | niter++; |
---|
[192] | 149 | converged = ( niter>20 ); |
---|
[176] | 150 | } |
---|
| 151 | |
---|
| 152 | } |
---|