Page Menu · Home · c4science

conjugate_gradient.py
No OneTemporary

File Metadata

Created
Wed, Aug 28, 22:23

conjugate_gradient.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 13:49:52 2018
@author: masc
"""
import numpy as np
from optimizer import S
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def conjugate_gradient(A, b, x, tol=1e-9, maxiter=None):
    """Solve ``A x = b`` with the (unpreconditioned) conjugate gradient method.

    Progress is printed to stdout at every iteration.

    Parameters
    ----------
    A : array_like, shape (n, n)
        System matrix.  CG is only guaranteed to converge when ``A`` is
        symmetric positive definite.
    b : array_like, shape (n,)
        Right-hand side.
    x : array_like, shape (n,)
        Initial guess.
    tol : float, optional
        Stop once the squared residual norm ``r . r`` drops below ``tol``
        (default ``1e-9``, the previously hard-coded threshold).
    maxiter : int, optional
        Maximum number of iterations; defaults to ``10 * n``.  Guarantees
        termination when ``A`` is not SPD or rounding stalls convergence.

    Returns
    -------
    list of list of float
        Every iterate, starting with the initial guess.  Each entry keeps
        all ``n`` components.
    """
    A = np.asarray(A, dtype=float)
    b = np.asarray(b, dtype=float)
    x = np.asarray(x, dtype=float)
    if maxiter is None:
        maxiter = 10 * b.size

    # initialisation
    r = b - np.dot(A, x)  # r -> residual
    d = r                 # d -> search direction
    rr = np.dot(r, r)     # cached r.r, reused for alpha and beta
    print("### initialisation ###", "\ninitial residual r*r\t", rr)
    xmat = [x.tolist()]
    if rr < tol:
        # Without this check alpha would be 0/0 (NaN) and the loop below
        # could never reach its convergence test.
        print("### Initial guess already satisfies r*r < {} ###".format(tol))
        return xmat

    for i in range(1, maxiter + 1):
        Ad = np.dot(A, d)  # the single matrix-vector product per iteration
        alpha = rr / np.dot(d, Ad)
        x = x + alpha * d
        r = r - alpha * Ad  # update the residual without recomputing b - A x
        rr_new = np.dot(r, r)
        xmat.append(x.tolist())
        if rr_new < tol:
            print("### Iteration finishes: r*r < {} ###".format(tol))
            print("final r \t\t", r)
            print("final r*r \t\t", rr_new)
            print("nit \t\t\t", i)
            print("final x \t\t", x)
            return xmat
        beta = rr_new / rr
        d = r + beta * d
        rr = rr_new
        print("residual \t\t", rr)
        print("x \t\t\t", x)
        print("")

    print("### Stopped after maxiter = {} iterations without reaching "
          "r*r < {} ###".format(maxiter, tol))
    return xmat
def plot(x):
    """Draw the surface of ``S`` with the iterates ``x`` overlaid in 3-D.

    Parameters
    ----------
    x : sequence of sequences
        Iterates as returned by ``conjugate_gradient``.  Only the first two
        components of each point are plotted, so this is a 2-D visualisation.

    Returns
    -------
    mpl_toolkits.mplot3d.Axes3D
        The 3-D axes that were drawn on; ``plt.show()`` is left to the caller.
    """
    # (x, y, S(x, y)) for every iterate; S comes from the ``optimizer`` module.
    iteration_points = np.array([[p[0], p[1], S(p)] for p in x])

    fig = plt.figure()
    # ``Axes3D(fig)`` is deprecated and, in Matplotlib 3.6+, no longer adds
    # the axes to the figure (the window stays empty); request the 3-D
    # projection through ``add_subplot`` instead.
    ax = fig.add_subplot(111, projection='3d')

    X_1, X_2 = np.meshgrid(np.linspace(-3, 3, 100), np.linspace(3, -3, 100))
    F = np.array([S((x_1, x_2))
                  for x_1, x_2 in zip(np.ravel(X_1), np.ravel(X_2))]
                 ).reshape(X_1.shape)
    ax.plot_surface(X_1, X_2, F, cmap='viridis')

    print("\nx, y and S(x,y) for each iteration\n", iteration_points)
    ax.plot(iteration_points[:, 0], iteration_points[:, 1],
            iteration_points[:, 2], 'ro-')
    ax.set_title("Conjugate gradient method")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    return ax
if __name__ == '__main__':
    # NOTE(review): the relations below presume S(x) = 1/2 x^T A x - b^T x,
    # so that grad S = A x - b -- confirm against optimizer.S.
    A = [[4, 1], [1, 3]]  # Hessian of S: A = `\nabla^2 S(\vec{x})` (SPD)
    b = [1, 2]            # b = A \vec{x} - `\nabla S(\vec{x})`
    x = [1, 3]            # initial guess
    xmat = conjugate_gradient(A=A, b=b, x=x)
    plot(xmat)
    plt.show()

Event Timeline