root/library/bdm/stat/merger.cpp @ 395

Revision 395, 6.3 kB (checked in by smidl, 15 years ago)

merging works for merger_mx

  • Property svn:eol-style set to native
Line 
1
2#include "merger.h"
3#include "../estim/arx.h"
4
5namespace bdm
6{
7        vec merger_base::merge_points ( mat &lW )
8        {
9                int nu=lW.rows();
10                switch ( METHOD )
11                {
12                        case ARITHMETIC:
13                                return log ( sum ( exp ( lW ) ) ); //ugly!
14                                break;
15                        case GEOMETRIC:
16                                return sum ( lW ) /nu;
17                                break;
18                        case LOGNORMAL:
19                                vec mu = sum ( lW ) /nu; //mean of logs
20                                // vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 ); ======= numerically unsafe!
21                                vec lam = sum ( pow ( lW-outer_product ( ones ( lW.rows() ),mu ),2 ) );
22                                double coef=0.0;
23                                vec sq2bl=sqrt ( 2*beta*lam ); //this term is everywhere
24                                switch ( nu )
25                                {
26                                        case 2:
27                                                coef= ( 1-0.5*sqrt ( ( 4.0*beta-3.0 ) /beta ) );
28                                                return  coef*sq2bl + mu ;
29                                                break;
30                                        case 3://Ratio of Bessel
31                                                coef = sqrt ( ( 3*beta-2 ) /3*beta );
32                                                return log ( besselk ( 0,sq2bl*coef ) ) - log ( besselk ( 0,sq2bl ) ) +  mu;
33                                                break;
34                                        case 4:
35                                                break;
36                                        default: // Approximate conditional density
37                                                break;
38                                }
39                                break;
40                }
41                return vec ( 0 );
42        }
43
44        void merger_mix::merge ( )
45        {
46                Array<vec> &Smp = eSmp._samples(); //aux
47                vec &w = eSmp._w(); //aux
48
49                mat Smp_ex =ones ( dim +1,Npoints ); // Extended samples for the ARX model - the last row is ones
50                for ( int i=0;i<Npoints;i++ ) { set_col_part ( Smp_ex,i,Smp ( i ) );}
51
52                if ( DBG )      *dbg_file << Name ( "Smp_0" ) << Smp_ex;
53
54                // Stuff for merging
55                vec lw_src ( Npoints );         // weights of the ith source
56                vec lw_mix ( Npoints );         // weights of the approximating mixture
57                vec lw ( Npoints );                     // tmp
58                mat lW=zeros ( Nsources,Npoints ); // array of weights of all sources
59                vec vec0 ( 0 );
60
61                //initialize importance weights
62                lw_mix = 1.0; // assuming uniform grid density -- otherwise
63
64                // Initial component in the mixture model
65                mat V0=1e-8*eye ( dim +1 );
66                ARX A0; A0.set_statistics ( dim, V0 ); //initial guess of Mix:
67
68                Mix.init ( &A0, Smp_ex, Ncoms );
69                //Preserve initial mixture for repetitive estimation via flattening
70                MixEF Mix_init ( Mix );
71
72                // ============= MAIN LOOP ==================
73                bool converged=false;
74                int niter = 0;
75                char dbg_str[100];
76
77                emix* Mpred=Mix.epredictor ( );
78                vec Mix_pdf ( Npoints );
79                while ( !converged )
80                {
81                        //Re-estimate Mix
82                        //Re-Initialize Mixture model
83                        Mix.flatten ( &Mix_init );
84                        Mix.bayesB ( Smp_ex, w*Npoints );
85                        delete Mpred;
86                        Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
87                        Mpred->set_rv ( rv ); //the predictor predicts rv of this merger
88
89                        // This will be active only later in iterations!!!
90                        if (  1./sum_sqr ( w ) <effss_coef*Npoints ) 
91                        {
92                                // Generate new samples
93                                eSmp.set_samples ( Mpred );
94                                for ( int i=0;i<Npoints;i++ )
95                                {
96                                        //////////// !!!!!!!!!!!!!
97                                        //if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
98                                        set_col_part ( Smp_ex,i,Smp ( i ) );
99                                        //Importance of the mixture
100                                        //lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
101                                        lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
102                                }
103                                if ( DBG )
104                                {
105                                        cout<<"Resampling =" << 1./sum_sqr ( w ) << endl;
106                                        cout << Mix._e()->mean() <<endl;
107                                        cout << sum ( Smp_ex,2 ) /Npoints <<endl;
108                                        cout << Smp_ex*Smp_ex.T() /Npoints << endl;
109                                }
110                        }
111                        if ( DBG )
112                        {
113                                sprintf ( dbg_str,"Mpred_mean%d",niter );
114                                *dbg_file << Name ( dbg_str ) << Mpred->mean();
115                                sprintf ( dbg_str,"Mpred_var%d",niter );
116                                *dbg_file << Name ( dbg_str ) << Mpred->variance();
117
118
119                                sprintf ( dbg_str,"Mpdf%d",niter );
120                                for ( int i=0;i<Npoints;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
121                                *dbg_file << Name ( dbg_str ) << Mix_pdf;
122
123                                sprintf ( dbg_str,"Smp%d",niter );
124                                *dbg_file << Name ( dbg_str ) << Smp_ex;
125
126                        }
127                        //Importace weighting
128                        for ( int i=0;i<mpdfs.length();i++ )
129                        {
130                                lw_src=0.0;
131                                //======== Same RVs ===========
132                                //Split according to dependency in rvs
133                                if ( mpdfs ( i )->dimension() ==dim )
134                                {
135                                        // no need for conditioning or marginalization
136                                        lw_src = mpdfs ( i )->_epdf().evallog_m ( Smp  );
137                                }
138                                else
139                                {
140                                        // compute likelihood of marginal on the conditional variable
141                                        if ( mpdfs ( i )->dimensionc() >0 )
142                                        {
143                                                // Make marginal on rvc_i
144                                                epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
145                                                //compute vector of lw_src
146                                                for ( int k=0;k<Npoints;k++ )
147                                                {
148                                                        // Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond
149                                                        lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
150                                                }
151                                                delete tmp_marg;
152
153//                                      sprintf ( str,"marg%d",niter );
154//                                      *dbg << Name ( str ) << lw_src;
155
156                                        }
157                                        // Compute likelihood of the missing variable
158                                        if ( dim > ( mpdfs ( i )->dimension() + mpdfs ( i )->dimensionc() ) )
159                                        {
160                                                ///////////////
161                                                // There are variales unknown to mpdfs(i) : rvzs
162                                                mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
163                                                // Compute likelihood
164                                                vec lw_dbg=lw_src;
165                                                for ( int k= 0; k<Npoints; k++ )
166                                                {
167                                                        lw_src ( k ) += log (
168                                                                            tmp_cond->evallogcond (
169                                                                                zdls ( i )->pushdown ( Smp ( k ) ),
170                                                                                zdls ( i )->get_cond ( Smp ( k ) ) ) );
171                                                        if ( !std::isfinite ( lw_src ( k ) ) )
172                                                        {
173                                                                lw_src ( k ) = -1e16; cout << "!";
174                                                        }
175                                                }
176                                                delete tmp_cond;
177                                        }
178                                        // Compute likelihood of the partial source
179                                        for ( int k= 0; k<Npoints; k++ )
180                                        {
181                                                mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) );
182                                                lw_src ( k ) += mpdfs ( i )->_epdf().evallog ( dls ( i )->pushdown ( Smp ( k ) ) );
183                                        }
184
185                                }
186                //                      it_assert_debug(std::isfinite(sum(lw_src)),"bad");
187                                lW.set_row ( i, lw_src ); // do not divide by mix
188                        }
189                        lw = merger_base::merge_points ( lW ); //merge
190
191                        //Importance weighting
192                        lw -=  lw_mix; // hoping that it is not numerically sensitive...
193                        w = exp ( lw-max ( lw ) );
194
195                        //renormalize
196                        double sumw =sum ( w );
197                        if ( std::isfinite ( sumw ) )
198                        {
199                                w = w/sumw;
200                        }
201                        else
202                        {
203                                it_file itf ( "merg_err.it" );
204                                itf << Name ( "w" ) << w;
205                        }
206
207                        if ( DBG )
208                        {
209                                sprintf ( dbg_str,"lW%d",niter );
210                                *dbg_file << Name ( dbg_str ) << lW;
211                                sprintf ( dbg_str,"w%d",niter );
212                                *dbg_file << Name ( dbg_str ) << w;
213                                sprintf ( dbg_str,"lw_m%d",niter );
214                                *dbg_file << Name ( dbg_str ) << lw_mix;
215                        }
216                        // ==== stopping rule ===
217                        niter++;
218                        converged = ( niter>stop_niter );
219                }
220                delete Mpred;
221//              cout << endl;
222
223        }
224
225}
Note: See TracBrowser for help on using the browser.