!-----------------------------------------------------------------------------------------------------------------------------------
! PROGRAM LAPLACE_MPI
! Computes the electrostatic potential around conductors using MPI.
! Employs a 1D domain decomposition and the Jacobi relaxation method.
! The lattice is split into slabs of columns: each MPI rank holds an L x Nloc block of columns, with Nloc = L / nprocs.
! 
! Ghost cells (halos): the local array is V(L, 0:Nloc+1).
! Columns 0 and Nloc+1 do not belong to the local process; they are buffers that receive the boundary columns of the adjacent processes.
! MPI_Sendrecv: combines the send and the receive in a single call, so neighbouring ranks exchange halos simultaneously without the deadlock risk of hand-ordered blocking MPI_Send/MPI_Recv pairs.
! Convergence: MPI_Allreduce with the MPI_MAX operation gives every process the same global maximum change, so all ranks exit the loop on the same sweep.
!
! DATA STRUCTURE (Per Process):
!   real(8) V(L, 0:Nloc+1)     : Local potential with ghost columns (0 and Nloc+1)
!   real(8) Vnew(L, 1:Nloc)    : Auxiliary array for the Jacobi update
!   logical isConductor(L, 1:Nloc): .TRUE. if the site has fixed potential
!
! Test:
!  mpifort laplace.f90 -o p                    ;  gfortran laplace-vanilla.f90 -o l;
! (echo -100 100;echo 1e-12) | mpirun -n 4 ./p ; (echo 60; echo -100 100;echo 1e-12) | ./l
! splot "<paste data pdata" u 1:2:($6-$3) w l  ; splot "<paste data pdata" u 1:2:3 w l t "Vanilla", "" u 1:2:6 w l t "Parallel"
!-----------------------------------------------------------------------------------------------------------------------------------
program laplace_mpi
 ! Jacobi relaxation of the Laplace equation on an L x L lattice, distributed over MPI ranks
 ! as slabs of Nloc = L / nprocs columns each. Rank 0 reads V1, V2 and epsilon from stdin,
 ! broadcasts them, and after convergence gathers the full lattice and writes it to "pdata".
 use, intrinsic        :: iso_fortran_env
 use mpi
 implicit none
 integer, parameter    :: dp = real64
 ! Global lattice size. Chosen as 60 so it divides easily by 2, 3, 4, 5, or 6 MPI tasks.
 integer, parameter    :: L  = 60

 ! MPI Variables
 integer               :: myrank, nprocs, ierr
 integer               :: left_neighbor, right_neighbor
 integer               :: status(MPI_STATUS_SIZE)

 ! Local domain variables
 integer               :: Nloc, j_global
 real(dp), allocatable :: V(:,:), Vnew(:,:)
 logical , allocatable :: isConductor(:,:)

 ! Physics & Iteration variables
 real(dp)              :: V1, V2, epsilon
 real(dp)              :: error_local, error_global, dV
 integer               :: i, j, icount

 ! I/O status handling (input on rank 0, output file on rank 0)
 integer               :: ios, iout
 character(len=256)    :: iomsg_buf

 ! For gathering results on Rank 0 (same kind as the local arrays, transferred as MPI_REAL8)
 real(dp), allocatable :: V_global(:,:)

 ! 1. Initialize MPI
 call MPI_Init(ierr)
 call MPI_Comm_rank(MPI_COMM_WORLD, myrank, ierr)
 call MPI_Comm_size(MPI_COMM_WORLD, nprocs, ierr)

 ! Ensure the lattice divides evenly among processes for this introductory example
 if (mod(L, nprocs) /= 0) then
  if (myrank == 0) print *, "Error: Lattice size L must be divisible by the number of processes."
  call MPI_Abort(MPI_COMM_WORLD, 1, ierr)
 endif

 Nloc = L / nprocs

 ! Determine neighbors (MPI_PROC_NULL handles the global boundaries)
 left_neighbor  = myrank - 1
 right_neighbor = myrank + 1
 if (myrank == 0         ) left_neighbor  = MPI_PROC_NULL
 if (myrank == nprocs - 1) right_neighbor = MPI_PROC_NULL

 ! Allocate local arrays. Notice the ghost columns at index 0 and Nloc+1.
 allocate(V          (L, 0:Nloc+1))
 allocate(Vnew       (L, 1:Nloc  ))
 allocate(isConductor(L, 1:Nloc  ))

 ! 2. Data Input (Rank 0 reads and broadcasts)
 if (myrank == 0) then
  print *, 'Enter V1, V2: '
  read (input_unit, *, iostat=ios, iomsg=iomsg_buf) V1, V2
  if (ios == 0) then
   print *, 'Enter epsilon: '
   read (input_unit, *, iostat=ios, iomsg=iomsg_buf) epsilon
  endif
  ! A failed read must not leave the other ranks blocked in MPI_Bcast: abort the whole job.
  if (ios /= 0) then
   write(error_unit, '(a)') 'Error reading input: '//trim(iomsg_buf)
   call MPI_Abort(MPI_COMM_WORLD, 2, ierr)
  endif
  ! With epsilon <= 0 (or NaN) the exit test "error_global < epsilon" may never be met.
  if (.not. (epsilon > 0.0_dp)) then
   write(error_unit, '(a)') 'Error: epsilon must be a positive number.'
   call MPI_Abort(MPI_COMM_WORLD, 3, ierr)
  endif
  print *, 'Starting Laplace MPI:'
  print *, 'Grid Size = ', L
  print *, 'Processes = ', nprocs
 endif

 call MPI_Bcast(V1     , 1, MPI_REAL8, 0, MPI_COMM_WORLD, ierr)
 call MPI_Bcast(V2     , 1, MPI_REAL8, 0, MPI_COMM_WORLD, ierr)
 call MPI_Bcast(epsilon, 1, MPI_REAL8, 0, MPI_COMM_WORLD, ierr)

 ! 3. Initialize Local Lattice Geometry
 V           = 0.0_dp
 isConductor = .false.

 conductorset: do j = 1, Nloc

  ! Global column index of local column j
  j_global = (myrank * Nloc) + j

  ! Grounded global boundaries (first and last global column)
  if (j_global == 1 .or. j_global == L) then
   isConductor(:, j) = .true.
  else
   ! Grounded top and bottom
   isConductor(1, j) = .true.
   isConductor(L, j) = .true.

   ! Parallel conductors (plates) at specific global i coordinates
   ! The plates are parallel to the y-axis, since i is fixed, and j varies within the limits of the extent of the plate:
   if (j_global >= 5 .and. j_global <= L-5) then
    isConductor(L/3+1  , j) = .true.
    V          (L/3+1  , j) = V1

    isConductor(2*L/3+1, j) = .true.
    V          (2*L/3+1, j) = V2
   endif
  endif
 enddo conductorset

 ! 4. The Main Relaxation Loop
 icount       = 0
 relaxation: do while (.true.)
  error_local = 0.0_dp
  ! --- HALO EXCHANGE ---

  ! 1. Shift Right:
  ! Send    my rightmost                 column (Nloc) to my right neighbor.
  ! Receive my left neighbor's rightmost column      into my left  ghost column (0).
  call MPI_Sendrecv(V(1, Nloc  ), L, MPI_REAL8, right_neighbor, 0, &
  !                               ^  send L elements from V(1,Nloc)                          <---- column major in Fortran! V(1,Nloc), V(2,Nloc), ... , V(L,Nloc)
                    V(1, 0     ), L, MPI_REAL8, left_neighbor , 0, &
  !                               ^  receive L elements, store in V(1,0) + L positions       <---- column major in Fortran! V(1,0   ), V(2,0   ), ... , V(L,0   )
       MPI_COMM_WORLD, status   , ierr)

  ! 2. Shift Left:
  ! Send my leftmost physical column (1) to my left neighbor.
  ! Receive my right neighbor's leftmost column into my right ghost column (Nloc+1).
  call MPI_Sendrecv(V(1, 1     ), L, MPI_REAL8, left_neighbor , 1, &
                    V(1, Nloc+1), L, MPI_REAL8, right_neighbor, 1, &
       MPI_COMM_WORLD, status   , ierr)

  !MPI_Barrier is not necessary here because MPI_Sendrecv is a blocking communication routine.

  ! --- JACOBI SWEEP ---
  ! Rows 1 and L are always conductors, so V(i-1,j) / V(i+1,j) never leave 1..L.
  jacobi: do j = 1, Nloc
   do        i = 1, L
    if (.not. isConductor(i, j)) then
     ! Calculate arithmetic mean of nearest neighbors
     Vnew(i,j) = 0.25_dp * (V(i-1,j) + V(i+1,j) + V(i,j-1) + V(i,j+1))

     dV        = abs(Vnew(i,j) - V(i,j))
     if (dV    > error_local) error_local = dV
    else
     ! Conductor points remain fixed
     Vnew(i,j) = V(i,j)
    endif
   enddo
  enddo jacobi

  ! Calculate global maximum error across all processes
  call MPI_Allreduce(error_local, error_global, 1, MPI_REAL8, MPI_MAX, MPI_COMM_WORLD, ierr)

  ! Update local potential array (exclude the halos!)
  V(1:L, 1:Nloc) = Vnew(1:L, 1:Nloc)

  icount = icount + 1

  ! Check convergence
  if (error_global < epsilon) exit

 enddo relaxation

 if (myrank == 0) then
  print *, 'Converged after   ', icount, ' sweeps.'
  print *, 'Final max error = ', error_global
 endif

 ! 5. Gather and Print Results
 ! Only rank 0 needs the full global array. The receive buffer is ignored on the other
 ! ranks, but passing an unallocated allocatable as an actual argument is not
 ! standard-conforming, so those ranks get a minimal placeholder.
 if (myrank == 0) then
  allocate(V_global(L, L))
 else
  allocate(V_global(1, 1))
 endif

 ! Gather blocks of Nloc columns from each process into V_global on Rank 0.
 ! The send buffer must start at V(1,1): V(1,0) is the left ghost column, and starting
 ! there would send columns 0..Nloc-1 (shifted by one, missing the last physical column).
 ! Columns 1..Nloc are contiguous in memory (column-major, full first extent L).
 call MPI_Gather(V(1, 1) , L * Nloc, MPI_REAL8, &
                 V_global, L * Nloc, MPI_REAL8, &
      0, MPI_COMM_WORLD, ierr)

 if (myrank == 0) then
  open(newunit=iout, file="pdata", status="replace", action="write", &
       iostat=ios, iomsg=iomsg_buf)
  if (ios /= 0) then
   write(error_unit, '(a)') 'Error opening pdata: '//trim(iomsg_buf)
  else
   do  i = 1, L
    do j = 1, L
     write(iout, *) i, j, V_global(i, j)
    enddo
    write(iout, *) "" ! Empty line for gnuplot splot
   enddo
   close(iout)
  endif
 endif

 call MPI_Finalize(ierr)

end program laplace_mpi
!-----------------------------------------------------------------------------------------------------------------------------------

