root/library/bdm/design/lq_ctrl.h @ 1383

Revision 1223, 13.1 kB (checked in by vahalam, 14 years ago)
Line 
1#include "ctrlbase.h"
2
3namespace bdm {
4
5const bool STRICT_RV = true; //empty RV NOT allowed
6
7//! left matrix division with upper triangular matrix using Gauss elimination
8//! fast variant of inv(A) * B where A is UT matrix, returns result in C and true if no error
9inline bool ldutg(mat &_utA, mat &_mB, mat &_mC){
10        int utsize = _utA.rows();
11       
12        if(utsize != _utA.cols()) return false; //utA not square
13        if(utsize != _mB.rows()) return false; //incorrect mB size
14       
15        int mbcol = _mB.cols(); 
16        int i, j, k;   
17       
18        double pvt;     
19        double *utA = _utA._data();
20        double *mB = _mB._data();
21                       
22        _mC.set_size(utsize, mbcol);
23       
24        double *mC = _mC._data();
25       
26        //copy data     
27        for(i = 0; i < utsize*mbcol; i++) 
28                mC[i] = mB[i]; 
29                   
30        for(i = utsize-1; i >= 0; i--){  //Gauss elimination
31                pvt = utA[i + utsize*i];    //get pivot
32                for(j = 0; j < mbcol; j++) mC[i + utsize*j] /= pvt;     //normalize row - only part on the right                                   
33                for(j = 0;j < i; j++){ //subtract normalized row from above ones
34                        pvt = utA[j + utsize*i]; //get pivot
35                        for(k = 0; k < mbcol; k++) //goes from utsize - do not need make matrix on the left = I
36                                mC[j + utsize*k] -= pvt * mC[i + utsize*k]; //create zero col above                                       
37                }
38                                                                                 
39        }                       
40       
41        return true;
42}
43
44
45//! extended class representing function \f$f(x) = Ax+B\f$
46class linfnEx: public linfn {
47  public:
48    //! Identification of returned value \f$f(x)\f$
49    RV rv_ret;
50        //!default constructor
51    linfnEx ( ) : linfn() { };
52    linfnEx ( const mat &A0, const vec &B0 ) : linfn(A0, B0) { };
53};
54
55  //! Universal LQG controller
56class LQG_universal : public Controller{
57public:
58        //! Controller inputs
59                //! loss function
60                Array<quadraticfn> Losses;
61                //! loss in final time
62                quadraticfn finalLoss;
63                //! model of evolutin
64                Array<linfnEx> Models;
65
66                //! control law rv is public member in Controller class 
67                //! input data rvc is protected member in Controller class
68                 
69                //! control horizon
70                int horizon;
71
72        //! Constructor
73                LQG_universal() {
74                        horizon = 0;
75                        curtime = -1;
76                }
77               
78  protected:
79    //! control law: rv = L [rvc, 1]
80    mat L;
81   
82    //! Matrix pre_qr
83    mat pre_qr;
84   
85    //! Matrix post_qr
86    mat post_qr;
87
88        //! time+1 optimized loss to be added to current one
89        mat tolC;
90
91        int curtime;
92   
93public:
94    //! function redesigning the control strategy
95    virtual void redesign() {
96                if (curtime == -1) curtime = horizon;
97                               
98                if (curtime > 0){
99                        curtime--;
100                        generateLmat(curtime);
101                }               
102        }
103    //! returns designed control action
104    virtual vec ctrlaction ( const vec &cond ) const {
105        return L * concat(cond, 1.0);
106    }
107
108    void from_setting ( const Setting &set ) {
109      UI::get(Losses, set, "losses",UI::compulsory);
110      UI::get(Models, set, "models",UI::compulsory);
111    }
112    //! access function
113    const RV& _rv() {
114        return rv;
115    }
116    //! access function
117    const RV& _rvc() {
118        return rvc;
119    }
120
121        void set_rvc(RV _rvc) {rvc = _rvc;}
122
123    //! register this controller with given datasource under name "name"
124    virtual void log_register ( logger &L, const string &prefix ) { }
125    //! write requested values into the logger
126    virtual void log_write ( ) const { }
127
128        //! access debug function
129        mat getL(){ return L; }
130
131        void resetTime() { curtime = -1; }
132
133        //! check if model and losses is correct and consistent
134        virtual void validate(){
135                /*
136                        RV:findself hleda cela rv jako vektory, pri nenalezeni je -1
137                        RV:dataind hleda datove slozky, tedy indexy v poli skalaru, pri nenalezeni vynecha
138                */
139                // (0) nonempty
140                bdm_assert((Models.size() > 0), "VALIDATION FAILED! Models array empty.");
141                bdm_assert((Losses.size() > 0), "VALIDATION FAILED! Losses array empty.");
142                if( (Models.size() <= 0) || (Losses.size() <= 0) ) return;
143
144                // (1) test Models array rv - acceptable rv is only part/composition of LQG_universal::rv, LQG_universal::rvc and const 1
145                RV accept_total;
146                accept_total = rv;
147                accept_total.add(rvc);
148                accept_total.add(RV("1", 1, 0));
149               
150                int i, j;
151                ivec finding1;
152
153                for(i = 0; i < Models.length(); i++){
154                        finding1 = Models(i).rv.findself(accept_total); 
155
156                        bdm_assert( !(STRICT_RV && (finding1.size() <= 0)), "VALIDATION FAILED! Empty RV used.");
157
158                        for(j = 0; j < finding1.size(); j++){                           
159                                bdm_assert( ( finding1(j) > (-1) ), "VALIDATION FAILED! Provided input RV for some Models function is unknown, forbidden or recursive.");                                                       
160                                if(finding1(j) <= (-1) ) return; //rv element is not part of admissible rvs => error
161                        }                       
162                }                       
163               
164                // (3) test Losses array - acceptable rv is only part/composition of LQG_universal::rv, LQG_universal::rvc, Models rv_ret and const 1
165                for(i = 0; i < Models.length(); i++) accept_total.add(Models(i).rv_ret); //old accept_total from (1) + all rv_ret from Models
166               
167                for(i = 0; i < Losses.length(); i++){
168                        finding1 = Losses(i).rv.findself(accept_total);
169
170                        bdm_assert( !(STRICT_RV && (finding1.size() <= 0)), "VALIDATION FAILED! Empty RV used.");
171
172                        for(j = 0; j < finding1.size(); j++){
173                                bdm_assert( ( finding1(j) > (-1) ), "VALIDATION FAILED! Unacceptable RV used in some Losses function.");
174                                if(finding1(j) <= (-1) ) return; //rv element is not part of admissible rvs => error
175                        }
176                }       
177
178                // same for finalLoss
179                finding1 = finalLoss.rv.findself(accept_total);
180
181                bdm_assert( !(STRICT_RV && (finding1.size() <= 0)), "VALIDATION FAILED! Empty RV used.");
182
183                for(j = 0; j < finding1.size(); j++){
184                        bdm_assert( ( finding1(j) > (-1) ), "VALIDATION FAILED! Unacceptable RV used in finalLoss function.");
185                        if(finding1(j) <= (-1) ) return; //rv element is not part of admissible rvs => error
186                }
187        }
188
189private:
190        RV trs_crv;     
191
192        mat getMatRow (quadraticfn sourceQfn){ //returns row of matrixes crated from quadratic function
193               
194                mat tmpMatRow; //tmp variable for row of matrixes to be returned       
195                tmpMatRow.set_size(sourceQfn.Q.rows(), trs_crv.countsize()); //(rows, cols)
196                tmpMatRow.zeros();
197
198                //set data in tmpMatRow - other times then current replace using model
199                RV tmpQrv = sourceQfn.rv;
200
201                ivec j_vec(1);
202                vec copycol;
203                ivec copysource;
204                for(int j = 0; j < tmpQrv.length(); j++){                       
205                        j_vec(0) = j;
206
207                        if( (sum(tmpQrv(j_vec).findself(trs_crv)) > (-1)) ) {//sum is only formal, summed vector is in fact scalar
208                                //jth element of tmpQrv is also element of trs_crv with a proper time
209
210                                ivec copytarget = (tmpQrv(j_vec)).dataind(trs_crv); //get target column position in tmpMatRow
211                                ivec copysource = (tmpQrv(j_vec)).dataind(tmpQrv); //get source column position in Losses(i).Q
212                                if(copytarget.size() != copysource.size()) {return mat(0); /*error*/}                           
213                                for(int k = 0; k < copysource.size(); k++){
214                                        copycol = sourceQfn.Q._Ch().get_col(copysource(k));
215                                        copycol += tmpMatRow.get_col(copytarget(k));
216                                        tmpMatRow.set_col(copytarget(k), copycol);
217                                }                                       
218                        }
219                        else {                         
220                                //!model is model_rv_ret = sum(Array<linfn>) = sum( A1*rv + B1 + ... + An*rv + Bn)                             
221                               
222                                //use kth Model to convert tmpQrv memeber to trs_crv memeber
223
224                                //get submatrix from Q which represents jth tmpQrv data
225                                copysource = (tmpQrv(j_vec)).dataind(tmpQrv); //get source column position in Losses(i).Q                               
226                                mat copysubmat;
227                                copysubmat.set_size(sourceQfn.Q.rows(), copysource.size()); //(rows, cols)
228                                copysubmat.zeros();
229                                vec copycol;
230                               
231                                int k;
232                                for(k = 0; k < copysource.size(); k++){
233                                        copycol = sourceQfn.Q._Ch().get_col(copysource(k));
234                                        copysubmat.set_col(k, copycol);
235                                }
236
237                                //check every Models element if it is a proper substitution: tmpQrv(j_vec) memeber of rv_ret
238                                for(k = 0; k < Models.size(); k++){
239                                        if( sum((tmpQrv(j_vec)).findself(Models(k).rv_ret)) > (-1) ){ //formal sum, find usable model
240                                                //check if model is correct
241                                                ivec check = (Models(k).rv).findself(trs_crv);
242                                                if(sum(check) <= -check.size()){
243                                                        bdm_assert (false , "Incorrect Model: Unusable Models element!" );                                               
244                                                        continue;
245                                                }
246
247                                                //create transformed submatrix
248                                                mat transsubmat = copysubmat * ((Models(k)).A);
249
250                                                //put them on a right place in tmpQrv
251                                                ivec copytarget = (Models(k)).rv.dataind(trs_crv); //get target column position in tmpMatRow
252                                                                                               
253                                                //copy transsubmat into tmpMatRow with adding to current one                                           
254                                                int l;
255                                                for(l = 0; l < copytarget.size(); l++){                                 
256                                                        copycol = tmpMatRow.get_col(copytarget(l));                                     
257                                                        copycol += transsubmat.get_col(l);                                                             
258                                                        tmpMatRow.set_col(copytarget(l), copycol);                                                             
259                                                }
260
261                                                //if linear fnc constant element vec B is nonzero
262                                                vec constElB = (Models(k)).B;                           
263                                                if(prod(constElB) != 0){
264                                                        //copy transformed constant vec into last (1's) col in tmpMatRow
265                                                        int lastcol = tmpMatRow.cols() - 1;
266                                                        copycol = tmpMatRow.get_col(lastcol);
267                                                        copycol += (copysubmat * ((Models(k)).B));
268                                                        tmpMatRow.set_col(lastcol, copycol);
269                                                }
270                                        }
271
272                                }
273
274                               
275                        }
276                }
277
278                return tmpMatRow;
279        }
280
281        //! create first(0) or other (1) pre_qr matrix
282        void build_pre_qr(bool next) {
283                int i;
284                //used fake quadratic function from tolC matrix
285                quadraticfn fakeQfn;           
286
287                //set proper size of pre_qr matrix
288                int rows = 0;
289                for(i = 0; i < Losses.size(); i++)
290                        rows += Losses(i).Q.rows();
291                if(!next) rows += finalLoss.Q.rows();
292                else{
293                        //used fake quadratic function from tolC matrix
294                        //setup fakeQfn
295                        fakeQfn.Q.setCh(tolC);
296                        RV fakeM1;
297                        fakeM1 = rvc;
298                        fakeM1.add(RV("1", 1, 0));
299                        fakeM1.t_plus(1); //RV in time t+1 => necessary use of Model to get RV in time t
300                        fakeM1.set_time((RV("1", 1, 1).findself(fakeM1))(0) , 0);
301
302                        fakeQfn.rv = fakeM1;
303                       
304                        rows += fakeQfn.Q.rows();
305                }
306
307                pre_qr.set_size(rows, trs_crv.countsize()); //(rows, cols)
308                pre_qr.zeros();
309
310                //fill pre_qr matrix for each Losses quadraticfn               
311                int rowIndex = 0;
312                mat tmpMatRow;
313                for(i = 0; i < Losses.size(); i++) {
314                        rows = Losses(i).Q.rows();
315                       
316                        //compute row matrix and insert it on proper place in pre_qr
317                        tmpMatRow = getMatRow(Losses(i));
318
319                        //copy tmpMatRow in pre_qr
320                        pre_qr.set_submatrix(rowIndex, (rowIndex + rows - 1), 0, (trs_crv.countsize() - 1), tmpMatRow);  //(int r1, int r2, int c1, int c2, const Mat<  Num_T > &m)
321                        rowIndex += rows;
322                }
323
324                if(!next) {                     
325                        tmpMatRow = getMatRow(finalLoss);                       
326                        pre_qr.set_submatrix(rowIndex, (rowIndex + finalLoss.Q.rows() - 1), 0, (trs_crv.countsize() - 1), tmpMatRow);  //(int r1, int r2, int c1, int c2, const Mat<  Num_T > &m)               
327                }
328                else { //next
329                        //based on tolC but time must be shifted by one - all implemented in getMatRow method
330                               
331                        //get matrix row via getMatRow method                   
332                        tmpMatRow = getMatRow(fakeQfn);
333                       
334                        pre_qr.set_submatrix(rowIndex, (rowIndex + fakeQfn.Q.rows() - 1), 0, (trs_crv.countsize() - 1), tmpMatRow);  //(int r1, int r2, int c1, int c2, const Mat<  Num_T > &m)         
335                }
336        }       
337
338        mat get_qr_submatrix(int submatidx) {
339        /*
340                |rv||rvc||1|
341
342                AAAABBBBBBBB
343                 AAABBBBBBBB
344                  AABBBBBBBB
345                   ABBBBBBBB
346                    CCCCCCCC
347                         CCCCCCC
348                          CCCCCC
349                           CCCCC
350                            CCCC
351                                 CCC
352                                  CC
353                                   C
354        */
355        /*!
356                submatidx | get_submatrix
357                ----------|--------------
358                    0     |      A
359                        1     |          B
360                        2+        |              C
361        */
362                int sizeA = rv.countsize();
363                int colsB = post_qr.cols() -  sizeA;
364                //  rowsB = sizeA;
365                //  colsC = colsB;
366                //not required whole C - it is triangular
367                //=> NOT int rowsC = post_qr.rows() - sizeA;
368                //=> int sizeC = colsB;
369
370                mat qr_submat;
371               
372                if(submatidx == 0) qr_submat            = post_qr.get(0,                (sizeA - 1),                    0,              (sizeA - 1));  //(int r1, int r2, int c1, int c2)
373                else if(submatidx == 1) qr_submat       = post_qr.get(0,                (sizeA - 1),                    sizeA,  (post_qr.cols() - 1));
374                else {
375                        if(post_qr.cols() > post_qr.rows()) { //extend post_qr matrix to be at least square
376                                post_qr.set_size(post_qr.cols(), post_qr.cols(), true);                         
377                        }
378                       
379                        qr_submat                                               = post_qr.get(sizeA,    (sizeA + colsB - 1),    sizeA,  (post_qr.cols() - 1));
380                       
381                }
382       
383                return qr_submat;
384        }
385
386        void generateLmat(int timestep){
387                //! control strategy matrix L is based on loss in time:
388                //!             time = horizon                  loss = finalLoss
389                //!             time = horizon - 1              loss = sum(Losses)(time) + finalLoss
390                //!             time = horizon - k > 1  loss = sum(Losses)(time) + tolC time+1 loss
391
392                trs_crv = rv; //transformed crv only in proper times, form [rv, rvc, 1]
393                trs_crv.add(rvc);
394                trs_crv.add(RV("1", 1, 0));
395                       
396                //!first time, time = horizon - 1
397                if(timestep == (horizon-1))             
398                        build_pre_qr(0);
399               
400                //!other times         
401                else
402                        build_pre_qr(1);
403
404                mat tmpQ;               
405                qr(pre_qr, tmpQ, post_qr);
406
407                mat qrA = -get_qr_submatrix(0);         
408                mat qrB = get_qr_submatrix(1);
409                mat qrC = get_qr_submatrix(2);
410
411                //mat L = inv(qrA)*qrB;                 
412                bool invmult = ldutg(qrA, qrB, L);
413                bdm_assert(invmult, "Matrix inversion error!");
414                                               
415                // ldutg is faster matrix inv&mult (like Matlab's \ operator) function than inv
416                // it uses Gauss elimination
417                // BUT based on NOT RECOMENDED direct data access method in mat class
418                // even faster implementation could be implemented using fix double arrays             
419               
420                tolC = qrC;
421        }
422
423};
424
425} // namespace
Note: See TracBrowser for help on using the browser.