root/library/bdm/stat/merger.cpp @ 404

Revision 404, 7.0 kB (checked in by smidl, 15 years ago)

Change in epdf: evallog returns -inf for points out of support. Merger is aware of it now.

  • Property svn:eol-style set to native
Line 
1
2#include "merger.h"
3#include "../estim/arx.h"
4
5namespace bdm
6{
7        vec merger_base::merge_points ( mat &lW ) {
8                int nu=lW.rows();
9                vec result;
10                ivec indW;
11                bool infexist;
12                switch ( METHOD ) {
13                        case ARITHMETIC:
14                                result= log ( sum ( exp ( lW ) ) ); //ugly!
15                                break;
16                        case GEOMETRIC:
17                                result= sum ( lW ) /nu;
18                                break;
19                        case LOGNORMAL:
20                                vec sumlW=sum ( lW ) ;
21                                indW=find((sumlW<inf) & (sumlW>-inf));
22                                infexist=(indW.size()<lW.cols());
23                                vec mu;
24                                vec lam;
25                                if (infexist){
26                                        mu = sumlW(indW) /nu; //mean of logs
27                                        //
28                                        mat validlW=lW.get_cols(indW);
29                                        lam = sum ( pow ( validlW-outer_product ( ones ( validlW.rows() ),mu ),2 ) );
30                                }
31                                else {
32                                        mu = sum ( lW ) /nu; //mean of logs
33                                        lam = sum ( pow ( lW-outer_product ( ones ( lW.rows() ),mu ),2 ) );
34                                }
35                                //
36                                double coef=0.0;
37                                vec sq2bl=sqrt ( 2*beta*lam ); //this term is everywhere
38                                switch ( nu ) {
39                                        case 2:
40                                                coef= ( 1-0.5*sqrt ( ( 4.0*beta-3.0 ) /beta ) );
41                                                result =coef*sq2bl + mu ;
42                                                break;
43                                        // case 4: == can be done similar to case 2 - is it worth it???
44                                        default: // see accompanying document merge_lognorm_derivation.lyx
45                                                coef = sqrt(1-(nu+1)/(2*beta*nu));
46                                                result =log(besselk((nu-3)/2, sq2bl*coef))-log(besselk((nu-3)/2, sq2bl)) + mu;
47                                                break;
48                                }
49                                break;
50                }
51                if (infexist){
52                        vec tmp =-inf*ones(lW.cols());
53                        set_subvector(tmp, indW, result);
54                        return tmp;
55                }
56                else {return result;}
57        }
58
59        void merger_mix::merge ( )
60        {
61                Array<vec> &Smp = eSmp._samples(); //aux
62                vec &w = eSmp._w(); //aux
63
64                mat Smp_ex =ones ( dim +1,Npoints ); // Extended samples for the ARX model - the last row is ones
65                for ( int i=0;i<Npoints;i++ ) { set_col_part ( Smp_ex,i,Smp ( i ) );}
66
67                if ( DBG )      *dbg_file << Name ( "Smp_0" ) << Smp_ex;
68
69                // Stuff for merging
70                vec lw_src ( Npoints );         // weights of the ith source
71                vec lw_mix ( Npoints );         // weights of the approximating mixture
72                vec lw ( Npoints );                     // tmp
73                mat lW=zeros ( Nsources,Npoints ); // array of weights of all sources
74                vec vec0 ( 0 );
75
76                //initialize importance weights
77                lw_mix = 1.0; // assuming uniform grid density -- otherwise
78
79                // Initial component in the mixture model
80                mat V0=1e-8*eye ( dim +1 );
81                ARX A0; A0.set_statistics ( dim, V0 ); //initial guess of Mix:
82
83                Mix.init ( &A0, Smp_ex, Ncoms );
84                //Preserve initial mixture for repetitive estimation via flattening
85                MixEF Mix_init ( Mix );
86
87                // ============= MAIN LOOP ==================
88                bool converged=false;
89                int niter = 0;
90                char dbg_str[100];
91
92                emix* Mpred=Mix.epredictor ( );
93                vec Mix_pdf ( Npoints );
94                while ( !converged )
95                {
96                        //Re-estimate Mix
97                        //Re-Initialize Mixture model
98                        Mix.flatten ( &Mix_init );
99                        Mix.bayesB ( Smp_ex, w*Npoints );
100                        delete Mpred;
101                        Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
102                        Mpred->set_rv ( rv ); //the predictor predicts rv of this merger
103
104                        // This will be active only later in iterations!!!
105                        if (  1./sum_sqr ( w ) <effss_coef*Npoints ) 
106                        {
107                                // Generate new samples
108                                eSmp.set_samples ( Mpred );
109                                for ( int i=0;i<Npoints;i++ )
110                                {
111                                        //////////// !!!!!!!!!!!!!
112                                        //if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
113                                        set_col_part ( Smp_ex,i,Smp ( i ) );
114                                        //Importance of the mixture
115                                        //lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
116                                        lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
117                                }
118                                if ( DBG )
119                                {
120                                        cout<<"Resampling =" << 1./sum_sqr ( w ) << endl;
121                                        cout << Mix._e()->mean() <<endl;
122                                        cout << sum ( Smp_ex,2 ) /Npoints <<endl;
123                                        cout << Smp_ex*Smp_ex.T() /Npoints << endl;
124                                }
125                        }
126                        if ( DBG )
127                        {
128                                sprintf ( dbg_str,"Mpred_mean%d",niter );
129                                *dbg_file << Name ( dbg_str ) << Mpred->mean();
130                                sprintf ( dbg_str,"Mpred_var%d",niter );
131                                *dbg_file << Name ( dbg_str ) << Mpred->variance();
132
133
134                                sprintf ( dbg_str,"Mpdf%d",niter );
135                                for ( int i=0;i<Npoints;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
136                                *dbg_file << Name ( dbg_str ) << Mix_pdf;
137
138                                sprintf ( dbg_str,"Smp%d",niter );
139                                *dbg_file << Name ( dbg_str ) << Smp_ex;
140
141                        }
142                        //Importace weighting
143                        for ( int i=0;i<mpdfs.length();i++ )
144                        {
145                                lw_src=0.0;
146                                //======== Same RVs ===========
147                                //Split according to dependency in rvs
148                                if ( mpdfs ( i )->dimension() ==dim )
149                                {
150                                        // no need for conditioning or marginalization
151                                        lw_src = mpdfs ( i )->_epdf().evallog_m ( Smp  );
152                                }
153                                else
154                                {
155                                        // compute likelihood of marginal on the conditional variable
156                                        if ( mpdfs ( i )->dimensionc() >0 )
157                                        {
158                                                // Make marginal on rvc_i
159                                                epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
160                                                //compute vector of lw_src
161                                                for ( int k=0;k<Npoints;k++ )
162                                                {
163                                                        // Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond
164                                                        lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
165                                                }
166                                                delete tmp_marg;
167
168//                                      sprintf ( str,"marg%d",niter );
169//                                      *dbg << Name ( str ) << lw_src;
170
171                                        }
172                                        // Compute likelihood of the missing variable
173                                        if ( dim > ( mpdfs ( i )->dimension() + mpdfs ( i )->dimensionc() ) )
174                                        {
175                                                ///////////////
176                                                // There are variales unknown to mpdfs(i) : rvzs
177                                                mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
178                                                // Compute likelihood
179                                                vec lw_dbg=lw_src;
180                                                for ( int k= 0; k<Npoints; k++ )
181                                                {
182                                                        lw_src ( k ) += log (
183                                                                            tmp_cond->evallogcond (
184                                                                                zdls ( i )->pushdown ( Smp ( k ) ),
185                                                                                zdls ( i )->get_cond ( Smp ( k ) ) ) );
186                                                        if ( !std::isfinite ( lw_src ( k ) ) )
187                                                        {
188                                                                lw_src ( k ) = -1e16; cout << "!";
189                                                        }
190                                                }
191                                                delete tmp_cond;
192                                        }
193                                        // Compute likelihood of the partial source
194                                        for ( int k= 0; k<Npoints; k++ )
195                                        {
196                                                mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) );
197                                                lw_src ( k ) += mpdfs ( i )->_epdf().evallog ( dls ( i )->pushdown ( Smp ( k ) ) );
198                                        }
199
200                                }
201                //                      it_assert_debug(std::isfinite(sum(lw_src)),"bad");
202                                lW.set_row ( i, lw_src ); // do not divide by mix
203                        }
204                        lw = merger_base::merge_points ( lW ); //merge
205
206                        //Importance weighting
207                        lw -=  lw_mix; // hoping that it is not numerically sensitive...
208                        w = exp ( lw-max ( lw ) );
209
210                        //renormalize
211                        double sumw =sum ( w );
212                        if ( std::isfinite ( sumw ) )
213                        {
214                                w = w/sumw;
215                        }
216                        else
217                        {
218                                it_file itf ( "merg_err.it" );
219                                itf << Name ( "w" ) << w;
220                        }
221
222                        if ( DBG )
223                        {
224                                sprintf ( dbg_str,"lW%d",niter );
225                                *dbg_file << Name ( dbg_str ) << lW;
226                                sprintf ( dbg_str,"w%d",niter );
227                                *dbg_file << Name ( dbg_str ) << w;
228                                sprintf ( dbg_str,"lw_m%d",niter );
229                                *dbg_file << Name ( dbg_str ) << lw_mix;
230                        }
231                        // ==== stopping rule ===
232                        niter++;
233                        converged = ( niter>stop_niter );
234                }
235                delete Mpred;
236//              cout << endl;
237
238        }
239
240        // DEFAULTS FOR MERGER_BASE
241        const MERGER_METHOD merger_base::DFLT_METHOD=LOGNORMAL;
242        const double merger_base::DFLT_beta=1.2;
243        // DEFAULTS FOR MERGER_MIX
244        const int merger_mix::DFLT_Ncoms=10;
245        const double merger_mix::DFLT_effss_coef=0.5;
246
247}
Note: See TracBrowser for help on using the browser.