root/bdm/estim/merger.cpp @ 283

Revision 283, 5.7 kB (checked in by smidl, 15 years ago)

get rid of BMcond + adaptation in doprava/

  • Property svn:eol-style set to native
Line 
1
2#include "merger.h"
3#include "arx.h"
4
5namespace bdm{
6vec merger::lognorm_merge ( mat &lW ) {
7        int nu=lW.rows();
8        vec mu = sum ( lW ) /nu; //mean of logs
9        // vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 ); ======= numerically unsafe!
10        vec lam = sum ( pow ( lW-outer_product ( ones ( lW.rows() ),mu ),2 ) );
11        double coef=0.0;
12        vec sq2bl=sqrt ( 2*beta*lam ); //this term is everywhere
13        switch ( nu ) {
14                case 2:
15                        coef= ( 1-0.5*sqrt ( ( 4.0*beta-3.0 ) /beta ) );
16                        return  coef*sq2bl + mu ;
17                        break;
18                case 3://Ratio of Bessel
19                        coef = sqrt ( ( 3*beta-2 ) /3*beta );
20                        return log ( besselk ( 0,sq2bl*coef ) ) - log ( besselk ( 0,sq2bl ) ) +  mu;
21                        break;
22                case 4:
23                        break;
24                default: // Approximate conditional density
25                        break;
26        }
27        return vec ( 0 );
28}
29
30void merger::merge ( const epdf* g0 ) {
31//      it_file dbg ( "merger_debug.it" );
32
33        it_assert_debug ( rv.equal ( g0->_rv() ),"Incompatible g0" );
34        //Empirical density - samples
35        eSmp.set_statistics ( ones ( Ns ), g0 );
36        Array<vec> &Smp = eSmp._samples(); //aux
37        vec &w = eSmp._w(); //aux
38
39        mat Smp_ex =ones ( dim +1,Ns ); // Extended samples for the ARX model - the last row is ones
40        for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
41
42//      dbg << Name ( "Smp_0" ) << Smp_ex;
43
44        // Stuff for merging
45        vec lw_src ( Ns );
46        vec lw_mix ( Ns );
47        vec lw ( Ns );
48        mat lW=zeros ( n,Ns );
49        vec vec0 ( 0 );
50
51        // Initial component in the mixture model
52        mat V0=1e-8*eye ( dim +1 );
53        ARX A0; 
54        A0.set_statistics(dim, V0, dim*dim +5.0 ); //initial guess of Mix: zero mean, large variance
55
56        Mix.init ( &A0, Smp_ex, Nc );
57        //Preserve initial mixture for repetitive estimation via flattening
58        MixEF Mix_init ( Mix );
59
60        // ============= MAIN LOOP ==================
61        bool converged=false;
62        int niter = 0;
63        char str[100];
64
65        epdf* Mpred=Mix.epredictor (  );
66        vec Mix_pdf ( Ns );
67        while ( !converged ) {
68                //Re-estimate Mix
69                //Re-Initialize Mixture model
70                Mix.flatten ( &Mix_init );
71                Mix.bayesB ( Smp_ex, w*Ns );
72                delete Mpred;
73                Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
74
75                // This will be active only later in iterations!!!
76                if ( 1./sum_sqr ( w ) <0.5*Ns ) {
77                        // Generate new samples
78                        eSmp.set_samples ( Mpred );
79                        for ( int i=0;i<Ns;i++ ) {
80                                //////////// !!!!!!!!!!!!!
81                                if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
82                                set_col_part ( Smp_ex,i,Smp ( i ) );
83                        }
84                        if(0){cout<<"Resampling =" << 1./sum_sqr ( w ) << endl;
85                        cout << sum ( Smp_ex,2 ) /Ns <<endl;
86                        cout << Smp_ex*Smp_ex.T() /Ns << endl;}
87                }
88//              sprintf ( str,"Mpred_mean%d",niter );
89//              dbg << Name ( str ) << Mpred->mean();
90
91
92//              sprintf ( str,"Mpdf%d",niter );
93//              for ( int i=0;i<Ns;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
94//              dbg << Name ( str ) << Mix_pdf;
95
96//              sprintf ( str,"Smp%d",niter );
97//              dbg << Name ( str ) << Smp_ex;
98
99                //Importace weighting
100                for ( int i=0;i<n;i++ ) {
101                        lw_src=0.0;
102                        //======== Same RVs ===========
103                        //Split according to dependency in rvs
104                        if ( mpdfs ( i )->dimension() ==dim ) {
105                                // no need for conditioning or marginalization
106                                for ( int j=0;j<Ns; j++ ) { // Smp is Array<> => for cycle
107                                        lw_src ( j ) =mpdfs ( i )->_epdf().evallog ( Smp ( j ) );
108                                }
109                        }
110                        else {
111                                // compute likelihood of marginal on the conditional variable
112                                if ( mpdfs ( i )->dimensionc() >0 ) {
113                                        // Make marginal on rvc_i
114                                        epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
115                                        //compute vector of lw_src
116                                        for ( int k=0;k<Ns;k++ ) {
117                                                // Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond
118                                                lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
119                                        }
120                                        delete tmp_marg;
121
122//                                      sprintf ( str,"marg%d",niter );
123//                                      dbg << Name ( str ) << lw_src;
124
125                                }
126                                // Compute likelihood of the missing variable
127                                if ( dim > ( mpdfs ( i )->dimension() + mpdfs ( i )->dimensionc() ) ) {
128                                        ///////////////
129                                        // There are variales unknown to mpdfs(i) : rvzs
130                                        mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
131                                        // Compute likelihood
132                                        vec lw_dbg=lw_src;
133                                        for ( int k= 0; k<Ns; k++ ) {
134                                                lw_src ( k ) += log (
135                                                                    tmp_cond->evallogcond (
136                                                                        zdls ( i )->pushdown ( Smp ( k ) ),
137                                                                        zdls ( i )->get_cond ( Smp ( k ) ) ) );
138                                                if ( !std::isfinite ( lw_src ( k ) ) ) {
139                                                        lw_src ( k ) = -1e16; cout << "!";
140                                                }
141                                        }
142                                        delete tmp_cond;
143                                }
144                                // Compute likelihood of the partial source
145                                for ( int k= 0; k<Ns; k++ ) {
146                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) );
147                                        lw_src ( k ) += mpdfs ( i )->_epdf().evallog ( dls ( i )->pushdown ( Smp ( k ) ) );
148                                }
149
150                        }
151//                      it_assert_debug(std::isfinite(sum(lw_src)),"bad");
152                        lW.set_row ( i, lw_src ); // do not divide by mix
153                }
154                //Importance of the mixture
155                for ( int j=0;j<Ns;j++ ) {
156                        lw_mix ( j ) =Mix.logpred ( Smp_ex.get_col ( j ) );
157                }
158//              sprintf ( str,"lW%d",niter );
159//              dbg << Name ( str ) << lW;
160
161                lw = lognorm_merge ( lW ); //merge
162
163//              sprintf ( str,"w%d",niter );
164//              dbg << Name ( str ) << w;
165//              sprintf ( str,"lw_m%d",niter );
166//              dbg << Name ( str ) << lw_mix;
167
168                //Importance weighting
169                lw -=  lw_mix; // hoping that it is not numerically sensitive...
170                w = exp ( lw-max ( lw ) );
171                //renormalize
172                double sumw =sum ( w );
173                if ( std::isfinite ( sumw ) ) {
174                        w = w/sumw;
175                }
176                else {
177                        it_file itf ( "merg_err.it" );
178                        itf << Name ( "w" ) << w;
179                }
180
181//              sprintf ( str,"w_is_%d",niter );
182//              dbg << Name ( str ) << w;
183
184//              eSmp.resample(); // So that it can be used in bayes
185//              for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
186
187//              sprintf ( str,"Smp_res%d",niter );
188//              dbg << Name ( str ) << Smp;
189
190                // ==== stopping rule ===
191                niter++;
192                converged = ( niter>20 );
193        }
194        delete Mpred;
195        cout << endl;
196
197}
198
199}
Note: See TracBrowser for help on using the browser.