/*
 * Runs one MPI_Allreduce per thread, concurrently. Each thread works on its
 * own duplicate of MPI_COMM_WORLD.
 */
#include <mpi.h>
#include <pthread.h>
#include <stdio.h>

/* Number of worker threads (and duplicated communicators) per process. */
enum { NUM_THREADS = 2 };

/* Per-thread index handed to worker(); global so it outlives each thread. */
int thread_id[NUM_THREADS];
/* split_comm[i] is a duplicate of MPI_COMM_WORLD used only by thread i. */
MPI_Comm split_comm[NUM_THREADS];
pthread_t thread[NUM_THREADS];

/*
 * Thread entry point.
 *
 * arg points at an int index i. The thread sums i across all ranks of
 * split_comm[i] (in place) and prints the result.
 *
 * Returns NULL; the result is reported only via stdout.
 */
void *worker(void *arg)
{
    int i = *(const int *) arg;
    int j = i;
    MPI_Comm comm = split_comm[i];

    MPI_Allreduce(MPI_IN_PLACE, &j, 1, MPI_INT, MPI_SUM, comm);
    printf("Thread %d: allreduce returned %d\n", i, j);
    return NULL;
}

int main(void)
{
    MPI_Info info;
    int provided;
    char s[16];

    MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &provided);
    /*
     * Several threads issue MPI calls at the same time, which is only legal
     * under MPI_THREAD_MULTIPLE. The thread-level constants are ordered, so
     * a lower level means the library refused full thread support.
     */
    if (provided < MPI_THREAD_MULTIPLE) {
        fprintf(stderr, "MPI_THREAD_MULTIPLE not supported (provided=%d)\n",
                provided);
        MPI_Abort(MPI_COMM_WORLD, 1);
    }

    MPI_Info_create(&info);
    for (int i = 0; i < NUM_THREADS; i++) {
        MPI_Comm_dup(MPI_COMM_WORLD, &split_comm[i]);

        /*
         * Attach a "thread_id" hint to the new communicator. NOTE(review):
         * this key is not defined by the MPI standard; presumably it is an
         * implementation-specific hint, and implementations may ignore it.
         */
        snprintf(s, sizeof s, "%d", i);
        MPI_Info_set(info, "thread_id", s);
        MPI_Comm_set_info(split_comm[i], info);

        thread_id[i] = i;
        int rc = pthread_create(&thread[i], NULL, worker, &thread_id[i]);
        if (rc != 0) {
            fprintf(stderr, "pthread_create failed for thread %d (error %d)\n",
                    i, rc);
            MPI_Abort(MPI_COMM_WORLD, 1);
        }
    }

    for (int i = 0; i < NUM_THREADS; i++) {
        pthread_join(thread[i], NULL);
    }

    /* Release the duplicated communicators before shutting MPI down. */
    for (int i = 0; i < NUM_THREADS; i++) {
        MPI_Comm_free(&split_comm[i]);
    }
    MPI_Info_free(&info);
    MPI_Finalize();
    return 0;
}