root/bdm/estim/merger.cpp @ 204

Revision 204, 6.0 kB (checked in by smidl, 16 years ago)

merger is now in logarithms + new merge_test

Line 
1#include <itpp/itbase.h>
2#include "merger.h"
3#include "arx.h"
4
5vec merger::lognorm_merge ( mat &lW ) {
6        int nu=lW.rows();
7        vec mu = sum ( lW ) /nu; //mean of logs
8        vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 );
9        double coef=0.0;
10        vec sq2bl=sqrt ( 2*beta*lam ); //this term is everywhere
11        switch ( nu ) {
12                case 2:
13                        coef= ( 1-0.5*sqrt ( ( 4.0*beta-3.0 ) /beta ) );
14                        return  coef*sq2bl + mu ;
15                        break;
16                case 3://Ratio of Bessel
17                        coef = sqrt ( ( 3*beta-2 ) /3*beta );
18                        return log( besselk ( 0,sq2bl*coef )) - log( besselk ( 0,sq2bl ) ) +  mu;
19                        break;
20                case 4:
21                        break;
22                default: // Approximate conditional density
23                        break;
24        }
25        return vec ( 0 );
26}
27
28void merger::merge ( const epdf* g0 ) {
29//      it_file dbg ( "merger_debug.it" );
30
31        it_assert_debug ( rv.equal ( g0->_rv() ),"Incompatible g0" );
32        //Empirical density - samples
33        eEmp eSmp ( rv,Ns );
34        eSmp.set_parameters ( ones ( Ns ), g0 );
35        Array<vec> &Smp = eSmp._samples(); //aux
36        vec &w = eSmp._w(); //aux
37
38        mat Smp_ex =ones ( rv.count() +1,Ns ); // Extended samples for the ARX model - the last row is ones
39        for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
40
41//      dbg << Name ( "Smp_0" ) << Smp_ex;
42
43        // Stuff for merging
44        vec lw_src ( Ns );
45        vec lw_mix ( Ns );
46        vec lw ( Ns );
47        mat lW=zeros ( n,Ns );
48        vec vec0 ( 0 );
49
50        // Initial component in the mixture model
51        mat V0=1e-8*eye ( rv.count() +1 );
52        ARX A0 ( RV ( "{th_r  }", vec_1 ( rv.count() * ( rv.count() +1 ) ) ),\
53                 V0, rv.count() *rv.count() +5.0 ); //initial guess of Mix: zero mean, large variance
54
55        Mix.init ( &A0, Smp_ex, Nc );
56        //Preserve initial mixture for repetitive estimation via flattening
57        MixEF Mix_init ( Mix );
58
59        // ============= MAIN LOOP ==================
60        bool converged=false;
61        int niter = 0;
62        char str[100];
63
64        epdf* Mpred;
65        vec Mix_pdf ( Ns );
66        while ( !converged ) {
67                //Re-estimate Mix
68                //Re-Initialize Mixture model
69                Mix.flatten ( &Mix_init );
70                Mix.bayesB ( Smp_ex, w*Ns );
71                Mpred = Mix.predictor ( rv ); // Allocation => must be deleted at the end!!
72
73//              sprintf ( str,"Mpred_mean%d",niter );
74//              dbg << Name ( str ) << Mpred->mean();
75
76                if ( 1 ) {
77                        // Generate new samples
78                        eSmp.set_samples ( Mpred );
79                        for ( int i=0;i<Ns;i++ ) {
80                                //////////// !!!!!!!!!!!!!
81                                if ( Smp ( i ) ( 1 ) <0 ) {Smp ( i ) ( 1 ) = GamRNG(); }
82                                set_col_part ( Smp_ex,i,Smp ( i ) );
83                        }
84                        {cout<<"Eff. sample size=" << 1./sum_sqr ( w ) << endl;}
85                        cout << sum ( Smp_ex,2 ) /Ns <<endl;
86                }
87                else
88                        {cout<<"Eff. sample size=" << 1./sum_sqr ( w ) << endl;}
89
90//              sprintf ( str,"Mpdf%d",niter );
91//              for ( int i=0;i<Ns;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
92//              dbg << Name ( str ) << Mix_pdf;
93
94//              sprintf ( str,"Smp%d",niter );
95//              dbg << Name ( str ) << Smp_ex;
96
97                //Importace weighting
98                for ( int i=0;i<n;i++ ) {
99                        lw_src=0.0;
100                        //======== Same RVs ===========
101                        //Split according to dependency in rvs
102                        if ( mpdfs ( i )->_rv().count() ==rv.count() ) {
103                                // no need for conditioning or marginalization
104                                for ( int j=0;j<Ns; j++ ) { // Smp is Array<> => for cycle
105                                        lw_src ( j ) =mpdfs ( i )->_epdf().evalpdflog ( Smp ( j ) );
106                                }
107                        }
108                        else {
109                                // compute likelihood of marginal on the conditional variable
110                                if ( mpdfs ( i )->_rvc().count() >0 ) {
111                                        // Make marginal on rvc_i
112                                        epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
113                                        //compute vector of lw_src
114                                        for ( int k=0;k<Ns;k++ ) {
115                                                // Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond
116                                                lw_src ( k ) += tmp_marg->evalpdflog ( dls ( i )->get_cond ( Smp ( k ) ) );
117                                        }
118                                        delete tmp_marg;
119
120//                                      sprintf ( str,"marg%d",niter );
121//                                      dbg << Name ( str ) << lw_src;
122
123                                }
124                                // Compute likelihood of the missing variable
125                                if ( rv.count() > ( mpdfs ( i )->_rv().count() + mpdfs ( i )->_rvc().count() ) ) {
126                                        ///////////////
127                                        // There are variales unknown to mpdfs(i) : rvzs
128                                        mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
129                                        // Compute likelihood
130                                        vec lw_dbg=lw_src;
131                                        for ( int k= 0; k<Ns; k++ ) {
132                                                lw_src ( k ) += log (
133                                                                    tmp_cond->evalcond (
134                                                                        zdls ( i )->get_val ( Smp ( k ) ),
135                                                                        zdls ( i )->get_cond ( Smp ( k ) ) ) );
136                                                if ( !std::isfinite ( lw_src ( k ) ) ) {
137                                                        cout << endl;
138                                                }
139                                        }
140                                        delete tmp_cond;
141                                }
142                                // Compute likelihood of the partial source
143                                for ( int k= 0; k<Ns; k++ ) {
144                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) );
145                                        lw_src ( k ) += mpdfs ( i )->_epdf().evalpdflog ( dls ( i )->get_val ( Smp ( k ) ) );
146                                }
147
148                        }
149//                      it_assert_debug(std::isfinite(sum(lw_src)),"bad");
150                        lW.set_row ( i, lw_src ); // do not divide by mix
151                }
152                //Importance of the mixture
153                for ( int j=0;j<Ns;j++ ) {
154                        lw_mix ( j ) =Mix.logpred ( Smp_ex.get_col ( j ) );
155                }
156//              sprintf ( str,"lW%d",niter );
157//              dbg << Name ( str ) << lW;
158
159                lw = lognorm_merge ( lW ); //merge
160
161//              sprintf ( str,"w%d",niter );
162//              dbg << Name ( str ) << w;
163//              sprintf ( str,"lw_m%d",niter );
164//              dbg << Name ( str ) << lw_mix;
165
166                //Importance weighting
167                lw -=  lw_mix; // hoping that it is not numerically sensitive...
168                w = exp(lw-max(lw));
169                //renormalize
170                double sumw =sum ( w );
171                if ( std::isfinite ( sumw ) ) {
172                        w = w/sumw;
173                }
174                else {
175                       
176                }
177                {
178                        double eff = 1./sum_sqr ( w );
179                        if ( eff<2 ) {
180                                int mi= max_index ( w );
181                                cout << w(mi) <<endl;
182                                cout << lW.get_col ( mi ) <<endl;
183                                mat mm(2,1);mm.set_col(0,lW.get_col(mi));
184                                cout << lognorm_merge(mm) <<endl;
185                                cout << lw_mix ( mi ) <<endl;
186                                cout << lw ( mi ) <<endl;
187                                cout << Smp ( mi ) <<endl;
188                                cout << Mix._Coms(0)->_e()->mean() <<endl;
189                        }
190                        cout << "Eff: " << eff <<endl;
191                }
192//              sprintf ( str,"w_is_%d",niter );
193//              dbg << Name ( str ) << w;
194
195//              eSmp.resample(); // So that it can be used in bayes
196//              for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
197
198//              sprintf ( str,"Smp_res%d",niter );
199//              dbg << Name ( str ) << Smp;
200
201                // ==== stopping rule ===
202                niter++;
203                converged = ( niter>20 );
204        }
205        cout << endl;
206
207}
Note: See TracBrowser for help on using the browser.