!=======================================================================
! 1-dimensional Diffusion Equation - MPI Version
! Periodic boundary conditions u(0,t) = u(1,t) on a circle of perimeter 1
! 0 <= x < 1 (with x=1 implicitly mapped back to x=0)
!-----------------------------------------------------------------------
! Test:
! mpifort  pdiffusion.f90 -o d; echo 60 1000 0.1 | mpirun -n 4 ./d
! splot [:][:][:4] "f.dat" w l t "diffusion"
!=======================================================================
program diffusion_1d_periodic_mpi
 !-----------------------------------------------------------------------
 ! Explicit FTCS solver for u_t = u_xx on 0 <= x < 1 with periodic
 ! boundary conditions. The Nx lattice points are split evenly among the
 ! MPI ranks, which are connected in a ring; each rank keeps one ghost
 ! cell on either side of its slab, filled by a halo exchange before
 ! every time step. Rank 0 reads (Nx, Nt, tf) from stdin and appends
 ! every time slice to f.dat as lines "t x u(t,x)", blank-line separated.
 !-----------------------------------------------------------------------
 use mpi
 use, intrinsic :: iso_fortran_env, only: real64, error_unit
 implicit none

 integer , parameter    :: dp = real64   ! must match MPI_REAL8 used below
 real(dp), allocatable  :: u(:), du(:), u_global(:)
 real(dp)               :: t, dx, dt, tf, courant
 integer                :: Nx, Nt, i, j, ios, fout

 real(dp), parameter    :: ZERO=0.0_dp, ONE=1.0_dp, HALF=0.5_dp, TWO=2.0_dp
 ! Not used by the delta initial condition; kept for alternative
 ! initial profiles such as sin(2*PI*x).
 real(dp), parameter    :: PI=atan2(ZERO,-ONE)

 ! --- MPI Variables ---
 integer                :: my_rank, num_procs, ierr
 integer                :: left_neighbor, right_neighbor
 integer                :: Nloc, i_global
 integer                :: status(MPI_STATUS_SIZE)

 call MPI_Init(ierr)
 call MPI_Comm_rank(MPI_COMM_WORLD, my_rank  , ierr)
 call MPI_Comm_size(MPI_COMM_WORLD, num_procs, ierr)

 ! --- Input (Rank 0 reads and broadcasts) ---
 ! Errors abort the whole job: a plain STOP on rank 0 would leave the
 ! other ranks blocked forever in MPI_Bcast.
 if (my_rank == 0) then
  print *, '# Enter: Nx, Nt, tf:'
  read(*, *, iostat=ios) Nx, Nt, tf
  if (ios /= 0) call fatal('Error reading Nx, Nt, tf')
  if (Nx <=  3) call fatal('Nx <= 3')
  if (Nt <=  2) call fatal('Nt <= 2')
 end if

 call MPI_Bcast(Nx, 1, MPI_INTEGER, 0, MPI_COMM_WORLD, ierr)
 call MPI_Bcast(Nt, 1, MPI_INTEGER, 0, MPI_COMM_WORLD, ierr)
 call MPI_Bcast(tf, 1, MPI_REAL8  , 0, MPI_COMM_WORLD, ierr)

 if (mod(Nx, num_procs) /= 0) &
  call fatal("Error: Nx must be divisible by the number of processes.")

 Nloc = Nx / num_procs

 ! --- RING TOPOLOGY ---
 ! Rank 0's left neighbour is the last rank and vice versa, which closes
 ! the domain into a periodic circle.
 left_neighbor  = modulo(my_rank - 1, num_procs)
 right_neighbor = modulo(my_rank + 1, num_procs)

 ! Local slab u(1:Nloc) plus ghost cells u(0) and u(Nloc+1)
 allocate(u(0:Nloc+1), du(1:Nloc))

 ! --- Lattice geometry ---
 ! Nx distinct points; x=1 is identified with x=0, hence dx = 1/Nx.
 dx      = ONE / Nx
 dt      = tf / (Nt-1)
 courant = dt / dx**2

 if (my_rank == 0) then
  print * ,'# 1d Diffusion Equation (Periodic): 0<=x<1, 0<=t<=tf'
  print * ,'# dx= ',dx,' dt= ',dt,' tf= ', tf
  print * ,'# Nx= ',Nx,' Nt= ',Nt
  print * ,'# Courant Number= ',courant
  if(courant > HALF) print *,'# WARNING: courant > 0.5'
  open(newunit=fout, file='f.dat', iostat=ios)
  if (ios /= 0) call fatal('Error opening f.dat')
  allocate(u_global(Nx))
 else
  ! The gather receive buffer is ignored on non-root ranks, but it must
  ! still be an allocated array to be a legal actual argument.
  allocate(u_global(1))
 end if

 ! --- Initial condition at t=0: discrete delta at x = 0.5 ---
 ! Global point i_global sits at x = (i_global-1)*dx, so x = 0.5 is
 ! i_global = Nx/2 + 1 (exact for even Nx, dx/2 below 0.5 for odd Nx).
 ! For a smooth profile use x = (i_global-1)*dx inside this loop.
 u = ZERO
 do i = 1, Nloc
  i_global = my_rank*Nloc + i
  if (i_global == Nx/2 + 1) u(i) = 100.0_dp
 end do

 call write_snapshot(ZERO)

 ! --- Calculate Time Evolution ---
 timeevolution: do j = 2, Nt
  t = (j-1)*dt

  ! --- HALO EXCHANGE ---
  ! The ring neighbours supply the periodic wrap-around.
  ! Last interior point goes right; left ghost u(0) comes from the left.
  call MPI_Sendrecv(                               &
       u(Nloc)  , 1, MPI_REAL8, right_neighbor, 0, &
       u(0)     , 1, MPI_REAL8, left_neighbor , 0, &
       MPI_COMM_WORLD, status, ierr)
  ! First interior point goes left; right ghost u(Nloc+1) comes from the right.
  call MPI_Sendrecv(                               &
       u(1)     , 1, MPI_REAL8, left_neighbor , 1, &
       u(Nloc+1), 1, MPI_REAL8, right_neighbor, 1, &
       MPI_COMM_WORLD, status, ierr)

  ! --- SECOND DERIVATIVE AND UPDATE (all local points 1..Nloc) ---
  ! du is fully evaluated before u is touched, so the update uses the
  ! old time level everywhere.
  du        = courant*(u(2:Nloc+1) - TWO*u(1:Nloc) + u(0:Nloc-1))
  u(1:Nloc) = u(1:Nloc) + du

  call write_snapshot(t)
 end do timeevolution

 if (my_rank == 0) close(fout)

 call MPI_Finalize(ierr)

contains

 !> Gather every rank's interior points onto rank 0 and append one time
 !> slice "time x u" to f.dat. u at x=1 is written as u_global(1)
 !> (the x=0 value) to close the periodic curve for plotting.
 subroutine write_snapshot(time)
  real(dp), intent(in) :: time
  real(dp)             :: x
  integer              :: k

  call MPI_Gather(                 &
       u(1)     , Nloc, MPI_REAL8, &
       u_global , Nloc, MPI_REAL8, &
       0, MPI_COMM_WORLD, ierr)

  if (my_rank /= 0) return

  do k = 1, Nx
   x = (k-1)*dx
   write(fout,*) time, x, u_global(k)
  end do
  write(fout,*) time, ONE, u_global(1)
  write(fout,*) ' '
 end subroutine write_snapshot

 !> Report msg on rank 0 (stderr) and abort all processes.
 !> Safe to call from rank 0 alone or from every rank collectively.
 subroutine fatal(msg)
  character(len=*), intent(in) :: msg
  if (my_rank == 0) write(error_unit, '(a)') msg
  call MPI_Abort(MPI_COMM_WORLD, 1, ierr)
 end subroutine fatal

end program diffusion_1d_periodic_mpi
!===================================================================================================================================
!-----------------------------------------------------------------------------------------------------------------------------------
!  Copyright by Konstantinos N. Anagnostopoulos, Physics Department, National Technical University of Athens, 2025
!  konstant@mail.ntua.gr, www.physics.ntua.gr/konstant
!  
!  This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as 
!  published by the Free Software Foundation, version 3 of the License.
!  
!  This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
!  
!  You should have received a copy of the GNU General Public License along with this program. If not, see http://www.gnu.org/licenses
!-----------------------------------------------------------------------------------------------------------------------------------
