!=======================================================================
! 1-dimensional Diffusion Equation - MPI Version
! Dirichlet boundary conditions u(0,t)=u(1,t)=0
! 0 <= x <= 1 and 0 <= t <= tf
!
!-----------------------------------------------------------------------
! Test:
! gfortran heatConduction-vanilla.f90 -o v; echo 20 500 0.5 | ./v
! mpifort  heatConduction.f90         -o h; echo 20 500 0.5 | mpirun -n 4 ./h
! splot "d.dat" w l t "vanilla","p.dat" w l t "parallel"
! splot "<paste [pd].dat" u 1:2:($6-$3) w l t "parallel-vanilla"
!=======================================================================
program diffusion_1d_mpi
  ! Explicit (forward-Euler in time, centered in space) solver for u_t = u_xx on
  ! 0 <= x <= 1 with u(0,t) = u(1,t) = 0, distributed over MPI ranks as equal
  ! contiguous chunks of the Nx-point lattice. Rank 0 reads Nx, Nt, tf from
  ! standard input and writes every time slice as "t x u" lines to p.dat.
  use mpi
  use, intrinsic :: iso_fortran_env, only: error_unit
  implicit none

  integer , parameter   :: dp = 8   ! assumed to be the 8-byte real described by MPI_REAL8 below
  real(dp), allocatable :: u(:), du(:), u_global(:)
  real(dp)              :: t, x, dx, dt, tf, courant
  integer               :: Nx, Nt, i, j

  real(dp), parameter   :: ZERO=0.0_dp, ONE=1.0_dp, HALF=0.5_dp, TWO=2.0_dp
  real(dp), parameter   :: PI=atan2(ZERO,-ONE)

  ! --- MPI Variables ---
  integer :: my_rank, num_procs, ierr
  integer :: left_neighbor, right_neighbor
  integer :: Nloc, i_global, istart, iend
  integer :: status(MPI_STATUS_SIZE)

  ! --- I/O Variables (meaningful on rank 0 only) ---
  integer :: out_unit, ios

  call MPI_Init(ierr)
  call MPI_Comm_rank(MPI_COMM_WORLD, my_rank, ierr)
  call MPI_Comm_size(MPI_COMM_WORLD, num_procs, ierr)

  ! --- Input (Rank 0 reads and broadcasts) ---
  ! Bad input must abort the whole job: a plain STOP on rank 0 alone would leave
  ! every other rank blocked forever in the MPI_Bcast calls that follow.
  if (my_rank == 0) then
    print *, '# Enter: Nx, Nt, tf:'
    read(*, *, iostat=ios) Nx, Nt, tf
    if (ios /= 0) call abort_run('Error reading Nx, Nt, tf')
    if (Nx <=  3) call abort_run('Nx <= 3')
    if (Nt <=  2) call abort_run('Nt <= 2')
  end if

  call MPI_Bcast(Nx, 1, MPI_INTEGER, 0, MPI_COMM_WORLD, ierr)
  call MPI_Bcast(Nt, 1, MPI_INTEGER, 0, MPI_COMM_WORLD, ierr)
  call MPI_Bcast(tf, 1, MPI_REAL8  , 0, MPI_COMM_WORLD, ierr)

  ! Ensure the lattice divides evenly among processes for this introduction
  if (mod(Nx, num_procs) /= 0) then
    if (my_rank == 0) print *, "Error: Nx must be divisible by the number of processes."
    call MPI_Abort(MPI_COMM_WORLD, 1, ierr)
  end if

  Nloc = Nx / num_procs

  ! Determine neighbors for 1D decomposition; MPI_PROC_NULL turns the halo
  ! exchange with a non-existent neighbor into a no-op.
  left_neighbor  = my_rank - 1
  right_neighbor = my_rank + 1

  if (my_rank   == 0            ) left_neighbor  = MPI_PROC_NULL
  if (my_rank   == num_procs - 1) right_neighbor = MPI_PROC_NULL

  ! Allocate local arrays with ghost cells at 0 and Nloc+1. The ghost cells are
  ! zeroed as well, because on boundary ranks the exchange never writes them.
  allocate(u(0:Nloc+1), du(1:Nloc))
  u  = ZERO
  du = ZERO

  ! --- Initialize Lattice Geometry ---
  dx      = ONE/real(Nx-1, dp)
  dt      = tf /real(Nt-1, dp)
  courant = dt /dx**2

  if (my_rank == 0) then
    print * ,'# 1d Diffusion Equation: 0<=x<=1, 0<=t<=tf'
    print * ,'# dx= ',dx,' dt= ',dt,' tf= ', tf
    print * ,'# Nx= ',Nx,' Nt= ',Nt
    print * ,'# Courant Number= ',courant
    if(courant > HALF) print *,'# WARNING: courant > 0.5'
    open(newunit=out_unit, file='p.dat', iostat=ios)
    if (ios /= 0) call abort_run('Error opening p.dat')
    allocate(u_global(Nx))
  else
    ! MPI_Gather ignores the receive buffer on non-root ranks, but it is still an
    ! actual argument there, so it must be allocated on every rank.
    allocate(u_global(1))
  end if

  ! --- Initial condition at t=0 ---
  do      i = 1, Nloc
    i_global = (my_rank  * Nloc) + i
    x        = (i_global -    1) * dx
    u(i)     = sin(PI*x)
  end do

  ! Enforce global Dirichlet boundaries
  if (my_rank == 0            ) u(1)    = ZERO
  if (my_rank == num_procs - 1) u(Nloc) = ZERO

  call write_snapshot(ZERO)

  ! Determine local computation bounds to protect Dirichlet boundaries
  istart = 1
  iend   = Nloc
  if (my_rank == 0)             istart = 2       ! Protect global x=0
  if (my_rank == num_procs - 1) iend   = Nloc-1  ! Protect global x=1

  ! --- Calculate Time Evolution ---
  timesteps: do j = 2, Nt
    t = (j-1)*dt

    ! --- HALO EXCHANGE ---
    ! Shift Right: Fill left ghost cells
    call MPI_Sendrecv(                               &
         u(Nloc)  , 1, MPI_REAL8, right_neighbor, 0, &
         u(0)     , 1, MPI_REAL8, left_neighbor , 0, &
         MPI_COMM_WORLD, status, ierr)

    ! Shift Left: Fill right ghost cells
    call MPI_Sendrecv(                               &
         u(1)     , 1, MPI_REAL8, left_neighbor , 1, &
         u(Nloc+1), 1, MPI_REAL8, right_neighbor, 1, &
         MPI_COMM_WORLD, status, ierr)

    ! --- SECOND DERIVATIVE AND UPDATE ---
    ! The whole increment is computed from the old u before u is modified.
    du(istart:iend) = courant*(u(istart+1:iend+1) - TWO*u(istart:iend) + u(istart-1:iend-1))
    u(istart:iend)  = u(istart:iend) + du(istart:iend)

    call write_snapshot(t)
  end do timesteps

  if (my_rank == 0) close(out_unit)
  deallocate(u, du, u_global)

  call MPI_Finalize(ierr)

contains

  !> Gather the interior values u(1:Nloc) of all ranks onto rank 0 and, on
  !> rank 0, append Nx lines "time x u" plus a separator line to p.dat.
  !> Collective over MPI_COMM_WORLD: every rank must call it.
  subroutine write_snapshot(time)
    real(dp), intent(in) :: time
    integer :: k

    call MPI_Gather(                 &
         u(1)     , Nloc, MPI_REAL8, &
         u_global , Nloc, MPI_REAL8, &
         0, MPI_COMM_WORLD, ierr)

    if (my_rank /= 0) return

    do k = 1, Nx
      write(out_unit,*) time, (k-1)*dx, u_global(k)
    end do
    write(out_unit,*) ' '
  end subroutine write_snapshot

  !> Report msg on standard error and terminate all ranks of MPI_COMM_WORLD.
  subroutine abort_run(msg)
    character(len=*), intent(in) :: msg
    integer :: abort_ierr

    write(error_unit, '(a)') msg
    call MPI_Abort(MPI_COMM_WORLD, 1, abort_ierr)
    error stop 1   ! fallback only; MPI_Abort is expected not to return
  end subroutine abort_run

end program diffusion_1d_mpi
!===================================================================================================================================
!-----------------------------------------------------------------------------------------------------------------------------------
!  Copyright by Konstantinos N. Anagnostopoulos, Physics Department, National Technical University of Athens, 2025
!  konstant@mail.ntua.gr, www.physics.ntua.gr/konstant
!  
!  This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as 
!  published by the Free Software Foundation, version 3 of the License.
!  
!  This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
!  
!  You should have received a copy of the GNU General Public License along with this program.
!  If not, see http://www.gnu.org/licenses
!-----------------------------------------------------------------------------------------------------------------------------------
