root/library/bdm/design/lq_ctrl.h @ 1190

Revision 1155, 11.8 kB (checked in by vahalam, 14 years ago)

first test version of LQG_universal class
using own linfnEx - extension of linfn (+ return rv)

Line 
1#include "ctrlbase.h"
2
3namespace bdm {
4
5//! extended class representing function \f$f(x) = Ax+B\f$
6class linfnEx: public linfn {
7  public:
8    //! Identification of returned value \f$f(x)\f$
9    RV rv_ret;
10        //!default constructor
11    linfnEx ( ) : linfn() { };
12    linfnEx ( const mat &A0, const vec &B0 ) : linfn(A0, B0) { };
13};
14
15  //! Universal LQG controller
16class LQG_universal : public Controller{
17public:
18        //! Controller inputs
19                //! loss function
20                Array<quadraticfn> Losses;
21                //! loss in final time
22                quadraticfn finalLoss;
23                //! model of evolutin
24                Array<linfnEx> Models;
25                RV model_rv_ret;
26
27                //! control law rv is public member in Controller class 
28                //! input data rvc is protected member in Controller class
29                 
30                //! control horizon
31                int horizon;
32
33        //! Constructor
34                LQG_universal() {
35                        horizon = 0;
36                        curtime = -1;
37                }
38               
39  protected:
40    //! control law: rv = L [rvc, 1]
41    mat L;
42   
43    //! Matrix pre_qr
44    mat pre_qr;
45   
46    //! Matrix post_qr
47    mat post_qr;
48
49        //! time+1 optimized loss to be added to current one
50        mat tolC;
51
52        int curtime;
53   
54public:
55    //! function redesigning the control strategy
56    virtual void redesign() {
57                if (curtime == -1) curtime = horizon;
58               
59                if (curtime > 0) curtime--;
60
61                if (curtime >= 0){
62                        generateLmat(curtime);
63                }
64
65                //report time 0 reached - LQG designing complete
66                //if(curtime == 0) cout << "time 0 reached" << endl;
67        }
68    //! returns designed control action
69    virtual vec ctrlaction ( const vec &cond ) const {
70        return L * concat(cond, 1.0);
71    }
72
73    void from_setting ( const Setting &set ) {
74      UI::get(Losses, set, "losses",UI::compulsory);
75      UI::get(Models, set, "models",UI::compulsory);
76    }
77    //! access function
78    const RV& _rv() {
79        return rv;
80    }
81    //! access function
82    const RV& _rvc() {
83        return rvc;
84    }
85
86        void set_rvc(RV _rvc) {rvc = _rvc;}
87
88    //! register this controller with given datasource under name "name"
89    virtual void log_register ( logger &L, const string &prefix ) { }
90    //! write requested values into the logger
91    virtual void log_write ( ) const { }
92
93        //! access debug function
94        mat getL(){ return L; }
95
96private:
97        RV trs_crv;
98
99        //! compute complete RV from all of RVs used in Losses array
100        /*RV getCompleteRV() {
101                RV cRv; //complete RV
102
103                //cRv has form [rv, "other", 1]
104                cRv = rv; //add rv
105
106                //add "other"
107                for(int i = 0; i < Losses.size(); i++)
108                        cRv.add(Losses(i).rv);
109
110                cRv.add(rvOne); //add 1
111
112                return cRv;
113        }*/
114
115        mat getMatRow (quadraticfn sourceQfn){ //returns row of matrixes crated from quadratic function
116               
117                mat tmpMatRow; //tmp variable for row of matrixes to be returned       
118                tmpMatRow.set_size(sourceQfn.Q.rows(), trs_crv.countsize()); //(rows, cols)
119                tmpMatRow.zeros();
120
121                //set data in tmpMatRow - other times then current replace using model
122                RV tmpQrv = sourceQfn.rv;
123                for(int j = 0; j < tmpQrv.length(); j++){
124                        ivec j_vec(1);
125                        j_vec(0) = j;   
126                        if( (tmpQrv.time(j) == 0) && (sum(tmpQrv(j_vec).findself(trs_crv)) > (-1)) ) {//sum is only formal, summed vector is in fact scalar
127                                //jth element of tmpQrv is also element of trs_crv with a proper time
128                                ivec copytarget = (tmpQrv(j_vec)).dataind(trs_crv); //get target column position in tmpMatRow
129                                ivec copysource = (tmpQrv(j_vec)).dataind(tmpQrv); //get source column position in Losses(i).Q
130                                if(copytarget.size() != copysource.size()) {return mat(0); /*error*/}
131                                vec copycol;
132                                for(int k = 0; k < copysource.size(); k++){
133                                        copycol = sourceQfn.Q._Ch().get_col(copysource(k));
134                                        tmpMatRow.set_col(copytarget(k), copycol);
135                                }                                       
136                        }
137                        else {
138//cout << "USING MODEL" << endl;
139                                //jth tmpQrv element is not in trs_crv -> using Model to teplace it
140                                ivec copysource;// = (tmpQrv(j_vec)).findself(tmpQrv); //get source column position in Losses(i).Q
141
142                                //int selectedModel = -1;
143
144                                //int k;
145                                ////find first usable replacement in Model                             
146                                //for(k = 0; k < Models.size(); k++){                                                   
147                                //      if( sum((tmpQrv(j_vec)).findself(Models(k).rv_ret)) > (-1) ){
148                                //      //TODO is tmpQrv(j) in kth Models RV
149                                //              //if( (Models(k)).rv_ret.findself_ids(trs_crv) ){//???????????????????
150                                //              // is kth Models rv_ret subset of trs_crv
151
152                                //              selectedModel = k;
153                                //              break;
154
155                                //              //}
156                                //      }
157                                //}
158                               
159                                //!model is model_rv_ret = sum(Array<linfn>) = sum( A1*rv + B1 + ... + An*rv + Bn)
160                               
161                                //if(selectedModel == -1) {cout << "NO MODEL" << endl;return mat("0");}//ERROR - inconsistent model data;                               
162                        //!!TODO!!if(NOT((tmpQrv(j_vec)).findself(model_rv_ret))) {cout << "NO MODEL" << endl;return mat("0");}//ERROR - inconsistent model data;                               
163                                //use kth Model to convert tmpQrv memeber to trs_crv memeber
164
165                                //get submatrix from Q which represents jth tmpQrv data
166                                copysource = (tmpQrv(j_vec)).dataind(tmpQrv); //get source column position in Losses(i).Q                               
167                                mat copysubmat;
168                                copysubmat.set_size(sourceQfn.Q.rows(), copysource.size()); //(rows, cols)
169                                copysubmat.zeros();
170                                vec copycol;
171                                int k;
172                                for(k = 0; k < copysource.size(); k++){
173                                        copycol = sourceQfn.Q._Ch().get_col(copysource(k));
174                                        copysubmat.set_col(k, copycol);
175                                }
176
177                                //check every Models element if it is a proper substitution: tmpQrv(j_vec) memeber of rv_ret
178                                for(k = 0; k < Models.size(); k++){
179                                        if( sum((tmpQrv(j_vec)).findself(Models(k).rv_ret)) > (-1) ){ //formal sum, find usable model
180                                                //check if model is correct
181                                                ivec check = (Models(k).rv).findself(trs_crv);
182                                                if(sum(check) <= -check.size()){
183                                                        bdm_assert (false , "Incorrect Model: Unusable Models element!" );                                               
184                                                        continue;
185                                                }
186
187                                                //create transformed submatrix
188                                                mat transsubmat = copysubmat * ((Models(k)).A);
189
190                                                //put them on a right place in tmpQrv
191                                                ivec copytarget = (Models(k)).rv.dataind(trs_crv); //get target column position in tmpMatRow
192                                                                                               
193                                                //copy transsubmat into tmpMatRow with adding to current one
194                                                //      tmpMatRow(new) = tmpMatRow(old) + transsubmat /all in proper indices
195                                                int l;
196                                                for(l = 0; l < copysource.size(); l++){
197                                                        copycol = tmpMatRow.get_col(copytarget(l));
198                                                        copycol += transsubmat.get_col(l);                                     
199                                                        tmpMatRow.set_col(copytarget(l), copycol);                                     
200                                                }
201
202                                                //if linear fnc constant element vec B is nonzero
203                                                vec constElB = (Models(k)).B;                           
204                                                if(prod(constElB) != 0){
205                                                        //copy transformed constant vec into last (1's) col in tmpMatRow
206                                                        int lastcol = tmpMatRow.cols() - 1;
207                                                        copycol = tmpMatRow.get_col(lastcol);
208                                                        copycol += (copysubmat * ((Models(k)).B));
209                                                        tmpMatRow.set_col(lastcol, copycol);
210                                                }
211                                        }
212
213                                }
214
215                               
216                        }
217                }
218
219                return tmpMatRow;
220        }
221
222        //! create first(0) or other (1) pre_qr matrix
223        void build_pre_qr(bool next) {
224                int i;
225                //used fake quadratic function from tolC matrix
226                quadraticfn fakeQfn;
227
228                //RV pretrs_crv = getCompleteRV(); // crv before transformation based on Losses array
229
230                //set proper size of pre_qr matrix
231                int rows = 0;
232                for(i = 0; i < Losses.size(); i++)
233                        rows += Losses(i).Q.rows();
234                if(!next) rows += finalLoss.Q.rows();
235                else{
236                        //used fake quadratic function from tolC matrix
237                        //setup fakeQfn
238                        fakeQfn.Q.setCh(tolC);
239                        RV fakeM1;
240                        fakeM1 = rvc;
241                        fakeM1.add(RV("1", 1, 0));
242                        fakeM1.t_plus(1); //RV in time t+1 => necessary use of Model to get RV in time t
243                        fakeQfn.rv = fakeM1;
244                       
245                        rows += fakeQfn.Q.rows();
246                }
247//cout << "buildpreqr trscrv: " << trs_crv << " of size " << trs_crv.countsize() << endl;
248                pre_qr.set_size(rows, trs_crv.countsize()); //(rows, cols)
249                pre_qr.zeros();
250
251                //fill pre_qr matrix for each Losses quadraticfn               
252                int rowIndex = 0;
253                mat tmpMatRow;
254                for(i = 0; i < Losses.size(); i++) {
255                        rows = Losses(i).Q.rows();
256                       
257                        //compute row matrix and insert it on proper place in pre_qr
258                        tmpMatRow = getMatRow(Losses(i));
259//cout << "tmpMatRow no " << i << endl << tmpMatRow << endl;
260                        //copy tmpMatRow in pre_qr
261
262                /*cout << "submatrix ( " << rowIndex << ", " <<
263                        (rowIndex + rows - 1) << ", 0, " << (trs_crv.countsize() - 1) << ")" << endl;
264                cout << "seting with submatrix of rows " << tmpMatRow.rows() << " and cols " << tmpMatRow.cols() <<
265                        "and data " << endl << tmpMatRow << endl;*/
266
267                        pre_qr.set_submatrix(rowIndex, (rowIndex + rows - 1), 0, (trs_crv.countsize() - 1), tmpMatRow);  //(int r1, int r2, int c1, int c2, const Mat<  Num_T > &m)
268                        rowIndex += rows;
269                }
270
271                if(!next) {
272                        tmpMatRow = getMatRow(finalLoss);
273                        pre_qr.set_submatrix(rowIndex, (rowIndex + finalLoss.Q.rows() - 1), 0, (trs_crv.countsize() - 1), tmpMatRow);  //(int r1, int r2, int c1, int c2, const Mat<  Num_T > &m)               
274                }
275                else { //next
276                        //based on tolC but time must be shifted by one - all implemented in getMatRow method
277                               
278                        //get matrix row via getMatRow method
279                        tmpMatRow = getMatRow(fakeQfn);
280                /*cout << "submatrix ( " << rowIndex << ", " <<
281                        (rowIndex + fakeQfn.Q.rows() - 1) << ", 0, " << (trs_crv.countsize() - 1) << ")" << endl;
282                cout << "seting with submatrix of rows " << tmpMatRow.rows() << " and cols " << tmpMatRow.cols() <<
283                        "and data " << endl << tmpMatRow << endl;*/
284                        pre_qr.set_submatrix(rowIndex, (rowIndex + fakeQfn.Q.rows() - 1), 0, (trs_crv.countsize() - 1), tmpMatRow);  //(int r1, int r2, int c1, int c2, const Mat<  Num_T > &m)         
285                //cout << "NEXT 3" << endl;
286                }
287//cout << "last tmpMatRow  " << endl << tmpMatRow << endl;
288        }       
289
290        mat get_qr_submatrix(int submatidx) {
291        /*
292                |rv||rvc||1|
293
294                AAAABBBBBBBB
295                 AAABBBBBBBB
296                  AABBBBBBBB
297                   ABBBBBBBB
298                    CCCCCCCC
299                         CCCCCCC
300                          CCCCCC
301                           CCCCC
302                            CCCC
303                                 CCC
304                                  CC
305                                   C
306        */
307        /*!
308                submatidx | get_submatrix
309                ----------|--------------
310                    0     |      A
311                        1     |          B
312                        2+        |              C
313        */
314                int sizeA = rv.countsize();
315                int colsB = post_qr.cols() -  sizeA;
316                //  rowsB = sizeA;
317                //  colsC = colsB;
318                //not required whole C - it is triangular
319                //=> NOT int rowsC = post_qr.rows() - sizeA;
320                //=> int sizeC = colsB;
321
322                mat qr_submat;
323               
324                if(submatidx == 0) qr_submat            = post_qr.get(0,                (sizeA - 1),                    0,              (sizeA - 1));  //(int r1, int r2, int c1, int c2)
325                else if(submatidx == 1) qr_submat       = post_qr.get(0,                (sizeA - 1),                    sizeA,  (post_qr.cols() - 1));
326                else qr_submat                                          = post_qr.get(sizeA,    (sizeA + colsB - 1),    sizeA,  (post_qr.cols() - 1));
327       
328                return qr_submat;
329        }
330
331        void generateLmat(int timestep){
332                //! control strategy matrix L is based on loss in time:
333                //!             time = horizon                  loss = finalLoss
334                //!             time = horizon - 1              loss = sum(Losses)(time) + finalLoss
335                //!             time = horizon - k > 1  loss = sum(Losses)(time) + tolC time+1 loss
336
337                trs_crv = rv; //transformed crv only in proper times, form [rv, rvc, 1]
338                trs_crv.add(rvc);
339                trs_crv.add(RV("1", 1, 0));
340               
341                //!first time, time = horizon - 1
342                if(timestep == (horizon-1))             
343                        build_pre_qr(0);
344               
345                //!other times         
346                else
347                        build_pre_qr(1);
348
349                mat tmpQ;               
350                qr(pre_qr, tmpQ, post_qr);
351//cout << "preQR " << pre_qr << endl << "postQR" << post_qr << endl;
352                mat qrA = get_qr_submatrix(0);
353                mat qrB = get_qr_submatrix(1);
354                mat qrC = get_qr_submatrix(2);
355//cout << "A " << qrA << "\n B " << qrB << "\n C " << qrC << endl;
356
357                L = - inv(qrA)*qrB; ///////// INVERSE OF TRIANGLE MATRIX! better?
358                tolC = qrC;
359        }
360
361};
362
363class LQG_recedinghorizon : public LQG_universal {
364protected:
365        //!total_curtime is curtime for total_horizon
366        int total_curtime;
367public:
368        //! LQG_universal::horizon means shorter receding horizon for designing control strategy
369        //! total_horizon is longer total horizon
370        int total_horizon;
371       
372        //!constructor
373        LQG_recedinghorizon() : LQG_universal() {
374                        total_horizon = 0;
375                        total_curtime = 0;
376                }
377
378        virtual void redesign() {
379                if (total_curtime < total_horizon){
380                        for(int i = 0; i < horizon - 1; i++) LQG_universal::redesign();
381                        total_curtime++;
382                }               
383        }
384
385        virtual vec ctrlaction ( const vec &cond ) const {
386        //return L * concat(cond, 1.0);
387    }
388
389    //! register this controller with given datasource under name "name"
390    virtual void log_register ( logger &L, const string &prefix ) { }
391    //! write requested values into the logger
392    virtual void log_write ( ) const { }       
393};
394
395} // namespace
Note: See TracBrowser for help on using the browser.