LCOV - code coverage report
Current view: top level - gfx/angle/src/common - matrix_utils.h (source / functions) Hit Total Coverage
Test: output.info Lines: 0 210 0.0 %
Date: 2017-07-14 16:53:18 Functions: 0 16 0.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //
       2             : // Copyright 2015 The ANGLE Project Authors. All rights reserved.
       3             : // Use of this source code is governed by a BSD-style license that can be
       4             : // found in the LICENSE file.
       5             : //
       6             : // Matrix:
       7             : //   Utility class implementing various matrix operations.
       8             : //   Supports matrices with 2 to 4 rows/columns; a single row or column (1xN or Nx1) is also accepted so that vectors can be represented.
       9             : //
      10             : // TODO: Check if we can merge Matrix.h in sample_util with this and replace it with this implementation.
      11             : // TODO: Rename this file to Matrix.h once we remove Matrix.h in sample_util.
      12             : 
      13             : #ifndef COMMON_MATRIX_UTILS_H_
      14             : #define COMMON_MATRIX_UTILS_H_
      15             : 
      16             : #include <vector>
      17             : 
      18             : #include "common/debug.h"
      19             : #include "common/mathutil.h"
      20             : 
      21             : namespace angle
      22             : {
      23             : 
      24             : template<typename T>
      25           0 : class Matrix
      26             : {
      27             :   public:
      28           0 :     Matrix(const std::vector<T> &elements, const unsigned int &numRows, const unsigned int &numCols)
      29             :         : mElements(elements),
      30             :           mRows(numRows),
      31           0 :           mCols(numCols)
      32             :     {
      33           0 :         ASSERT(rows() >= 1 && rows() <= 4);
      34           0 :         ASSERT(columns() >= 1 && columns() <= 4);
      35           0 :     }
      36             : 
      37           0 :     Matrix(const std::vector<T> &elements, const unsigned int &size)
      38             :         : mElements(elements),
      39             :           mRows(size),
      40           0 :           mCols(size)
      41             :     {
      42           0 :         ASSERT(rows() >= 1 && rows() <= 4);
      43           0 :         ASSERT(columns() >= 1 && columns() <= 4);
      44           0 :     }
      45             : 
      46           0 :     Matrix(const T *elements, const unsigned int &size)
      47             :         : mRows(size),
      48           0 :           mCols(size)
      49             :     {
      50           0 :         ASSERT(rows() >= 1 && rows() <= 4);
      51           0 :         ASSERT(columns() >= 1 && columns() <= 4);
      52           0 :         for (size_t i = 0; i < size * size; i++)
      53           0 :             mElements.push_back(elements[i]);
      54           0 :     }
      55             : 
      56           0 :     const T &operator()(const unsigned int &rowIndex, const unsigned int &columnIndex) const
      57             :     {
      58           0 :         return mElements[rowIndex * columns() + columnIndex];
      59             :     }
      60             : 
      61           0 :     T &operator()(const unsigned int &rowIndex, const unsigned int &columnIndex)
      62             :     {
      63           0 :         return mElements[rowIndex * columns() + columnIndex];
      64             :     }
      65             : 
      66           0 :     const T &at(const unsigned int &rowIndex, const unsigned int &columnIndex) const
      67             :     {
      68           0 :         return operator()(rowIndex, columnIndex);
      69             :     }
      70             : 
      71             :     Matrix<T> operator*(const Matrix<T> &m)
      72             :     {
      73             :         ASSERT(columns() == m.rows());
      74             : 
      75             :         unsigned int resultRows = rows();
      76             :         unsigned int resultCols = m.columns();
      77             :         Matrix<T> result(std::vector<T>(resultRows * resultCols), resultRows, resultCols);
      78             :         for (unsigned int i = 0; i < resultRows; i++)
      79             :         {
      80             :             for (unsigned int j = 0; j < resultCols; j++)
      81             :             {
      82             :                 T tmp = 0.0f;
      83             :                 for (unsigned int k = 0; k < columns(); k++)
      84             :                     tmp += at(i, k) * m(k, j);
      85             :                 result(i, j) = tmp;
      86             :             }
      87             :         }
      88             : 
      89             :         return result;
      90             :     }
      91             : 
      92           0 :     unsigned int size() const
      93             :     {
      94           0 :         ASSERT(rows() == columns());
      95           0 :         return rows();
      96             :     }
      97             : 
      98           0 :     unsigned int rows() const { return mRows; }
      99             : 
     100           0 :     unsigned int columns() const { return mCols; }
     101             : 
     102           0 :     std::vector<T> elements() const { return mElements; }
     103             : 
     104           0 :     Matrix<T> compMult(const Matrix<T> &mat1) const
     105             :     {
     106           0 :         Matrix result(std::vector<T>(mElements.size()), size());
     107           0 :         for (unsigned int i = 0; i < columns(); i++)
     108           0 :             for (unsigned int j = 0; j < rows(); j++)
     109           0 :                 result(i, j) = at(i, j) * mat1(i, j);
     110             : 
     111           0 :         return result;
     112             :     }
     113             : 
     114           0 :     Matrix<T> outerProduct(const Matrix<T> &mat1) const
     115             :     {
     116           0 :         unsigned int cols = mat1.columns();
     117           0 :         Matrix result(std::vector<T>(rows() * cols), rows(), cols);
     118           0 :         for (unsigned int i = 0; i < rows(); i++)
     119           0 :             for (unsigned int j = 0; j < cols; j++)
     120           0 :                 result(i, j) = at(i, 0) * mat1(0, j);
     121             : 
     122           0 :         return result;
     123             :     }
     124             : 
     125           0 :     Matrix<T> transpose() const
     126             :     {
     127           0 :         Matrix result(std::vector<T>(mElements.size()), columns(), rows());
     128           0 :         for (unsigned int i = 0; i < columns(); i++)
     129           0 :             for (unsigned int j = 0; j < rows(); j++)
     130           0 :                 result(i, j) = at(j, i);
     131             : 
     132           0 :         return result;
     133             :     }
     134             : 
     135           0 :     T determinant() const
     136             :     {
     137           0 :         ASSERT(rows() == columns());
     138             : 
     139           0 :         switch (size())
     140             :         {
     141             :           case 2:
     142           0 :             return at(0, 0) * at(1, 1) - at(0, 1) * at(1, 0);
     143             : 
     144             :           case 3:
     145           0 :             return at(0, 0) * at(1, 1) * at(2, 2) +
     146           0 :                 at(0, 1) * at(1, 2) * at(2, 0) +
     147           0 :                 at(0, 2) * at(1, 0) * at(2, 1) -
     148           0 :                 at(0, 2) * at(1, 1) * at(2, 0) -
     149           0 :                 at(0, 1) * at(1, 0) * at(2, 2) -
     150           0 :                 at(0, 0) * at(1, 2) * at(2, 1);
     151             : 
     152             :           case 4:
     153             :             {
     154             :                 const float minorMatrices[4][3 * 3] =
     155             :                 {
     156             :                     {
     157           0 :                         at(1, 1), at(2, 1), at(3, 1),
     158           0 :                         at(1, 2), at(2, 2), at(3, 2),
     159           0 :                         at(1, 3), at(2, 3), at(3, 3),
     160             :                     },
     161             :                     {
     162           0 :                         at(1, 0), at(2, 0), at(3, 0),
     163           0 :                         at(1, 2), at(2, 2), at(3, 2),
     164           0 :                         at(1, 3), at(2, 3), at(3, 3),
     165             :                     },
     166             :                     {
     167           0 :                         at(1, 0), at(2, 0), at(3, 0),
     168           0 :                         at(1, 1), at(2, 1), at(3, 1),
     169           0 :                         at(1, 3), at(2, 3), at(3, 3),
     170             :                     },
     171             :                     {
     172           0 :                         at(1, 0), at(2, 0), at(3, 0),
     173           0 :                         at(1, 1), at(2, 1), at(3, 1),
     174           0 :                         at(1, 2), at(2, 2), at(3, 2),
     175             :                     }
     176           0 :               };
     177           0 :               return at(0, 0) * Matrix<T>(minorMatrices[0], 3).determinant() -
     178           0 :                   at(0, 1) * Matrix<T>(minorMatrices[1], 3).determinant() +
     179           0 :                   at(0, 2) * Matrix<T>(minorMatrices[2], 3).determinant() -
     180           0 :                   at(0, 3) * Matrix<T>(minorMatrices[3], 3).determinant();
     181             :             }
     182             : 
     183             :           default:
     184           0 :             UNREACHABLE();
     185             :             break;
     186             :         }
     187             : 
     188             :         return T();
     189             :     }
     190             : 
     191           0 :     Matrix<T> inverse() const
     192             :     {
     193           0 :         ASSERT(rows() == columns());
     194             : 
     195           0 :         Matrix<T> cof(std::vector<T>(mElements.size()), rows(), columns());
     196           0 :         switch (size())
     197             :         {
     198             :           case 2:
     199           0 :             cof(0, 0) = at(1, 1);
     200           0 :             cof(0, 1) = -at(1, 0);
     201           0 :             cof(1, 0) = -at(0, 1);
     202           0 :             cof(1, 1) = at(0, 0);
     203           0 :             break;
     204             : 
     205             :           case 3:
     206           0 :             cof(0, 0) = at(1, 1) * at(2, 2) -
     207           0 :                 at(2, 1) * at(1, 2);
     208           0 :             cof(0, 1) = -(at(1, 0) * at(2, 2) -
     209           0 :                 at(2, 0) * at(1, 2));
     210           0 :             cof(0, 2) = at(1, 0) * at(2, 1) -
     211           0 :                 at(2, 0) * at(1, 1);
     212           0 :             cof(1, 0) = -(at(0, 1) * at(2, 2) -
     213           0 :                 at(2, 1) * at(0, 2));
     214           0 :             cof(1, 1) = at(0, 0) * at(2, 2) -
     215           0 :                 at(2, 0) * at(0, 2);
     216           0 :             cof(1, 2) = -(at(0, 0) * at(2, 1) -
     217           0 :                 at(2, 0) * at(0, 1));
     218           0 :             cof(2, 0) = at(0, 1) * at(1, 2) -
     219           0 :                 at(1, 1) * at(0, 2);
     220           0 :             cof(2, 1) = -(at(0, 0) * at(1, 2) -
     221           0 :                 at(1, 0) * at(0, 2));
     222           0 :             cof(2, 2) = at(0, 0) * at(1, 1) -
     223           0 :                 at(1, 0) * at(0, 1);
     224           0 :             break;
     225             : 
     226             :           case 4:
     227           0 :             cof(0, 0) = at(1, 1) * at(2, 2) * at(3, 3) +
     228           0 :                 at(2, 1) * at(3, 2) * at(1, 3) +
     229           0 :                 at(3, 1) * at(1, 2) * at(2, 3) -
     230           0 :                 at(1, 1) * at(3, 2) * at(2, 3) -
     231           0 :                 at(2, 1) * at(1, 2) * at(3, 3) -
     232           0 :                 at(3, 1) * at(2, 2) * at(1, 3);
     233           0 :             cof(0, 1) = -(at(1, 0) * at(2, 2) * at(3, 3) +
     234           0 :                 at(2, 0) * at(3, 2) * at(1, 3) +
     235           0 :                 at(3, 0) * at(1, 2) * at(2, 3) -
     236           0 :                 at(1, 0) * at(3, 2) * at(2, 3) -
     237           0 :                 at(2, 0) * at(1, 2) * at(3, 3) -
     238           0 :                 at(3, 0) * at(2, 2) * at(1, 3));
     239           0 :             cof(0, 2) = at(1, 0) * at(2, 1) * at(3, 3) +
     240           0 :                 at(2, 0) * at(3, 1) * at(1, 3) +
     241           0 :                 at(3, 0) * at(1, 1) * at(2, 3) -
     242           0 :                 at(1, 0) * at(3, 1) * at(2, 3) -
     243           0 :                 at(2, 0) * at(1, 1) * at(3, 3) -
     244           0 :                 at(3, 0) * at(2, 1) * at(1, 3);
     245           0 :             cof(0, 3) = -(at(1, 0) * at(2, 1) * at(3, 2) +
     246           0 :                 at(2, 0) * at(3, 1) * at(1, 2) +
     247           0 :                 at(3, 0) * at(1, 1) * at(2, 2) -
     248           0 :                 at(1, 0) * at(3, 1) * at(2, 2) -
     249           0 :                 at(2, 0) * at(1, 1) * at(3, 2) -
     250           0 :                 at(3, 0) * at(2, 1) * at(1, 2));
     251           0 :             cof(1, 0) = -(at(0, 1) * at(2, 2) * at(3, 3) +
     252           0 :                 at(2, 1) * at(3, 2) * at(0, 3) +
     253           0 :                 at(3, 1) * at(0, 2) * at(2, 3) -
     254           0 :                 at(0, 1) * at(3, 2) * at(2, 3) -
     255           0 :                 at(2, 1) * at(0, 2) * at(3, 3) -
     256           0 :                 at(3, 1) * at(2, 2) * at(0, 3));
     257           0 :             cof(1, 1) = at(0, 0) * at(2, 2) * at(3, 3) +
     258           0 :                 at(2, 0) * at(3, 2) * at(0, 3) +
     259           0 :                 at(3, 0) * at(0, 2) * at(2, 3) -
     260           0 :                 at(0, 0) * at(3, 2) * at(2, 3) -
     261           0 :                 at(2, 0) * at(0, 2) * at(3, 3) -
     262           0 :                 at(3, 0) * at(2, 2) * at(0, 3);
     263           0 :             cof(1, 2) = -(at(0, 0) * at(2, 1) * at(3, 3) +
     264           0 :                 at(2, 0) * at(3, 1) * at(0, 3) +
     265           0 :                 at(3, 0) * at(0, 1) * at(2, 3) -
     266           0 :                 at(0, 0) * at(3, 1) * at(2, 3) -
     267           0 :                 at(2, 0) * at(0, 1) * at(3, 3) -
     268           0 :                 at(3, 0) * at(2, 1) * at(0, 3));
     269           0 :             cof(1, 3) = at(0, 0) * at(2, 1) * at(3, 2) +
     270           0 :                 at(2, 0) * at(3, 1) * at(0, 2) +
     271           0 :                 at(3, 0) * at(0, 1) * at(2, 2) -
     272           0 :                 at(0, 0) * at(3, 1) * at(2, 2) -
     273           0 :                 at(2, 0) * at(0, 1) * at(3, 2) -
     274           0 :                 at(3, 0) * at(2, 1) * at(0, 2);
     275           0 :             cof(2, 0) = at(0, 1) * at(1, 2) * at(3, 3) +
     276           0 :                 at(1, 1) * at(3, 2) * at(0, 3) +
     277           0 :                 at(3, 1) * at(0, 2) * at(1, 3) -
     278           0 :                 at(0, 1) * at(3, 2) * at(1, 3) -
     279           0 :                 at(1, 1) * at(0, 2) * at(3, 3) -
     280           0 :                 at(3, 1) * at(1, 2) * at(0, 3);
     281           0 :             cof(2, 1) = -(at(0, 0) * at(1, 2) * at(3, 3) +
     282           0 :                 at(1, 0) * at(3, 2) * at(0, 3) +
     283           0 :                 at(3, 0) * at(0, 2) * at(1, 3) -
     284           0 :                 at(0, 0) * at(3, 2) * at(1, 3) -
     285           0 :                 at(1, 0) * at(0, 2) * at(3, 3) -
     286           0 :                 at(3, 0) * at(1, 2) * at(0, 3));
     287           0 :             cof(2, 2) = at(0, 0) * at(1, 1) * at(3, 3) +
     288           0 :                 at(1, 0) * at(3, 1) * at(0, 3) +
     289           0 :                 at(3, 0) * at(0, 1) * at(1, 3) -
     290           0 :                 at(0, 0) * at(3, 1) * at(1, 3) -
     291           0 :                 at(1, 0) * at(0, 1) * at(3, 3) -
     292           0 :                 at(3, 0) * at(1, 1) * at(0, 3);
     293           0 :             cof(2, 3) = -(at(0, 0) * at(1, 1) * at(3, 2) +
     294           0 :                 at(1, 0) * at(3, 1) * at(0, 2) +
     295           0 :                 at(3, 0) * at(0, 1) * at(1, 2) -
     296           0 :                 at(0, 0) * at(3, 1) * at(1, 2) -
     297           0 :                 at(1, 0) * at(0, 1) * at(3, 2) -
     298           0 :                 at(3, 0) * at(1, 1) * at(0, 2));
     299           0 :             cof(3, 0) = -(at(0, 1) * at(1, 2) * at(2, 3) +
     300           0 :                 at(1, 1) * at(2, 2) * at(0, 3) +
     301           0 :                 at(2, 1) * at(0, 2) * at(1, 3) -
     302           0 :                 at(0, 1) * at(2, 2) * at(1, 3) -
     303           0 :                 at(1, 1) * at(0, 2) * at(2, 3) -
     304           0 :                 at(2, 1) * at(1, 2) * at(0, 3));
     305           0 :             cof(3, 1) = at(0, 0) * at(1, 2) * at(2, 3) +
     306           0 :                 at(1, 0) * at(2, 2) * at(0, 3) +
     307           0 :                 at(2, 0) * at(0, 2) * at(1, 3) -
     308           0 :                 at(0, 0) * at(2, 2) * at(1, 3) -
     309           0 :                 at(1, 0) * at(0, 2) * at(2, 3) -
     310           0 :                 at(2, 0) * at(1, 2) * at(0, 3);
     311           0 :             cof(3, 2) = -(at(0, 0) * at(1, 1) * at(2, 3) +
     312           0 :                 at(1, 0) * at(2, 1) * at(0, 3) +
     313           0 :                 at(2, 0) * at(0, 1) * at(1, 3) -
     314           0 :                 at(0, 0) * at(2, 1) * at(1, 3) -
     315           0 :                 at(1, 0) * at(0, 1) * at(2, 3) -
     316           0 :                 at(2, 0) * at(1, 1) * at(0, 3));
     317           0 :             cof(3, 3) = at(0, 0) * at(1, 1) * at(2, 2) +
     318           0 :                 at(1, 0) * at(2, 1) * at(0, 2) +
     319           0 :                 at(2, 0) * at(0, 1) * at(1, 2) -
     320           0 :                 at(0, 0) * at(2, 1) * at(1, 2) -
     321           0 :                 at(1, 0) * at(0, 1) * at(2, 2) -
     322           0 :                 at(2, 0) * at(1, 1) * at(0, 2);
     323           0 :             break;
     324             : 
     325             :           default:
     326           0 :             UNREACHABLE();
     327             :             break;
     328             :         }
     329             : 
     330             :         // The inverse of A is the transpose of the cofactor matrix times the reciprocal of the determinant of A.
     331           0 :         Matrix<T> adjugateMatrix(cof.transpose());
     332           0 :         T det = determinant();
     333           0 :         Matrix<T> result(std::vector<T>(mElements.size()), rows(), columns());
     334           0 :         for (unsigned int i = 0; i < rows(); i++)
     335           0 :             for (unsigned int j = 0; j < columns(); j++)
     336           0 :                 result(i, j) = det ? adjugateMatrix(i, j) / det : T();
     337             : 
     338           0 :         return result;
     339             :     }
     340             : 
     341             :     void setToIdentity()
     342             :     {
     343             :         ASSERT(rows() == columns());
     344             : 
     345             :         const auto one  = T(1);
     346             :         const auto zero = T(0);
     347             : 
     348             :         for (auto &e : mElements)
     349             :             e = zero;
     350             : 
     351             :         for (unsigned int i = 0; i < rows(); ++i)
     352             :         {
     353             :             const auto pos = i * columns() + (i % columns());
     354             :             mElements[pos] = one;
     355             :         }
     356             :     }
     357             : 
     358             :     template <unsigned int Size>
     359             :     static void setToIdentity(T(&matrix)[Size])
     360             :     {
     361             :         static_assert(gl::iSquareRoot<Size>() != 0, "Matrix is not square.");
     362             : 
     363             :         const auto cols = gl::iSquareRoot<Size>();
     364             :         const auto one  = T(1);
     365             :         const auto zero = T(0);
     366             : 
     367             :         for (auto &e : matrix)
     368             :             e = zero;
     369             : 
     370             :         for (unsigned int i = 0; i < cols; ++i)
     371             :         {
     372             :             const auto pos = i * cols + (i % cols);
     373             :             matrix[pos]    = one;
     374             :         }
     375             :     }
     376             : 
     377             :   private:
     378             :     std::vector<T> mElements;
     379             :     unsigned int mRows;
     380             :     unsigned int mCols;
     381             : };
     382             : 
     383             : } // namespace angle
     384             : 
     385             : #endif   // COMMON_MATRIX_UTILS_H_
     386             : 

Generated by: LCOV version 1.13