root/library/bdm/stat/merger.cpp @ 423

Revision 423, 7.1 kB (checked in by vbarta, 15 years ago)

fixed merger_base constructor to initialize debug fields (still not all fields, though...)

  • Property svn:eol-style set to native
Line 
1
2#include "merger.h"
3#include "../estim/arx.h"
4
5namespace bdm
6{
7
8        merger_base::merger_base(const Array<mpdf*> &S, bool own) {
9            DBG = false;
10            dbg_file = NULL;
11            set_sources(S, own);
12        }
13
14        vec merger_base::merge_points ( mat &lW ) {
15                int nu=lW.rows();
16                vec result;
17                ivec indW;
18                bool infexist;
19                switch ( METHOD ) {
20                        case ARITHMETIC:
21                                result= log ( sum ( exp ( lW ) ) ); //ugly!
22                                break;
23                        case GEOMETRIC:
24                                result= sum ( lW ) /nu;
25                                break;
26                        case LOGNORMAL:
27                                vec sumlW=sum ( lW ) ;
28                                indW=find((sumlW<inf) & (sumlW>-inf));
29                                infexist=(indW.size()<lW.cols());
30                                vec mu;
31                                vec lam;
32                                if (infexist){
33                                        mu = sumlW(indW) /nu; //mean of logs
34                                        //
35                                        mat validlW=lW.get_cols(indW);
36                                        lam = sum ( pow ( validlW-outer_product ( ones ( validlW.rows() ),mu ),2 ) );
37                                }
38                                else {
39                                        mu = sum ( lW ) /nu; //mean of logs
40                                        lam = sum ( pow ( lW-outer_product ( ones ( lW.rows() ),mu ),2 ) );
41                                }
42                                //
43                                double coef=0.0;
44                                vec sq2bl=sqrt ( 2*beta*lam ); //this term is everywhere
45                                switch ( nu ) {
46                                        case 2:
47                                                coef= ( 1-0.5*sqrt ( ( 4.0*beta-3.0 ) /beta ) );
48                                                result =coef*sq2bl + mu ;
49                                                break;
50                                        // case 4: == can be done similar to case 2 - is it worth it???
51                                        default: // see accompanying document merge_lognorm_derivation.lyx
52                                                coef = sqrt(1-(nu+1)/(2*beta*nu));
53                                                result =log(besselk((nu-3)/2, sq2bl*coef))-log(besselk((nu-3)/2, sq2bl)) + mu;
54                                                break;
55                                }
56                                break;
57                }
58                if (infexist){
59                        vec tmp =-inf*ones(lW.cols());
60                        set_subvector(tmp, indW, result);
61                        return tmp;
62                }
63                else {return result;}
64        }
65
66        void merger_mix::merge ( )
67        {
68                Array<vec> &Smp = eSmp._samples(); //aux
69                vec &w = eSmp._w(); //aux
70
71                mat Smp_ex =ones ( dim +1,Npoints ); // Extended samples for the ARX model - the last row is ones
72                for ( int i=0;i<Npoints;i++ ) { set_col_part ( Smp_ex,i,Smp ( i ) );}
73
74                if ( DBG )      *dbg_file << Name ( "Smp_0" ) << Smp_ex;
75
76                // Stuff for merging
77                vec lw_src ( Npoints );         // weights of the ith source
78                vec lw_mix ( Npoints );         // weights of the approximating mixture
79                vec lw ( Npoints );                     // tmp
80                mat lW=zeros ( Nsources,Npoints ); // array of weights of all sources
81                vec vec0 ( 0 );
82
83                //initialize importance weights
84                lw_mix = 1.0; // assuming uniform grid density -- otherwise
85
86                // Initial component in the mixture model
87                mat V0=1e-8*eye ( dim +1 );
88                ARX A0; A0.set_statistics ( dim, V0 ); //initial guess of Mix:
89
90                Mix.init ( &A0, Smp_ex, Ncoms );
91                //Preserve initial mixture for repetitive estimation via flattening
92                MixEF Mix_init ( Mix );
93
94                // ============= MAIN LOOP ==================
95                bool converged=false;
96                int niter = 0;
97                char dbg_str[100];
98
99                emix* Mpred=Mix.epredictor ( );
100                vec Mix_pdf ( Npoints );
101                while ( !converged )
102                {
103                        //Re-estimate Mix
104                        //Re-Initialize Mixture model
105                        Mix.flatten ( &Mix_init );
106                        Mix.bayesB ( Smp_ex, w*Npoints );
107                        delete Mpred;
108                        Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
109                        Mpred->set_rv ( rv ); //the predictor predicts rv of this merger
110
111                        // This will be active only later in iterations!!!
112                        if (  1./sum_sqr ( w ) <effss_coef*Npoints ) 
113                        {
114                                // Generate new samples
115                                eSmp.set_samples ( Mpred );
116                                for ( int i=0;i<Npoints;i++ )
117                                {
118                                        //////////// !!!!!!!!!!!!!
119                                        //if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
120                                        set_col_part ( Smp_ex,i,Smp ( i ) );
121                                        //Importance of the mixture
122                                        //lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
123                                        lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
124                                }
125                                if ( DBG )
126                                {
127                                        cout<<"Resampling =" << 1./sum_sqr ( w ) << endl;
128                                        cout << Mix._e()->mean() <<endl;
129                                        cout << sum ( Smp_ex,2 ) /Npoints <<endl;
130                                        cout << Smp_ex*Smp_ex.T() /Npoints << endl;
131                                }
132                        }
133                        if ( DBG )
134                        {
135                                sprintf ( dbg_str,"Mpred_mean%d",niter );
136                                *dbg_file << Name ( dbg_str ) << Mpred->mean();
137                                sprintf ( dbg_str,"Mpred_var%d",niter );
138                                *dbg_file << Name ( dbg_str ) << Mpred->variance();
139
140
141                                sprintf ( dbg_str,"Mpdf%d",niter );
142                                for ( int i=0;i<Npoints;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
143                                *dbg_file << Name ( dbg_str ) << Mix_pdf;
144
145                                sprintf ( dbg_str,"Smp%d",niter );
146                                *dbg_file << Name ( dbg_str ) << Smp_ex;
147
148                        }
149                        //Importace weighting
150                        for ( int i=0;i<mpdfs.length();i++ )
151                        {
152                                lw_src=0.0;
153                                //======== Same RVs ===========
154                                //Split according to dependency in rvs
155                                if ( mpdfs ( i )->dimension() ==dim )
156                                {
157                                        // no need for conditioning or marginalization
158                                        lw_src = mpdfs ( i )->_epdf().evallog_m ( Smp  );
159                                }
160                                else
161                                {
162                                        // compute likelihood of marginal on the conditional variable
163                                        if ( mpdfs ( i )->dimensionc() >0 )
164                                        {
165                                                // Make marginal on rvc_i
166                                                epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
167                                                //compute vector of lw_src
168                                                for ( int k=0;k<Npoints;k++ )
169                                                {
170                                                        // Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond
171                                                        lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
172                                                }
173                                                delete tmp_marg;
174
175//                                      sprintf ( str,"marg%d",niter );
176//                                      *dbg << Name ( str ) << lw_src;
177
178                                        }
179                                        // Compute likelihood of the missing variable
180                                        if ( dim > ( mpdfs ( i )->dimension() + mpdfs ( i )->dimensionc() ) )
181                                        {
182                                                ///////////////
183                                                // There are variales unknown to mpdfs(i) : rvzs
184                                                mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
185                                                // Compute likelihood
186                                                vec lw_dbg=lw_src;
187                                                for ( int k= 0; k<Npoints; k++ )
188                                                {
189                                                        lw_src ( k ) += log (
190                                                                            tmp_cond->evallogcond (
191                                                                                zdls ( i )->pushdown ( Smp ( k ) ),
192                                                                                zdls ( i )->get_cond ( Smp ( k ) ) ) );
193                                                        if ( !std::isfinite ( lw_src ( k ) ) )
194                                                        {
195                                                                lw_src ( k ) = -1e16; cout << "!";
196                                                        }
197                                                }
198                                                delete tmp_cond;
199                                        }
200                                        // Compute likelihood of the partial source
201                                        for ( int k= 0; k<Npoints; k++ )
202                                        {
203                                                mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) );
204                                                lw_src ( k ) += mpdfs ( i )->_epdf().evallog ( dls ( i )->pushdown ( Smp ( k ) ) );
205                                        }
206
207                                }
208                //                      it_assert_debug(std::isfinite(sum(lw_src)),"bad");
209                                lW.set_row ( i, lw_src ); // do not divide by mix
210                        }
211                        lw = merger_base::merge_points ( lW ); //merge
212
213                        //Importance weighting
214                        lw -=  lw_mix; // hoping that it is not numerically sensitive...
215                        w = exp ( lw-max ( lw ) );
216
217                        //renormalize
218                        double sumw =sum ( w );
219                        if ( std::isfinite ( sumw ) )
220                        {
221                                w = w/sumw;
222                        }
223                        else
224                        {
225                                it_file itf ( "merg_err.it" );
226                                itf << Name ( "w" ) << w;
227                        }
228
229                        if ( DBG )
230                        {
231                                sprintf ( dbg_str,"lW%d",niter );
232                                *dbg_file << Name ( dbg_str ) << lW;
233                                sprintf ( dbg_str,"w%d",niter );
234                                *dbg_file << Name ( dbg_str ) << w;
235                                sprintf ( dbg_str,"lw_m%d",niter );
236                                *dbg_file << Name ( dbg_str ) << lw_mix;
237                        }
238                        // ==== stopping rule ===
239                        niter++;
240                        converged = ( niter>stop_niter );
241                }
242                delete Mpred;
243//              cout << endl;
244
245        }
246
247        // DEFAULTS FOR MERGER_BASE
248        const MERGER_METHOD merger_base::DFLT_METHOD=LOGNORMAL;
249        const double merger_base::DFLT_beta=1.2;
250        // DEFAULTS FOR MERGER_MIX
251        const int merger_mix::DFLT_Ncoms=10;
252        const double merger_mix::DFLT_effss_coef=0.5;
253
254}
Note: See TracBrowser for help on using the browser.