root/library/bdm/estim/arx.cpp @ 625

Revision 625, 5.9 kB (checked in by smidl, 15 years ago)

ARX re-designed

  • Property svn:eol-style set to native
RevLine 
[97]1#include "arx.h"
[270]2namespace bdm {
[13]3
[170]4void ARX::bayes ( const vec &dt, const double w ) {
[97]5        double lnc;
[13]6
[625]7               
[477]8        if ( frg < 1.0 ) {
[170]9                est.pow ( frg );
10                if ( evalll ) {
[162]11                        last_lognc = est.lognc();
12                }
13        }
[625]14        if (have_constant) {
15                _dt.set_subvector(0,dt);
16                V.opupdt ( _dt, w );
17        } else {
18                V.opupdt ( dt, w );
19        }
[477]20        nu += w;
[97]21
[204]22        // log(sqrt(2*pi)) = 0.91893853320467
[97]23        if ( evalll ) {
24                lnc = est.lognc();
[204]25                ll = lnc - last_lognc - 0.91893853320467;
[97]26                last_lognc = lnc;
27        }
28}
29
[170]30double ARX::logpred ( const vec &dt ) const {
31        egiw pred ( est );
[477]32        ldmat &V = pred._V();
33        double &nu = pred._nu();
[170]34
35        double lll;
36
[477]37        if ( frg < 1.0 ) {
[170]38                pred.pow ( frg );
39                lll = pred.lognc();
[477]40        } else//should be save: last_lognc is changed only by bayes;
41                if ( evalll ) {
42                        lll = last_lognc;
43                } else {
44                        lll = pred.lognc();
45                }
[170]46
[477]47        V.opupdt ( dt, 1.0 );
48        nu += 1.0;
[201]49        // log(sqrt(2*pi)) = 0.91893853320467
[477]50        return pred.lognc() - lll - 0.91893853320467;
[170]51}
52
[283]53ARX* ARX::_copy_ ( ) const {
[477]54        ARX* Tmp = new ARX ( *this );
[170]55        return Tmp;
56}
57
58void ARX::set_statistics ( const BMEF* B0 ) {
[477]59        const ARX* A0 = dynamic_cast<const ARX*> ( B0 );
[170]60
[565]61        bdm_assert_debug ( V.rows() == A0->V.rows(), "ARX::set_statistics Statistics  differ" );
[270]62        set_statistics ( A0->dimx, A0->V, A0->nu );
[170]63}
[180]64
[270]65enorm<ldmat>* ARX::epredictor ( const vec &rgr ) const {
[477]66        int dim = dimx;//est.dimension();
67        mat mu ( dim, V.rows() - dim );
68        mat R ( dim, dim );
[270]69
[198]70        enorm<ldmat>* tmp;
[477]71        tmp = new enorm<ldmat> ( );
[270]72        //TODO: too hackish
[477]73        if ( drv._dsize() > 0 ) {
[270]74        }
[198]75
[477]76        est.mean_mat ( mu, R ); //mu =
[198]77        //correction for student-t  -- TODO check if correct!!
78        //R*=nu/(nu-2);
[477]79        mat p_mu = mu.T() * rgr;        //the result is one column
80        tmp->set_parameters ( p_mu.get_col ( 0 ), ldmat ( R ) );
[198]81        return tmp;
82}
83
[270]84mlnorm<ldmat>* ARX::predictor ( ) const {
[477]85        int dim = est.dimension();
[625]86       
[477]87        mat mu ( dim, V.rows() - dim );
88        mat R ( dim, dim );
[198]89        mlnorm<ldmat>* tmp;
[477]90        tmp = new mlnorm<ldmat> ( );
[198]91
[477]92        est.mean_mat ( mu, R ); //mu =
[198]93        mu = mu.T();
94        //correction for student-t  -- TODO check if correct!!
95
[625]96        if ( have_constant) { // constant term
[198]97                //Assume the constant term is the last one:
[477]98                tmp->set_parameters ( mu.get_cols ( 0, mu.cols() - 2 ), mu.get_col ( mu.cols() - 1 ), ldmat ( R ) );
[625]99        } else {
100                tmp->set_parameters ( mu, zeros ( dim ), ldmat ( R ) );
[198]101        }
102        return tmp;
103}
104
[270]105mlstudent* ARX::predictor_student ( ) const {
106        int dim = est.dimension();
[198]107
[477]108        mat mu ( dim, V.rows() - dim );
109        mat R ( dim, dim );
[198]110        mlstudent* tmp;
[477]111        tmp = new mlstudent ( );
[198]112
[477]113        est.mean_mat ( mu, R ); //
[198]114        mu = mu.T();
[270]115
116        int xdim = dimx;
[477]117        int end = V._L().rows() - 1;
118        ldmat Lam ( V._L() ( xdim, end, xdim, end ), V._D() ( xdim, end ) );  //exp val of R
[198]119
120
[625]121        if ( have_constant) { // no constant term
[198]122                //Assume the constant term is the last one:
[477]123                if ( mu.cols() > 1 ) {
124                        tmp->set_parameters ( mu.get_cols ( 0, mu.cols() - 2 ), mu.get_col ( mu.cols() - 1 ), ldmat ( R ), Lam );
125                } else {
126                        tmp->set_parameters ( mat ( dim, 0 ), mu.get_col ( mu.cols() - 1 ), ldmat ( R ), Lam );
[270]127                }
[625]128        } else {
129                // no constant term
130                tmp->set_parameters ( mu, zeros ( xdim ), ldmat ( R ), Lam );
[198]131        }
[180]132        return tmp;
133}
134
[585]135
136
[97]137/*! \brief Return the best structure
138@param Eg a copy of GiW density that is being examined
139@param Eg0 a copy of prior GiW density before estimation
140@param Egll likelihood of the current Eg
141@param indeces current indeces
142\return best likelihood in the structure below the given one
143*/
144double egiw_bestbelow ( egiw Eg, egiw Eg0, double Egll, ivec &indeces ) { //parameter Eg is a copy!
145        ldmat Vo = Eg._V(); //copy
146        ldmat Vo0 = Eg._V(); //copy
147        ldmat& Vp = Eg._V(); // pointer into Eg
148        ldmat& Vp0 = Eg._V(); // pointer into Eg
[477]149        int end = Vp.rows() - 1;
[97]150        int i;
151        mat Li;
152        mat Li0;
[477]153        double maxll = Egll;
154        double tmpll = Egll;
155        double belll = Egll;
[97]156
157        ivec tmpindeces;
[477]158        ivec maxindeces = indeces;
[97]159
[115]160
[477]161        cout << "bb:(" << indeces << ") ll=" << Egll << endl;
[115]162
[97]163        //try to remove only one rv
[477]164        for ( i = 0; i < end; i++ ) {
[97]165                //copy original
166                Li = Vo._L();
167                Li0 = Vo0._L();
168                //remove stuff
[477]169                Li.del_col ( i + 1 );
170                Li0.del_col ( i + 1 );
171                Vp.ldform ( Li, Vo._D() );
172                Vp0.ldform ( Li0, Vo0._D() );
173                tmpll = Eg.lognc() - Eg0.lognc(); // likelihood is difference of norm. coefs.
[115]174
[477]175                cout << "i=(" << i << ") ll=" << tmpll << endl;
[170]176
[97]177                //
178                if ( tmpll > Egll ) { //increase of the likelihood
179                        tmpindeces = indeces;
180                        tmpindeces.del ( i );
181                        //search for a better match in this substructure
[477]182                        belll = egiw_bestbelow ( Eg, Eg0, tmpll, tmpindeces );
183                        if ( belll > maxll ) { //better match found
[97]184                                maxll = belll;
185                                maxindeces = tmpindeces;
186                        }
187                }
188        }
189        indeces = maxindeces;
190        return maxll;
191}
192
193ivec ARX::structure_est ( egiw est0 ) {
[477]194        ivec ind = linspace ( 1, est.dimension() - 1 );
195        egiw_bestbelow ( est, est0, est.lognc() - est0.lognc(), ind );
[97]196        return ind;
197}
[254]198
[577]199
200
201ivec ARX::structure_est_LT ( egiw est0 ) {
202        //some stuff with beliefs etc.
[585]203//      ivec ind = bdm::straux1(V,nu, est0._V(), est0._nu());
204        return ivec();//ind;
[577]205}
206
[477]207void ARX::from_setting ( const Setting &set ) {
[625]208        shared_ptr<RV> yrv = UI::build<RV> ( set, "rv", UI::compulsory );
[527]209        shared_ptr<RV> rrv = UI::build<RV> ( set, "rgr", UI::compulsory );
[357]210        int ylen = yrv->_dsize();
[625]211        // rgrlen - including constant!!!
[357]212        int rgrlen = rrv->_dsize();
[625]213       
214        set_rv ( *yrv, *rrv );
215       
[585]216        string opt;
217        if ( UI::get(opt, set,  "options", UI::optional) ) {
218                BM::set_options(opt);
219        }
[625]220        if (!UI::get(have_constant, set, "constant", UI::optional)){
221                have_constant=true;
222        }
223        if (have_constant) {rgrlen++;_dt=ones(rgrlen+ylen);}
[585]224
[357]225        //init
226        mat V0;
[412]227        vec dV0;
[625]228        if (!UI::get(V0, set, "V0",UI::optional)){
229                if ( !UI::get ( dV0, set, "dV0" ) )
230                        dV0 = concat ( 1e-3 * ones ( ylen ), 1e-5 * ones ( rgrlen ) );
231                V0 = diag ( dV0 );
232        }
[357]233        double nu0;
[477]234        if ( !UI::get ( nu0, set, "nu0" ) )
235                nu0 = rgrlen + ylen + 2;
[357]236
237        double frg;
[477]238        if ( !UI::get ( frg, set, "frg" ) )
[357]239                frg = 1.0;
240
[477]241        set_parameters ( frg );
242        set_statistics ( ylen, V0, nu0 );
[625]243       
[357]244        //name results (for logging)
[625]245        shared_ptr<RV> rv_par=UI::build<RV>(set, "rv_param",UI::optional );
246        if (!rv_par){
247                est.set_rv ( RV ( "{theta r }", vec_2 ( ylen*rgrlen, ylen*ylen ) ) );
248        } else {
249                est.set_rv ( *rv_par );
250        }
251        validate();
[270]252}
[357]253
254}
Note: See TracBrowser for help on using the browser.