Line data Source code
1 : //
2 : // Copyright 2015 The ANGLE Project Authors. All rights reserved.
3 : // Use of this source code is governed by a BSD-style license that can be
4 : // found in the LICENSE file.
5 : //
6 : // Matrix:
7 : // Utility class implementing various matrix operations.
8 : // Supports matrices with minimum 2 and maximum 4 number of rows/columns.
9 : //
10 : // TODO: Check if we can merge Matrix.h in sample_util with this and replace it with this implementation.
11 : // TODO: Rename this file to Matrix.h once we remove Matrix.h in sample_util.
12 :
13 : #ifndef COMMON_MATRIX_UTILS_H_
14 : #define COMMON_MATRIX_UTILS_H_
15 :
16 : #include <vector>
17 :
18 : #include "common/debug.h"
19 : #include "common/mathutil.h"
20 :
21 : namespace angle
22 : {
23 :
24 : template<typename T>
25 0 : class Matrix
26 : {
27 : public:
28 0 : Matrix(const std::vector<T> &elements, const unsigned int &numRows, const unsigned int &numCols)
29 : : mElements(elements),
30 : mRows(numRows),
31 0 : mCols(numCols)
32 : {
33 0 : ASSERT(rows() >= 1 && rows() <= 4);
34 0 : ASSERT(columns() >= 1 && columns() <= 4);
35 0 : }
36 :
37 0 : Matrix(const std::vector<T> &elements, const unsigned int &size)
38 : : mElements(elements),
39 : mRows(size),
40 0 : mCols(size)
41 : {
42 0 : ASSERT(rows() >= 1 && rows() <= 4);
43 0 : ASSERT(columns() >= 1 && columns() <= 4);
44 0 : }
45 :
46 0 : Matrix(const T *elements, const unsigned int &size)
47 : : mRows(size),
48 0 : mCols(size)
49 : {
50 0 : ASSERT(rows() >= 1 && rows() <= 4);
51 0 : ASSERT(columns() >= 1 && columns() <= 4);
52 0 : for (size_t i = 0; i < size * size; i++)
53 0 : mElements.push_back(elements[i]);
54 0 : }
55 :
56 0 : const T &operator()(const unsigned int &rowIndex, const unsigned int &columnIndex) const
57 : {
58 0 : return mElements[rowIndex * columns() + columnIndex];
59 : }
60 :
61 0 : T &operator()(const unsigned int &rowIndex, const unsigned int &columnIndex)
62 : {
63 0 : return mElements[rowIndex * columns() + columnIndex];
64 : }
65 :
66 0 : const T &at(const unsigned int &rowIndex, const unsigned int &columnIndex) const
67 : {
68 0 : return operator()(rowIndex, columnIndex);
69 : }
70 :
71 : Matrix<T> operator*(const Matrix<T> &m)
72 : {
73 : ASSERT(columns() == m.rows());
74 :
75 : unsigned int resultRows = rows();
76 : unsigned int resultCols = m.columns();
77 : Matrix<T> result(std::vector<T>(resultRows * resultCols), resultRows, resultCols);
78 : for (unsigned int i = 0; i < resultRows; i++)
79 : {
80 : for (unsigned int j = 0; j < resultCols; j++)
81 : {
82 : T tmp = 0.0f;
83 : for (unsigned int k = 0; k < columns(); k++)
84 : tmp += at(i, k) * m(k, j);
85 : result(i, j) = tmp;
86 : }
87 : }
88 :
89 : return result;
90 : }
91 :
92 0 : unsigned int size() const
93 : {
94 0 : ASSERT(rows() == columns());
95 0 : return rows();
96 : }
97 :
98 0 : unsigned int rows() const { return mRows; }
99 :
100 0 : unsigned int columns() const { return mCols; }
101 :
102 0 : std::vector<T> elements() const { return mElements; }
103 :
104 0 : Matrix<T> compMult(const Matrix<T> &mat1) const
105 : {
106 0 : Matrix result(std::vector<T>(mElements.size()), size());
107 0 : for (unsigned int i = 0; i < columns(); i++)
108 0 : for (unsigned int j = 0; j < rows(); j++)
109 0 : result(i, j) = at(i, j) * mat1(i, j);
110 :
111 0 : return result;
112 : }
113 :
114 0 : Matrix<T> outerProduct(const Matrix<T> &mat1) const
115 : {
116 0 : unsigned int cols = mat1.columns();
117 0 : Matrix result(std::vector<T>(rows() * cols), rows(), cols);
118 0 : for (unsigned int i = 0; i < rows(); i++)
119 0 : for (unsigned int j = 0; j < cols; j++)
120 0 : result(i, j) = at(i, 0) * mat1(0, j);
121 :
122 0 : return result;
123 : }
124 :
125 0 : Matrix<T> transpose() const
126 : {
127 0 : Matrix result(std::vector<T>(mElements.size()), columns(), rows());
128 0 : for (unsigned int i = 0; i < columns(); i++)
129 0 : for (unsigned int j = 0; j < rows(); j++)
130 0 : result(i, j) = at(j, i);
131 :
132 0 : return result;
133 : }
134 :
135 0 : T determinant() const
136 : {
137 0 : ASSERT(rows() == columns());
138 :
139 0 : switch (size())
140 : {
141 : case 2:
142 0 : return at(0, 0) * at(1, 1) - at(0, 1) * at(1, 0);
143 :
144 : case 3:
145 0 : return at(0, 0) * at(1, 1) * at(2, 2) +
146 0 : at(0, 1) * at(1, 2) * at(2, 0) +
147 0 : at(0, 2) * at(1, 0) * at(2, 1) -
148 0 : at(0, 2) * at(1, 1) * at(2, 0) -
149 0 : at(0, 1) * at(1, 0) * at(2, 2) -
150 0 : at(0, 0) * at(1, 2) * at(2, 1);
151 :
152 : case 4:
153 : {
154 : const float minorMatrices[4][3 * 3] =
155 : {
156 : {
157 0 : at(1, 1), at(2, 1), at(3, 1),
158 0 : at(1, 2), at(2, 2), at(3, 2),
159 0 : at(1, 3), at(2, 3), at(3, 3),
160 : },
161 : {
162 0 : at(1, 0), at(2, 0), at(3, 0),
163 0 : at(1, 2), at(2, 2), at(3, 2),
164 0 : at(1, 3), at(2, 3), at(3, 3),
165 : },
166 : {
167 0 : at(1, 0), at(2, 0), at(3, 0),
168 0 : at(1, 1), at(2, 1), at(3, 1),
169 0 : at(1, 3), at(2, 3), at(3, 3),
170 : },
171 : {
172 0 : at(1, 0), at(2, 0), at(3, 0),
173 0 : at(1, 1), at(2, 1), at(3, 1),
174 0 : at(1, 2), at(2, 2), at(3, 2),
175 : }
176 0 : };
177 0 : return at(0, 0) * Matrix<T>(minorMatrices[0], 3).determinant() -
178 0 : at(0, 1) * Matrix<T>(minorMatrices[1], 3).determinant() +
179 0 : at(0, 2) * Matrix<T>(minorMatrices[2], 3).determinant() -
180 0 : at(0, 3) * Matrix<T>(minorMatrices[3], 3).determinant();
181 : }
182 :
183 : default:
184 0 : UNREACHABLE();
185 : break;
186 : }
187 :
188 : return T();
189 : }
190 :
191 0 : Matrix<T> inverse() const
192 : {
193 0 : ASSERT(rows() == columns());
194 :
195 0 : Matrix<T> cof(std::vector<T>(mElements.size()), rows(), columns());
196 0 : switch (size())
197 : {
198 : case 2:
199 0 : cof(0, 0) = at(1, 1);
200 0 : cof(0, 1) = -at(1, 0);
201 0 : cof(1, 0) = -at(0, 1);
202 0 : cof(1, 1) = at(0, 0);
203 0 : break;
204 :
205 : case 3:
206 0 : cof(0, 0) = at(1, 1) * at(2, 2) -
207 0 : at(2, 1) * at(1, 2);
208 0 : cof(0, 1) = -(at(1, 0) * at(2, 2) -
209 0 : at(2, 0) * at(1, 2));
210 0 : cof(0, 2) = at(1, 0) * at(2, 1) -
211 0 : at(2, 0) * at(1, 1);
212 0 : cof(1, 0) = -(at(0, 1) * at(2, 2) -
213 0 : at(2, 1) * at(0, 2));
214 0 : cof(1, 1) = at(0, 0) * at(2, 2) -
215 0 : at(2, 0) * at(0, 2);
216 0 : cof(1, 2) = -(at(0, 0) * at(2, 1) -
217 0 : at(2, 0) * at(0, 1));
218 0 : cof(2, 0) = at(0, 1) * at(1, 2) -
219 0 : at(1, 1) * at(0, 2);
220 0 : cof(2, 1) = -(at(0, 0) * at(1, 2) -
221 0 : at(1, 0) * at(0, 2));
222 0 : cof(2, 2) = at(0, 0) * at(1, 1) -
223 0 : at(1, 0) * at(0, 1);
224 0 : break;
225 :
226 : case 4:
227 0 : cof(0, 0) = at(1, 1) * at(2, 2) * at(3, 3) +
228 0 : at(2, 1) * at(3, 2) * at(1, 3) +
229 0 : at(3, 1) * at(1, 2) * at(2, 3) -
230 0 : at(1, 1) * at(3, 2) * at(2, 3) -
231 0 : at(2, 1) * at(1, 2) * at(3, 3) -
232 0 : at(3, 1) * at(2, 2) * at(1, 3);
233 0 : cof(0, 1) = -(at(1, 0) * at(2, 2) * at(3, 3) +
234 0 : at(2, 0) * at(3, 2) * at(1, 3) +
235 0 : at(3, 0) * at(1, 2) * at(2, 3) -
236 0 : at(1, 0) * at(3, 2) * at(2, 3) -
237 0 : at(2, 0) * at(1, 2) * at(3, 3) -
238 0 : at(3, 0) * at(2, 2) * at(1, 3));
239 0 : cof(0, 2) = at(1, 0) * at(2, 1) * at(3, 3) +
240 0 : at(2, 0) * at(3, 1) * at(1, 3) +
241 0 : at(3, 0) * at(1, 1) * at(2, 3) -
242 0 : at(1, 0) * at(3, 1) * at(2, 3) -
243 0 : at(2, 0) * at(1, 1) * at(3, 3) -
244 0 : at(3, 0) * at(2, 1) * at(1, 3);
245 0 : cof(0, 3) = -(at(1, 0) * at(2, 1) * at(3, 2) +
246 0 : at(2, 0) * at(3, 1) * at(1, 2) +
247 0 : at(3, 0) * at(1, 1) * at(2, 2) -
248 0 : at(1, 0) * at(3, 1) * at(2, 2) -
249 0 : at(2, 0) * at(1, 1) * at(3, 2) -
250 0 : at(3, 0) * at(2, 1) * at(1, 2));
251 0 : cof(1, 0) = -(at(0, 1) * at(2, 2) * at(3, 3) +
252 0 : at(2, 1) * at(3, 2) * at(0, 3) +
253 0 : at(3, 1) * at(0, 2) * at(2, 3) -
254 0 : at(0, 1) * at(3, 2) * at(2, 3) -
255 0 : at(2, 1) * at(0, 2) * at(3, 3) -
256 0 : at(3, 1) * at(2, 2) * at(0, 3));
257 0 : cof(1, 1) = at(0, 0) * at(2, 2) * at(3, 3) +
258 0 : at(2, 0) * at(3, 2) * at(0, 3) +
259 0 : at(3, 0) * at(0, 2) * at(2, 3) -
260 0 : at(0, 0) * at(3, 2) * at(2, 3) -
261 0 : at(2, 0) * at(0, 2) * at(3, 3) -
262 0 : at(3, 0) * at(2, 2) * at(0, 3);
263 0 : cof(1, 2) = -(at(0, 0) * at(2, 1) * at(3, 3) +
264 0 : at(2, 0) * at(3, 1) * at(0, 3) +
265 0 : at(3, 0) * at(0, 1) * at(2, 3) -
266 0 : at(0, 0) * at(3, 1) * at(2, 3) -
267 0 : at(2, 0) * at(0, 1) * at(3, 3) -
268 0 : at(3, 0) * at(2, 1) * at(0, 3));
269 0 : cof(1, 3) = at(0, 0) * at(2, 1) * at(3, 2) +
270 0 : at(2, 0) * at(3, 1) * at(0, 2) +
271 0 : at(3, 0) * at(0, 1) * at(2, 2) -
272 0 : at(0, 0) * at(3, 1) * at(2, 2) -
273 0 : at(2, 0) * at(0, 1) * at(3, 2) -
274 0 : at(3, 0) * at(2, 1) * at(0, 2);
275 0 : cof(2, 0) = at(0, 1) * at(1, 2) * at(3, 3) +
276 0 : at(1, 1) * at(3, 2) * at(0, 3) +
277 0 : at(3, 1) * at(0, 2) * at(1, 3) -
278 0 : at(0, 1) * at(3, 2) * at(1, 3) -
279 0 : at(1, 1) * at(0, 2) * at(3, 3) -
280 0 : at(3, 1) * at(1, 2) * at(0, 3);
281 0 : cof(2, 1) = -(at(0, 0) * at(1, 2) * at(3, 3) +
282 0 : at(1, 0) * at(3, 2) * at(0, 3) +
283 0 : at(3, 0) * at(0, 2) * at(1, 3) -
284 0 : at(0, 0) * at(3, 2) * at(1, 3) -
285 0 : at(1, 0) * at(0, 2) * at(3, 3) -
286 0 : at(3, 0) * at(1, 2) * at(0, 3));
287 0 : cof(2, 2) = at(0, 0) * at(1, 1) * at(3, 3) +
288 0 : at(1, 0) * at(3, 1) * at(0, 3) +
289 0 : at(3, 0) * at(0, 1) * at(1, 3) -
290 0 : at(0, 0) * at(3, 1) * at(1, 3) -
291 0 : at(1, 0) * at(0, 1) * at(3, 3) -
292 0 : at(3, 0) * at(1, 1) * at(0, 3);
293 0 : cof(2, 3) = -(at(0, 0) * at(1, 1) * at(3, 2) +
294 0 : at(1, 0) * at(3, 1) * at(0, 2) +
295 0 : at(3, 0) * at(0, 1) * at(1, 2) -
296 0 : at(0, 0) * at(3, 1) * at(1, 2) -
297 0 : at(1, 0) * at(0, 1) * at(3, 2) -
298 0 : at(3, 0) * at(1, 1) * at(0, 2));
299 0 : cof(3, 0) = -(at(0, 1) * at(1, 2) * at(2, 3) +
300 0 : at(1, 1) * at(2, 2) * at(0, 3) +
301 0 : at(2, 1) * at(0, 2) * at(1, 3) -
302 0 : at(0, 1) * at(2, 2) * at(1, 3) -
303 0 : at(1, 1) * at(0, 2) * at(2, 3) -
304 0 : at(2, 1) * at(1, 2) * at(0, 3));
305 0 : cof(3, 1) = at(0, 0) * at(1, 2) * at(2, 3) +
306 0 : at(1, 0) * at(2, 2) * at(0, 3) +
307 0 : at(2, 0) * at(0, 2) * at(1, 3) -
308 0 : at(0, 0) * at(2, 2) * at(1, 3) -
309 0 : at(1, 0) * at(0, 2) * at(2, 3) -
310 0 : at(2, 0) * at(1, 2) * at(0, 3);
311 0 : cof(3, 2) = -(at(0, 0) * at(1, 1) * at(2, 3) +
312 0 : at(1, 0) * at(2, 1) * at(0, 3) +
313 0 : at(2, 0) * at(0, 1) * at(1, 3) -
314 0 : at(0, 0) * at(2, 1) * at(1, 3) -
315 0 : at(1, 0) * at(0, 1) * at(2, 3) -
316 0 : at(2, 0) * at(1, 1) * at(0, 3));
317 0 : cof(3, 3) = at(0, 0) * at(1, 1) * at(2, 2) +
318 0 : at(1, 0) * at(2, 1) * at(0, 2) +
319 0 : at(2, 0) * at(0, 1) * at(1, 2) -
320 0 : at(0, 0) * at(2, 1) * at(1, 2) -
321 0 : at(1, 0) * at(0, 1) * at(2, 2) -
322 0 : at(2, 0) * at(1, 1) * at(0, 2);
323 0 : break;
324 :
325 : default:
326 0 : UNREACHABLE();
327 : break;
328 : }
329 :
330 : // The inverse of A is the transpose of the cofactor matrix times the reciprocal of the determinant of A.
331 0 : Matrix<T> adjugateMatrix(cof.transpose());
332 0 : T det = determinant();
333 0 : Matrix<T> result(std::vector<T>(mElements.size()), rows(), columns());
334 0 : for (unsigned int i = 0; i < rows(); i++)
335 0 : for (unsigned int j = 0; j < columns(); j++)
336 0 : result(i, j) = det ? adjugateMatrix(i, j) / det : T();
337 :
338 0 : return result;
339 : }
340 :
341 : void setToIdentity()
342 : {
343 : ASSERT(rows() == columns());
344 :
345 : const auto one = T(1);
346 : const auto zero = T(0);
347 :
348 : for (auto &e : mElements)
349 : e = zero;
350 :
351 : for (unsigned int i = 0; i < rows(); ++i)
352 : {
353 : const auto pos = i * columns() + (i % columns());
354 : mElements[pos] = one;
355 : }
356 : }
357 :
358 : template <unsigned int Size>
359 : static void setToIdentity(T(&matrix)[Size])
360 : {
361 : static_assert(gl::iSquareRoot<Size>() != 0, "Matrix is not square.");
362 :
363 : const auto cols = gl::iSquareRoot<Size>();
364 : const auto one = T(1);
365 : const auto zero = T(0);
366 :
367 : for (auto &e : matrix)
368 : e = zero;
369 :
370 : for (unsigned int i = 0; i < cols; ++i)
371 : {
372 : const auto pos = i * cols + (i % cols);
373 : matrix[pos] = one;
374 : }
375 : }
376 :
377 : private:
378 : std::vector<T> mElements;
379 : unsigned int mRows;
380 : unsigned int mCols;
381 : };
382 :
383 : } // namespace angle
384 :
385 : #endif // COMMON_MATRIX_UTILS_H_
386 :
|