root/applications/robust/main.cpp @ 1401

Revision 1401, 8.5 kB (checked in by sindj, 13 years ago)

Nevim, co se zmenilo, dodelani experimentu s maxlik odhady asi. JS

RevLine 
[976]1
2/*!
3\file
4\brief Robust
5\author Vasek Smidl
6
7 */
8
[1337]9#include "estim/arx.h"
[976]10#include "robustlib.h"
[1216]11#include <vector>
[1284]12#include <iostream>
[1282]13#include <fstream>
[1376]14//#include <itpp/itsignal.h>
[1361]15#include "windows.h"
16#include "ddeml.h"
17#include "stdio.h"
[1282]18
[1361]19//#include "DDEClient.h"
20//#include <conio.h>
[1358]21
[1361]22
[1208]23using namespace itpp;
[1337]24using namespace bdm;
[976]25
[1361]26//const int emlig_size = 2;
27//const int utility_constant = 5;
[1268]28
[1401]29const int max_model_order = 2;
[1389]30const double apriorno     = 0.01;
[1401]31const int max_window_size = 30;
[1272]32
[1361]33class model
34{
35public:
[1376]36        set<pair<int,int>> ar_components;
[1358]37
[1361]38        // Best thing would be to inherit the two models from a single souce, this is planned, but now structurally
39        // problematic.
[1376]40        RARX* my_rarx; //vzmenovane parametre pre triedu model
[1379]41        ARXwin* my_arx;
[1361]42
43        bool has_constant;
[1376]44        int  window_size;  //musi byt vacsia ako pocet krokov ak to nema ovplyvnit
[1361]45        int  predicted_channel;
46        mat* data_matrix;
[1383]47        vec  predictions;
[1393]48        char name[80];
[1361]49       
[1376]50        model(set<pair<int,int>> ar_components, //funkcie treidz model-konstruktor
[1361]51                  bool robust, 
52                  bool has_constant, 
53                  int window_size, 
[1376]54                  int predicted_channel,
[1361]55                  mat* data_matrix)
[1358]56        {
[1376]57                this->ar_components.insert(ar_components.begin(),ar_components.end());
[1393]58
59                strcpy(name,"M");
60
61                for(set<pair<int,int>>::iterator ar_ref = ar_components.begin();ar_ref!=ar_components.end();ar_ref++)
62                {
63                        char buffer1[2];
64                        char buffer2[2];
65                        itoa((*ar_ref).first,buffer1,10);
66                        itoa((*ar_ref).second,buffer2,10);
67
68                        strcat(name,buffer1);
69                        strcat(name,buffer2);
70                        strcat(name,"_");
71                }
72
[1376]73                this->has_constant      = has_constant;
[1393]74
75                if(has_constant)
76                {
77                        strcat(name,"C");
78                }
79
[1376]80                this->window_size       = window_size;
81                this->predicted_channel = predicted_channel;
82                this->data_matrix       = data_matrix;
[1361]83
84                if(robust)
85                {
[1393]86                        strcat(name,"R");
87
[1361]88                        if(has_constant)
89                        {
[1395]90                                my_rarx = new RARX(ar_components.size()+1,window_size,true,sqrt(2*apriorno),sqrt(2*apriorno),ar_components.size()+4);
[1361]91                                my_arx  = NULL;
92                        }
[1376]93                else
[1361]94                        {
[1395]95                                my_rarx = new RARX(ar_components.size(),window_size,false,sqrt(2*apriorno),sqrt(2*apriorno),ar_components.size()+3);
[1361]96                                my_arx  = NULL;
97                        }
98                }
99                else
100                {
101                        my_rarx = NULL;
[1379]102                        my_arx  = new ARXwin();
[1361]103                        mat V0;
104
105                        if(has_constant)
106                        {                               
[1376]107                                V0  = apriorno * eye(ar_components.size()+2); //aj tu konst
[1384]108                                //V0(0,0) = 0;
[1379]109                                my_arx->set_constant(true);                             
[1361]110                        }
111                        else
112                        {
113                               
[1376]114                                V0  = apriorno * eye(ar_components.size()+1);//menit konstantu
[1396]115                                //V0(0,1) = -0.01;
116                                //V0(1,0) = -0.01;
[1361]117                                my_arx->set_constant(false);                           
118                               
119                        }
120
[1384]121                        my_arx->set_statistics(1, V0, V0.rows()+2);                     
[1379]122                        my_arx->set_parameters(window_size);
[1361]123                        my_arx->validate();
[1396]124
125                        vec mean = my_arx->posterior().mean();
126                        cout << mean << endl;
[1361]127                }
[1358]128        }
[1361]129
[1376]130        void data_update(int time) //vlozime cas a ono vlozi do data_vector podmineky(conditions) a predikce, ktore pouzije do bayes
[1358]131        {
[1376]132                vec data_vector;
133                for(set<pair<int,int>>::iterator ar_iterator = ar_components.begin();ar_iterator!=ar_components.end();ar_iterator++)
[1401]134                { 
[1376]135                        data_vector.ins(data_vector.size(),(*data_matrix).get(ar_iterator->first,time-ar_iterator->second));
[1358]136                }
[1376]137                if(my_rarx!=NULL)
[1401]138                {       
[1376]139                        data_vector.ins(0,(*data_matrix).get(predicted_channel,time));
140                        my_rarx->bayes(data_vector);
141                }
[1358]142                else
143                {
[1401]144                        vec pred_vec;
[1376]145                        pred_vec.ins(0,(*data_matrix).get(predicted_channel,time));
146                        my_arx->bayes(pred_vec,data_vector);
[1361]147                }
148        }
149
[1376]150        pair<vec,vec> predict(int sample_size, int time, itpp::Laplace_RNG* LapRNG)  //nerozumiem, ale vraj to netreba, nepouziva to
[1367]151        {
[1376]152                vec condition_vector;
153                for(set<pair<int,int>>::iterator ar_iterator = ar_components.begin();ar_iterator!=ar_components.end();ar_iterator++)
[1367]154                {
[1376]155                        condition_vector.ins(condition_vector.size(),(*data_matrix).get(ar_iterator->first,time-ar_iterator->second+1));
156                }
[1367]157
[1376]158                if(my_rarx!=NULL)
159                {
[1396]160                        pair<vec,mat> imp_samples = my_rarx->posterior->sample(sample_size,false);
[1367]161
[1393]162                        //cout << imp_samples.first << endl;                   
[1376]163                       
164                        vec sample_prediction;                 
165                        for(int t = 0;t<sample_size;t++)
[1367]166                        {
[1376]167                                vec lap_sample = condition_vector;
[1367]168                               
[1376]169                                if(has_constant)
[1367]170                                {
[1376]171                                        lap_sample.ins(lap_sample.size(),1.0);
[1367]172                                }
[1376]173                               
[1393]174                                lap_sample.ins(lap_sample.size(),(*LapRNG)());
[1367]175
[1376]176                                sample_prediction.ins(0,lap_sample*imp_samples.second.get_col(t));                             
[1367]177                        }
178
[1376]179                        return pair<vec,vec>(imp_samples.first,sample_prediction);
[1393]180                }       
[1376]181                else
182                {
[1383]183                        mat samples = my_arx->posterior().sample_mat(sample_size);                     
[1376]184                       
185                        vec sample_prediction;
186                        for(int t = 0;t<sample_size;t++)
187                        {
188                                vec gau_sample = condition_vector;
[1367]189                               
[1376]190                                if(has_constant)
[1367]191                                {
[1376]192                                        gau_sample.ins(gau_sample.size(),1.0);
[1367]193                                }
[1376]194                               
[1383]195                                gau_sample.ins(gau_sample.size(),randn());
[1367]196
[1376]197                                sample_prediction.ins(0,gau_sample*samples.get_col(t));                         
[1367]198                        }
[1376]199
200                        return pair<vec,vec>(ones(sample_prediction.size()),sample_prediction);
[1367]201                }
202       
203        }
204
205
[1376]206        static set<set<pair<int,int>>> possible_models_recurse(int max_order,int number_of_channels)
[1361]207        {
[1376]208                set<set<pair<int,int>>> created_model_types;           
[1361]209
[1401]210                if(max_order == 1)
[1361]211                {                       
[1401]212                        for(int channel = 0;channel<number_of_channels;channel++)
[1358]213                        {
[1376]214                                set<pair<int,int>> returned_type;
[1401]215                                returned_type.insert(pair<int,int>(channel,1)); 
[1376]216                                created_model_types.insert(returned_type);
[1358]217                        }
[1361]218
219                        return created_model_types;
220                }
221                else
222                {
[1401]223                        created_model_types = possible_models_recurse(max_order-1,number_of_channels);
[1376]224                        set<set<pair<int,int>>> returned_types;
[1361]225                       
[1376]226                        for(set<set<pair<int,int>>>::iterator model_ref = created_model_types.begin();model_ref!=created_model_types.end();model_ref++)
[1361]227                        {                               
228                               
229                                for(int order = 1; order<=max_order; order++)
[1358]230                                {
[1361]231                                        for(int channel = 0;channel<number_of_channels;channel++)
232                                        {
[1376]233                                                set<pair<int,int>> returned_type;
[1401]234                                                pair<int,int> new_pair = pair<int,int>(channel,order);
235                                                if(find((*model_ref).begin(),(*model_ref).end(),new_pair)==(*model_ref).end()) 
[1361]236                                                {
[1401]237                                                        returned_type.insert((*model_ref).begin(),(*model_ref).end()); 
[1376]238                                                        returned_type.insert(new_pair);
239                                                       
240
241                                                        returned_types.insert(returned_type);                                                   
[1361]242                                                }
243                                        }
[1358]244                                }
245                        }
[1361]246
[1376]247                        created_model_types.insert(returned_types.begin(),returned_types.end());
[1361]248
249                        return created_model_types;
250                }               
[1358]251        }
[1361]252};
253
254
255
256
[1383]257int main ( int argc, char* argv[] ) 
258{
[1376]259        vector<vector<string>> strings;
[1301]260
[1401]261        char* file_string =  "C:\\results\\normalM"; // "C:\\dataADClosePercDiff"; // 
[1301]262
[1376]263        char dfstring[80];
264        strcpy(dfstring,file_string);
265        strcat(dfstring,".txt");
266       
267       
268        mat data_matrix;
269        ifstream myfile(dfstring);
270        if (myfile.is_open())
271        {               
272                string line;
273                while(getline(myfile,line))
274                {                       
275                        vec data_vector;
[1401]276                        while(line.find(',') != string::npos) 
277                        {                               
[1376]278                                int loc2 = line.find('\n');
279                                int loc  = line.find(',');
280                                data_vector.ins(data_vector.size(),atof(line.substr(0,loc).c_str()));                           
281                                line.erase(0,loc+1);                                   
282                        }
[1301]283
[1376]284                        data_matrix.ins_row(data_matrix.rows(),data_vector);
285                }               
[1361]286
[1376]287                myfile.close(); 
288        }
289        else
290        {
291                cout << "Can't open data file!" << endl;
292        }
[1365]293
[1401]294        set<pair<int,int>> model_type;
295        model_type.insert(pair<int,int>(0,1));
296        model_type.insert(pair<int,int>(0,2));
[1376]297
[1401]298        vector<model*> models;
299
[1383]300        ofstream myfilew;
[1401]301               
[1383]302
[1401]303        while(data_matrix.rows()!=0)
304        {
305                for(int i=0;i<models.size();i++)
[1396]306                {
[1401]307                        delete models[i];
[1396]308                }
309               
[1401]310                models.clear();
311                models.push_back(new model(model_type,true,false,max_window_size,0,&data_matrix));
312                models.push_back(new model(model_type,false,false,max_window_size,0,&data_matrix));
[1379]313
[1401]314                for(int time = max_model_order;time<max_window_size;time++) //time<data_matrix.cols()
315                {               
316                        vec cur_res_lognc;
317               
318                        vector<string> nazvy;
319                        for(vector<model*>::iterator model_ref = models.begin();model_ref!=models.end();model_ref++)
[1376]320                        {
[1401]321                                (*model_ref)->data_update(time);
[1396]322
[1401]323                                cout << "Updated:" << time << endl;     
324
325                                if(time == max_window_size-1)
[1396]326                                {
[1401]327                                        char fstring[80];                                       
328                                        strcpy(fstring,file_string);
329                                        strcat(fstring,"ml");                                   
330                                        strcat(fstring,(*model_ref)->name);
331                                        strcat(fstring,".txt");
[1396]332
[1401]333                                        vec coords;
334                                        if((*model_ref)->my_arx!=NULL)
[1396]335                                        {
[1401]336                                                coords = (*model_ref)->my_arx->posterior().est_theta();
337                                        }
338                                        else
339                                        {
340                                                coords = (*model_ref)->my_rarx->posterior->minimal_vertex->get_coordinates();                                           
341                                        }
[1396]342
[1401]343                                        myfilew.open(fstring,ios::app);
[1396]344
[1401]345                                        for(int i=0;i<coords.size();i++)
346                                        {
347                                                myfilew << coords.get(i) << ",";
348                                        }
349                                        myfilew << endl;
[1367]350
[1401]351                                        myfilew.close();                               
[1383]352                                }
[1376]353                        }
[1383]354                }
[1301]355
[1401]356                data_matrix.del_row(0);
357        }
[1383]358       
[1301]359
[1337]360
361       
[1301]362
[1376]363        return 0;
364}
[976]365
[1282]366
Note: See TracBrowser for help on using the browser.