Path: blob/master/src/FMNM/cython/solvers.pyx
1700 views
"""
Created on Mon Jul 29 11:13:45 2019

@author: cantaro86

Successive over-relaxation (SOR) and projected SOR (PSOR) solvers for
tridiagonal linear systems whose three diagonals are constant scalars.
"""

import numpy as np
from scipy.linalg import norm  # NOTE(review): not used in this module
cimport numpy as np
cimport cython


@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.float64_t distance2(np.float64_t[:] a, np.float64_t[:] b, unsigned int N):
    """Return the squared Euclidean distance between a[:N] and b[:N]."""
    cdef np.float64_t dist = 0
    cdef np.float64_t diff
    cdef unsigned int i
    for i in range(N):
        diff = a[i] - b[i]
        dist += diff * diff
    return dist


@cython.boundscheck(False)
@cython.wraparound(False)
def SOR(np.float64_t aa,
        np.float64_t bb, np.float64_t cc,
        np.float64_t[:] b,
        np.float64_t w=1, np.float64_t eps=1e-10, unsigned int N_max = 500):
    """
    Solve a constant-coefficient tridiagonal system by SOR iteration.

    Row i of the system is ``aa*x[i-1] + bb*x[i] + cc*x[i+1] = b[i]``.
    The ``x[i-1]`` term is absent in the first row and the ``x[i+1]`` term
    is absent in the last row.

    Parameters
    ----------
    aa, bb, cc : float
        Sub-diagonal, main-diagonal and super-diagonal coefficients.
    b : 1-D float64 array
        Right-hand side. Its length N is the size of the system.
    w : float
        Relaxation factor. ``w=1`` reduces the update to Gauss-Seidel.
    eps : float
        The iteration stops when the Euclidean distance between two
        successive iterates is below ``eps``.
    N_max : int
        Maximum number of sweeps.

    Returns
    -------
    A float64 typed memoryview of length N with the last iterate.
    Use ``np.asarray`` to get an ndarray. If the tolerance is not reached
    within ``N_max`` sweeps, a message is printed and the last iterate is
    returned anyway.
    """
    cdef unsigned int N = b.size

    # Both iterates start from the all-ones initial guess.
    cdef np.float64_t[:] x0 = np.ones(N, dtype=np.float64)     # previous sweep
    cdef np.float64_t[:] x_new = np.ones(N, dtype=np.float64)  # current sweep

    cdef unsigned int i, k
    cdef np.float64_t S

    for k in range(1, N_max + 1):
        for i in range(N):
            # Off-diagonal contribution. Boundary rows drop the missing
            # neighbour, so N == 1 never reads outside the array (bounds
            # checking is disabled here).
            S = 0
            if i > 0:
                S += aa * x_new[i - 1]
            if i + 1 < N:
                S += cc * x_new[i + 1]
            x_new[i] = (1 - w) * x_new[i] + (w / bb) * (b[i] - S)
        if distance2(x_new, x0, N) < eps * eps:
            return x_new
        x0[:] = x_new

    # Reached only without convergence. Placing this after the loop also
    # covers N_max == 0, which previously fell through and returned None.
    print("Fail to converge in {} iterations".format(N_max))
    return x_new


@cython.boundscheck(False)
@cython.wraparound(False)
def PSOR(np.float64_t aa,
         np.float64_t bb, np.float64_t cc,
         np.float64_t[:] B, np.float64_t[:] C,
         np.float64_t w=1, np.float64_t eps=1e-10, unsigned int N_max = 500):
    """
    Projected SOR for a constant-coefficient tridiagonal system with a
    componentwise lower bound.

    Each sweep performs the SOR update for ``aa*x[i-1] + bb*x[i] + cc*x[i+1]
    = B[i]`` and then immediately projects it as ``x[i] = max(x[i], C[i])``.

    Parameters
    ----------
    aa, bb, cc : float
        Sub-diagonal, main-diagonal and super-diagonal coefficients.
    B : 1-D float64 array
        Right-hand side. Its length N is the size of the system.
    C : 1-D float64 array
        Lower bound for each component. Must have at least N entries.
    w : float
        Relaxation factor. ``w=1`` gives projected Gauss-Seidel.
    eps : float
        The iteration stops when the Euclidean distance between two
        successive iterates is below ``eps``.
    N_max : int
        Maximum number of sweeps.

    Returns
    -------
    A float64 typed memoryview of length N with the last iterate.
    Use ``np.asarray`` to get an ndarray. A message is printed both on
    convergence and on failure to converge.

    Raises
    ------
    ValueError
        If ``C`` has fewer entries than ``B``.
    """
    cdef unsigned int N = B.size

    # Bounds checking is disabled, so a short C would be read out of range.
    if C.shape[0] < N:
        raise ValueError("C must have at least as many entries as B")

    # Both iterates start from the all-ones initial guess.
    cdef np.float64_t[:] x0 = np.ones(N, dtype=np.float64)     # previous sweep
    cdef np.float64_t[:] x_new = np.ones(N, dtype=np.float64)  # current sweep

    cdef unsigned int i, k
    cdef np.float64_t S

    for k in range(1, N_max + 1):
        for i in range(N):
            # Off-diagonal contribution. Missing neighbours at the
            # boundaries (including the N == 1 case) are simply omitted.
            S = 0
            if i > 0:
                S += aa * x_new[i - 1]
            if i + 1 < N:
                S += cc * x_new[i + 1]
            x_new[i] = (1 - w) * x_new[i] + (w / bb) * (B[i] - S)
            # Project onto the constraint x[i] >= C[i].
            x_new[i] = x_new[i] if (x_new[i] > C[i]) else C[i]

        if distance2(x_new, x0, N) < eps * eps:
            print("Convergence after {} iterations".format(k))
            return x_new
        x0[:] = x_new

    # Reached only without convergence. This also covers N_max == 0,
    # which previously returned None.
    print("Fail to converge in {} iterations".format(N_max))
    return x_new