Changeset 379 for bdm/estim/merger.h

Show
Ignore:
Timestamp:
06/17/09 23:53:11 (15 years ago)
Author:
smidl
Message:

merger restructured

Files:
1 modified

Legend:

Unmodified
Added
Removed
  • bdm/estim/merger.h

    r311 r379  
    1919namespace bdm 
    2020{ 
    21         using std::string; 
    22  
    23         /*! 
    24         @brief Function for general combination of pdfs 
    25  
    26         Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial. 
    27         */ 
    28  
    29         class merger : public compositepdf, public epdf 
    30         { 
    31                 protected: 
    32                         //!Internal mixture of EF models 
    33                         MixEF Mix; 
    34                         //! Data link for each mpdf in mpdfs 
    35                         Array<datalink_m2e*> dls; 
    36                         //! Array of rvs that are not modelled by mpdfs at all (aux) 
    37                         Array<RV> rvzs; 
    38                         //! Data Links of rv0 mpdfs - these will be conditioned the (rv,rvc) of mpdfs 
    39                         Array<datalink_m2e*> zdls; 
    40  
    41                         //!Number of samples used in approximation 
    42                         int Ns; 
    43                         //!Number of components in a mixture 
    44                         int Nc; 
    45                         //!Prior on the log-normal merging model 
    46                         double beta; 
    47                         //! Projection to empirical density 
    48                         eEmp eSmp; 
    49                         //! coefficient of resampling 
    50                         double effss_coef; 
    51  
    52                         //! debug or not debug 
    53                         bool DBG; 
    54                         //! debugging file 
    55                         it_file* dbg; 
    56                         //! Flag if the samples are fixed or not 
    57                         bool fix_smp; 
    58                 public: 
    59 //!Default constructor 
    60                         merger ( const Array<mpdf*> &S ) : 
    61                                         compositepdf ( S ), epdf ( ), 
    62                                         Mix ( Array<BMEF*> ( 0 ),vec ( 0 ) ), dls ( n ), rvzs ( n ), zdls ( n ), eSmp() 
    63                         { 
    64                                 RV ztmp; 
    65                                 rv = getrv ( false ); 
    66                                 RV rvc; setrvc ( rv,rvc ); // Extend rv by rvc! 
    67                                 rv.add ( rvc ); 
    68                                 // get dimension 
    69                                 dim = rv._dsize(); 
    70  
    71                                 for ( int i=0;i<n;i++ ) 
    72                                 { 
    73                                         //Establich connection between mpdfs and merger 
    74                                         dls ( i ) = new datalink_m2e;dls ( i )->set_connection ( mpdfs ( i )->_rv(), mpdfs ( i )->_rvc(), rv ); 
    75                                         // find out what is missing in each mpdf 
    76                                         ztmp= mpdfs ( i )->_rv(); 
    77                                         ztmp.add ( mpdfs ( i )->_rvc() ); 
    78                                         rvzs ( i ) =rv.subt ( ztmp ); 
    79                                         zdls ( i ) = new datalink_m2e; zdls ( i )->set_connection ( rvzs ( i ), ztmp, rv ) ; 
    80                                 }; 
    81                                 //Set Default values of parameters 
    82                                 beta=2.0; 
    83                                 Ns=100; 
    84                                 Nc=10; 
    85                                 Mix.set_method ( EM ); 
    86                                 DBG = false; 
    87                                 fix_smp = false; 
    88                         } 
    89                         //! set debug file 
    90                         void debug_file ( const string fname ) { if ( DBG ) delete dbg; dbg = new it_file ( fname ); if ( dbg ) DBG=true;} 
    91 //! Set internal parameters used in approximation 
    92                         void set_parameters ( double beta0, int Ns0, int Nc0, double effss_coef0=0.5 ) {beta=beta0; 
    93                                 Ns=Ns0; 
    94                                 Nc=Nc0; 
    95                                 effss_coef=effss_coef0; 
    96                                 eSmp.set_parameters ( Ns0,false ); 
    97                         } 
    98                         void set_grid ( Array<vec> &XYZ ) 
    99                         { 
    100                                 int dim=XYZ.length(); ivec szs ( dim ); 
    101                                 for(int i=0; i<dim;i++){szs=XYZ(i).length();} 
    102                                 Ns=prod(szs); 
    103                                 eSmp.set_parameters(Ns,false); 
    104                                 Array<vec> &samples=eSmp._samples(); 
    105                                 eSmp._w()=ones(Ns)/Ns; 
    106                                                  
    107                                 //set samples 
    108                                 ivec is=zeros_i(dim);//indeces of dimensions in for cycle; 
    109                                 vec smpi(dim); 
    110                                 for(int i=0; i<Ns; i++){ 
    111                                         for(int j=0; j<dim; j++){smpi(j)=XYZ(j)(is(j)); /* jty vector*/ } 
    112                                         samples(i)=smpi; 
    113                                         // shift indeces 
    114                                         for (int j=0;j<dim;j++){ 
    115                                                 if (is(j)==szs(j)-1) { //j-th index is full 
    116                                                         is(j)=0; //shift back 
    117                                                         is(j+1)++; //increase th next dimension; 
    118                                                         if (is(j+1)<szs(j+1)-1) break; 
    119                                                 } else { 
    120                                                         is(j)++; break; 
    121                                                 } 
     21using std::string; 
     22 
     23//!Merging methods 
     24enum MERGER_METHOD {ARITHMETIC = 1, GEOMETRIC = 2, LOGNORMAL = 3}; 
     25 
     26/*! 
     27@brief Base class for general combination of pdfs on discrete support 
     28 
     29Mixtures of Gaussian densities are used internally. Switching to other densities should be trivial. 
     30 
     31The merged pdfs are expected to be of the form: 
     32 \f[ f(x_i|y_i),  i=1..n \f] 
     33where the resulting merger is a density on \f$ \cup [x_i,y_i] \f$ . 
     34Note that all variables will be joined. 
     35 
     36As a result of this feature, each source must be extended to common support 
     37\f[ f(z_i|y_i,x_i) f(x_i|y_i) f(y_i)  i=1..n \f] 
     38where \f$ z_i \f$ accumulate variables that were not in the original source. 
     39These extensions are calculated on-the-fly. 
     40 
     41However, these operations can not be performed in general. Hence, this class merges only sources on common support, \f$ y_i={}, z_i={}, \forall i \f$. 
     42For merging of more general cases, use offsprings merger_mix and merger_grid. 
     43*/ 
     44 
     45class merger_base : public compositepdf, public epdf 
     46{ 
     47        protected: 
     48                //! Data link for each mpdf in mpdfs 
     49                Array<datalink_m2e*> dls; 
     50                //! Array of rvs that are not modelled by mpdfs at all, \f$ z_i \f$ 
     51                Array<RV> rvzs; 
     52                //! Data Links for extension \f$ f(z_i|x_i,y_i) \f$ 
     53                Array<datalink_m2e*> zdls; 
     54                //! number of support points 
     55                int Npoints; 
     56                //! number of sources 
     57                int Nsources; 
     58 
     59                //! switch of the methoh used for merging 
     60                MERGER_METHOD METHOD; 
     61                //!Prior on the log-normal merging model 
     62                double beta; 
     63 
     64                //! Projection to empirical density (could also be piece-wise linear) 
     65                eEmp eSmp; 
     66 
     67                //! debug or not debug 
     68                bool DBG; 
     69 
     70                //! debugging file 
     71                it_file* dbg_file; 
     72        public: 
     73                //! \name Constructors 
     74                //! @{ 
     75 
     76                //!Empty constructor 
     77                merger_base () : compositepdf() {DBG=false;dbg_file=NULL;}; 
     78                //!Constructor from sources 
     79                merger_base (const Array<mpdf*> &S) {set_sources (S);}; 
     80                //! Function setting the main internal structures  
     81                void set_sources (const Array<mpdf*> &Sources) { 
     82                        compositepdf::set_elements (Sources); 
     83                        //set sizes 
     84                        dls.set_size (Sources.length()); 
     85                        rvzs.set_size (Sources.length()); 
     86                        zdls.set_size (Sources.length()); 
     87 
     88                        rv = getrv (/* checkoverlap = */ false); 
     89                        RV rvc; setrvc (rv, rvc);  // Extend rv by rvc! 
     90                        // join rv and rvc - see descriprion 
     91                        rv.add (rvc); 
     92                        // get dimension 
     93                        dim = rv._dsize(); 
     94 
     95                        // create links between sources and common rv 
     96                        RV xytmp; 
     97                        for (int i = 0;i < mpdfs.length();i++) { 
     98                                //Establich connection between mpdfs and merger 
     99                                dls (i) = new datalink_m2e; 
     100                                dls (i)->set_connection (mpdfs (i)->_rv(), mpdfs (i)->_rvc(), rv); 
     101 
     102                                // find out what is missing in each mpdf 
     103                                xytmp = mpdfs (i)->_rv(); 
     104                                xytmp.add (mpdfs (i)->_rvc()); 
     105                                // z_i = common_rv-xy 
     106                                rvzs (i) = rv.subt (xytmp); 
     107                                //establish connection between extension (z_i|x,y)s and common rv 
     108                                zdls (i) = new datalink_m2e; zdls (i)->set_connection (rvzs (i), xytmp, rv) ; 
     109                        }; 
     110                } 
     111                //! set debug file 
     112                void set_debug_file (const string fname) {  
     113                        if (DBG) delete dbg_file;  
     114                        dbg_file = new it_file (fname);  
     115                        if (dbg_file) DBG = true; 
     116                } 
     117                //! Set internal parameters used in approximation 
     118                void set_method (MERGER_METHOD MTH, double beta0=0.0) { 
     119                        METHOD = MTH;  
     120                        beta = beta0; 
     121                } 
     122                //! Set support points from a pdf by drawing N samples 
     123                void set_support(const epdf &overall, int N){ 
     124                        it_assert_debug ( rv.equal ( overall._rv() ),"Incompatible parameter overall!" ); 
     125                        eSmp.set_statistics(&overall,N); 
     126                        Npoints=N; 
     127                } 
     128                 
     129                //! Destructor 
     130                virtual ~merger_base() { 
     131                        for (int i = 0; i < Nsources; i++) { 
     132                                delete dls (i); 
     133                                delete zdls (i); 
     134                        } 
     135                        if (DBG) delete dbg_file; 
     136                }; 
     137                //!@} 
     138                 
     139                //! \name Mathematical operations 
     140                //!@{ 
     141                 
     142                //!Merge given sources in given points 
     143                void merge () { 
     144                        if (eSmp._w().length() ==0) {it_error("Empty support points use set_support" );} 
     145                        //check if sources overlap: 
     146                        bool OK = true; 
     147                        for (int i = 0;i < mpdfs.length(); i++) { 
     148                                OK &= (rvzs (i)._dsize() == 0); // z_i is empty 
     149                                OK &= (mpdfs (i)->_rvc()._dsize() == 0); // y_i is empty 
     150                        } 
     151 
     152                        if (OK) { 
     153                                mat lW = zeros (mpdfs.length(), eSmp._w().length()); 
     154 
     155                                vec emptyvec (0); 
     156                                for (int i = 0; i < mpdfs.length(); i++) { 
     157                                        for (int j = 0; j < eSmp._w().length(); j++) { 
     158                                                lW (i, j) = mpdfs (i)->evallogcond (eSmp._samples() (j), emptyvec); 
    122159                                        } 
    123160                                } 
    124                                  
    125                                 fix_smp = true; 
    126                         } 
    127 //!Initialize the proposal density. This function must be called before merge()! 
    128                         void init()   ////////////// NOT FINISHED 
    129                         { 
    130                                 Array<vec> Smps ( n ); 
    131                                 //Gibbs sampling 
    132                                 for ( int i=0;i<n;i++ ) {Smps ( i ) =zeros ( 0 );} 
    133                         } 
    134 //!Create a mixture density using known proposal 
    135                         void merge ( const epdf* g0 ); 
    136 //!Create a mixture density, make sure to call init() before the first call 
    137                         void merge () {merge ( & ( Mix.posterior() ) );}; 
    138  
    139 //! Merge log-likelihood values 
    140                         vec lognorm_merge ( mat &lW ); 
    141 //! sample from merged density 
     161 
     162                                vec wtmp = exp (merge_points (lW)); 
     163                                //renormalize 
     164                                eSmp._w() = wtmp / sum (wtmp); 
     165                        } else { 
     166                                it_error("Sources are not compatible - use merger_mix"); 
     167                        } 
     168                        ; 
     169                }; 
     170 
     171 
     172                //! Merge log-likelihood values in points using method specified by parameter METHOD 
     173                vec merge_points (mat &lW); 
     174                 
     175                 
     176                //! sample from merged density 
    142177//! weight w is a 
    143                         vec sample ( ) const { return Mix.posterior().sample();} 
    144                         double evallog ( const vec &dt ) const 
    145                         { 
    146                                 vec dtf=ones ( dt.length() +1 ); 
    147                                 dtf.set_subvector ( 0,dt ); 
    148                                 return Mix.logpred ( dtf ); 
    149                         } 
    150                         vec mean() const 
    151                         { 
    152                                 const Vec<double> &w = eSmp._w(); 
    153                                 const Array<vec> &S = eSmp._samples(); 
    154                                 vec tmp=zeros ( dim ); 
    155                                 for ( int i=0; i<Ns; i++ ) 
    156                                 { 
    157                                         tmp+=w ( i ) *S ( i ); 
    158                                 } 
    159                                 return tmp; 
    160                         } 
    161                         mat covariance() const 
    162                         { 
    163                                 const vec &w = eSmp._w(); 
    164                                 const Array<vec> &S = eSmp._samples(); 
    165  
    166                                 vec mea = mean(); 
    167  
    168                                 cout << sum ( w ) << "," << w*w <<endl; 
    169  
    170                                 mat Tmp=zeros ( dim, dim ); 
    171                                 for ( int i=0; i<Ns; i++ ) 
    172                                 { 
    173                                         Tmp+=w ( i ) *outer_product ( S ( i ), S ( i ) ); 
    174                                 } 
    175                                 return Tmp-outer_product ( mea,mea ); 
    176                         } 
    177                         vec variance() const 
    178                         { 
    179                                 const vec &w = eSmp._w(); 
    180                                 const Array<vec> &S = eSmp._samples(); 
    181  
    182                                 vec tmp=zeros ( dim ); 
    183                                 for ( int i=0; i<Ns; i++ ) 
    184                                 { 
    185                                         tmp+=w ( i ) *pow ( S ( i ),2 ); 
    186                                 } 
    187                                 return tmp-pow ( mean(),2 ); 
    188                         } 
    189 //! for future use 
    190                         virtual ~merger() 
    191                         { 
    192                                 for ( int i=0; i<n; i++ ) 
    193                                 { 
    194                                         delete dls ( i ); 
    195                                         delete zdls ( i ); 
    196                                 } 
    197                                 if ( DBG ) delete dbg; 
    198                         }; 
    199  
     178                vec mean() const { 
     179                        const Vec<double> &w = eSmp._w(); 
     180                        const Array<vec> &S = eSmp._samples(); 
     181                        vec tmp = zeros (dim); 
     182                        for (int i = 0; i < Npoints; i++) { 
     183                                tmp += w (i) * S (i); 
     184                        } 
     185                        return tmp; 
     186                } 
     187                mat covariance() const { 
     188                        const vec &w = eSmp._w(); 
     189                        const Array<vec> &S = eSmp._samples(); 
     190 
     191                        vec mea = mean(); 
     192 
     193                        cout << sum (w) << "," << w*w << endl; 
     194 
     195                        mat Tmp = zeros (dim, dim); 
     196                        for (int i = 0; i < Npoints; i++) { 
     197                                Tmp += w (i) * outer_product (S (i), S (i)); 
     198                        } 
     199                        return Tmp -outer_product (mea, mea); 
     200                } 
     201                vec variance() const { 
     202                        const vec &w = eSmp._w(); 
     203                        const Array<vec> &S = eSmp._samples(); 
     204 
     205                        vec tmp = zeros (dim); 
     206                        for (int i = 0; i < Nsources; i++) { 
     207                                tmp += w (i) * pow (S (i), 2); 
     208                        } 
     209                        return tmp -pow (mean(), 2); 
     210                } 
     211                //!@} 
     212 
     213                //! \name Access to attributes 
     214                //! @{ 
     215 
     216                //! Access function 
     217                eEmp& _Smp() {return eSmp;} 
     218                 
     219                //!@} 
     220}; 
     221 
     222class merger_mix : public merger_base 
     223{ 
     224        protected: 
     225                //!Internal mixture of EF models 
     226                MixEF Mix; 
     227                //!Number of components in a mixture 
     228                int Ncoms; 
     229                //! coefficient of resampling 
     230                double effss_coef; 
     231 
     232        public: 
     233                //!\name Constructors 
     234                //!@{ 
     235                merger_mix () {}; 
     236                merger_mix (const Array<mpdf*> &S) {set_sources(S);}; 
     237                //! Set sources and prepare all internal structures 
     238                void set_sources (const Array<mpdf*> &S) { 
     239                        merger_base::set_sources(S); 
     240                        Nsources = S.length(); 
     241                } 
     242                //! Set internal parameters used in approximation 
     243                void set_parameters (int Ncoms0=10, double effss_coef0 = 0.5) { 
     244                        Ncoms = Ncoms0; 
     245                        effss_coef = effss_coef0; 
     246                } 
     247                //!@} 
     248                 
     249                //! \name Mathematical operations 
     250                //!@{ 
     251                 
     252                //!Merge values using mixture approximation 
     253                void merge (); 
     254 
     255                //! sample from the approximating mixture 
     256                vec sample () const { return Mix.posterior().sample();} 
     257                //! loglikelihood computed on mixture models 
     258                double evallog (const vec &dt) const { 
     259                        vec dtf = ones (dt.length() + 1); 
     260                        dtf.set_subvector (0, dt); 
     261                        return Mix.logpred (dtf); 
     262                } 
     263                //!@} 
     264 
     265                //!\name Access functions 
     266                //!@{ 
    200267//! Access function 
    201                         MixEF& _Mix() {return Mix;} 
    202 //! Access function 
    203                         emix* proposal() {emix* tmp=Mix.epredictor(); tmp->set_rv(rv); return tmp;} 
    204 //! Access function 
    205                         eEmp& _Smp() {return eSmp;} 
    206         }; 
     268                MixEF& _Mix() {return Mix;} 
     269                //! Access function 
     270                emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;} 
     271        //! @} 
     272                 
     273}; 
    207274 
    208275}