root/bdm/estim/merger.cpp @ 197

Revision 197, 4.7 kB (checked in by smidl, 16 years ago)

opravy v bdm

Line 
1#include <itpp/itbase.h>
2#include "merger.h"
3#include "arx.h"
4
5vec merger::lognorm_merge ( mat &lW ) {
6        int nu=lW.rows();
7        vec mu = sum ( lW ) /nu; //mean of logs
8        vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 );
9        double coef=0.0;
10        switch ( nu ) {
11                case 2:
12                        coef=sqrt ( beta*2 ) * ( 1-0.5*sqrt ( ( 4*beta-3 ) /beta ) );
13                        return exp ( coef*sqrt ( lam ) + mu );
14                        break;
15                case 3://Ration of Bessel
16                        break;
17                case 4:
18                        break;
19                default: // Approximate conditional density
20                        break;
21        }
22        return vec ( 0 );
23}
24
25void merger::merge ( const epdf* g0 ) {
26        it_file dbg ( "merger_debug.it" );
27
28        it_assert_debug ( rv.equal ( g0->_rv() ),"Incompatible g0" );
29        //Empirical density - samples
30        eEmp eSmp ( rv,Ns );
31        eSmp.set_parameters ( ones ( Ns ), g0 );
32        Array<vec> &Smp = eSmp._samples(); //aux
33        vec &w = eSmp._w(); //aux
34
35        mat Smp_ex =ones ( rv.count() +1,Ns ); // Extended samples for the ARX model - the last row is ones
36        for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
37
38        dbg << Name ( "Smp_0" ) << Smp_ex;
39
40        // Stuff for merging
41        vec lw_src ( Ns );
42        vec lw_mix ( Ns );
43        mat lW=zeros ( n,Ns );
44        vec vec0 ( 0 );
45
46        // Initial component in the mixture model
47        mat V0=1e-8*eye ( rv.count() +1 );
48        ARX A0 ( RV ( "{th_r  }", vec_1 ( rv.count() * ( rv.count() +1 ) ) ),\
49                 V0, rv.count() *rv.count() +5.0 ); //initial guess of Mix: zero mean, large variance
50
51        Mix.init ( &A0, Smp_ex, Nc );
52        //Preserve initial mixture for repetitive estimation via flattening
53        MixEF Mix_init(Mix);
54
55        // ============= MAIN LOOP ==================
56        bool converged=false;
57        int niter = 0;
58        char str[100];
59
60        epdf* Mpred;
61        vec Mix_pdf ( Ns );
62        while ( !converged ) {
63                //Re-estimate Mix
64                //Re-Initialize Mixture model
65                Mix.flatten(&Mix_init);
66                Mix.bayesB ( Smp_ex, w*Ns );
67                Mpred = Mix.predictor ( rv ); // Allocation => must be deleted at the end!!
68
69                sprintf ( str,"Mpred_mean%d",niter );
70                dbg << Name ( str ) << Mpred->mean();
71
72                // Generate new samples
73                eSmp.set_samples ( Mpred );
74                for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
75
76                sprintf ( str,"Mpdf%d",niter );
77                for ( int i=0;i<Ns;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
78                dbg << Name ( str ) << Mix_pdf;
79
80                sprintf ( str,"Smp%d",niter );
81                dbg << Name ( str ) << Smp_ex;
82
83                //Importace weighting
84                for ( int i=0;i<n;i++ ) {
85                        lw_src=0.0;
86                        //======== Same RVs ===========
87                        //Split according to dependency in rvs
88                        if ( mpdfs ( i )->_rv().count() ==rv.count() ) {
89                                // no need for conditioning or marginalization
90                                for ( int j=0;j<Ns; j++ ) { // Smp is Array<> => for cycle
91                                        lw_src ( j ) =mpdfs ( i )->_epdf().evalpdflog ( Smp ( j ) );
92                                }
93                        }
94                        else {
95                                // compute likelihood of marginal on the conditional variable
96                                if ( mpdfs ( i )->_rvc().count() >0 ) {
97                                        // Make marginal on rvc_i
98                                        epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
99                                        //compute vector of lw_src
100                                        for ( int k=0;k<Ns;k++ ) {
101                                                lw_src ( k ) += tmp_marg->evalpdflog ( dls ( i )->get_val ( Smp ( i ) ) );
102                                        }
103                                        delete tmp_marg;
104                                }
105                                // Compute likelihood of the missing variable
106                                if ( rv.count() > ( mpdfs ( i )->_rv().count() + mpdfs ( i )->_rvc().count() ) ) {
107                                        ///////////////
108                                        // There are variales unknown to mpdfs(i) : rvzs
109                                        mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
110                                        // Compute likelihood
111                                        vec lw_dbg=lw_src;
112                                        for ( int k= 0; k<Ns; k++ ) {
113                                                lw_src ( k ) += log (
114                                                                    tmp_cond->evalcond (
115                                                                        zdls ( i )->get_val ( Smp ( k ) ),
116                                                                        zdls ( i )->get_cond ( Smp ( k ) ) ) );
117                                        }
118                                        delete tmp_cond;
119                                }
120                                // Compute likelihood of the partial source
121                                for ( int k= 0; k<Ns; k++ ) {
122                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) );
123                                        lw_src ( k ) += mpdfs ( i )->_epdf().evalpdflog ( dls ( i )->get_val ( Smp ( k ) ) );
124                                }
125
126                        }
127                        lW.set_row ( i, lw_src ); // do not divide by mix
128                }
129                //Importance of the mixture
130                for ( int j=0;j<Ns;j++ ) {
131                        lw_mix ( j ) =Mix.logpred ( Smp_ex.get_col ( j ) );
132                }
133                sprintf ( str,"lW%d",niter );
134                dbg << Name ( str ) << lW;
135
136                w = lognorm_merge ( lW ); //merge
137
138                sprintf ( str,"w%d",niter );
139                dbg << Name ( str ) << w;
140                sprintf ( str,"lw_m%d",niter );
141                dbg << Name ( str ) << lw_mix;
142
143                //Importance weighting
144                w /=exp ( lw_mix ); // hoping that it is not numerically sensitive...
145                //renormalize
146                w /=sum ( w );
147
148                sprintf ( str,"w_is_%d",niter );
149                dbg << Name ( str ) << w;
150
151//              eSmp.resample(); // So that it can be used in bayes
152//              for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
153
154                sprintf ( str,"Smp_res%d",niter );
155                dbg << Name ( str ) << Smp;
156
157                // ==== stopping rule ===
158                niter++;
159                converged = ( niter>9 );
160        }
161
162}
Note: See TracBrowser for help on using the browser.