!>
!> @file test_pwsmp.f90
!>
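!> @brief Test driver for the PWSMP sparse direct solver (pwssmp).
!> Reads a matrix in compressed row storage, a right-hand side and a
!> reference solution from an unformatted file, solves the system and
!> reports the error and the factorisation/backsolve timings.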
!>
!> @copyright
!> Copyright (©) 2021 EPFL (Ecole Polytechnique Fédérale de Lausanne)
!> SPC (Swiss Plasma Center)
!>
!> spclibs is free software: you can redistribute it and/or modify it under
!> the terms of the GNU Lesser General Public License as published by the Free
!> Software Foundation, either version 3 of the License, or (at your option)
!> any later version.
!>
!> spclibs is distributed in the hope that it will be useful, but WITHOUT ANY
!> WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
!> FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
!>
!> You should have received a copy of the GNU Lesser General Public License
!> along with this program. If not, see <https://www.gnu.org/licenses/>.
!>
!> @authors
!> (in alphabetical order)
!> @author Trach-Minh Tran
!>
PROGRAM main
!
! Test driver for the parallel WSMP solver (pwssmp).
!
! Every rank reads the whole matrix file (header, row pointers, column
! indices, values, right-hand side and, at the end, the reference solution)
! and keeps only its own contiguous block of rows. pwssmp is then called
! step by step: initialisation, ordering, symbolic factorisation, Cholesky
! factorisation, backsolve (repeated nits times for timing) and iterative
! refinement. Finally the computed solution is compared with the reference
! solution stored in the file.
!
! The matrix file name is taken from the first command line argument and
! defaults to "mat.dat".
!
!!$ USE futils
IMPLICIT NONE
INCLUDE 'mpif.h'
!
INTEGER :: npes, me, ierr, comm=MPI_COMM_WORLD
INTEGER :: l, lun=99
INTEGER :: nrank, nnz, s, e, nrank_loc, nnz_loc, nnz_sum
INTEGER :: istart, iend
INTEGER, ALLOCATABLE :: irow(:), cols(:)
INTEGER, ALLOCATABLE :: irow_loc(:), cols_loc(:)
DOUBLE PRECISION, ALLOCATABLE :: val(:), val_loc(:)
DOUBLE PRECISION, ALLOCATABLE :: rhs(:), rhs_loc(:)
DOUBLE PRECISION, ALLOCATABLE :: sol(:), sol_loc(:)
CHARACTER(len=128) :: fname = "mat.dat"
! mem and mem_loc are only referenced by the disabled (!!$) memory diagnostics
DOUBLE PRECISION :: mem
DOUBLE PRECISION :: mem_loc
DOUBLE PRECISION :: err, err_max, err_norm
DOUBLE PRECISION :: t0, tfact, tsolv
INTEGER :: it, nits=100
!
! PWSMP vars
!
DOUBLE PRECISION :: dparm(64)
INTEGER :: iparm(64)
INTEGER, ALLOCATABLE :: perm(:), invp(:)
!
INTEGER :: mrp ! just a placeholder in this program
DOUBLE PRECISION :: aux, diag ! just placeholders in this program
INTEGER :: naux=0, nrhs=1
!===========================================================================
! 1.0 Prologue
!
CALL mpi_init(ierr)
CALL mpi_comm_size(comm, npes, ierr)
CALL mpi_comm_rank(comm, me, ierr)
!===========================================================================
! 2.0 Read matrix
!
! File name: first command line argument if given, else the default.
! A non-zero status means the argument could not be retrieved or did not
! fit into fname; continuing would open the wrong file.
IF( command_argument_count() > 0 ) THEN
   CALL get_command_argument(1, fname, l, ierr)
   IF(ierr.NE.0) THEN
      PRINT*, 'Cannot get matrix file name from command line, status =', ierr
      CALL mpi_abort(comm, 1, ierr)
   END IF
END IF
!
! File header. The file is opened read-only and must already exist, so a
! wrong name is reported here instead of creating an empty file.
OPEN(unit=lun, file=TRIM(fname), form="unformatted", status="old", &
     & action="read", iostat=ierr)
IF(ierr.NE.0) THEN
   PRINT*, 'Cannot open matrix file ', TRIM(fname), ', iostat =', ierr
   CALL mpi_abort(comm, 1, ierr)
END IF
READ(lun) nrank, nnz
IF(me.EQ.0) WRITE(*,'(a,3i16)') 'npes, nrank, nnz', npes, nrank, nnz
!
! Matrix partition: rows istart..iend belong to this rank
CALL dist1d(comm, 1, nrank, istart, nrank_loc)
iend = istart+nrank_loc-1
WRITE(*,'(a,i3.3,a,2i12)') 'PE', me, ':istart, iend', istart, iend
ALLOCATE(irow_loc(nrank_loc+1))
!
! Read irow (global row pointers) and derive the local number of nonzeros
ALLOCATE(irow(nrank+1))
READ(lun) irow
nnz_loc = irow(iend+1)-irow(istart)
CALL mpi_reduce(nnz_loc, nnz_sum, 1, MPI_INTEGER, MPI_SUM, 0, comm, ierr)
IF(me.EQ.0) THEN
   PRINT*, 'nnz_sum', nnz_sum
END IF
irow_loc(:) = irow(istart:iend+1) ! Still unshifted
DEALLOCATE(irow)
!
ALLOCATE(cols_loc(nnz_loc))
ALLOCATE(val_loc(nnz_loc))
ALLOCATE(rhs_loc(nrank_loc))
ALLOCATE(sol_loc(nrank_loc))
!
! s..e is the range of global nonzero indices owned by this rank
s = irow_loc(1)
e = irow_loc(nrank_loc+1)-1
irow_loc(:) = irow_loc(:)-s+1 ! Shifted relative irow
WRITE(*,'(a,i3.3,a,3i12)') 'PE', me, ':s, e, nnz_loc', s, e, nnz_loc
!
! Read cols
ALLOCATE(cols(nnz))
READ(lun) cols
cols_loc(:) = cols(s:e)
DEALLOCATE(cols)
!
! Read vals
ALLOCATE(val(nnz))
READ(lun) val
val_loc(:) = val(s:e)
DEALLOCATE(val)
!
! Read RHS
ALLOCATE(rhs(nrank))
READ(lun) rhs
rhs_loc(:) = rhs(istart:iend)
DEALLOCATE(rhs)
!
!!$ mem_loc = mem()
!!$ CALL minmax_r(mem_loc, comm, 'mem used (MB) after matrix read')
!===========================================================================
! 3.0 Call PWSMP
!
! Each step below selects its task through iparm(2) and iparm(3) and is
! followed by a test of iparm(64); any non-zero value aborts the whole run.
!
! Initializing of PWSMP.
!
!!$ CALL pwsmp_initialize
ALLOCATE(invp(nrank), perm(nrank))
!
! Fill 'iparm' and 'dparm' arrays with default values.
iparm(1:3) = 0
CALL pwssmp (nrank_loc, irow_loc, cols_loc, val_loc, diag, perm, invp, &
     & rhs_loc, nrank_loc, nrhs, &
     & aux, naux, mrp, iparm, dparm)
IF(iparm(64).NE.0) THEN
   PRINT*, 'WSMP init failed with iparm(64) =', iparm(64)
   CALL mpi_abort(comm, iparm(64), ierr)
ELSE
   IF(me.EQ.0) PRINT*, 'WSMP init ok'
END IF
!
! Ordering
iparm(2) = 1
iparm(3) = 1
CALL pwssmp (nrank_loc, irow_loc, cols_loc, val_loc, diag, perm, invp, &
     & rhs_loc, nrank_loc, nrhs, &
     & aux, naux, mrp, iparm, dparm)
IF(iparm(64).NE.0) THEN
   PRINT*, 'WSMP ordering failed with iparm(64) =', iparm(64)
   CALL mpi_abort(comm, iparm(64), ierr)
ELSE
   IF(me.EQ.0) PRINT*, 'WSMP ordering ok'
END IF
!
! Symbolic factorization
iparm(2) = 2
iparm(3) = 2
CALL pwssmp (nrank_loc, irow_loc, cols_loc, val_loc, diag, perm, invp, &
     & rhs_loc, nrank_loc, nrhs, &
     & aux, naux, mrp, iparm, dparm)
IF(iparm(64).NE.0) THEN
   PRINT*, 'WSMP symbolic failed with iparm(64) =', iparm(64)
   CALL mpi_abort(comm, iparm(64), ierr)
ELSE
   IF(me.EQ.0) PRINT*, 'WSMP symbolic ok'
END IF
IF(me.EQ.0) THEN
   PRINT *,'Number of nonzeros in factor L = 1000 X ',iparm(24)
   PRINT *,'Number of FLOPS in factorization = ',dparm(23)
   PRINT *,'Double words needed to factor on 0 = 1000 X ',iparm(23)
END IF
!
! Cholesky factorization (timed)
iparm(2) = 3
iparm(3) = 3
t0 = mpi_wtime()
CALL pwssmp (nrank_loc, irow_loc, cols_loc, val_loc, diag, perm, invp, &
     & rhs_loc, nrank_loc, nrhs, &
     & aux, naux, mrp, iparm, dparm)
tfact = mpi_wtime()-t0
IF(iparm(64).NE.0) THEN
   PRINT*, 'WSMP Choleski failed with iparm(64) =', iparm(64)
   CALL mpi_abort(comm, iparm(64), ierr)
ELSE
   IF(me.EQ.0) PRINT*, 'WSMP Choleski ok'
END IF
!
! Backsolve, repeated nits times on a fresh copy of the right-hand side so
! that tsolv is the average time of one solve. The loop stops at the first
! failing call so that a later call cannot hide the error code.
t0 = mpi_wtime()
DO it=1,nits
   sol_loc=rhs_loc
   iparm(2) = 4
   iparm(3) = 4
   CALL pwssmp (nrank_loc, irow_loc, cols_loc, val_loc, diag, perm, invp, &
        & sol_loc, nrank_loc, nrhs, &
        & aux, naux, mrp, iparm, dparm)
   IF(iparm(64).NE.0) EXIT
END DO
rhs_loc=sol_loc
tsolv = (mpi_wtime()-t0)/DBLE(nits)
IF(iparm(64).NE.0) THEN
   PRINT*, 'WSMP backsolve failed with iparm(64) =', iparm(64)
   CALL mpi_abort(comm, iparm(64), ierr)
ELSE
   IF(me.EQ.0) PRINT*, 'WSMP backsolve ok'
END IF
!
! Iterative refinement
! NOTE(review): rhs_loc holds the backsolve result at this point, not the
! original right-hand side - confirm against the WSMP manual that this is
! the input expected by the refinement step.
iparm(2) = 5
iparm(3) = 5
CALL pwssmp (nrank_loc, irow_loc, cols_loc, val_loc, diag, perm, invp, &
     & rhs_loc, nrank_loc, nrhs, &
     & aux, naux, mrp, iparm, dparm)
IF(iparm(64).NE.0) THEN
   PRINT*, 'WSMP refinement failed with iparm(64) =', iparm(64)
   CALL mpi_abort(comm, iparm(64), ierr)
ELSE
   IF(me.EQ.0) PRINT*, 'WSMP refinement ok'
END IF
!
!!$ mem_loc = mem()
!!$ CALL minmax_r(mem_loc, comm, 'mem used (MB) after PWSMP')
!===========================================================================
! 4.0 Check SOL
!
! Read SOL (reference solution, last record used from the file)
ALLOCATE(sol(nrank))
READ(lun) sol
CLOSE(lun)
sol_loc(:) = sol(istart:iend)
DEALLOCATE(sol)
PRINT*, 'Comp. sol', SUM(rhs_loc)
PRINT*, 'Exact sol', SUM(sol_loc)
!
! Maximum absolute error over all ranks
err=MAXVAL(ABS(sol_loc-rhs_loc))
CALL mpi_reduce(err, err_max, 1, MPI_DOUBLE_PRECISION, MPI_MAX, 0, comm, ierr)
IF(me.EQ.0) THEN
   PRINT*, 'Max. error', err_max
END IF
! 2-norm of the error: local sums of squares are added on rank 0
rhs_loc = rhs_loc-sol_loc
err = DOT_PRODUCT(rhs_loc,rhs_loc)
CALL mpi_reduce(err, err_norm, 1, MPI_DOUBLE_PRECISION, MPI_SUM, 0, comm, ierr)
IF(me.EQ.0) THEN
   PRINT*, 'Norm of error', SQRT(err_norm)
END IF
!
!!$ mem_loc = mem()
!!$ CALL minmax_r(mem_loc, comm, 'mem used (MB)')
!===========================================================================
! 9.0 Epilogue
!
CALL minmax_r(tfact, comm, 'Factorisation time(s)')
CALL minmax_r(tsolv, comm, ' Backsolve time(s)')
CALL mpi_finalize(ierr)
!
CONTAINS
SUBROUTINE dist1d(comm, s0, ntot, s, nloc)
  !
  ! 1d block distribution of ntot elements over the ranks of comm.
  !
  !   comm  (in)  communicator whose size and rank define the partition
  !   s0    (in)  index of the first element of the global range
  !   ntot  (in)  total number of elements to distribute
  !   s     (out) index of the first element owned by the calling rank
  !   nloc  (out) number of elements owned by the calling rank
  !
  ! Every rank gets ntot/npes elements; the MODULO(ntot,npes) remaining
  ! elements are given one each to the lowest ranks, so local sizes differ
  ! by at most one.
  !
  ! No MPI constant is used here, so 'mpif.h' is not included again.
  !
  IMPLICIT NONE
  INTEGER, INTENT(in) :: comm, s0, ntot
  INTEGER, INTENT(out) :: s, nloc
  INTEGER :: me, npes, ierr, naver, rem
  !
  CALL MPI_COMM_SIZE(comm, npes, ierr)
  CALL MPI_COMM_RANK(comm, me, ierr)
  naver = ntot/npes
  rem = MODULO(ntot,npes)
  ! Ranks below me hold naver elements each, plus one extra for each of the
  ! MIN(rem,me) ranks that received a remainder element.
  s = s0 + MIN(rem,me) + me*naver
  nloc = naver
  IF( me.LT.rem ) nloc = nloc+1
END SUBROUTINE dist1d
!
SUBROUTINE minmax_r(x, comm, str)
  !
  ! Reduce the scalar x over all ranks of comm and print, on rank 0 only,
  ! its minimum and maximum preceded by the label str.
  !
  DOUBLE PRECISION, INTENT(in) :: x
  INTEGER, INTENT(in) :: comm
  CHARACTER(len=*), INTENT(in) :: str
  INTEGER, PARAMETER :: root = 0
  INTEGER :: rank, info
  DOUBLE PRECISION :: lo, hi
  !
  ! Both reductions deliver their result to the root rank only.
  CALL mpi_reduce(x, lo, 1, MPI_DOUBLE_PRECISION, MPI_MIN, root, comm, info)
  CALL mpi_reduce(x, hi, 1, MPI_DOUBLE_PRECISION, MPI_MAX, root, comm, info)
  !
  CALL mpi_comm_rank(comm, rank, info)
  IF( rank == root ) WRITE(*,'(a,2(1pe12.4))') 'Minmax of ' // TRIM(str), lo, hi
END SUBROUTINE minmax_r
!
END PROGRAM main