root/mpdm/merg_pred.cpp @ 198

Revision 198, 3.8 kB (checked in by smidl, 16 years ago)

opravy + zavedeni studenta + zakomentovani debug v mergeru

Line 
1#include <estim/arx.h>
2#include <estim/merger.h>
3#include <stat/libEF.h>
4#include <stat/loggers.h>
5//#include <stat/merger.h>
6using namespace itpp;
7
8//These lines are needed for use of cout and endl
9using std::cout;
10using std::endl;
11
12int main() {
13        // Setup model
14        RV y ( "{y }" );
15        RV u1 ( "{u1 }" );
16        RV u2 ( "{u2 }" );
17        RV ym=y; ym.t(-1);
18        RV yy = y; yy.add(ym);
19        RV uu=u1; uu.add(u2);
20
21        // Full system
22        vec thg ( "0.7 1 1 0" ); //Simulated system - zero for constant term
23        //y=a y_t-1 + u1 + u2
24        double sqr=0.1;
25        int ord = 1;
26
27        // Estimated systems ARX(2)
28        RV thri ( "{thr_i }",vec_1 ( 3+1 ) );
29        RV thrg ( "{thr_g }",vec_1 ( 4+1 ) );
30        // Setup values
31
32        //ARX constructor
33        mat V0 = 0.001*eye ( thri.count() ); V0 ( 0,0 ) *= 10; //
34        mat V0g = 0.001*eye ( thrg.count() ); V0g ( 0,0 ) *= 10; //
35        double nu0 = ord+6.0;
36        double frg = 1.0;//0.99;
37
38        ARX P1 ( thri, V0, nu0, frg );
39        ARX P2 ( thri, V0, nu0, frg );
40        ARX PG ( thrg, V0g, nu0, frg );
41        //Test estimation
42        int ndat = 500;
43        int t;
44
45        // Logging
46        dirfilelog L ( "exp/merg",ndat );
47
48        int Li_Eth1 = L.add ( thri,"P1" );
49        int Li_Eth2 = L.add ( thri,"P2" );
50        int Li_Ethg = L.add ( thrg,"PG" );
51        int Li_Data = L.add ( RV ( "{Y U1 U2 }" ), "" );
52        int Li_LL   = L.add ( RV ( "{1 2 G }" ), "LL" );
53        int Li_Pred = L.add ( RV ( "{1 2 G ar ge ln}" ), "Pred" );
54
55        L.init();
56
57        vec Yt ( ndat );
58        vec yt(1);
59
60        Yt.set_subvector ( 0,randn ( ord ) ); //initial values
61        vec rgrg ( thrg.count() -1); // constant terms are in!
62        vec rgr1 ( thri.count() -1);
63        vec rgr2 ( thri.count() -1);
64
65        vec PredLLs(5);
66        vec PostLLs(3);
67        vec PredLLs_m=zeros(5);
68        vec PostLLs_m=zeros(3);
69        ivec ind_r1 = "0 1 3";
70        ivec ind_r2 = "0 2 3";
71        for ( t=0; t<ndat; t++ ) {
72                // True system
73                if ( t>0 ) {
74                        rgrg ( 0 ) =Yt ( t-1 );
75                        rgrg ( 1 ) = pow(sin ( ( t/40.0 ) *pi ),3);
76                        rgrg ( 2 ) = pow(cos ( ( t/40.0 ) *pi ),3);
77                        rgrg (3) = 1.0; // constant term
78                       
79                        rgr1(0) = rgrg(0); rgr1(1) = rgrg(1); rgr1(2) = rgrg(3); // no u2
80                        rgr2(0) = rgrg(0); rgr2(1) = rgrg(2); rgr2(2) = rgrg(3); // no u1
81
82                        Yt ( t ) = thg*rgrg + sqr * NorRNG();
83
84                        // Test predictors
85                        if (t>2){
86                                mlnorm<ldmat>* P1p = P1.predictor_student(y,concat(ym,u1));
87                                mlnorm<ldmat>* P2p = P2.predictor_student(y,concat(ym,u2));
88                                mlnorm<ldmat>* Pgp = PG.predictor_student(y,concat(ym,uu));
89                               
90                                Array<mpdf*> A(2); A(0)=P1p;A(1)=P2p;
91                                merger M(A);
92                                enorm<ldmat> g0(concat(yy,uu)); g0.set_parameters("0 0 0 0 ",3*eye(4));
93                                M.set_parameters(1.2, 100,1);
94                                M.merge(&g0);
95                               
96                                yt(0) = Yt(t);
97                                double P1pl = P1p->evalcond(yt,rgr1);   
98                                double P2pl = P2p->evalcond(yt,rgr2);   
99                                double PGpl = Pgp->evalcond(yt,rgrg);   
100                                PredLLs(0) = log(P1pl); PredLLs(1) = log(P2pl); PredLLs(2) = log(PGpl);
101                                {
102                                        ARX* Apred = (ARX*)M._Mix()._Coms(0);
103                                        enorm<ldmat>* MP= Apred->predictor(concat(yy,uu));
104                                        enorm<ldmat>* mP1p = (enorm<ldmat>*)MP->marginal(concat(yy,u1));
105                                        enorm<ldmat>* mP2p = (enorm<ldmat>*)MP->marginal(concat(yy,u2));
106                                        mlnorm<ldmat>* cP1p = (mlnorm<ldmat>*)mP1p->condition(y);
107                                        mlnorm<ldmat>* cP2p = (mlnorm<ldmat>*)mP2p->condition(y);
108
109                                        double cP1pl = cP1p->evalcond(yt,rgr1); 
110                                        double cP2pl = cP2p->evalcond(yt,rgr2); 
111                                        PredLLs(3) = log(cP1pl);
112                                        PredLLs(4) = log(cP2pl);
113                                }
114
115                                L.logit(Li_Pred, PredLLs_m*frg+ PredLLs); //log-normal
116                                PredLLs_m *=frg;
117                                PredLLs_m += PredLLs;
118                               
119                                delete P1p;
120                                delete P2p;
121                                delete Pgp;
122                        }
123                       
124                        // 1st
125                        P1.bayes ( concat(Yt(t),rgr1) );
126                        // 2nd
127                        P2.bayes ( concat(Yt(t),rgr2) );
128
129                        //Global
130                        PG.bayes ( concat ( Yt ( t ),rgrg ) );
131                       
132                        //Merger
133                }
134                L.logit ( Li_Eth1, P1._epdf().mean() );
135                L.logit ( Li_Eth2, P2._epdf().mean() );
136                L.logit ( Li_Ethg, PG._epdf().mean() );
137                L.logit ( Li_Data, vec_3 ( Yt ( t ), rgrg ( 1 ), rgrg ( 2 ) ) );
138                PostLLs *= frg;
139                PostLLs += vec_3 ( P1._ll(), P2._ll(), PG._ll() );
140                L.logit ( Li_LL, PostLLs );
141                L.step (  );
142        }
143        L.finalize( );
144        L.itsave ( "merg.it" );
145}
Note: See TracBrowser for help on using the browser.