root/bdm/estim/merger.cpp @ 379

Revision 379, 6.4 kB (checked in by smidl, 15 years ago)

merger restructured

  • Property svn:eol-style set to native
Line 
1
2#include "merger.h"
3#include "arx.h"
4
5namespace bdm
6{
7        vec merger_base::merge_points ( mat &lW )
8        {
9                int nu=lW.rows();
10                switch ( METHOD )
11                {
12                        case ARITHMETIC:
13                                return log ( sum ( exp ( lW ) ) ); //ugly!
14                                break;
15                        case GEOMETRIC:
16                                return sum ( lW ) /nu;
17                                break;
18                        case LOGNORMAL:
19                                vec mu = sum ( lW ) /nu; //mean of logs
20                                // vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 ); ======= numerically unsafe!
21                                vec lam = sum ( pow ( lW-outer_product ( ones ( lW.rows() ),mu ),2 ) );
22                                double coef=0.0;
23                                vec sq2bl=sqrt ( 2*beta*lam ); //this term is everywhere
24                                switch ( nu )
25                                {
26                                        case 2:
27                                                coef= ( 1-0.5*sqrt ( ( 4.0*beta-3.0 ) /beta ) );
28                                                return  coef*sq2bl + mu ;
29                                                break;
30                                        case 3://Ratio of Bessel
31                                                coef = sqrt ( ( 3*beta-2 ) /3*beta );
32                                                return log ( besselk ( 0,sq2bl*coef ) ) - log ( besselk ( 0,sq2bl ) ) +  mu;
33                                                break;
34                                        case 4:
35                                                break;
36                                        default: // Approximate conditional density
37                                                break;
38                                }
39                                break;
40                }
41                return vec ( 0 );
42        }
43
44        void merger_mix::merge ( )
45        {
46                Array<vec> &Smp = eSmp._samples(); //aux
47                vec &w = eSmp._w(); //aux
48
49                mat Smp_ex =ones ( dim +1,Npoints ); // Extended samples for the ARX model - the last row is ones
50                for ( int i=0;i<Npoints;i++ ) { set_col_part ( Smp_ex,i,Smp ( i ) );}
51
52                if ( DBG )      *dbg_file << Name ( "Smp_0" ) << Smp_ex;
53
54                // Stuff for merging
55                vec lw_src ( Npoints );         // weights of the ith source
56                vec lw_mix ( Npoints );         // weights of the approximating mixture
57                vec lw ( Npoints );                     // tmp
58                mat lW=zeros ( Nsources,Npoints ); // array of weights of all sources
59                vec vec0 ( 0 );
60
61                //initialize importance weights
62                lw_mix = 1.0; // assuming uniform grid density -- otherwise
63
64                // Initial component in the mixture model
65                mat V0=1e-8*eye ( dim +1 );
66                ARX A0; A0.set_statistics ( dim, V0 ); //initial guess of Mix:
67
68                Mix.init ( &A0, Smp_ex, Ncoms );
69                //Preserve initial mixture for repetitive estimation via flattening
70                MixEF Mix_init ( Mix );
71
72                // ============= MAIN LOOP ==================
73                bool converged=false;
74                int niter = 0;
75                char dbg_str[100];
76
77                emix* Mpred=Mix.epredictor ( );
78                vec Mix_pdf ( Npoints );
79                while ( !converged )
80                {
81                        //Re-estimate Mix
82                        //Re-Initialize Mixture model
83                        Mix.flatten ( &Mix_init );
84                        Mix.bayesB ( Smp_ex, w*Npoints );
85                        delete Mpred;
86                        Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
87                        Mpred->set_rv ( rv ); //the predictor predicts rv of this merger
88
89                        // This will be active only later in iterations!!!
90                        if (  1./sum_sqr ( w ) <effss_coef*Npoints ) 
91                        {
92                                // Generate new samples
93                                eSmp.set_samples ( Mpred );
94                                for ( int i=0;i<Npoints;i++ )
95                                {
96                                        //////////// !!!!!!!!!!!!!
97                                        //if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
98                                        set_col_part ( Smp_ex,i,Smp ( i ) );
99                                        //Importance of the mixture
100                                        //lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
101                                        lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
102                                }
103                                if ( DBG )
104                                {
105                                        cout<<"Resampling =" << 1./sum_sqr ( w ) << endl;
106                                        cout << Mix._e()->mean() <<endl;
107                                        cout << sum ( Smp_ex,2 ) /Npoints <<endl;
108                                        cout << Smp_ex*Smp_ex.T() /Npoints << endl;
109                                }
110                        }
111                        if ( DBG )
112                        {
113                                sprintf ( dbg_str,"Mpred_mean%d",niter );
114                                *dbg_file << Name ( dbg_str ) << Mpred->mean();
115                                sprintf ( dbg_str,"Mpred_var%d",niter );
116                                *dbg_file << Name ( dbg_str ) << Mpred->variance();
117
118
119                                sprintf ( dbg_str,"Mpdf%d",niter );
120                                for ( int i=0;i<Npoints;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
121                                *dbg_file << Name ( dbg_str ) << Mix_pdf;
122
123                                sprintf ( dbg_str,"Smp%d",niter );
124                                *dbg_file << Name ( dbg_str ) << Smp_ex;
125
126                        }
127                        //Importace weighting
128                        for ( int i=0;i<mpdfs.length();i++ )
129                        {
130                                lw_src=0.0;
131                                //======== Same RVs ===========
132                                //Split according to dependency in rvs
133                                if ( mpdfs ( i )->dimension() ==dim )
134                                {
135                                        // no need for conditioning or marginalization
136                                        for ( int j=0;j<Npoints; j++ )   // Smp is Array<> => for cycle
137                                        {
138                                                lw_src ( j ) =mpdfs ( i )->_epdf().evallog ( Smp ( j ) );
139                                        }
140                                }
141                                else
142                                {
143                                        // compute likelihood of marginal on the conditional variable
144                                        if ( mpdfs ( i )->dimensionc() >0 )
145                                        {
146                                                // Make marginal on rvc_i
147                                                epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
148                                                //compute vector of lw_src
149                                                for ( int k=0;k<Npoints;k++ )
150                                                {
151                                                        // Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond
152                                                        lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
153                                                }
154                                                delete tmp_marg;
155
156//                                      sprintf ( str,"marg%d",niter );
157//                                      *dbg << Name ( str ) << lw_src;
158
159                                        }
160                                        // Compute likelihood of the missing variable
161                                        if ( dim > ( mpdfs ( i )->dimension() + mpdfs ( i )->dimensionc() ) )
162                                        {
163                                                ///////////////
164                                                // There are variales unknown to mpdfs(i) : rvzs
165                                                mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
166                                                // Compute likelihood
167                                                vec lw_dbg=lw_src;
168                                                for ( int k= 0; k<Npoints; k++ )
169                                                {
170                                                        lw_src ( k ) += log (
171                                                                            tmp_cond->evallogcond (
172                                                                                zdls ( i )->pushdown ( Smp ( k ) ),
173                                                                                zdls ( i )->get_cond ( Smp ( k ) ) ) );
174                                                        if ( !std::isfinite ( lw_src ( k ) ) )
175                                                        {
176                                                                lw_src ( k ) = -1e16; cout << "!";
177                                                        }
178                                                }
179                                                delete tmp_cond;
180                                        }
181                                        // Compute likelihood of the partial source
182                                        for ( int k= 0; k<Npoints; k++ )
183                                        {
184                                                mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) );
185                                                lw_src ( k ) += mpdfs ( i )->_epdf().evallog ( dls ( i )->pushdown ( Smp ( k ) ) );
186                                        }
187
188                                }
189                //                      it_assert_debug(std::isfinite(sum(lw_src)),"bad");
190                                lW.set_row ( i, lw_src ); // do not divide by mix
191                        }
192                        lw = merger_base::merge_points ( lW ); //merge
193
194                        //Importance weighting
195                        lw -=  lw_mix; // hoping that it is not numerically sensitive...
196                        w = exp ( lw-max ( lw ) );
197
198                        //renormalize
199                        double sumw =sum ( w );
200                        if ( std::isfinite ( sumw ) )
201                        {
202                                w = w/sumw;
203                        }
204                        else
205                        {
206                                it_file itf ( "merg_err.it" );
207                                itf << Name ( "w" ) << w;
208                        }
209
210                        if ( DBG )
211                        {
212                                sprintf ( dbg_str,"lW%d",niter );
213                                *dbg_file << Name ( dbg_str ) << lW;
214                                sprintf ( dbg_str,"w%d",niter );
215                                *dbg_file << Name ( dbg_str ) << w;
216                                sprintf ( dbg_str,"lw_m%d",niter );
217                                *dbg_file << Name ( dbg_str ) << lw_mix;
218                        }
219                        // ==== stopping rule ===
220                        niter++;
221                        converged = ( niter>40 );
222                }
223                delete Mpred;
224//              cout << endl;
225
226        }
227
228}
Note: See TracBrowser for help on using the browser.