root/bdm/estim/merger.cpp @ 205

Revision 205, 5.6 kB (checked in by smidl, 16 years ago)

merger posledni verze

Line 
1#include <itpp/itbase.h>
2#include "merger.h"
3#include "arx.h"
4
5vec merger::lognorm_merge ( mat &lW ) {
6        int nu=lW.rows();
7        vec mu = sum ( lW ) /nu; //mean of logs
8        vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 );
9        double coef=0.0;
10        vec sq2bl=sqrt ( 2*beta*lam ); //this term is everywhere
11        switch ( nu ) {
12                case 2:
13                        coef= ( 1-0.5*sqrt ( ( 4.0*beta-3.0 ) /beta ) );
14                        return  coef*sq2bl + mu ;
15                        break;
16                case 3://Ratio of Bessel
17                        coef = sqrt ( ( 3*beta-2 ) /3*beta );
18                        return log ( besselk ( 0,sq2bl*coef ) ) - log ( besselk ( 0,sq2bl ) ) +  mu;
19                        break;
20                case 4:
21                        break;
22                default: // Approximate conditional density
23                        break;
24        }
25        return vec ( 0 );
26}
27
28void merger::merge ( const epdf* g0 ) {
29//      it_file dbg ( "merger_debug.it" );
30
31        it_assert_debug ( rv.equal ( g0->_rv() ),"Incompatible g0" );
32        //Empirical density - samples
33        eSmp.set_parameters ( ones ( Ns ), g0 );
34        Array<vec> &Smp = eSmp._samples(); //aux
35        vec &w = eSmp._w(); //aux
36
37        mat Smp_ex =ones ( rv.count() +1,Ns ); // Extended samples for the ARX model - the last row is ones
38        for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
39
40//      dbg << Name ( "Smp_0" ) << Smp_ex;
41
42        // Stuff for merging
43        vec lw_src ( Ns );
44        vec lw_mix ( Ns );
45        vec lw ( Ns );
46        mat lW=zeros ( n,Ns );
47        vec vec0 ( 0 );
48
49        // Initial component in the mixture model
50        mat V0=1e-8*eye ( rv.count() +1 );
51        ARX A0 ( RV ( "{th_r  }", vec_1 ( rv.count() * ( rv.count() +1 ) ) ),\
52                 V0, rv.count() *rv.count() +5.0 ); //initial guess of Mix: zero mean, large variance
53
54        Mix.init ( &A0, Smp_ex, Nc );
55        //Preserve initial mixture for repetitive estimation via flattening
56        MixEF Mix_init ( Mix );
57
58        // ============= MAIN LOOP ==================
59        bool converged=false;
60        int niter = 0;
61        char str[100];
62
63        epdf* Mpred=Mix.predictor(rv);
64        vec Mix_pdf ( Ns );
65        while ( !converged ) {
66                //Re-estimate Mix
67                //Re-Initialize Mixture model
68                Mix.flatten ( &Mix_init );
69                Mix.bayesB ( Smp_ex, w*Ns );
70                delete Mpred;
71                Mpred = Mix.predictor ( rv ); // Allocation => must be deleted at the end!!
72               
73                if ( 1./sum_sqr ( w ) <0.5*Ns ) {
74                        // Generate new samples
75                        eSmp.set_samples ( Mpred );
76                        for ( int i=0;i<Ns;i++ ) {
77                                //////////// !!!!!!!!!!!!!
78                                if ( Smp ( i ) ( 1 ) <0 ) {Smp ( i ) ( 1 ) = 0.01; }
79                                set_col_part ( Smp_ex,i,Smp ( i ) );
80                        }
81                        {cout<<"Resampling =" << 1./sum_sqr ( w ) << endl;}
82                        cout << sum ( Smp_ex,2 ) /Ns <<endl;
83                        cout << Smp_ex*Smp_ex.T()/Ns << endl; 
84                }
85//              sprintf ( str,"Mpred_mean%d",niter );
86//              dbg << Name ( str ) << Mpred->mean();
87
88
89//              sprintf ( str,"Mpdf%d",niter );
90//              for ( int i=0;i<Ns;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
91//              dbg << Name ( str ) << Mix_pdf;
92
93//              sprintf ( str,"Smp%d",niter );
94//              dbg << Name ( str ) << Smp_ex;
95
96                //Importace weighting
97                for ( int i=0;i<n;i++ ) {
98                        lw_src=0.0;
99                        //======== Same RVs ===========
100                        //Split according to dependency in rvs
101                        if ( mpdfs ( i )->_rv().count() ==rv.count() ) {
102                                // no need for conditioning or marginalization
103                                for ( int j=0;j<Ns; j++ ) { // Smp is Array<> => for cycle
104                                        lw_src ( j ) =mpdfs ( i )->_epdf().evalpdflog ( Smp ( j ) );
105                                }
106                        }
107                        else {
108                                // compute likelihood of marginal on the conditional variable
109                                if ( mpdfs ( i )->_rvc().count() >0 ) {
110                                        // Make marginal on rvc_i
111                                        epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
112                                        //compute vector of lw_src
113                                        for ( int k=0;k<Ns;k++ ) {
114                                                // Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond
115                                                lw_src ( k ) += tmp_marg->evalpdflog ( dls ( i )->get_cond ( Smp ( k ) ) );
116                                        }
117                                        delete tmp_marg;
118
119//                                      sprintf ( str,"marg%d",niter );
120//                                      dbg << Name ( str ) << lw_src;
121
122                                }
123                                // Compute likelihood of the missing variable
124                                if ( rv.count() > ( mpdfs ( i )->_rv().count() + mpdfs ( i )->_rvc().count() ) ) {
125                                        ///////////////
126                                        // There are variales unknown to mpdfs(i) : rvzs
127                                        mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
128                                        // Compute likelihood
129                                        vec lw_dbg=lw_src;
130                                        for ( int k= 0; k<Ns; k++ ) {
131                                                lw_src ( k ) += log (
132                                                                    tmp_cond->evalcond (
133                                                                        zdls ( i )->get_val ( Smp ( k ) ),
134                                                                        zdls ( i )->get_cond ( Smp ( k ) ) ) );
135                                                if ( !std::isfinite ( lw_src ( k ) ) ) {
136                                                        lw_src ( k ) = -1e16; cout << "!";
137                                                }
138                                        }
139                                        delete tmp_cond;
140                                }
141                                // Compute likelihood of the partial source
142                                for ( int k= 0; k<Ns; k++ ) {
143                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) );
144                                        lw_src ( k ) += mpdfs ( i )->_epdf().evalpdflog ( dls ( i )->get_val ( Smp ( k ) ) );
145                                }
146
147                        }
148//                      it_assert_debug(std::isfinite(sum(lw_src)),"bad");
149                        lW.set_row ( i, lw_src ); // do not divide by mix
150                }
151                //Importance of the mixture
152                for ( int j=0;j<Ns;j++ ) {
153                        lw_mix ( j ) =Mix.logpred ( Smp_ex.get_col ( j ) );
154                }
155//              sprintf ( str,"lW%d",niter );
156//              dbg << Name ( str ) << lW;
157
158                lw = lognorm_merge ( lW ); //merge
159
160//              sprintf ( str,"w%d",niter );
161//              dbg << Name ( str ) << w;
162//              sprintf ( str,"lw_m%d",niter );
163//              dbg << Name ( str ) << lw_mix;
164
165                //Importance weighting
166                lw -=  lw_mix; // hoping that it is not numerically sensitive...
167                w = exp ( lw-max ( lw ) );
168                //renormalize
169                double sumw =sum ( w );
170                if ( std::isfinite ( sumw ) ) {
171                        w = w/sumw;
172                }
173                else {
174
175                }
176
177//              sprintf ( str,"w_is_%d",niter );
178//              dbg << Name ( str ) << w;
179
180//              eSmp.resample(); // So that it can be used in bayes
181//              for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
182
183//              sprintf ( str,"Smp_res%d",niter );
184//              dbg << Name ( str ) << Smp;
185
186                // ==== stopping rule ===
187                niter++;
188                converged = ( niter>20 );
189        }
190        cout << endl;
191
192}
Note: See TracBrowser for help on using the browser.