root/bdm/estim/merger.cpp @ 186

Revision 186, 4.5 kB (checked in by smidl, 16 years ago)

correction in merger

Line 
1#include <itpp/itbase.h>
2#include "merger.h"
3#include "arx.h"
4
5vec merger::lognorm_merge ( mat &lW ) {
6        int nu=lW.rows();
7        vec mu = sum ( lW ) /nu; //mean of logs
8        vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 );
9        double coef=0.0;
10        switch ( nu ) {
11                case 2:
12                        coef=sqrt ( beta*2 ) * ( 1-0.5*sqrt ( ( 4*beta-3 ) /beta ) );
13                        return exp ( coef*sqrt ( lam ) + mu );
14                        break;
15                case 3://Ration of Bessel
16                        break;
17                case 4:
18                        break;
19                default: // Approximate conditional density
20                        break;
21        }
22        return vec ( 0 );
23}
24
25void merger::merge ( const epdf* g0 ) {
26        it_file dbg ( "merger_debug.it" );
27
28        it_assert_debug ( rv.equal ( g0->_rv() ),"Incompatible g0" );
29        //Empirical density - samples
30        eEmp eSmp ( rv,Ns );
31        eSmp.set_parameters ( ones ( Ns ), g0 );
32        Array<vec> &Smp = eSmp._samples(); //aux
33        vec &w = eSmp._w(); //aux
34
35        mat Smp_ex =ones ( rv.count() +1,Ns ); // Extended samples The last row is ones for the ARX model
36        for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
37
38        dbg << Name ( "Smp_0" ) << Smp_ex;
39
40        // Stuff for merging
41        vec lw_src ( Ns );
42        vec lw_mix ( Ns );
43        mat lW=zeros ( n,Ns );
44        vec vec0 ( 0 );
45
46        // Initial component in the mixture model
47        mat V0=1e-8*eye ( rv.count() +1 );
48        ARX A0 ( RV ( "{th_r  }", vec_1 ( rv.count() * ( rv.count() +1 ) ) ),\
49                 V0, rv.count() *rv.count() +3.0 ); //initial guess of Mix: zero mean, large variance
50
51
52
53        // ============= MAIN LOOP ==================
54        bool converged=false;
55        int niter = 0;
56        char str[100];
57
58        epdf* Mpred;
59        vec Mix_pdf ( Ns );
60        while ( !converged ) {
61                //Re-estimate Mix
62                //Re-Initialize Mixture model
63                Mix.init ( &A0, Smp_ex, Nc );
64                Mix.bayesB ( Smp_ex );
65                Mpred = Mix.predictor(rv); // Allocation => must be deleted at the end!!
66       
67                // Generate new samples
68                eSmp.set_samples ( Mpred );
69                for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
70
71                sprintf ( str,"Mpdf%d",niter );
72                for ( int i=0;i<Ns;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
73                dbg << Name ( str ) << Mix_pdf;
74
75                sprintf ( str,"Smp%d",niter );
76                dbg << Name ( str ) << Smp_ex;
77
78                //Importace weighting
79                for ( int i=0;i<n;i++ ) {
80                        lw_src=0.0;
81                        //======== Same RVs ===========
82                        //Split according to dependency in rvs
83                        if ( rvsinrv ( i ).length() ==rv.count() ) {
84                                // no need for conditioning or marginalization
85                                for ( int j=0;j<Ns; j++ ) { // Smp is Array<> => for cycle
86                                        lw_src ( j ) =mpdfs ( i )->_epdf().evalpdflog ( Smp ( j ) );
87                                }
88                        }
89                        else {
90                                vec smpk;
91                                // compute likelihood of marginal on the conditional variable
92                                if ( mpdfs ( i )->_rvc().count() >0 ) {
93                                        // Make marginal on rvc_i
94                                        epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
95                                        for(int k=0;k<Ns;k++){
96                                        lw_src(k) += tmp_marg->evalpdflog ( get_vec(Smp(i), irv_rvcs(i) ) );}
97                                        delete tmp_marg;
98                                }
99                                // Compute likelihood of the missing variable
100                                if ( rv.count() > mpdfs ( i )->_rv().count() + mpdfs ( i )->_rvc().count() ) {
101                                        // There are variales unknown to mpdfs
102                                        RV z = ( rv.subt ( mpdfs ( i )->_rv() ) ).subt ( mpdfs ( i )->_rvc() );
103                                        mpdf* tmp_cond = Mpred->condition ( z ); 
104                                        // Indeces of rest rv in Smp
105                                        ivec zinrv=z.dataind ( rv ) ;
106                                        // Indeces of rest rvc in Smp
107                                        ivec zinrvc=tmp_cond->_rvc().dataind ( rv );
108                                        // Compute likelihood
109                                        for ( int k= 0; k<Ns; k++ ) {
110                                                smpk=Smp( k );
111                                                lw_src ( k ) += log( tmp_cond->evalcond ( get_vec ( smpk,zinrv ), get_vec ( smpk,zinrvc ) ));
112                                        }
113                                        delete tmp_cond;
114                                }
115                                // Compute likelihood of the partial source
116                                for ( int k= 0; k<Ns; k++ ) {
117                                        smpk=Smp( k );
118                                        mpdfs ( i )->condition ( get_vec ( smpk,irv_rvcs(i) ) );
119                                        lw_src ( k ) += mpdfs ( i )->_epdf().evalpdflog ( get_vec ( smpk, rvsinrv ( i ) ) );
120                                }
121                        }
122                        lW.set_row ( i, lw_src ); // do not divide by mix
123                }
124                //Importance of the mixture
125                for ( int j=0;j<Ns;j++ ) {
126                        lw_mix ( j ) =Mix.logpred ( Smp_ex.get_col ( j ) );
127                }
128                sprintf ( str,"lW%d",niter );
129                dbg << Name ( str ) << lW;
130
131                w = lognorm_merge ( lW ); //merge
132               
133                sprintf ( str,"w%d",niter );
134                dbg << Name ( str ) << w;
135                sprintf ( str,"lw_m%d",niter );
136                dbg << Name ( str ) << lw_mix;
137
138                //Importance weighting
139                w /=exp ( lw_mix ); // hoping that it is not numerically sensitive...
140                //renormalize
141                w /=sum ( w );
142
143                sprintf ( str,"w_is_%d",niter );
144                dbg << Name ( str ) << w;
145
146                eSmp.resample(); // So that it can be used in bayes
147                for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
148
149                sprintf ( str,"Smp_res%d",niter );
150                dbg << Name ( str ) << Smp;
151
152                // ==== stopping rule ===
153                niter++;
154                converged = ( niter>6);
155        }
156
157}
Note: See TracBrowser for help on using the browser.