Changeset 395

Show
Ignore:
Timestamp:
06/22/09 13:17:49 (15 years ago)
Author:
smidl
Message:

merging works for merger_mx

Files:
7 modified

Legend:

Unmodified
Added
Removed
  • applications/bdmtoolbox/mex/merger_mx.cpp

    r393 r395  
    8383        Merger->set_sources(Sources,true); // takes care of deletion of sources 
    8484        Merger->merge(); 
    85          
     85 
     86        mxArray* tmp ; 
    8687        // Save results 
    8788        if (n_output>0){ 
    88                 mxArray* tmp = mxCreateStructMatrix(1,1,0,NULL); 
     89                tmp = mxCreateStructMatrix(1,1,0,NULL); 
    8990                //support 
    9091                Array<vec> &samples=Merger->_Smp()._samples(); 
     
    100101                vec2mxArray(w,fldw); 
    101102                mxReplaceFieldNM(tmp, "weights", fldw); 
    102                  
     103 
     104                // sources 
     105                                char srcstr[20]; 
     106                for (int i=0;i<Sources.length();i++){ 
     107                        sprintf(srcstr,"source%d",i+1); 
     108                        vec sll=exp(Sources(i)->evallogcond_m(Merger->_Smp()._samples(),vec(0))); 
     109 
     110                        mxArray* fldw=mxCreateDoubleMatrix(1, sll.length(), mxREAL); 
     111                        vec2mxArray(sll/sum(sll),fldw); 
     112                        mxReplaceFieldNM(tmp, srcstr, fldw); 
     113                }                
     114 
    103115                output[0] = tmp; 
    104116        } 
  • library/bdm/base/bdmbase.h

    r392 r395  
    268268  //! Compute log-probability of multiple values argument \c val 
    269269  virtual vec evallog_m(const mat &Val) const { 
    270     vec x(Val.cols()); 
    271     for (int i = 0; i < Val.cols(); i++) {x(i) = evallog(Val.get_col(i)) ;} 
    272     return x; 
     270          vec x(Val.cols()); 
     271          for (int i = 0; i < Val.cols(); i++) {x(i) = evallog(Val.get_col(i)) ;} 
     272          return x; 
     273  } 
     274  //! Compute log-probability of multiple values argument \c val 
     275  virtual vec evallog_m(const Array<vec> &Avec) const { 
     276          vec x(Avec.size()); 
     277          for (int i = 0; i < Avec.size(); i++) {x(i) = evallog(Avec(i)) ;} 
     278          return x; 
    273279  } 
    274280  //! Return conditional density on the given RV, the remaining rvs will be in conditioning 
     
    311317  //! Size of the random variable 
    312318  int dimension() const {return dim;} 
    313   //!@} 
     319  //! Load from structure with elements: 
     320  //!  \code 
     321  //! { rv = {class="RV", names=(...),}; // RV describing meaning of random variable 
     322  //!   // elements of offsprings 
     323  //! } 
     324  //! \endcode 
     325  //!@} 
     326  void from_setting(const Setting &set){ 
     327          if (set.exists("rv")){ 
     328                  RV* r = UI::build<RV>(set,"rv"); 
     329                  set_rv(*r);  
     330                  delete r; 
     331          } 
     332  } 
    314333 
    315334}; 
     
    368387  //! Matrix version of evallogcond 
    369388  virtual vec evallogcond_m(const mat &Dt, const vec &cond) {this->condition(cond); return ep->evallog_m(Dt);}; 
     389  //! Array<vec> version of evallogcond 
     390  virtual vec evallogcond_m(const Array<vec> &Dt, const vec &cond) {this->condition(cond); return ep->evallog_m(Dt);}; 
    370391 
    371392  //! \name Access to attributes 
     
    378399  epdf& _epdf() {return *ep;} 
    379400  epdf* _e() {return ep;} 
     401  //! Load from structure with elements: 
     402  //!  \code 
     403  //! { rv = {class="RV", names=(...),}; // RV describing meaning of random variable 
     404  //!   rvc= {class="RV", names=(...),}; // RV describing meaning of random variable in condition 
     405  //!   // elements of offsprings 
     406  //! } 
     407  //! \endcode 
     408  //!@} 
     409  void from_setting(const Setting &set){ 
     410          if (set.exists("rv")){ 
     411                  RV* r = UI::build<RV>(set,"rv"); 
     412                  set_rv(*r);  
     413                  delete r; 
     414          } 
     415          if (set.exists("rvc")){ 
     416                  RV* r = UI::build<RV>(set,"rvc"); 
     417                  set_rvc(*r);  
     418                  delete r; 
     419          } 
     420  } 
    380421  //!@} 
    381422 
     
    570611public: 
    571612        //!Default constructor 
     613        mepdf(){}; 
    572614        mepdf ( epdf* em, bool owning_ep0=false ) :mpdf ( ) {ep= em ;owning_ep=owning_ep0;dimc=0;}; 
    573615        mepdf (const epdf* em ) :mpdf ( ) {ep=const_cast<epdf*>( em );}; 
    574616        void condition ( const vec &cond ) {} 
    575617        ~mepdf(){if (owning_ep) delete ep;} 
    576 }; 
     618  //! Load from structure with elements: 
     619  //!  \code 
     620  //! { class = "mepdf",          
     621  //!   epdfs = {class="epdfs",...} 
     622  //! } 
     623  //! \endcode 
     624  //!@} 
     625        void from_setting(const Setting &set){ 
     626                epdf* e = UI::build<epdf>(set,"epdf"); 
     627                ep=     e;  
     628                owning_ep=true; 
     629        } 
     630}; 
     631UIREGISTER(mepdf); 
    577632 
    578633//!\brief Chain rule of pdfs - abstract part common for mprod and merger.  
  • library/bdm/stat/emix.h

    r394 r395  
    235235        mprod (){}; 
    236236        mprod (Array<mpdf*> mFacs ){set_elements( mFacs );}; 
    237         void set_elements(Array<mpdf*> mFacs ) { 
     237        void set_elements(Array<mpdf*> mFacs , bool own=false) { 
    238238                 
    239                 set_elements(mFacs); 
    240                  
     239                compositepdf::set_elements(mFacs,own); 
     240                dls.set_size(mFacs.length()); 
     241                epdfs.set_size(mFacs.length()); 
     242                                 
    241243                ep=&dummy; 
    242244                RV rv=getrv ( true ); 
    243                 set_rv ( rv );dummy.set_parameters ( rv._dsize() ); 
     245                set_rv ( rv ); 
     246                dummy.set_parameters ( rv._dsize() ); 
    244247                setrvc ( ep->_rv(),rvc ); 
    245248                // rv and rvc established = > we can link them with mpdfs 
     
    270273                return res; 
    271274        } 
     275        vec evallogcond_m(const mat &Dt, const vec &cond) { 
     276                vec tmp(Dt.cols()); 
     277                for(int i=0;i<Dt.cols(); i++){ 
     278                        tmp(i) = evallogcond(Dt.get_col(i),cond); 
     279                } 
     280                return tmp; 
     281        }; 
     282        vec evallogcond_m(const Array<vec> &Dt, const vec &cond) { 
     283                vec tmp(Dt.length()); 
     284                for(int i=0;i<Dt.length(); i++){ 
     285                        tmp(i) = evallogcond(Dt(i),cond); 
     286                } 
     287                return tmp;              
     288        }; 
     289 
     290 
    272291        //TODO smarter... 
    273292        vec samplecond ( const vec &cond ) { 
     
    294313 
    295314        ~mprod() {}; 
    296 }; 
     315        //! Load from structure with elements: 
     316        //!  \code 
     317        //! { class='mprod'; 
     318        //!   mpdfs = (..., ...);     // list of mpdfs in the order of chain rule 
     319        //! } 
     320        //! \endcode 
     321        //!@} 
     322        void from_setting(const Setting &set){ 
     323                Array<mpdf*> Atmp; //temporary Array 
     324                UI::get(Atmp,set, "mpdfs"); 
     325                set_elements(Atmp,true); 
     326        } 
     327         
     328}; 
     329UIREGISTER(mprod); 
    297330 
    298331//! Product of independent epdfs. For dependent pdfs, use mprod. 
  • library/bdm/stat/exp_family.cpp

    r384 r395  
    2424        if ( dimx==1 ) { //same as the following, just quicker. 
    2525                double r = val ( vend ); //last entry! 
     26                if (r<0) return -1e+100; 
    2627                vec Psi ( nPsi+dimx ); 
    2728                Psi ( 0 ) = -1.0; 
     
    3435                mat Th= reshape ( val ( 0,nPsi*dimx-1 ),nPsi,dimx ); 
    3536                fsqmat R ( reshape ( val ( nPsi*dimx,vend ),dimx,dimx ) ); 
     37                double ldetR=R.logdet(); 
     38                if (ldetR) return -1e+100; 
    3639                mat Tmp=concat_vertical ( -eye ( dimx ),Th ); 
    3740                fsqmat iR ( dimx ); 
    3841                R.inv ( iR ); 
    3942 
    40                 return -0.5* ( nu*R.logdet() + trace ( iR.to_mat() *Tmp.T() *V.to_mat() *Tmp ) ); 
     43                return -0.5* ( nu*ldetR + trace ( iR.to_mat() *Tmp.T() *V.to_mat() *Tmp ) ); 
    4144        } 
    4245} 
     
    187190        int i; 
    188191 
     192        if (any(val<=0.)) return -1e100; 
     193        if (any(beta<=0.)) return -1e100; 
    189194        for ( i=0; i<dim; i++ ) { 
    190195                res += ( alpha ( i ) - 1 ) *std::log ( val ( i ) ) - beta ( i ) *val ( i ); 
  • library/bdm/stat/exp_family.h

    r388 r395  
    4848                        virtual double evallog_nn ( const vec &val ) const{it_error ( "Not implemented" );return 0.0;}; 
    4949                        //!Evaluate normalized log-probability 
    50                         virtual double evallog ( const vec &val ) const {double tmp;tmp= evallog_nn ( val )-lognc();it_assert_debug ( std::isfinite ( tmp ),"Infinite value" ); return tmp;} 
     50                        virtual double evallog ( const vec &val ) const { 
     51                                double tmp; 
     52                                tmp= evallog_nn ( val )-lognc(); 
     53                                it_assert_debug ( std::isfinite ( tmp ),"Infinite value" );  
     54                                return tmp;} 
    5155                        //!Evaluate normalized log-probability for many samples 
    5256                        virtual vec evallog ( const mat &Val ) const 
     
    125129                        void set_parameters ( const vec &mu,const sq_T &R ); 
    126130                        void from_setting(const Setting &root); 
     131                        void validate() { 
     132                                it_assert(mu.length()==R.rows(),"parameters mismatch"); 
     133                                dim = mu.length(); 
     134                        } 
    127135                        //!@} 
    128136 
     
    225233                        double& _nu()  {return nu;} 
    226234                        const double& _nu() const {return nu;} 
    227                         //!@} 
    228         }; 
     235                        void from_setting(const Setting &set){ 
     236                                set.lookupValue("nu",nu); 
     237                                set.lookupValue("dimx",dimx); 
     238                                mat V; 
     239                                UI::get(V,set,"V"); 
     240                                set_parameters(dimx, V, nu); 
     241                                RV* rv=UI::build<RV>(set,"rv"); 
     242                                set_rv(*rv); 
     243                                delete rv; 
     244                        } 
     245                        //!@} 
     246        }; 
     247        UIREGISTER(egiw); 
    229248 
    230249        /*! \brief Dirichlet posterior density 
     
    370389                        vec mean() const {return elem_div ( alpha,beta );} 
    371390                        vec variance() const {return elem_div ( alpha,elem_mult ( beta,beta ) ); } 
    372         }; 
    373  
     391                         
     392                        //! Load from structure with elements: 
     393                        //!  \code 
     394                        //! { alpha = [...];         // vector of alpha 
     395                        //!   beta = [...];          // vector of beta 
     396                        //!   rv = {class="RV",...}  // description 
     397                        //! } 
     398                        //! \endcode 
     399                        //!@} 
     400                        void from_setting(const Setting &set){ 
     401                                epdf::from_setting(set); // reads rv 
     402                                UI::get(alpha,set,"alpha"); 
     403                                UI::get(beta,set,"beta"); 
     404                                validate(); 
     405                        } 
     406                        void validate(){ 
     407                                it_assert(alpha.length() ==beta.length(), "parameters do not match"); 
     408                                dim =alpha.length(); 
     409                        } 
     410        }; 
     411UIREGISTER(egamma); 
    374412        /*! 
    375413         \brief Inverse-Gamma posterior density 
     
    462500                        vec mean() const {return ( high-low ) /2.0;} 
    463501                        vec variance() const {return ( pow ( high,2 ) +pow ( low,2 ) +elem_mult ( high,low ) ) /3.0;} 
     502                        //! Load from structure with elements: 
     503                        //!  \code 
     504                        //! { high = [...];          // vector of upper bounds 
     505                        //!   low = [...];           // vector of lower bounds 
     506                        //!   rv = {class="RV",...}  // description of RV 
     507                        //! } 
     508                        //! \endcode 
     509                        //!@} 
     510                        void from_setting(const Setting &set){ 
     511                                epdf::from_setting(set); // reads rv and rvc 
     512                                UI::get(high,set,"high"); 
     513                                UI::get(low,set,"low"); 
     514                        } 
    464515        }; 
    465516 
     
    650701                        void set_parameters ( double k, const vec &beta0 ); 
    651702                        void condition ( const vec &val ) {_beta=k/val;}; 
    652         }; 
    653  
     703                        //! Load from structure with elements: 
     704                        //!  \code 
     705                        //! { alpha = [...];         // vector of alpha 
     706                        //!   k = 1.1;               // multiplicative constant k 
     707                        //!   rv = {class="RV",...}  // description of RV 
     708                        //!   rvc = {class="RV",...} // description of RV in condition 
     709                        //! } 
     710                        //! \endcode 
     711                        //!@} 
     712                        void from_setting(const Setting &set){ 
     713                                mpdf::from_setting(set); // reads rv and rvc 
     714                                vec betatmp; // ugly but necessary 
     715                                UI::get(betatmp,set,"beta"); 
     716                                set.lookupValue("k",k); 
     717                                set_parameters(k,betatmp); 
     718                        } 
     719        }; 
     720        UIREGISTER(mgamma); 
     721         
    654722        /*! 
    655723        \brief  Inverse-Gamma random walk 
     
    10881156                mu = mu0; 
    10891157                R = R0; 
    1090                 dim = mu0.length(); 
    1091         }; 
    1092  
    1093         template<class sq_T> 
    1094         void enorm<sq_T>::from_setting(const Setting &root){ 
    1095                 vec mu; 
    1096                 UI::get(mu,root,"mu"); 
    1097                 mat R; 
    1098                 UI::get(R,root,"R"); 
    1099                 set_parameters(mu,R); 
     1158                validate(); 
     1159        }; 
     1160 
     1161        template<class sq_T> 
     1162        void enorm<sq_T>::from_setting(const Setting &set){ 
     1163                epdf::from_setting(set); //reads rv 
    11001164                 
    1101                 RV* r = UI::build<RV>(root,"rv"); 
    1102                 set_rv(*r);  
    1103                 delete r; 
     1165                UI::get(mu,set,"mu"); 
     1166                mat Rtmp;// necessary for conversion 
     1167                UI::get(Rtmp,set,"R"); 
     1168                R=Rtmp; // conversion 
     1169                validate(); 
    11041170        } 
    11051171 
  • library/bdm/stat/merger.cpp

    r384 r395  
    134134                                { 
    135135                                        // no need for conditioning or marginalization 
    136                                         for ( int j=0;j<Npoints; j++ )   // Smp is Array<> => for cycle 
    137                                         { 
    138                                                 lw_src ( j ) =mpdfs ( i )->_epdf().evallog ( Smp ( j ) ); 
    139                                         } 
     136                                        lw_src = mpdfs ( i )->_epdf().evallog_m ( Smp  ); 
    140137                                } 
    141138                                else 
     
    219216                        // ==== stopping rule === 
    220217                        niter++; 
    221                         converged = ( niter>40 ); 
     218                        converged = ( niter>stop_niter ); 
    222219                } 
    223220                delete Mpred; 
  • library/bdm/stat/merger.h

    r392 r395  
    132132                        } 
    133133                        // fill samples 
    134                         int act_dim=0; //active dimension 
    135134                        for (int i = 0; i < Npoints; i++) { 
    136135                                // copy  
     
    139138                                for (int j = 0;j < dim;j++) { 
    140139                                        if (ind (j) == gridsize (j) - 1) { //j-th index is full 
    141 //                                              ind (j) = 0; //shift back 
     140                                                ind (j) = 0; //shift back 
    142141                                                smpi(j) = XYZ(j)(0); 
    143142                                                 
    144 //                                              ind (j + 1) ++; //increase the next dimension; 
    145                                                 smpi(j+1) += steps(j+1); 
     143                                                if (i<Npoints-1) { 
     144                                                        ind (j + 1) ++; //increase the next dimension; 
     145                                                        smpi(j+1) += steps(j+1); 
     146                                                        break; 
     147                                                } 
    146148                                                 
    147                                                 if (ind (j + 1) < gridsize (j + 1) - 1) break; 
    148149                                        } else { 
    149 //                                              ind (j) ++;  
     150                                                ind (j) ++;  
    150151                                                smpi(j) +=steps(j); 
    151152                                                break; 
     
    185186 
    186187                //!Merge given sources in given points 
    187                 void merge () { 
     188                virtual void merge () { 
    188189                        validate(); 
    189190 
     
    205206                                } 
    206207 
    207                                 vec wtmp = exp (merge_points (lW)); 
     208                                vec w_nn=merge_points (lW); 
     209                                vec wtmp = exp (w_nn-max(w_nn)); 
    208210                                //renormalize 
    209211                                eSmp._w() = wtmp / sum (wtmp); 
     
    277279                                } 
    278280                        } 
     281                        if (set.exists("dbg_file")){  
     282                                string dbg_file; 
     283                                UI::get<string> (dbg_file, set, "dbg_file"); 
     284                                set_debug_file(dbg_file); 
     285                        } 
     286                        //validate() - not used 
    279287                } 
    280288 
     
    295303                //!Number of components in a mixture 
    296304                int Ncoms; 
    297                 //! coefficient of resampling 
     305                //! coefficient of resampling [0,1] 
    298306                double effss_coef; 
     307                //! stop after niter iterations 
     308                int stop_niter; 
    299309 
    300310        public: 
     
    338348                emix* proposal() {emix* tmp = Mix.epredictor(); tmp->set_rv (rv); return tmp;} 
    339349                //! from_settings 
    340                 void from_settings(const Setting& set){ 
     350                void from_setting(const Setting& set){ 
    341351                        merger_base::from_setting(set); 
    342352                        set.lookupValue("ncoms",Ncoms); 
     353                        set.lookupValue("effss_coef",effss_coef); 
     354                        set.lookupValue("stop_niter",stop_niter); 
    343355                } 
    344356