Advertisement
BumiBarbi

fEIG.cc

Mar 29th, 2016
100
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 29.32 KB | None | 0 0
  1. /*
  2.  
  3. Copyright (C) 1994-2015 John W. Eaton
  4.  
  5. This file is part of Octave.
  6.  
  7. Octave is free software; you can redistribute it and/or modify it
  8. under the terms of the GNU General Public License as published by the
  9. Free Software Foundation; either version 3 of the License, or (at your
  10. option) any later version.
  11.  
  12. Octave is distributed in the hope that it will be useful, but WITHOUT
  13. ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  14. FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  15. for more details.
  16.  
  17. You should have received a copy of the GNU General Public License
  18. along with Octave; see the file COPYING.  If not, see
  19. <http://www.gnu.org/licenses/>.
  20.  
  21. */
  22.  
  23. #ifdef HAVE_CONFIG_H
  24. #  include "config.h"
  25. #endif
  26.  
  27. #include "fEIG.h"
  28. #include "fColVector.h"
  29. #include "f77-fcn.h"
  30. #include "lo-error.h"
  31.  
  32. extern "C"
  33. {
  34.   F77_RET_T
  35.   F77_FUNC (sgeevx, SGEEVX) (F77_CONST_CHAR_ARG_DECL,
  36.                              F77_CONST_CHAR_ARG_DECL,
  37.                              F77_CONST_CHAR_ARG_DECL,
  38.                              F77_CONST_CHAR_ARG_DECL,
  39.                              const octave_idx_type&, float*,
  40.                              const octave_idx_type&, float*, float*, float*,
  41.                              const octave_idx_type&, float*,
  42.                              const octave_idx_type&, octave_idx_type&,
  43.                              octave_idx_type&, float*, float&, float*,
  44.                              float*, float*, const octave_idx_type&,
  45.                              octave_idx_type*, octave_idx_type&
  46.                              F77_CHAR_ARG_LEN_DECL
  47.                              F77_CHAR_ARG_LEN_DECL
  48.                              F77_CHAR_ARG_LEN_DECL
  49.                              F77_CHAR_ARG_LEN_DECL);
  50.  
  51.   F77_RET_T
  52.   F77_FUNC (sgeev, SGEEV) (F77_CONST_CHAR_ARG_DECL,
  53.                            F77_CONST_CHAR_ARG_DECL,
  54.                            const octave_idx_type&, float*,
  55.                            const octave_idx_type&, float*, float*, float*,
  56.                            const octave_idx_type&, float*,
  57.                            const octave_idx_type&, float*,
  58.                            const octave_idx_type&, octave_idx_type&
  59.                            F77_CHAR_ARG_LEN_DECL
  60.                            F77_CHAR_ARG_LEN_DECL);
  61.                            
  62.   F77_RET_T
  63.   F77_FUNC (cgeevx, CGEEVX) (F77_CONST_CHAR_ARG_DECL,
  64.                              F77_CONST_CHAR_ARG_DECL,
  65.                              F77_CONST_CHAR_ARG_DECL,
  66.                              F77_CONST_CHAR_ARG_DECL,
  67.                              const octave_idx_type&, FloatComplex*,
  68.                              const octave_idx_type&, FloatComplex*, FloatComplex*,
  69.                              const octave_idx_type&, FloatComplex*,
  70.                              const octave_idx_type&, octave_idx_type&,
  71.                              octave_idx_type&, float*, float&, float*,
  72.                              float*, FloatComplex*, const octave_idx_type&,
  73.                              float*, octave_idx_type&
  74.                              F77_CHAR_ARG_LEN_DECL
  75.                              F77_CHAR_ARG_LEN_DECL
  76.                              F77_CHAR_ARG_LEN_DECL
  77.                              F77_CHAR_ARG_LEN_DECL);
  78.  
  79.   F77_RET_T
  80.   F77_FUNC (cgeev, CGEEV) (F77_CONST_CHAR_ARG_DECL,
  81.                            F77_CONST_CHAR_ARG_DECL,
  82.                            const octave_idx_type&, FloatComplex*,
  83.                            const octave_idx_type&, FloatComplex*, FloatComplex*,
  84.                            const octave_idx_type&, FloatComplex*,
  85.                            const octave_idx_type&, FloatComplex*,
  86.                            const octave_idx_type&, float*, octave_idx_type&
  87.                            F77_CHAR_ARG_LEN_DECL
  88.                            F77_CHAR_ARG_LEN_DECL);
  89.  
  90.   F77_RET_T
  91.   F77_FUNC (ssyev, SSYEV) (F77_CONST_CHAR_ARG_DECL,
  92.                            F77_CONST_CHAR_ARG_DECL,
  93.                            const octave_idx_type&, float*,
  94.                            const octave_idx_type&, float*, float*,
  95.                            const octave_idx_type&, octave_idx_type&
  96.                            F77_CHAR_ARG_LEN_DECL
  97.                            F77_CHAR_ARG_LEN_DECL);
  98.  
  99.   F77_RET_T
  100.   F77_FUNC (cheev, CHEEV) (F77_CONST_CHAR_ARG_DECL,
  101.                            F77_CONST_CHAR_ARG_DECL,
  102.                            const octave_idx_type&, FloatComplex*,
  103.                            const octave_idx_type&, float*, FloatComplex*,
  104.                            const octave_idx_type&, float*, octave_idx_type&
  105.                            F77_CHAR_ARG_LEN_DECL
  106.                            F77_CHAR_ARG_LEN_DECL);
  107.  
  108.   F77_RET_T
  109.   F77_FUNC (spotrf, SPOTRF) (F77_CONST_CHAR_ARG_DECL,
  110.                              const octave_idx_type&, float*,
  111.                              const octave_idx_type&, octave_idx_type&
  112.                              F77_CHAR_ARG_LEN_DECL
  113.                              F77_CHAR_ARG_LEN_DECL);
  114.  
  115.   F77_RET_T
  116.   F77_FUNC (cpotrf, CPOTRF) (F77_CONST_CHAR_ARG_DECL,
  117.                              const octave_idx_type&, FloatComplex*,
  118.                              const octave_idx_type&, octave_idx_type&
  119.                              F77_CHAR_ARG_LEN_DECL
  120.                              F77_CHAR_ARG_LEN_DECL);
  121.  
  122.   F77_RET_T
  123.   F77_FUNC (sggev, SGGEV) (F77_CONST_CHAR_ARG_DECL,
  124.                            F77_CONST_CHAR_ARG_DECL,
  125.                            const octave_idx_type&, float*,
  126.                            const octave_idx_type&, float*,
  127.                            const octave_idx_type&, float*, float*, float*,
  128.                            float*, const octave_idx_type&, float*,
  129.                            const octave_idx_type&, float*,
  130.                            const octave_idx_type&, octave_idx_type&
  131.                            F77_CHAR_ARG_LEN_DECL
  132.                            F77_CHAR_ARG_LEN_DECL);
  133.  
  134.   F77_RET_T
  135.   F77_FUNC (ssygv, SSYGV) (const octave_idx_type&,
  136.                            F77_CONST_CHAR_ARG_DECL,
  137.                            F77_CONST_CHAR_ARG_DECL,
  138.                            const octave_idx_type&, float*,
  139.                            const octave_idx_type&, float*,
  140.                            const octave_idx_type&, float*, float*,
  141.                            const octave_idx_type&, octave_idx_type&
  142.                            F77_CHAR_ARG_LEN_DECL
  143.                            F77_CHAR_ARG_LEN_DECL);
  144.  
  145.   F77_RET_T
  146.   F77_FUNC (cggev, CGGEV) (F77_CONST_CHAR_ARG_DECL,
  147.                            F77_CONST_CHAR_ARG_DECL,
  148.                            const octave_idx_type&, FloatComplex*,
  149.                            const octave_idx_type&, FloatComplex*,
  150.                            const octave_idx_type&, FloatComplex*,
  151.                            FloatComplex*, FloatComplex*,
  152.                            const octave_idx_type&, FloatComplex*,
  153.                            const octave_idx_type&, FloatComplex*,
  154.                            const octave_idx_type&, float*, octave_idx_type&
  155.                            F77_CHAR_ARG_LEN_DECL
  156.                            F77_CHAR_ARG_LEN_DECL);
  157.  
  158.   F77_RET_T
  159.   F77_FUNC (chegv, CHEGV) (const octave_idx_type&,
  160.                            F77_CONST_CHAR_ARG_DECL,
  161.                            F77_CONST_CHAR_ARG_DECL,
  162.                            const octave_idx_type&, FloatComplex*,
  163.                            const octave_idx_type&, FloatComplex*,
  164.                            const octave_idx_type&, float*, FloatComplex*,
  165.                            const octave_idx_type&, float*, octave_idx_type&
  166.                            F77_CHAR_ARG_LEN_DECL
  167.                            F77_CHAR_ARG_LEN_DECL);
  168. }
  169.  
  170. octave_idx_type
  171. FloatEIG::init (const FloatMatrix& a, bool calc_ev, bool balance)
  172. {
  173.   if (a.any_element_is_inf_or_nan ())
  174.     (*current_liboctave_error_handler)
  175.       ("EIG: matrix contains Inf or NaN values");
  176.  
  177.   if (a.is_symmetric ())
  178.     return symmetric_init (a, calc_ev);
  179.  
  180.   octave_idx_type n = a.rows ();
  181.  
  182.   if (n != a.cols ())
  183.     (*current_liboctave_error_handler) ("EIG requires square matrix");
  184.  
  185.   octave_idx_type info = 0;
  186.  
  187.   FloatMatrix atmp = a;
  188.   float *tmp_data = atmp.fortran_vec ();
  189.  
  190.   Array<float> wr (dim_vector (n, 1));
  191.   float *pwr = wr.fortran_vec ();
  192.  
  193.   Array<float> wi (dim_vector (n, 1));
  194.   float *pwi = wi.fortran_vec ();
  195.  
  196.   volatile octave_idx_type nvr = calc_ev ? n : 0;
  197.   FloatMatrix vr (nvr, nvr);
  198.   float *pvr = vr.fortran_vec ();
  199.  
  200.   octave_idx_type lwork = -1;
  201.   float dummy_work;
  202.  
  203.   float *dummy = 0;
  204.   octave_idx_type idummy = 1;
  205.  
  206.   octave_idx_type ilo;
  207.   octave_idx_type ihi;
  208.  
  209.   Array<float> scale (dim_vector (n, 1));
  210.   float *pscale = scale.fortran_vec ();
  211.  
  212.   float abnrm;
  213.  
  214.   Array<float> rconde (dim_vector (n, 1));
  215.   float *prconde = rconde.fortran_vec ();
  216.  
  217.   Array<float> rcondv (dim_vector (n, 1));
  218.   float *prcondv = rcondv.fortran_vec ();
  219.  
  220.   octave_idx_type dummy_iwork;
  221.  
  222.   F77_XFCN (sgeevx, SGEEVX, (F77_CONST_CHAR_ARG2 (balance ? "B" : "N", 1),
  223.                              F77_CONST_CHAR_ARG2 ("N", 1),
  224.                              F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  225.                              F77_CONST_CHAR_ARG2 ("N", 1),
  226.                              n, tmp_data, n, pwr, pwi, dummy,
  227.                              idummy, pvr, n,
  228.                              ilo, ihi, pscale, abnrm, prconde, prcondv,
  229.                              &dummy_work, lwork, &dummy_iwork, info
  230.                              F77_CHAR_ARG_LEN (1)
  231.                              F77_CHAR_ARG_LEN (1)
  232.                              F77_CHAR_ARG_LEN (1)
  233.                              F77_CHAR_ARG_LEN (1)));                        
  234.  
  235.   if (info != 0)
  236.     (*current_liboctave_error_handler) ("sgeevx workspace query failed");
  237.  
  238.   lwork = static_cast<octave_idx_type> (dummy_work);
  239.   Array<float> work (dim_vector (lwork, 1));
  240.   float *pwork = work.fortran_vec ();
  241.  
  242.   F77_XFCN (sgeevx, SGEEVX, (F77_CONST_CHAR_ARG2 (balance ? "B" : "N", 1),
  243.                              F77_CONST_CHAR_ARG2 ("N", 1),
  244.                              F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  245.                              F77_CONST_CHAR_ARG2 ("N", 1),
  246.                              n, tmp_data, n, pwr, pwi, dummy,
  247.                              idummy, pvr, n,
  248.                              ilo, ihi, pscale, abnrm, prconde, prcondv,
  249.                              pwork, lwork, &dummy_iwork, info
  250.                              F77_CHAR_ARG_LEN (1)
  251.                              F77_CHAR_ARG_LEN (1)
  252.                              F77_CHAR_ARG_LEN (1)
  253.                              F77_CHAR_ARG_LEN (1)));
  254.  
  255.   if (info < 0)
  256.     (*current_liboctave_error_handler) ("unrecoverable error in sgeevx");
  257.  
  258.   if (info > 0)
  259.     (*current_liboctave_error_handler) ("sgeevx failed to converge");
  260.  
  261.   lambda.resize (n);
  262.   v.resize (nvr, nvr);
  263.  
  264.   for (octave_idx_type j = 0; j < n; j++)
  265.     {
  266.       if (wi.elem (j) == 0.0)
  267.         {
  268.           lambda.elem (j) = FloatComplex (wr.elem (j));
  269.           for (octave_idx_type i = 0; i < nvr; i++)
  270.             v.elem (i, j) = vr.elem (i, j);
  271.         }
  272.       else
  273.         {
  274.           if (j+1 >= n)
  275.             (*current_liboctave_error_handler) ("EIG: internal error");
  276.  
  277.           lambda.elem (j) = FloatComplex (wr.elem (j), wi.elem (j));
  278.           lambda.elem (j+1) = FloatComplex (wr.elem (j+1), wi.elem (j+1));
  279.  
  280.           for (octave_idx_type i = 0; i < nvr; i++)
  281.             {
  282.               float real_part = vr.elem (i, j);
  283.               float imag_part = vr.elem (i, j+1);
  284.               v.elem (i, j) = FloatComplex (real_part, imag_part);
  285.               v.elem (i, j+1) = FloatComplex (real_part, -imag_part);
  286.             }
  287.           j++;
  288.         }
  289.     }
  290.  
  291.   return info;
  292. }
  293.  
  294. octave_idx_type
  295. FloatEIG::symmetric_init (const FloatMatrix& a, bool calc_ev)
  296. {
  297.   octave_idx_type n = a.rows ();
  298.  
  299.   if (n != a.cols ())
  300.     (*current_liboctave_error_handler) ("EIG requires square matrix");
  301.  
  302.   octave_idx_type info = 0;
  303.  
  304.   FloatMatrix atmp = a;
  305.   float *tmp_data = atmp.fortran_vec ();
  306.  
  307.   FloatColumnVector wr (n);
  308.   float *pwr = wr.fortran_vec ();
  309.  
  310.   octave_idx_type lwork = -1;
  311.   float dummy_work;
  312.  
  313.   F77_XFCN (ssyev, SSYEV, (F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  314.                            F77_CONST_CHAR_ARG2 ("U", 1),
  315.                            n, tmp_data, n, pwr, &dummy_work, lwork, info
  316.                            F77_CHAR_ARG_LEN (1)
  317.                            F77_CHAR_ARG_LEN (1)));
  318.  
  319.   if (info != 0)
  320.     (*current_liboctave_error_handler) ("ssyev workspace query failed");
  321.  
  322.   lwork = static_cast<octave_idx_type> (dummy_work);
  323.   Array<float> work (dim_vector (lwork, 1));
  324.   float *pwork = work.fortran_vec ();
  325.  
  326.   F77_XFCN (ssyev, SSYEV, (F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  327.                            F77_CONST_CHAR_ARG2 ("U", 1),
  328.                            n, tmp_data, n, pwr, pwork, lwork, info
  329.                            F77_CHAR_ARG_LEN (1)
  330.                            F77_CHAR_ARG_LEN (1)));
  331.  
  332.   if (info < 0)
  333.     (*current_liboctave_error_handler) ("unrecoverable error in ssyev");
  334.  
  335.   if (info > 0)
  336.     (*current_liboctave_error_handler) ("ssyev failed to converge");
  337.  
  338.   lambda = FloatComplexColumnVector (wr);
  339.   v = calc_ev ? FloatComplexMatrix (atmp) : FloatComplexMatrix ();
  340.  
  341.   return info;
  342. }
  343.  
  344. octave_idx_type
  345. FloatEIG::init (const FloatComplexMatrix& a, bool calc_ev, bool balance)
  346. {
  347.   if (a.any_element_is_inf_or_nan ())
  348.     (*current_liboctave_error_handler)
  349.       ("EIG: matrix contains Inf or NaN values");
  350.  
  351.   if (a.is_hermitian ())
  352.     return hermitian_init (a, calc_ev);
  353.  
  354.   octave_idx_type n = a.rows ();
  355.  
  356.   if (n != a.cols ())
  357.     (*current_liboctave_error_handler) ("EIG requires square matrix");
  358.  
  359.   octave_idx_type info = 0;
  360.  
  361.   FloatComplexMatrix atmp = a;
  362.   FloatComplex *tmp_data = atmp.fortran_vec ();
  363.  
  364.   FloatComplexColumnVector w (n);
  365.   FloatComplex *pw = w.fortran_vec ();
  366.  
  367.   octave_idx_type nvr = calc_ev ? n : 0;
  368.   FloatComplexMatrix vtmp (nvr, nvr);
  369.   FloatComplex *pv = vtmp.fortran_vec ();
  370.  
  371.   octave_idx_type lwork = -1;
  372.   FloatComplex dummy_work;
  373.  
  374.   octave_idx_type lrwork = 2*n;
  375.   Array<float> rwork (dim_vector (lrwork, 1));
  376.   float *prwork = rwork.fortran_vec ();
  377.  
  378.   FloatComplex *dummy = 0;
  379.   octave_idx_type idummy = 1;
  380.  
  381.   octave_idx_type ilo;
  382.   octave_idx_type ihi;
  383.  
  384.   Array<float> scale (dim_vector (n, 1));
  385.   float *pscale = scale.fortran_vec ();
  386.  
  387.   float abnrm;
  388.  
  389.   Array<float> rconde (dim_vector (n, 1));
  390.   float *prconde = rconde.fortran_vec ();
  391.  
  392.   Array<float> rcondv (dim_vector (n, 1));
  393.   float *prcondv = rcondv.fortran_vec ();
  394.  
  395.   F77_XFCN (cgeevx, CGEEVX, (F77_CONST_CHAR_ARG2 (balance ? "B" : "N", 1),
  396.                              F77_CONST_CHAR_ARG2 ("N", 1),
  397.                              F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  398.                              F77_CONST_CHAR_ARG2 ("N", 1),
  399.                              n, tmp_data, n, pw, dummy, idummy,
  400.                              pv, n, ilo, ihi, pscale, abnrm, prconde, prcondv,
  401.                              &dummy_work, lwork, prwork, info
  402.                              F77_CHAR_ARG_LEN (1)
  403.                              F77_CHAR_ARG_LEN (1)
  404.                              F77_CHAR_ARG_LEN (1)
  405.                              F77_CHAR_ARG_LEN (1)));
  406.  
  407.   if (info != 0)
  408.     (*current_liboctave_error_handler) ("cgeevx workspace query failed");
  409.  
  410.   lwork = static_cast<octave_idx_type> (dummy_work.real ());
  411.   Array<FloatComplex> work (dim_vector (lwork, 1));
  412.   FloatComplex *pwork = work.fortran_vec ();
  413.  
  414.   F77_XFCN (cgeevx, CGEEVX, (F77_CONST_CHAR_ARG2 (balance ? "B" : "N", 1),
  415.                              F77_CONST_CHAR_ARG2 ("N", 1),
  416.                              F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  417.                              F77_CONST_CHAR_ARG2 ("N", 1),
  418.                              n, tmp_data, n, pw, dummy, idummy,
  419.                              pv, n, ilo, ihi, pscale,abnrm, prconde, prcondv,
  420.                              pwork, lwork, prwork, info
  421.                              F77_CHAR_ARG_LEN (1)
  422.                              F77_CHAR_ARG_LEN (1)
  423.                              F77_CHAR_ARG_LEN (1)
  424.                              F77_CHAR_ARG_LEN (1)));  
  425.  
  426.   if (info < 0)
  427.     (*current_liboctave_error_handler) ("unrecoverable error in cgeevx");
  428.  
  429.   if (info > 0)
  430.     (*current_liboctave_error_handler) ("cgeevx failed to converge");
  431.  
  432.   lambda = w;
  433.   v = vtmp;
  434.  
  435.   return info;
  436. }
  437.  
  438. octave_idx_type
  439. FloatEIG::hermitian_init (const FloatComplexMatrix& a, bool calc_ev)
  440. {
  441.   octave_idx_type n = a.rows ();
  442.  
  443.   if (n != a.cols ())
  444.     (*current_liboctave_error_handler) ("EIG requires square matrix");
  445.  
  446.   octave_idx_type info = 0;
  447.  
  448.   FloatComplexMatrix atmp = a;
  449.   FloatComplex *tmp_data = atmp.fortran_vec ();
  450.  
  451.   FloatColumnVector wr (n);
  452.   float *pwr = wr.fortran_vec ();
  453.  
  454.   octave_idx_type lwork = -1;
  455.   FloatComplex dummy_work;
  456.  
  457.   octave_idx_type lrwork = 3*n;
  458.   Array<float> rwork (dim_vector (lrwork, 1));
  459.   float *prwork = rwork.fortran_vec ();
  460.  
  461.   F77_XFCN (cheev, CHEEV, (F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  462.                            F77_CONST_CHAR_ARG2 ("U", 1),
  463.                            n, tmp_data, n, pwr, &dummy_work, lwork,
  464.                            prwork, info
  465.                            F77_CHAR_ARG_LEN (1)
  466.                            F77_CHAR_ARG_LEN (1)));
  467.  
  468.   if (info != 0)
  469.     (*current_liboctave_error_handler) ("cheev workspace query failed");
  470.  
  471.   lwork = static_cast<octave_idx_type> (dummy_work.real ());
  472.   Array<FloatComplex> work (dim_vector (lwork, 1));
  473.   FloatComplex *pwork = work.fortran_vec ();
  474.  
  475.   F77_XFCN (cheev, CHEEV, (F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  476.                            F77_CONST_CHAR_ARG2 ("U", 1),
  477.                            n, tmp_data, n, pwr, pwork, lwork, prwork, info
  478.                            F77_CHAR_ARG_LEN (1)
  479.                            F77_CHAR_ARG_LEN (1)));
  480.  
  481.   if (info < 0)
  482.     (*current_liboctave_error_handler) ("unrecoverable error in cheev");
  483.  
  484.   if (info > 0)
  485.     (*current_liboctave_error_handler) ("cheev failed to converge");
  486.  
  487.   lambda = FloatComplexColumnVector (wr);
  488.   v = calc_ev ? FloatComplexMatrix (atmp) : FloatComplexMatrix ();
  489.  
  490.   return info;
  491. }
  492.  
  493. octave_idx_type
  494. FloatEIG::init (const FloatMatrix& a, const FloatMatrix& b, bool calc_ev)
  495. {
  496.   if (a.any_element_is_inf_or_nan () || b.any_element_is_inf_or_nan ())
  497.     (*current_liboctave_error_handler)
  498.       ("EIG: matrix contains Inf or NaN values");
  499.  
  500.   octave_idx_type n = a.rows ();
  501.   octave_idx_type nb = b.rows ();
  502.  
  503.   if (n != a.cols () || nb != b.cols ())
  504.     (*current_liboctave_error_handler) ("EIG requires square matrix");
  505.  
  506.   if (n != nb)
  507.     (*current_liboctave_error_handler) ("EIG requires same size matrices");
  508.  
  509.   octave_idx_type info = 0;
  510.  
  511.   FloatMatrix tmp = b;
  512.   float *tmp_data = tmp.fortran_vec ();
  513.  
  514.   F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1),
  515.                              n, tmp_data, n,
  516.                              info
  517.                              F77_CHAR_ARG_LEN (1)
  518.                              F77_CHAR_ARG_LEN (1)));
  519.  
  520.   if (a.is_symmetric () && b.is_symmetric () && info == 0)
  521.     return symmetric_init (a, b, calc_ev);
  522.  
  523.   FloatMatrix atmp = a;
  524.   float *atmp_data = atmp.fortran_vec ();
  525.  
  526.   FloatMatrix btmp = b;
  527.   float *btmp_data = btmp.fortran_vec ();
  528.  
  529.   Array<float> ar (dim_vector (n, 1));
  530.   float *par = ar.fortran_vec ();
  531.  
  532.   Array<float> ai (dim_vector (n, 1));
  533.   float *pai = ai.fortran_vec ();
  534.  
  535.   Array<float> beta (dim_vector (n, 1));
  536.   float *pbeta = beta.fortran_vec ();
  537.  
  538.   volatile octave_idx_type nvr = calc_ev ? n : 0;
  539.   FloatMatrix vr (nvr, nvr);
  540.   float *pvr = vr.fortran_vec ();
  541.  
  542.   octave_idx_type lwork = -1;
  543.   float dummy_work;
  544.  
  545.   float *dummy = 0;
  546.   octave_idx_type idummy = 1;
  547.  
  548.   F77_XFCN (sggev, SGGEV, (F77_CONST_CHAR_ARG2 ("N", 1),
  549.                            F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  550.                            n, atmp_data, n, btmp_data, n,
  551.                            par, pai, pbeta,
  552.                            dummy, idummy, pvr, n,
  553.                            &dummy_work, lwork, info
  554.                            F77_CHAR_ARG_LEN (1)
  555.                            F77_CHAR_ARG_LEN (1)));
  556.  
  557.   if (info != 0)
  558.     (*current_liboctave_error_handler) ("sggev workspace query failed");
  559.  
  560.   lwork = static_cast<octave_idx_type> (dummy_work);
  561.   Array<float> work (dim_vector (lwork, 1));
  562.   float *pwork = work.fortran_vec ();
  563.  
  564.   F77_XFCN (sggev, SGGEV, (F77_CONST_CHAR_ARG2 ("N", 1),
  565.                            F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  566.                            n, atmp_data, n, btmp_data, n,
  567.                            par, pai, pbeta,
  568.                            dummy, idummy, pvr, n,
  569.                            pwork, lwork, info
  570.                            F77_CHAR_ARG_LEN (1)
  571.                            F77_CHAR_ARG_LEN (1)));
  572.  
  573.   if (info < 0)
  574.     (*current_liboctave_error_handler) ("unrecoverable error in sggev");
  575.  
  576.   if (info > 0)
  577.     (*current_liboctave_error_handler) ("sggev failed to converge");
  578.  
  579.   lambda.resize (n);
  580.   v.resize (nvr, nvr);
  581.  
  582.   for (octave_idx_type j = 0; j < n; j++)
  583.     {
  584.       if (ai.elem (j) == 0.0)
  585.         {
  586.           lambda.elem (j) = FloatComplex (ar.elem (j) / beta.elem (j));
  587.           for (octave_idx_type i = 0; i < nvr; i++)
  588.             v.elem (i, j) = vr.elem (i, j);
  589.         }
  590.       else
  591.         {
  592.           if (j+1 >= n)
  593.             (*current_liboctave_error_handler) ("EIG: internal error");
  594.  
  595.           lambda.elem (j) = FloatComplex (ar.elem (j) / beta.elem (j),
  596.                                           ai.elem (j) / beta.elem (j));
  597.           lambda.elem (j+1) = FloatComplex (ar.elem (j+1) / beta.elem (j+1),
  598.                                             ai.elem (j+1) / beta.elem (j+1));
  599.  
  600.           for (octave_idx_type i = 0; i < nvr; i++)
  601.             {
  602.               float real_part = vr.elem (i, j);
  603.               float imag_part = vr.elem (i, j+1);
  604.               v.elem (i, j) = FloatComplex (real_part, imag_part);
  605.               v.elem (i, j+1) = FloatComplex (real_part, -imag_part);
  606.             }
  607.           j++;
  608.         }
  609.     }
  610.  
  611.   return info;
  612. }
  613.  
  614. octave_idx_type
  615. FloatEIG::symmetric_init (const FloatMatrix& a, const FloatMatrix& b,
  616.                           bool calc_ev)
  617. {
  618.   octave_idx_type n = a.rows ();
  619.   octave_idx_type nb = b.rows ();
  620.  
  621.   if (n != a.cols () || nb != b.cols ())
  622.     (*current_liboctave_error_handler) ("EIG requires square matrix");
  623.  
  624.   if (n != nb)
  625.     (*current_liboctave_error_handler) ("EIG requires same size matrices");
  626.  
  627.   octave_idx_type info = 0;
  628.  
  629.   FloatMatrix atmp = a;
  630.   float *atmp_data = atmp.fortran_vec ();
  631.  
  632.   FloatMatrix btmp = b;
  633.   float *btmp_data = btmp.fortran_vec ();
  634.  
  635.   FloatColumnVector wr (n);
  636.   float *pwr = wr.fortran_vec ();
  637.  
  638.   octave_idx_type lwork = -1;
  639.   float dummy_work;
  640.  
  641.   F77_XFCN (ssygv, SSYGV, (1, F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  642.                            F77_CONST_CHAR_ARG2 ("U", 1),
  643.                            n, atmp_data, n,
  644.                            btmp_data, n,
  645.                            pwr, &dummy_work, lwork, info
  646.                            F77_CHAR_ARG_LEN (1)
  647.                            F77_CHAR_ARG_LEN (1)));
  648.  
  649.   if (info != 0)
  650.     (*current_liboctave_error_handler) ("ssygv workspace query failed");
  651.  
  652.   lwork = static_cast<octave_idx_type> (dummy_work);
  653.   Array<float> work (dim_vector (lwork, 1));
  654.   float *pwork = work.fortran_vec ();
  655.  
  656.   F77_XFCN (ssygv, SSYGV, (1, F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  657.                            F77_CONST_CHAR_ARG2 ("U", 1),
  658.                            n, atmp_data, n,
  659.                            btmp_data, n,
  660.                            pwr, pwork, lwork, info
  661.                            F77_CHAR_ARG_LEN (1)
  662.                            F77_CHAR_ARG_LEN (1)));
  663.  
  664.   if (info < 0)
  665.     (*current_liboctave_error_handler) ("unrecoverable error in ssygv");
  666.  
  667.   if (info > 0)
  668.     (*current_liboctave_error_handler) ("ssygv failed to converge");
  669.  
  670.   lambda = FloatComplexColumnVector (wr);
  671.   v = calc_ev ? FloatComplexMatrix (atmp) : FloatComplexMatrix ();
  672.  
  673.   return info;
  674. }
  675.  
  676. octave_idx_type
  677. FloatEIG::init (const FloatComplexMatrix& a, const FloatComplexMatrix& b,
  678.                 bool calc_ev)
  679. {
  680.   if (a.any_element_is_inf_or_nan () || b.any_element_is_inf_or_nan ())
  681.     (*current_liboctave_error_handler)
  682.       ("EIG: matrix contains Inf or NaN values");
  683.  
  684.   octave_idx_type n = a.rows ();
  685.   octave_idx_type nb = b.rows ();
  686.  
  687.   if (n != a.cols () || nb != b.cols ())
  688.     (*current_liboctave_error_handler) ("EIG requires square matrix");
  689.  
  690.   if (n != nb)
  691.     (*current_liboctave_error_handler) ("EIG requires same size matrices");
  692.  
  693.   octave_idx_type info = 0;
  694.  
  695.   FloatComplexMatrix tmp = b;
  696.   FloatComplex *tmp_data = tmp.fortran_vec ();
  697.  
  698.   F77_XFCN (cpotrf, CPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1),
  699.                              n, tmp_data, n,
  700.                              info
  701.                              F77_CHAR_ARG_LEN (1)
  702.                              F77_CHAR_ARG_LEN (1)));
  703.  
  704.   if (a.is_hermitian () && b.is_hermitian () && info == 0)
  705.     return hermitian_init (a, b, calc_ev);
  706.  
  707.   FloatComplexMatrix atmp = a;
  708.   FloatComplex *atmp_data = atmp.fortran_vec ();
  709.  
  710.   FloatComplexMatrix btmp = b;
  711.   FloatComplex *btmp_data = btmp.fortran_vec ();
  712.  
  713.   FloatComplexColumnVector alpha (n);
  714.   FloatComplex *palpha = alpha.fortran_vec ();
  715.  
  716.   FloatComplexColumnVector beta (n);
  717.   FloatComplex *pbeta = beta.fortran_vec ();
  718.  
  719.   octave_idx_type nvr = calc_ev ? n : 0;
  720.   FloatComplexMatrix vtmp (nvr, nvr);
  721.   FloatComplex *pv = vtmp.fortran_vec ();
  722.  
  723.   octave_idx_type lwork = -1;
  724.   FloatComplex dummy_work;
  725.  
  726.   octave_idx_type lrwork = 8*n;
  727.   Array<float> rwork (dim_vector (lrwork, 1));
  728.   float *prwork = rwork.fortran_vec ();
  729.  
  730.   FloatComplex *dummy = 0;
  731.   octave_idx_type idummy = 1;
  732.  
  733.   F77_XFCN (cggev, CGGEV, (F77_CONST_CHAR_ARG2 ("N", 1),
  734.                            F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  735.                            n, atmp_data, n, btmp_data, n,
  736.                            palpha, pbeta, dummy, idummy,
  737.                            pv, n, &dummy_work, lwork, prwork, info
  738.                            F77_CHAR_ARG_LEN (1)
  739.                            F77_CHAR_ARG_LEN (1)));
  740.  
  741.   if (info != 0)
  742.     (*current_liboctave_error_handler) ("cggev workspace query failed");
  743.  
  744.   lwork = static_cast<octave_idx_type> (dummy_work.real ());
  745.   Array<FloatComplex> work (dim_vector (lwork, 1));
  746.   FloatComplex *pwork = work.fortran_vec ();
  747.  
  748.   F77_XFCN (cggev, CGGEV, (F77_CONST_CHAR_ARG2 ("N", 1),
  749.                            F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  750.                            n, atmp_data, n, btmp_data, n,
  751.                            palpha, pbeta, dummy, idummy,
  752.                            pv, n, pwork, lwork, prwork, info
  753.                            F77_CHAR_ARG_LEN (1)
  754.                            F77_CHAR_ARG_LEN (1)));
  755.  
  756.   if (info < 0)
  757.     (*current_liboctave_error_handler) ("unrecoverable error in cggev");
  758.  
  759.   if (info > 0)
  760.     (*current_liboctave_error_handler) ("cggev failed to converge");
  761.  
  762.   lambda.resize (n);
  763.  
  764.   for (octave_idx_type j = 0; j < n; j++)
  765.     lambda.elem (j) = alpha.elem (j) / beta.elem (j);
  766.  
  767.   v = vtmp;
  768.  
  769.   return info;
  770. }
  771.  
  772. octave_idx_type
  773. FloatEIG::hermitian_init (const FloatComplexMatrix& a,
  774.                           const FloatComplexMatrix& b, bool calc_ev)
  775. {
  776.   octave_idx_type n = a.rows ();
  777.   octave_idx_type nb = b.rows ();
  778.  
  779.   if (n != a.cols () || nb != b.cols ())
  780.     (*current_liboctave_error_handler) ("EIG requires square matrix");
  781.  
  782.   if (n != nb)
  783.     (*current_liboctave_error_handler) ("EIG requires same size matrices");
  784.  
  785.   octave_idx_type info = 0;
  786.  
  787.   FloatComplexMatrix atmp = a;
  788.   FloatComplex *atmp_data = atmp.fortran_vec ();
  789.  
  790.   FloatComplexMatrix btmp = b;
  791.   FloatComplex *btmp_data = btmp.fortran_vec ();
  792.  
  793.   FloatColumnVector wr (n);
  794.   float *pwr = wr.fortran_vec ();
  795.  
  796.   octave_idx_type lwork = -1;
  797.   FloatComplex dummy_work;
  798.  
  799.   octave_idx_type lrwork = 3*n;
  800.   Array<float> rwork (dim_vector (lrwork, 1));
  801.   float *prwork = rwork.fortran_vec ();
  802.  
  803.   F77_XFCN (chegv, CHEGV, (1, F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  804.                            F77_CONST_CHAR_ARG2 ("U", 1),
  805.                            n, atmp_data, n,
  806.                            btmp_data, n,
  807.                            pwr, &dummy_work, lwork,
  808.                            prwork, info
  809.                            F77_CHAR_ARG_LEN (1)
  810.                            F77_CHAR_ARG_LEN (1)));
  811.  
  812.   if (info != 0)
  813.     (*current_liboctave_error_handler) ("zhegv workspace query failed");
  814.  
  815.   lwork = static_cast<octave_idx_type> (dummy_work.real ());
  816.   Array<FloatComplex> work (dim_vector (lwork, 1));
  817.   FloatComplex *pwork = work.fortran_vec ();
  818.  
  819.   F77_XFCN (chegv, CHEGV, (1, F77_CONST_CHAR_ARG2 (calc_ev ? "V" : "N", 1),
  820.                            F77_CONST_CHAR_ARG2 ("U", 1),
  821.                            n, atmp_data, n,
  822.                            btmp_data, n,
  823.                            pwr, pwork, lwork, prwork, info
  824.                            F77_CHAR_ARG_LEN (1)
  825.                            F77_CHAR_ARG_LEN (1)));
  826.  
  827.   if (info < 0)
  828.     (*current_liboctave_error_handler) ("unrecoverable error in zhegv");
  829.  
  830.   if (info > 0)
  831.     (*current_liboctave_error_handler) ("zhegv failed to converge");
  832.  
  833.   lambda = FloatComplexColumnVector (wr);
  834.   v = calc_ev ? FloatComplexMatrix (atmp) : FloatComplexMatrix ();
  835.  
  836.   return info;
  837. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement