Path: blob/master/src/FMNM/Solvers.py
1700 views
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jul 27 11:13:45 2019

@author: cantaro86

Solvers for the linear system Ax = b:
- Thomas: direct tridiagonal solver (wrapper of LAPACK gtsv).
- SOR:    Successive Over-Relaxation in matrix form (dense or scipy.sparse A).
- SOR2:   Successive Over-Relaxation written coefficient by coefficient.
"""

import numpy as np
from scipy import sparse
from scipy.linalg import LinAlgError, norm, solve_triangular
from scipy.linalg.lapack import get_lapack_funcs


def Thomas(A, b):
    """
    Solver for the linear equation Ax=b using the Thomas algorithm.
    It is a wrapper of the LAPACK function dgtsv.

    Only the main diagonal and the first sub-/super-diagonal of A are read;
    any other entry of A is ignored (A is assumed to be tridiagonal).

    Arguments:
        A = square matrix (numpy array or scipy.sparse matrix)
        b = right-hand side, with b.shape[0] == A.shape[0]
    Returns:
        x = solution of Ax=b, as returned by dgtsv
    Raises:
        ValueError   if A is not square, if the dimensions of A and b differ,
                     or if LAPACK reports an illegal argument
        LinAlgError  if the matrix is singular
    """
    # Validate the shapes before touching the diagonals.
    if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("expected square matrix")
    if A.shape[0] != b.shape[0]:
        raise ValueError("incompatible dimensions")

    D = A.diagonal(0)  # main diagonal
    L = A.diagonal(-1)  # sub-diagonal
    U = A.diagonal(1)  # super-diagonal

    (dgtsv,) = get_lapack_funcs(("gtsv",))
    # dgtsv also returns the factorized diagonals, which are not needed here.
    _, _, _, x, info = dgtsv(L, D, U, b)

    if info == 0:
        return x
    if info > 0:
        raise LinAlgError("singular matrix: resolution failed at diagonal %d" % (info - 1))
    # info < 0: LAPACK rejected argument number -info (previously returned None silently).
    raise ValueError("illegal value in argument %d of internal gtsv" % -info)


def SOR(A, b, w=1, eps=1e-10, N_max=100):
    """
    Solver for the linear equation Ax=b using the SOR algorithm.
    With A = L + D + U (strict lower, diagonal, strict upper parts), each
    iteration solves the lower-triangular system
        (w*L + D) x_new = w*b - (w*U + (w-1)*D) x_old

    Arguments:
        A = square matrix (numpy array or scipy.sparse matrix)
        b = right-hand side, also used as the initial guess
        w = relaxation coefficient
        eps = tolerance on the norm of the difference between two iterates
        N_max = maximum number of iterations
    Returns:
        x = approximate solution of Ax=b
    Raises:
        ValueError if the method does not converge within N_max iterations
    """
    x0 = b.copy()  # initial guess

    if sparse.issparse(A):
        D = sparse.diags(A.diagonal())  # diagonal
        U = sparse.triu(A, k=1)  # strict upper part
        L = sparse.tril(A, k=-1)  # strict lower part
        DD = (w * L + D).toarray()  # solve_triangular needs a dense matrix
    else:
        D = np.eye(A.shape[0]) * np.diag(A)  # diagonal
        U = np.triu(A, k=1)  # strict upper part
        L = np.tril(A, k=-1)  # strict lower part
        DD = w * L + D

    # Loop-invariant parts of the right-hand side, computed once.
    wb = w * b
    T = w * U + (w - 1) * D

    for _ in range(N_max):
        x_new = solve_triangular(DD, wb - T @ x0, lower=True)
        if norm(x_new - x0) < eps:
            return x_new
        x0 = x_new
    # Raised after the loop so that N_max < 1 fails loudly instead of returning None.
    raise ValueError("Fail to converge in {} iterations".format(N_max))


def SOR2(A, b, w=1, eps=1e-10, N_max=100):
    """
    Solver for the linear equation Ax=b using the SOR algorithm.
    It uses the coefficients and not the matrix multiplication.

    Arguments:
        A = square matrix supporting A[i, j] indexing
        b = right-hand side
        w = relaxation coefficient
        eps = tolerance on the norm of the difference between two sweeps
        N_max = maximum number of sweeps
    Returns:
        x = approximate solution of Ax=b (float64 array; initial guess is ones)
    Raises:
        ValueError if the method does not converge within N_max sweeps
                   (consistent with SOR; previously it only printed a message
                   and returned None)
    """
    N = len(b)
    x0 = np.ones_like(b, dtype=np.float64)  # initial guess
    x_new = np.ones_like(x0)  # iterate, updated in place during each sweep

    for _ in range(N_max):
        for i in range(N):
            # Off-diagonal contribution of row i: entries j < i already hold
            # this sweep's values, entries j > i still hold the previous sweep's.
            S = sum(A[i, j] * x_new[j] for j in range(N) if j != i)
            x_new[i] = (1 - w) * x_new[i] + (w / A[i, i]) * (b[i] - S)

        if norm(x_new - x0) < eps:
            return x_new
        x0 = x_new.copy()
    raise ValueError("Fail to converge in {} iterations".format(N_max))