root/mpdm/merg_pred.cpp @ 204

Revision 204, 4.0 kB (checked in by smidl, 16 years ago)

merger is now in logarithms + new merge_test

RevLine 
[161]1#include <estim/arx.h>
[198]2#include <estim/merger.h>
[161]3#include <stat/libEF.h>
4#include <stat/loggers.h>
[176]5//#include <stat/merger.h>
[161]6using namespace itpp;
7
8//These lines are needed for use of cout and endl
9using std::cout;
10using std::endl;
11
12int main() {
13        // Setup model
[162]14        RV y ( "{y }" );
15        RV u1 ( "{u1 }" );
16        RV u2 ( "{u2 }" );
[198]17        RV ym=y; ym.t(-1);
18        RV yy = y; yy.add(ym);
19        RV uu=u1; uu.add(u2);
[162]20
[161]21        // Full system
[198]22        vec thg ( "0.7 1 1 0" ); //Simulated system - zero for constant term
[162]23        //y=a y_t-1 + u1 + u2
[161]24        double sqr=0.1;
25        int ord = 1;
[162]26
[161]27        // Estimated systems ARX(2)
[198]28        RV thri ( "{thr_i }",vec_1 ( 3+1 ) );
29        RV thrg ( "{thr_g }",vec_1 ( 4+1 ) );
[161]30        // Setup values
[162]31
[161]32        //ARX constructor
[162]33        mat V0 = 0.001*eye ( thri.count() ); V0 ( 0,0 ) *= 10; //
34        mat V0g = 0.001*eye ( thrg.count() ); V0g ( 0,0 ) *= 10; //
[198]35        double nu0 = ord+6.0;
[204]36        double frg = 0.95;
[162]37
38        ARX P1 ( thri, V0, nu0, frg );
39        ARX P2 ( thri, V0, nu0, frg );
40        ARX PG ( thrg, V0g, nu0, frg );
[161]41        //Test estimation
[204]42        int ndat = 200;
[161]43        int t;
[162]44
[161]45        // Logging
[162]46        dirfilelog L ( "exp/merg",ndat );
[161]47
[198]48        int Li_Eth1 = L.add ( thri,"P1" );
49        int Li_Eth2 = L.add ( thri,"P2" );
50        int Li_Ethg = L.add ( thrg,"PG" );
51        int Li_Data = L.add ( RV ( "{Y U1 U2 }" ), "" );
52        int Li_LL   = L.add ( RV ( "{1 2 G }" ), "LL" );
[204]53        int Li_Pred = L.add ( RV ( "{1 2 G ar ge }" ), "Pred" );
[162]54
55        L.init();
56
57        vec Yt ( ndat );
[198]58        vec yt(1);
[162]59
60        Yt.set_subvector ( 0,randn ( ord ) ); //initial values
[198]61        vec rgrg ( thrg.count() -1); // constant terms are in!
62        vec rgr1 ( thri.count() -1);
63        vec rgr2 ( thri.count() -1);
[162]64
[198]65        vec PredLLs(5);
66        vec PostLLs(3);
67        vec PredLLs_m=zeros(5);
68        ivec ind_r1 = "0 1 3";
69        ivec ind_r2 = "0 2 3";
[162]70        for ( t=0; t<ndat; t++ ) {
[161]71                // True system
[162]72                if ( t>0 ) {
73                        rgrg ( 0 ) =Yt ( t-1 );
74                        rgrg ( 1 ) = pow(sin ( ( t/40.0 ) *pi ),3);
75                        rgrg ( 2 ) = pow(cos ( ( t/40.0 ) *pi ),3);
[198]76                        rgrg (3) = 1.0; // constant term
77                       
78                        rgr1(0) = rgrg(0); rgr1(1) = rgrg(1); rgr1(2) = rgrg(3); // no u2
79                        rgr2(0) = rgrg(0); rgr2(1) = rgrg(2); rgr2(2) = rgrg(3); // no u1
[162]80
81                        Yt ( t ) = thg*rgrg + sqr * NorRNG();
82
[198]83                        // Test predictors
84                        if (t>2){
[204]85                                mlnorm<ldmat>* P1p = P1.predictor(y,concat(ym,u1));
86                                mlnorm<ldmat>* P2p = P2.predictor(y,concat(ym,u2));
87                                mlnorm<ldmat>* Pgp = PG.predictor(y,concat(ym,uu));
[198]88                               
89                                Array<mpdf*> A(2); A(0)=P1p;A(1)=P2p;
90                                merger M(A);
91                                enorm<ldmat> g0(concat(yy,uu)); g0.set_parameters("0 0 0 0 ",3*eye(4));
[204]92                                M.set_parameters(10000000.0, 100,1);
[198]93                                M.merge(&g0);
94                               
95                                yt(0) = Yt(t);
96                                double P1pl = P1p->evalcond(yt,rgr1);   
97                                double P2pl = P2p->evalcond(yt,rgr2);   
98                                double PGpl = Pgp->evalcond(yt,rgrg);   
99                                {
[204]100                                        cout << "yt: " << yt << endl;
101                                        cout << "yt_1: " << P1p->_epdf().mean() << endl;
102                                        cout << "yt_2: " << P2p->_epdf().mean() << endl;
103                                        cout << "yt_G: " << P2p->_epdf().mean() << endl;
104                                }
105                                double cP1pl;
106                                double cP2pl;
107                                {
[198]108                                        ARX* Apred = (ARX*)M._Mix()._Coms(0);
109                                        enorm<ldmat>* MP= Apred->predictor(concat(yy,uu));
110                                        enorm<ldmat>* mP1p = (enorm<ldmat>*)MP->marginal(concat(yy,u1));
111                                        enorm<ldmat>* mP2p = (enorm<ldmat>*)MP->marginal(concat(yy,u2));
112                                        mlnorm<ldmat>* cP1p = (mlnorm<ldmat>*)mP1p->condition(y);
113                                        mlnorm<ldmat>* cP2p = (mlnorm<ldmat>*)mP2p->condition(y);
114
[204]115                                        cP1pl = cP1p->evalcond(yt,rgr1);       
116                                        cP2pl = cP2p->evalcond(yt,rgr2);       
117                               
118                                        cout << "ytm1: " << cP1p->_epdf().mean() << endl;
119                                        cout << "ytm2: " << cP2p->_epdf().mean() << endl;
[198]120                                }
121
[204]122                                PredLLs *=frg;
123                                PredLLs += log(concat(vec_3(P1pl, P2pl, PGpl), vec_2(cP1pl,cP2pl)));
124                                L.logit(Li_Pred, PredLLs); //log-normal
[198]125                               
126                                delete P1p;
127                                delete P2p;
128                                delete Pgp;
129                        }
130                       
[162]131                        // 1st
[198]132                        P1.bayes ( concat(Yt(t),rgr1) );
[162]133                        // 2nd
[198]134                        P2.bayes ( concat(Yt(t),rgr2) );
[162]135
136                        //Global
137                        PG.bayes ( concat ( Yt ( t ),rgrg ) );
138                       
139                        //Merger
140                }
[198]141                L.logit ( Li_Eth1, P1._epdf().mean() );
142                L.logit ( Li_Eth2, P2._epdf().mean() );
143                L.logit ( Li_Ethg, PG._epdf().mean() );
144                L.logit ( Li_Data, vec_3 ( Yt ( t ), rgrg ( 1 ), rgrg ( 2 ) ) );
145                PostLLs *= frg;
146                PostLLs += vec_3 ( P1._ll(), P2._ll(), PG._ll() );
147                L.logit ( Li_LL, PostLLs );
[162]148                L.step (  );
[161]149        }
[162]150        L.finalize( );
151        L.itsave ( "merg.it" );
[161]152}
Note: See TracBrowser for help on using the browser.