root/bdm/estim/merger.cpp @ 198

Revision 198, 5.1 kB (checked in by smidl, 16 years ago)

opravy + zavedeni studenta + zakomentovani debug v mergeru

Line 
1#include <itpp/itbase.h>
2#include "merger.h"
3#include "arx.h"
4
5vec merger::lognorm_merge ( mat &lW ) {
6        int nu=lW.rows();
7        vec mu = sum ( lW ) /nu; //mean of logs
8        vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 );
9        double coef=0.0;
10        vec sq2bl=sqrt ( 2*beta*lam ); //this term is everywhere
11        switch ( nu ) {
12                case 2:
13                        coef= ( 1-0.5*sqrt ( ( 4*beta-3 ) /beta ) );
14                        return exp ( coef*sq2bl + mu );
15                        break;
16                case 3://Ratio of Bessel
17                        coef = sqrt ( ( 3*beta-2 ) /3*beta );
18                        return elem_mult (
19                                   elem_div ( besselk ( 0,sq2bl*coef ), besselk ( 0,sq2bl ) ),
20                                   exp ( mu ) );
21                        break;
22                case 4:
23                        break;
24                default: // Approximate conditional density
25                        break;
26        }
27        return vec ( 0 );
28}
29
30void merger::merge ( const epdf* g0 ) {
31//      it_file dbg ( "merger_debug.it" );
32
33        it_assert_debug ( rv.equal ( g0->_rv() ),"Incompatible g0" );
34        //Empirical density - samples
35        eEmp eSmp ( rv,Ns );
36        eSmp.set_parameters ( ones ( Ns ), g0 );
37        Array<vec> &Smp = eSmp._samples(); //aux
38        vec &w = eSmp._w(); //aux
39
40        mat Smp_ex =ones ( rv.count() +1,Ns ); // Extended samples for the ARX model - the last row is ones
41        for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
42
43//      dbg << Name ( "Smp_0" ) << Smp_ex;
44
45        // Stuff for merging
46        vec lw_src ( Ns );
47        vec lw_mix ( Ns );
48        mat lW=zeros ( n,Ns );
49        vec vec0 ( 0 );
50
51        // Initial component in the mixture model
52        mat V0=1e-8*eye ( rv.count() +1 );
53        ARX A0 ( RV ( "{th_r  }", vec_1 ( rv.count() * ( rv.count() +1 ) ) ),\
54                 V0, rv.count() *rv.count() +5.0 ); //initial guess of Mix: zero mean, large variance
55
56        Mix.init ( &A0, Smp_ex, Nc );
57        //Preserve initial mixture for repetitive estimation via flattening
58        MixEF Mix_init ( Mix );
59
60        // ============= MAIN LOOP ==================
61        bool converged=false;
62        int niter = 0;
63        char str[100];
64
65        epdf* Mpred;
66        vec Mix_pdf ( Ns );
67        while ( !converged ) {
68                //Re-estimate Mix
69                //Re-Initialize Mixture model
70                Mix.flatten ( &Mix_init );
71                Mix.bayesB ( Smp_ex, w*Ns );
72                Mpred = Mix.predictor ( rv ); // Allocation => must be deleted at the end!!
73
74//              sprintf ( str,"Mpred_mean%d",niter );
75//              dbg << Name ( str ) << Mpred->mean();
76
77                if ( niter<10 ) {
78                        // Generate new samples
79                        eSmp.set_samples ( Mpred );
80                        for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
81
82                }
83//              sprintf ( str,"Mpdf%d",niter );
84//              for ( int i=0;i<Ns;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
85//              dbg << Name ( str ) << Mix_pdf;
86
87//              sprintf ( str,"Smp%d",niter );
88//              dbg << Name ( str ) << Smp_ex;
89
90                //Importace weighting
91                for ( int i=0;i<n;i++ ) {
92                        lw_src=0.0;
93                        //======== Same RVs ===========
94                        //Split according to dependency in rvs
95                        if ( mpdfs ( i )->_rv().count() ==rv.count() ) {
96                                // no need for conditioning or marginalization
97                                for ( int j=0;j<Ns; j++ ) { // Smp is Array<> => for cycle
98                                        lw_src ( j ) =mpdfs ( i )->_epdf().evalpdflog ( Smp ( j ) );
99                                }
100                        }
101                        else {
102                                // compute likelihood of marginal on the conditional variable
103                                if ( mpdfs ( i )->_rvc().count() >0 ) {
104                                        // Make marginal on rvc_i
105                                        epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
106                                        //compute vector of lw_src
107                                        for ( int k=0;k<Ns;k++ ) {
108                                                // Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond
109                                                lw_src ( k ) += tmp_marg->evalpdflog ( dls ( i )->get_cond ( Smp ( k ) ) );
110                                        }
111                                        delete tmp_marg;
112
113//                                      sprintf ( str,"marg%d",niter );
114//                                      dbg << Name ( str ) << lw_src;
115
116                                }
117                                // Compute likelihood of the missing variable
118                                if ( rv.count() > ( mpdfs ( i )->_rv().count() + mpdfs ( i )->_rvc().count() ) ) {
119                                        ///////////////
120                                        // There are variales unknown to mpdfs(i) : rvzs
121                                        mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
122                                        // Compute likelihood
123                                        vec lw_dbg=lw_src;
124                                        for ( int k= 0; k<Ns; k++ ) {
125                                                lw_src ( k ) += log (
126                                                                    tmp_cond->evalcond (
127                                                                        zdls ( i )->get_val ( Smp ( k ) ),
128                                                                        zdls ( i )->get_cond ( Smp ( k ) ) ) );
129                                        }
130                                        delete tmp_cond;
131                                }
132                                // Compute likelihood of the partial source
133                                for ( int k= 0; k<Ns; k++ ) {
134                                        mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) );
135                                        lw_src ( k ) += mpdfs ( i )->_epdf().evalpdflog ( dls ( i )->get_val ( Smp ( k ) ) );
136                                }
137
138                        }
139//                      it_assert_debug(std::isfinite(sum(lw_src)),"bad");
140                        lW.set_row ( i, lw_src ); // do not divide by mix
141                }
142                //Importance of the mixture
143                for ( int j=0;j<Ns;j++ ) {
144                        lw_mix ( j ) =Mix.logpred ( Smp_ex.get_col ( j ) );
145                }
146//              sprintf ( str,"lW%d",niter );
147//              dbg << Name ( str ) << lW;
148
149                w = lognorm_merge ( lW ); //merge
150
151//              sprintf ( str,"w%d",niter );
152//              dbg << Name ( str ) << w;
153//              sprintf ( str,"lw_m%d",niter );
154//              dbg << Name ( str ) << lw_mix;
155
156                //Importance weighting
157                w /=exp ( lw_mix ); // hoping that it is not numerically sensitive...
158                //renormalize
159                w /=sum ( w );
160
161//              sprintf ( str,"w_is_%d",niter );
162//              dbg << Name ( str ) << w;
163
164//              eSmp.resample(); // So that it can be used in bayes
165//              for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
166
167//              sprintf ( str,"Smp_res%d",niter );
168//              dbg << Name ( str ) << Smp;
169
170                // ==== stopping rule ===
171                niter++;
172                converged = ( niter>5 );
173        }
174
175}
Note: See TracBrowser for help on using the browser.