root/bdm/estim/merger.cpp @ 311

Revision 311, 6.1 kB (checked in by smidl, 15 years ago)

merger

  • Property svn:eol-style set to native
Line 
1
2#include "merger.h"
3#include "arx.h"
4
5namespace bdm
6{
7        vec merger::lognorm_merge ( mat &lW )
8        {
9                int nu=lW.rows();
10                vec mu = sum ( lW ) /nu; //mean of logs
11                // vec lam = sum ( pow ( lW,2 ) )-nu*pow ( mu,2 ); ======= numerically unsafe!
12                vec lam = sum ( pow ( lW-outer_product ( ones ( lW.rows() ),mu ),2 ) );
13                double coef=0.0;
14                vec sq2bl=sqrt ( 2*beta*lam ); //this term is everywhere
15                switch ( nu )
16                {
17                        case 2:
18                                coef= ( 1-0.5*sqrt ( ( 4.0*beta-3.0 ) /beta ) );
19                                return  coef*sq2bl + mu ;
20                                break;
21                        case 3://Ratio of Bessel
22                                coef = sqrt ( ( 3*beta-2 ) /3*beta );
23                                return log ( besselk ( 0,sq2bl*coef ) ) - log ( besselk ( 0,sq2bl ) ) +  mu;
24                                break;
25                        case 4:
26                                break;
27                        default: // Approximate conditional density
28                                break;
29                }
30                return vec ( 0 );
31        }
32
33        void merger::merge ( const epdf* g0 )
34        {
35
36                it_assert_debug ( rv.equal ( g0->_rv() ),"Incompatible g0" );
37                //Empirical density - samples
38                if ( !fix_smp )
39                {
40                        eSmp.set_statistics ( ones ( Ns ), g0 );
41                }
42
43                Array<vec> &Smp = eSmp._samples(); //aux
44                vec &w = eSmp._w(); //aux
45
46                mat Smp_ex =ones ( dim +1,Ns ); // Extended samples for the ARX model - the last row is ones
47                for ( int i=0;i<Ns;i++ ) {      set_col_part ( Smp_ex,i,Smp ( i ) );}
48
49                if ( DBG )      *dbg << Name ( "Smp_0" ) << Smp_ex;
50
51                // Stuff for merging
52                vec lw_src ( Ns );
53                vec lw_mix ( Ns );
54                vec lw ( Ns );
55                mat lW=zeros ( n,Ns );
56                vec vec0 ( 0 );
57
58                //initialize importance weights
59                if ( !fix_smp )
60                        for ( int i=0;i<Ns;i++ )
61                        {
62                                lw_mix ( i ) =g0->evallog ( Smp ( i ) );
63                        }
64
65                // Initial component in the mixture model
66                mat V0=1e-8*eye ( dim +1 );
67                ARX A0; A0.set_statistics ( dim, V0 ); //initial guess of Mix:
68
69                Mix.init ( &A0, Smp_ex, Nc );
70                //Preserve initial mixture for repetitive estimation via flattening
71                MixEF Mix_init ( Mix );
72
73                // ============= MAIN LOOP ==================
74                bool converged=false;
75                int niter = 0;
76                char str[100];
77
78                emix* Mpred=Mix.epredictor ( );
79                vec Mix_pdf ( Ns );
80                while ( !converged )
81                {
82                        //Re-estimate Mix
83                        //Re-Initialize Mixture model
84                        Mix.flatten ( &Mix_init );
85                        Mix.bayesB ( Smp_ex, w*Ns );
86                        delete Mpred;
87                        Mpred = Mix.epredictor ( ); // Allocation => must be deleted at the end!!
88                        Mpred->set_rv ( rv ); //the predictor predicts rv of this merger
89
90                        // This will be active only later in iterations!!!
91                        if ( ( !fix_smp ) & ( 1./sum_sqr ( w ) <effss_coef*Ns ) )
92                        {
93                                // Generate new samples
94                                eSmp.set_samples ( Mpred );
95                                for ( int i=0;i<Ns;i++ )
96                                {
97                                        //////////// !!!!!!!!!!!!!
98                                        if ( Smp ( i ) ( 2 ) <0 ) {Smp ( i ) ( 2 ) = 0.01; }
99                                        set_col_part ( Smp_ex,i,Smp ( i ) );
100                                        //Importance of the mixture
101                                        //lw_mix ( i ) =Mix.logpred (Smp_ex.get_col(i) );
102                                        lw_mix ( i ) = Mpred->evallog ( Smp ( i ) );
103                                }
104                                if ( 0 )
105                                {
106                                        cout<<"Resampling =" << 1./sum_sqr ( w ) << endl;
107                                        cout << Mix._e()->mean() <<endl;
108                                        cout << sum ( Smp_ex,2 ) /Ns <<endl;
109                                        cout << Smp_ex*Smp_ex.T() /Ns << endl;
110                                }
111                        }
112                        if ( DBG )
113                        {
114                                sprintf ( str,"Mpred_mean%d",niter );
115                                *dbg << Name ( str ) << Mpred->mean();
116                                sprintf ( str,"Mpred_var%d",niter );
117                                *dbg << Name ( str ) << Mpred->variance();
118
119
120                                sprintf ( str,"Mpdf%d",niter );
121                                for ( int i=0;i<Ns;i++ ) {Mix_pdf ( i ) = Mix.logpred ( Smp_ex.get_col ( i ) );}
122                                *dbg << Name ( str ) << Mix_pdf;
123
124                                sprintf ( str,"Smp%d",niter );
125                                *dbg << Name ( str ) << Smp_ex;
126
127                        }
128                        //Importace weighting
129                        for ( int i=0;i<n;i++ )
130                        {
131                                lw_src=0.0;
132                                //======== Same RVs ===========
133                                //Split according to dependency in rvs
134                                if ( mpdfs ( i )->dimension() ==dim )
135                                {
136                                        // no need for conditioning or marginalization
137                                        for ( int j=0;j<Ns; j++ )   // Smp is Array<> => for cycle
138                                        {
139                                                lw_src ( j ) =mpdfs ( i )->_epdf().evallog ( Smp ( j ) );
140                                        }
141                                }
142                                else
143                                {
144                                        // compute likelihood of marginal on the conditional variable
145                                        if ( mpdfs ( i )->dimensionc() >0 )
146                                        {
147                                                // Make marginal on rvc_i
148                                                epdf* tmp_marg = Mpred->marginal ( mpdfs ( i )->_rvc() );
149                                                //compute vector of lw_src
150                                                for ( int k=0;k<Ns;k++ )
151                                                {
152                                                        // Here val of tmp_marg = cond of mpdfs(i) ==> calling dls->get_cond
153                                                        lw_src ( k ) += tmp_marg->evallog ( dls ( i )->get_cond ( Smp ( k ) ) );
154                                                }
155                                                delete tmp_marg;
156
157//                                      sprintf ( str,"marg%d",niter );
158//                                      *dbg << Name ( str ) << lw_src;
159
160                                        }
161                                        // Compute likelihood of the missing variable
162                                        if ( dim > ( mpdfs ( i )->dimension() + mpdfs ( i )->dimensionc() ) )
163                                        {
164                                                ///////////////
165                                                // There are variales unknown to mpdfs(i) : rvzs
166                                                mpdf* tmp_cond = Mpred->condition ( rvzs ( i ) );
167                                                // Compute likelihood
168                                                vec lw_dbg=lw_src;
169                                                for ( int k= 0; k<Ns; k++ )
170                                                {
171                                                        lw_src ( k ) += log (
172                                                                            tmp_cond->evallogcond (
173                                                                                zdls ( i )->pushdown ( Smp ( k ) ),
174                                                                                zdls ( i )->get_cond ( Smp ( k ) ) ) );
175                                                        if ( !std::isfinite ( lw_src ( k ) ) )
176                                                        {
177                                                                lw_src ( k ) = -1e16; cout << "!";
178                                                        }
179                                                }
180                                                delete tmp_cond;
181                                        }
182                                        // Compute likelihood of the partial source
183                                        for ( int k= 0; k<Ns; k++ )
184                                        {
185                                                mpdfs ( i )->condition ( dls ( i )->get_cond ( Smp ( k ) ) );
186                                                lw_src ( k ) += mpdfs ( i )->_epdf().evallog ( dls ( i )->pushdown ( Smp ( k ) ) );
187                                        }
188
189                                }
190//                      it_assert_debug(std::isfinite(sum(lw_src)),"bad");
191                                lW.set_row ( i, lw_src ); // do not divide by mix
192                        }
193                        lw = lognorm_merge ( lW ); //merge
194
195                        //Importance weighting
196                        if ( !fix_smp )
197                                lw -=  lw_mix; // hoping that it is not numerically sensitive...
198                        w = exp ( lw-max ( lw ) );
199
200                        //renormalize
201                        double sumw =sum ( w );
202                        if ( std::isfinite ( sumw ) )
203                        {
204                                w = w/sumw;
205                        }
206                        else
207                        {
208                                it_file itf ( "merg_err.it" );
209                                itf << Name ( "w" ) << w;
210                        }
211
212                        if ( DBG )
213                        {
214                                sprintf ( str,"lW%d",niter );
215                                *dbg << Name ( str ) << lW;
216                                sprintf ( str,"w%d",niter );
217                                *dbg << Name ( str ) << w;
218                                sprintf ( str,"lw_m%d",niter );
219                                *dbg << Name ( str ) << lw_mix;
220                        }
221                        // ==== stopping rule ===
222                        niter++;
223                        converged = ( niter>40 );
224                }
225                delete Mpred;
226//              cout << endl;
227
228        }
229
230}
Note: See TracBrowser for help on using the browser.