root/bdm/estim/merger.cpp @ 192

Revision 192, 4.4 kB (checked in by smidl, 16 years ago)

modification of datalinks and switch mprod and merger to use them

Line 
1#include <itpp/itbase.h>
2#include "merger.h"
3#include "arx.h"
4
5vec merger::lognorm_merge ( mat &lW ) {
6        int nu=lW.rows();
7        vec mu = sum ( lW ) /nu; //mean of logs
8        vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 );
9        double coef=0.0;
10        switch ( nu ) {
11                case 2:
12                        coef=sqrt ( beta*2 ) * ( 1-0.5*sqrt ( ( 4*beta-3 ) /beta ) );
13                        return exp ( coef*sqrt ( lam ) + mu );
14                        break;
15                case 3://Ration of Bessel
16                        break;
17                case 4:
18                        break;
19                default: // Approximate conditional density
20                        break;
21        }
22        return vec ( 0 );
23}
24
25void merger::merge ( const epdf* g0 ) {
26        it_file dbg ( "merger_debug.it" );
27
28        it_assert_debug ( rv.equal ( g0->_rv() ),"Incompatible g0" );
29        //Empirical density - samples
30        eEmp eSmp ( rv,Ns );
31        eSmp.set_parameters ( ones ( Ns ), g0 );
32        Array<vec> &Smp = eSmp._samples(); //aux
33        vec &w = eSmp._w(); //aux
34
35        mat Smp_ex =ones ( rv.count() +1,Ns ); // Extended samples for the ARX model - the last row is ones
36        for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
37
38        dbg << Name ( "Smp_0" ) << Smp_ex;
39
40        // Stuff for merging
41        vec lw_src ( Ns );
42        vec lw_mix ( Ns );
43        mat lW=zeros ( n,Ns );
44        vec vec0 ( 0 );
45
46        // Initial component in the mixture model
47        mat V0=1e-8*eye ( rv.count() +1 );
48        ARX A0 ( RV ( "{th_r  }", vec_1 ( rv.count() * ( rv.count() +1 ) ) ),\
49                 V0, rv.count() *rv.count() +3.0 ); //initial guess of Mix: zero mean, large variance
50
51        // ============= MAIN LOOP ==================
52        bool converged=false;
53        int niter = 0;
54        char str[100];
55
56        epdf* Mpred;
57        vec Mix_pdf ( Ns );
58        while ( !converged ) {
59                //Re-estimate Mix
60                //Re-Initialize Mixture model
61                Mix.init ( &A0, Smp_ex, Nc );
62                Mix.bayesB ( Smp_ex, ones ( Ns ) );//w*Ns );
63                Mpred = Mix.predictor ( rv ); // Allocation => must be deleted at the end!!
64
65                // Generate new samples
66                eSmp.set_samples ( Mpred );
67                for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
68
69                sprintf ( str,"Mpdf%d",niter );
70                for ( int i=0;i<Ns;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
71                dbg << Name ( str ) << Mix_pdf;
72
73                sprintf ( str,"Smp%d",niter );
74                dbg << Name ( str ) << Smp_ex;
75
76                //Importace weighting
77                for ( int i=0;i<n;i++ ) {
78                        lw_src=0.0;
79                        //======== Same RVs ===========
80                        //Split according to dependency in rvs
81                        if ( mpdfs ( i )->_rv().count() ==rv.count() ) {
82                                // no need for conditioning or marginalization
83                                for ( int j=0;j<Ns; j++ ) { // Smp is Array<> => for cycle
84                                        lw_src ( j ) =mpdfs ( i )->_epdf().evalpdflog ( Smp ( j ) );
85                                }
86                        }
87                        else {
88                                // compute likelihood of marginal on the conditional variable
89                                if ( mpdfs ( i )->_rvc().count() >0 ) {
90                                        // Make marginal on rvc_i
91                                        epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
92                                        //compute vector of lw_src
93                                        for ( int k=0;k<Ns;k++ ) {
94                                                lw_src ( k ) += tmp_marg->evalpdflog ( dls ( i )->get_val ( Smp ( i ) ) );
95                                        }
96                                        delete tmp_marg;
97                                }
98                                // Compute likelihood of the missing variable
99                                if ( rv.count() > ( mpdfs ( i )->_rv().count() + mpdfs ( i )->_rvc().count() ) ) {
100                                        // There are variales unknown to mpdfs(i) : rvzs
101                                        mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
102                                        // Compute likelihood
103                                        for ( int k= 0; k<Ns; k++ ) {
104                                                lw_src ( k ) += log (
105                                                                    tmp_cond->evalcond (
106                                                                        zdls ( i )->get_val ( Smp ( k ) ),
107                                                                        zdls ( i )->get_cond ( Smp ( k ) ) ) );
108                                        }
109                                        delete tmp_cond;
110                                }
111                                // Compute likelihood of the partial source
112                                for ( int k= 0; k<Ns; k++ ) {
113                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp(k) ) );
114                                        lw_src ( k ) += mpdfs ( i )->_epdf().evalpdflog ( dls ( i )->get_val ( Smp(k) ) );
115                                }
116                        }
117                        lW.set_row ( i, lw_src ); // do not divide by mix
118                }
119                //Importance of the mixture
120                for ( int j=0;j<Ns;j++ ) {
121                        lw_mix ( j ) =Mix.logpred ( Smp_ex.get_col ( j ) );
122                }
123                sprintf ( str,"lW%d",niter );
124                dbg << Name ( str ) << lW;
125
126                w = lognorm_merge ( lW ); //merge
127
128                sprintf ( str,"w%d",niter );
129                dbg << Name ( str ) << w;
130                sprintf ( str,"lw_m%d",niter );
131                dbg << Name ( str ) << lw_mix;
132
133                //Importance weighting
134                w /=exp ( lw_mix ); // hoping that it is not numerically sensitive...
135                //renormalize
136                w /=sum ( w );
137
138                sprintf ( str,"w_is_%d",niter );
139                dbg << Name ( str ) << w;
140
141                eSmp.resample(); // So that it can be used in bayes
142                for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
143
144                sprintf ( str,"Smp_res%d",niter );
145                dbg << Name ( str ) << Smp;
146
147                // ==== stopping rule ===
148                niter++;
149                converged = ( niter>20 );
150        }
151
152}
Note: See TracBrowser for help on using the browser.