Path: blob/master/thirdparty/jolt_physics/Jolt/Math/GaussianElimination.h
9913 views
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)1// SPDX-FileCopyrightText: 2021 Jorrit Rouwe2// SPDX-License-Identifier: MIT34#pragma once56JPH_NAMESPACE_BEGIN78/// This function performs Gauss-Jordan elimination to solve a matrix equation.9/// A must be an NxN matrix and B must be an NxM matrix forming the equation A * x = B10/// on output B will contain x and A will be destroyed.11///12/// This code can be used for example to compute the inverse of a matrix.13/// Set A to the matrix to invert, set B to identity and let GaussianElimination solve14/// the equation, on return B will be the inverse of A. And A is destroyed.15///16/// Taken and adapted from Numerical Recipes in C paragraph 2.117template <class MatrixA, class MatrixB>18bool GaussianElimination(MatrixA &ioA, MatrixB &ioB, float inTolerance = 1.0e-16f)19{20// Get problem dimensions21const uint n = ioA.GetCols();22const uint m = ioB.GetCols();2324// Check matrix requirement25JPH_ASSERT(ioA.GetRows() == n);26JPH_ASSERT(ioB.GetRows() == n);2728// Create array for bookkeeping on pivoting29int *ipiv = (int *)JPH_STACK_ALLOC(n * sizeof(int));30memset(ipiv, 0, n * sizeof(int));3132for (uint i = 0; i < n; ++i)33{34// Initialize pivot element as the diagonal35uint pivot_row = i, pivot_col = i;3637// Determine pivot element38float largest_element = 0.0f;39for (uint j = 0; j < n; ++j)40if (ipiv[j] != 1)41for (uint k = 0; k < n; ++k)42{43if (ipiv[k] == 0)44{45float element = abs(ioA(j, k));46if (element >= largest_element)47{48largest_element = element;49pivot_row = j;50pivot_col = k;51}52}53else if (ipiv[k] > 1)54{55return false;56}57}5859// Mark this column as used60++ipiv[pivot_col];6162// Exchange rows when needed so that the pivot element is at ioA(pivot_col, pivot_col) instead of at ioA(pivot_row, pivot_col)63if (pivot_row != pivot_col)64{65for (uint j = 0; j < n; ++j)66std::swap(ioA(pivot_row, j), ioA(pivot_col, j));67for (uint j = 0; j < m; ++j)68std::swap(ioB(pivot_row, j), ioB(pivot_col, j));69}7071// Get diagonal element that we are about to set to 172float diagonal_element = ioA(pivot_col, pivot_col);73if (abs(diagonal_element) < inTolerance)74return false;7576// Divide the whole row by the pivot element, making ioA(pivot_col, pivot_col) = 177for (uint j = 0; j < n; ++j)78ioA(pivot_col, j) /= diagonal_element;79for (uint j = 0; j < m; ++j)80ioB(pivot_col, j) /= diagonal_element;81ioA(pivot_col, pivot_col) = 1.0f;8283// Next reduce the rows, except for the pivot one,84// after this step the pivot_col column is zero except for the pivot element which is 185for (uint j = 0; j < n; ++j)86if (j != pivot_col)87{88float element = ioA(j, pivot_col);89for (uint k = 0; k < n; ++k)90ioA(j, k) -= ioA(pivot_col, k) * element;91for (uint k = 0; k < m; ++k)92ioB(j, k) -= ioB(pivot_col, k) * element;93ioA(j, pivot_col) = 0.0f;94}95}9697// Success98return true;99}100101JPH_NAMESPACE_END102103104