diff --git a/exercice_1/conjugate_gradient.py b/exercice_1/conjugate_gradient.py index e69de29..5f2e7a4 100644 --- a/exercice_1/conjugate_gradient.py +++ b/exercice_1/conjugate_gradient.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Tue Oct 9 13:49:52 2018 + +@author: masc +""" + +import numpy as np + +def trans(*args): + args = np.array(*args) + if args.ndim == 1: + transverse = args.T + elif args.ndim != 1: + transverse = np.einsum('ij->ji', args) + return transverse +def dot(a, b): + arg1 = np.array(a) + arg2 = np.array(b) + if arg1.ndim == 1: + if arg2.ndim == 1: + dot = np.einsum('j, j', arg1, arg2) + elif arg1.ndim == 2: + if arg2.ndim == 1: + dot = np.einsum('ij, j', arg1, arg2) + else: + dot = np.dot(arg1, arg2) + return dot + +def conjugate_gradient(A, b, x): +# initialisation + r = b - np.dot(A, x) + p = r + i = 1 + print("### initialisation ###","\ninitial residual\t", np.dot(r.T,r)) + + while True: + alpha = np.dot(r.T, r) / np.dot(np.dot(p.T, A), p) + x = x + np.dot(alpha, p) + r_new = r - np.dot(np.dot(alpha,A), p) + beta = np.dot(r_new.T, r_new) / np.dot(r.T, r) + p_new = r_new + np.dot(beta, p) + + r = r_new + p = p_new + if np.dot(r.T,r) < 1e-9:#np.isclose(r_k, np.zeros(r_k.shape)).all(): + print("### Iteration finishes: r*r < 1e9 ###") + print("final r \t\t", r) + print("final r*r \t\t",np.dot(r, r)) + print("nit \t\t\t", i) + print("final x \t\t", x) + break + + print("residual \t\t", np.dot(r.T,r)) + print("x \t\t\t", x) + print("") + + i += 1 + +if __name__ == '__main__': + A = [[5,1],[1,5]]#[np.linspace(-5, 5, 100)]#[[3, -1], [-1, 3]] + b = [3,1] + x = [0,0] + conjugate_gradient(A=A,b=b,x=x) + + + + + + + + + + + + +#def conj_grad(A, b, x): +# tol = 1e-9 +# r = b - dot(A,x) +# p = r +# rN = dot(r, r) +# while True: +# Ap = dot(A, p) +# alpha = rN / dot(trans(p), Ap) +# x = x - alpha * p +# r = r - alpha * Ap +# rnewN = dot(r,r) +# print("\n") +# print(rnewN) +# +# +# if rnew < tol: +# print(i) +# break +# beta = rnewN / rN +# rN = rnewN +# p = beta * p - r +# +# +#A = [[np.linspace(-5, 5, 100)], [np.linspace(-5, 5, 100)]]#[[3, -1], [-1, 3]] +#b = [1,0] +# +#x = [1,1] +# +#d = [] +#r = [] +#alpha = [] +# +#r_k = np.dot(A, x) - b +#d_k = -r_k +# +#i = 1 +#while True: +# print("iterations:",i) +# +# alpha_k = np.dot(r_k, r_k) / np.dot(np.dot(d_k, Q), d_k) +# x = x + np.dot(alpha_k, d_k) +# r_k1 = r_k + np.dot(np.dot(alpha_k,Q), d_k) +# beta_k1 = np.dot(r_k1, r_k1) / np.dot(r_k, r_k) +# d_k1 = -r_k1 + np.dot(beta_k1, d_k) +# +# r_k = r_k1 +# d_k = d_k1 +# +# if np.isclose(r_k, np.zeros(r_k.shape)).all(): +# print("Breaking!") +# print("final r_k:", r_k) +# print("final x:", x) +# break +# +# print("r_k:",r_k) +# print("x:", x) +# i += 1 +# +# +# +# +## +# +# +## rTr = np.einsum('ij->ji', r) * r +## for i in b: +## pT = np.einsum('ij->ji', p) +## Ap = np.einsum('k,k->', A, p) +## alpha = rTr / (np.einsum('k,k->', pT, Ap)) +## x = x + np.einsum('k,k->', alpha, p) +## r = r - np.einsum('k,k->', alpha, Ap) +## rsnew = np.einsum('ij->ji', r) * r +## +## return x +##def conjugate_gradient(A, b, x=None, **kwargs): +## tol = 1e-9 +## n = len(b) +## if not x: +## x = np.ones(n) +## r = b - np.einsum('k,k->', A, x)#np.dot(A, x) +## p = r +## r_k_norm = np.dot(r, r) +## for i in range(2*n): +## rT = r +## pT = p#np.einsum('ij->ji', p) +## Ap = np.einsum('k,k->', A, p) +## rTr = np.einsum('k,k->', rT, r) +## print(pT.shape); print(Ap.shape) +#### pTAp = np.einsum('k,k->', pT, Ap) +### alpha = rTr / pTAp +### x += alpha * p +### r += alpha * Ap +### r_kplus1_norm = np.einsum('k,k->', r, r) +### beta = r_kplus1_norm / rTr +### r_k_norm = r_kplus1_norm +### if r_kplus1_norm < tol: +### print('Iterations:', i) +### print(r_kplus1_norm) +### break +### p = beta * p - r +### return x +## +### x_ph = tf.placeholder('float32', [None, None]) +### r = tf.matmul(A, x_ph) - b +## +#if __name__ == '__main__': +# n = 100 +# P = np.random.normal(size=[n,n]) #(-5,5,n) +# A = np.dot(P.T, P) +# b = np.ones(n) +# x = ones(n) +# +# t1 = time.time() +# print('start') +# x = conj_grad(A=A, b=b, x=x) \ No newline at end of file