root/bdm/estim/merger.cpp @ 213

Revision 213, 5.8 kB (checked in by smidl, 15 years ago)

Merging - new experiment

Line 
1#include <itpp/itbase.h>
2#include "merger.h"
3#include "arx.h"
4
5vec merger::lognorm_merge ( mat &lW ) {
6        int nu=lW.rows();
7        vec mu = sum ( lW ) /nu; //mean of logs
8        // vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 ); ======= numerically unsafe!
9        vec lam = sum ( pow ( lW-outer_product ( ones ( lW.rows() ),mu ),2 ) );
10        double coef=0.0;
11        vec sq2bl=sqrt ( 2*beta*lam ); //this term is everywhere
12        switch ( nu ) {
13                case 2:
14                        coef= ( 1-0.5*sqrt ( ( 4.0*beta-3.0 ) /beta ) );
15                        return  coef*sq2bl + mu ;
16                        break;
17                case 3://Ratio of Bessel
18                        coef = sqrt ( ( 3*beta-2 ) /3*beta );
19                        return log ( besselk ( 0,sq2bl*coef ) ) - log ( besselk ( 0,sq2bl ) ) +  mu;
20                        break;
21                case 4:
22                        break;
23                default: // Approximate conditional density
24                        break;
25        }
26        return vec ( 0 );
27}
28
29void merger::merge ( const epdf* g0 ) {
30//      it_file dbg ( "merger_debug.it" );
31
32        it_assert_debug ( rv.equal ( g0->_rv() ),"Incompatible g0" );
33        //Empirical density - samples
34        eSmp.set_parameters ( ones ( Ns ), g0 );
35        Array<vec> &Smp = eSmp._samples(); //aux
36        vec &w = eSmp._w(); //aux
37
38        mat Smp_ex =ones ( rv.count() +1,Ns ); // Extended samples for the ARX model - the last row is ones
39        for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
40
41//      dbg << Name ( "Smp_0" ) << Smp_ex;
42
43        // Stuff for merging
44        vec lw_src ( Ns );
45        vec lw_mix ( Ns );
46        vec lw ( Ns );
47        mat lW=zeros ( n,Ns );
48        vec vec0 ( 0 );
49
50        // Initial component in the mixture model
51        mat V0=1e-8*eye ( rv.count() +1 );
52        ARX A0 ( RV ( "{th_r  }", vec_1 ( rv.count() * ( rv.count() +1 ) ) ),\
53                 V0, rv.count() *rv.count() +5.0 ); //initial guess of Mix: zero mean, large variance
54
55        Mix.init ( &A0, Smp_ex, Nc );
56        //Preserve initial mixture for repetitive estimation via flattening
57        MixEF Mix_init ( Mix );
58
59        // ============= MAIN LOOP ==================
60        bool converged=false;
61        int niter = 0;
62        char str[100];
63
64        epdf* Mpred=Mix.predictor ( rv );
65        vec Mix_pdf ( Ns );
66        while ( !converged ) {
67                //Re-estimate Mix
68                //Re-Initialize Mixture model
69                Mix.flatten ( &Mix_init );
70                Mix.bayesB ( Smp_ex, w*Ns );
71                delete Mpred;
72                Mpred = Mix.predictor ( rv ); // Allocation => must be deleted at the end!!
73
74                // This will be active only later in iterations!!!
75                if ( 1./sum_sqr ( w ) <0.5*Ns ) {
76                        // Generate new samples
77                        eSmp.set_samples ( Mpred );
78                        for ( int i=0;i<Ns;i++ ) {
79                                //////////// !!!!!!!!!!!!!
80                                if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
81                                set_col_part ( Smp_ex,i,Smp ( i ) );
82                        }
83                        if(0){cout<<"Resampling =" << 1./sum_sqr ( w ) << endl;
84                        cout << sum ( Smp_ex,2 ) /Ns <<endl;
85                        cout << Smp_ex*Smp_ex.T() /Ns << endl;}
86                }
87//              sprintf ( str,"Mpred_mean%d",niter );
88//              dbg << Name ( str ) << Mpred->mean();
89
90
91//              sprintf ( str,"Mpdf%d",niter );
92//              for ( int i=0;i<Ns;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
93//              dbg << Name ( str ) << Mix_pdf;
94
95//              sprintf ( str,"Smp%d",niter );
96//              dbg << Name ( str ) << Smp_ex;
97
98                //Importace weighting
99                for ( int i=0;i<n;i++ ) {
100                        lw_src=0.0;
101                        //======== Same RVs ===========
102                        //Split according to dependency in rvs
103                        if ( mpdfs ( i )->_rv().count() ==rv.count() ) {
104                                // no need for conditioning or marginalization
105                                for ( int j=0;j<Ns; j++ ) { // Smp is Array<> => for cycle
106                                        lw_src ( j ) =mpdfs ( i )->_epdf().evallog ( Smp ( j ) );
107                                }
108                        }
109                        else {
110                                // compute likelihood of marginal on the conditional variable
111                                if ( mpdfs ( i )->_rvc().count() >0 ) {
112                                        // Make marginal on rvc_i
113                                        epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
114                                        //compute vector of lw_src
115                                        for ( int k=0;k<Ns;k++ ) {
116                                                // Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond
117                                                lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
118                                        }
119                                        delete tmp_marg;
120
121//                                      sprintf ( str,"marg%d",niter );
122//                                      dbg << Name ( str ) << lw_src;
123
124                                }
125                                // Compute likelihood of the missing variable
126                                if ( rv.count() > ( mpdfs ( i )->_rv().count() + mpdfs ( i )->_rvc().count() ) ) {
127                                        ///////////////
128                                        // There are variales unknown to mpdfs(i) : rvzs
129                                        mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
130                                        // Compute likelihood
131                                        vec lw_dbg=lw_src;
132                                        for ( int k= 0; k<Ns; k++ ) {
133                                                lw_src ( k ) += log (
134                                                                    tmp_cond->evallogcond (
135                                                                        zdls ( i )->get_val ( Smp ( k ) ),
136                                                                        zdls ( i )->get_cond ( Smp ( k ) ) ) );
137                                                if ( !std::isfinite ( lw_src ( k ) ) ) {
138                                                        lw_src ( k ) = -1e16; cout << "!";
139                                                }
140                                        }
141                                        delete tmp_cond;
142                                }
143                                // Compute likelihood of the partial source
144                                for ( int k= 0; k<Ns; k++ ) {
145                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) );
146                                        lw_src ( k ) += mpdfs ( i )->_epdf().evallog ( dls ( i )->get_val ( Smp ( k ) ) );
147                                }
148
149                        }
150//                      it_assert_debug(std::isfinite(sum(lw_src)),"bad");
151                        lW.set_row ( i, lw_src ); // do not divide by mix
152                }
153                //Importance of the mixture
154                for ( int j=0;j<Ns;j++ ) {
155                        lw_mix ( j ) =Mix.logpred ( Smp_ex.get_col ( j ) );
156                }
157//              sprintf ( str,"lW%d",niter );
158//              dbg << Name ( str ) << lW;
159
160                lw = lognorm_merge ( lW ); //merge
161
162//              sprintf ( str,"w%d",niter );
163//              dbg << Name ( str ) << w;
164//              sprintf ( str,"lw_m%d",niter );
165//              dbg << Name ( str ) << lw_mix;
166
167                //Importance weighting
168                lw -=  lw_mix; // hoping that it is not numerically sensitive...
169                w = exp ( lw-max ( lw ) );
170                //renormalize
171                double sumw =sum ( w );
172                if ( std::isfinite ( sumw ) ) {
173                        w = w/sumw;
174                }
175                else {
176                        it_file itf ( "merg_err.it" );
177                        itf << Name ( "w" ) << w;
178                }
179
180//              sprintf ( str,"w_is_%d",niter );
181//              dbg << Name ( str ) << w;
182
183//              eSmp.resample(); // So that it can be used in bayes
184//              for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
185
186//              sprintf ( str,"Smp_res%d",niter );
187//              dbg << Name ( str ) << Smp;
188
189                // ==== stopping rule ===
190                niter++;
191                converged = ( niter>20 );
192        }
193        delete Mpred;
194        cout << endl;
195
196}
Note: See TracBrowser for help on using the browser.