mixpp: mixtures.h Source File

mixtures.h

Go to the documentation of this file.
00001
00013 #ifndef MIXTURES_H
00014 #define MIXTURES_H
00015 
00016
00017 #include "../math/functions.h"
00018 #include "../stat/exp_family.h"
00019 #include "../stat/emix.h"
00020 #include "arx.h"
00021
00022 namespace bdm {
00023
//! Selects the update algorithm used by MixEF (stored in MixEF::Options::method).
//! EM and QB are presumably expectation-maximization and quasi-Bayes;
//! confirm against the implementation in mixtures.cpp.
enum MixEF_METHOD {
    EM = 0, //!< serialized as "EM"
    QB = 1  //!< serialized as "QB" (the default in MixEF::Options)
};
00026
00050 class MixEF: public BMEF {
00051 protected:
00053     Array<BMEF*> Coms;
00055     multiBM weights;
00056     //aux
00057     friend class eprod_mix;
00058
00060     class eprod_mix: public eprod_base {
00061     protected:
00062         const MixEF &mix; // pointer to parent.n
00063     public:
00064         eprod_mix(const MixEF &m):mix(m) {}
00065         const epdf* factor(int i) const {
00066             return (i==(mix.Coms.length()-1)) ? &mix.weights.posterior() : &mix.Coms(i)->posterior();
00067         }
00068         const int no_factors()const {
00069             return mix.Coms.length()+1;
00070         }
00071     } est;
00073
00074     class Options: public root {
00075     public:
00077         MixEF_METHOD method;
00078
00080         int max_niter;
00081
00082         Options():method(QB),max_niter(10) {};
00083
00099         void from_setting(const Setting &set) {
00100             string meth;
00101             UI::get(meth,set,"method",UI::optional);
00102             if (meth=="EM") {
00103                 method=EM;
00104             }
00105             max_niter =10;
00106             UI::get(max_niter,set,"max_niter",UI::optional);
00107         };
00108
00109         void to_setting(Setting &set)const {
00110             string meth=(method==EM ? "EM" : "QB");
00111             UI::save(meth,set,"method");
00112             UI::save(max_niter,set,"max_niter");
00113         };
00114     };
00115
00116     Options options;
00117 public:
00119     MixEF ( const Array<BMEF*> &Coms0, const vec &alpha0 ) :
00120         BMEF ( ), Coms ( Coms0.length() ),
00121         weights (), est(*this), options() {
00122         for ( int i = 0; i < Coms0.length(); i++ ) {
00123             Coms ( i ) = ( BMEF* ) Coms0 ( i )->_copy();
00124         }
00125         weights.set_parameters(alpha0);
00126         weights.validate();
00127     }
00128
00130     MixEF () :
00131         BMEF ( ), Coms ( 0 ),
00132         weights (), est(*this), options() {
00133     }
00135     MixEF ( const MixEF &M2 ) : BMEF ( ),  Coms ( M2.Coms.length() ),
00136         weights ( M2.weights ), est(*this), options(M2.options) {
00137         for ( int i = 0; i < M2.Coms.length(); i++ ) {
00138             Coms ( i ) = (BMEF*) M2.Coms ( i )->_copy();
00139         }
00140     }
00141
00147     void init ( BMEF* Com0, const mat &Data, const int c = 5 );
00148     //Destructor
00150     void bayes ( const vec &yt, const vec &cond );
00152     double bayes_batch_weighted ( const mat &yt, const mat &cond, const vec &wData );
00153     double bayes_batch ( const mat &yt, const mat &cond) {
00154         return bayes_batch_weighted(yt,cond,ones(yt.cols()));
00155     };
00156     double logpred ( const vec &yt, const vec &cond ) const;
00158     const eprod_mix& posterior() const {
00159         return est;
00160     }
00161
00162     emix* epredictor(const vec &cond=vec()) const;
00164     void flatten ( const BMEF* M2, double weight );
00166     BMEF* _Coms ( int i ) {
00167         return Coms ( i );
00168     }
00169
00171     void set_method ( MixEF_METHOD M ) {
00172         options.method = M;
00173     }
00174
00175     void to_setting ( Setting &set ) const {
00176         BMEF::to_setting( set );
00177         UI::save ( Coms, set, "Coms" );
00178         UI::save ( &weights, set, "weights" );
00179         UI::save (options, set, "options");
00180     }
00181
00194     void from_setting (const  Setting &set ) {
00195         BMEF::from_setting( set );
00196         UI::get ( Coms, set, "Coms" );
00197         UI::get ( weights, set, "weights" );
00198         UI::get (options, set, "options",UI::optional);
00199     }
00200 };
00201 UIREGISTER ( MixEF );
00202
00204 class ARXprod: public ProdBMBase {
00205     Array<shared_ptr<ARX> > arxs;
00206 public:
00207     ARX* bm(int i) {
00208         return arxs(i).get();
00209     }
00210     int no_bms() {
00211         return arxs.length();
00212     }
00213 };
00214 UIREGISTER(ARXprod);
00215
00216 }
00217 #endif // MIXTURES_H

Generated on 2 Dec 2013 for mixpp by  doxygen 1.4.7