Ginkgo 1.10.0, documentation generated from the develop branch
A numerical linear algebra library targeting many-core architectures
 
mpi.hpp
1// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_MPI_HPP_
6#define GKO_PUBLIC_CORE_BASE_MPI_HPP_
7
8
9#include <memory>
10#include <type_traits>
11#include <utility>
12
13#include <ginkgo/config.hpp>
14#include <ginkgo/core/base/exception.hpp>
15#include <ginkgo/core/base/exception_helpers.hpp>
16#include <ginkgo/core/base/executor.hpp>
17#include <ginkgo/core/base/half.hpp>
18#include <ginkgo/core/base/types.hpp>
19#include <ginkgo/core/base/utils_helper.hpp>
20
21
22#if GINKGO_BUILD_MPI
23
24
25#include <mpi.h>
26
27
28namespace gko {
29namespace experimental {
36namespace mpi {
37
38
42inline constexpr bool is_gpu_aware()
43{
44#if GINKGO_HAVE_GPU_AWARE_MPI
45 return true;
46#else
47 return false;
48#endif
49}
50
51
59int map_rank_to_device_id(MPI_Comm comm, int num_devices);
60
61
62#define GKO_REGISTER_MPI_TYPE(input_type, mpi_type) \
63 template <> \
64 struct type_impl<input_type> { \
65 static MPI_Datatype get_type() { return mpi_type; } \
66 }
67
76template <typename T>
77struct type_impl {};
78
79
80GKO_REGISTER_MPI_TYPE(char, MPI_CHAR);
81GKO_REGISTER_MPI_TYPE(unsigned char, MPI_UNSIGNED_CHAR);
82GKO_REGISTER_MPI_TYPE(unsigned, MPI_UNSIGNED);
83GKO_REGISTER_MPI_TYPE(int, MPI_INT);
84GKO_REGISTER_MPI_TYPE(unsigned short, MPI_UNSIGNED_SHORT);
85GKO_REGISTER_MPI_TYPE(unsigned long, MPI_UNSIGNED_LONG);
86GKO_REGISTER_MPI_TYPE(long, MPI_LONG);
87GKO_REGISTER_MPI_TYPE(long long, MPI_LONG_LONG_INT);
88GKO_REGISTER_MPI_TYPE(unsigned long long, MPI_UNSIGNED_LONG_LONG);
89GKO_REGISTER_MPI_TYPE(float, MPI_FLOAT);
90GKO_REGISTER_MPI_TYPE(double, MPI_DOUBLE);
91GKO_REGISTER_MPI_TYPE(long double, MPI_LONG_DOUBLE);
92#if GINKGO_ENABLE_HALF
93// OpenMPI 5.0 and MPICH v3.4a1 provide MPIX_C_FLOAT16
94// Only OpenMPI supports complex half
95// TODO: use native type when mpi is configured with half feature
96GKO_REGISTER_MPI_TYPE(half, MPI_UNSIGNED_SHORT);
97GKO_REGISTER_MPI_TYPE(std::complex<half>, MPI_FLOAT);
98#endif // GINKGO_ENABLE_HALF
99GKO_REGISTER_MPI_TYPE(std::complex<float>, MPI_C_FLOAT_COMPLEX);
100GKO_REGISTER_MPI_TYPE(std::complex<double>, MPI_C_DOUBLE_COMPLEX);
101
102
109class contiguous_type {
110public:
117 contiguous_type(int count, MPI_Datatype old_type) : type_(MPI_DATATYPE_NULL)
118 {
119 GKO_ASSERT_NO_MPI_ERRORS(MPI_Type_contiguous(count, old_type, &type_));
120 GKO_ASSERT_NO_MPI_ERRORS(MPI_Type_commit(&type_));
121 }
122
126 contiguous_type() : type_(MPI_DATATYPE_NULL) {}
127
132 contiguous_type(const contiguous_type&) = delete;
137 contiguous_type& operator=(const contiguous_type&) = delete;
143 contiguous_type(contiguous_type&& other) noexcept : type_(MPI_DATATYPE_NULL)
144 {
145 *this = std::move(other);
146 }
147
155 contiguous_type& operator=(contiguous_type&& other) noexcept
156 {
157 if (this != &other) {
158 this->type_ = std::exchange(other.type_, MPI_DATATYPE_NULL);
159 }
160 return *this;
161 }
162
166 ~contiguous_type()
167 {
168 if (type_ != MPI_DATATYPE_NULL) {
169 MPI_Type_free(&type_);
170 }
171 }
172
178 MPI_Datatype get() const { return type_; }
179
180private:
181 MPI_Datatype type_;
182};
183
184
189enum class thread_type {
190 serialized = MPI_THREAD_SERIALIZED,
191 funneled = MPI_THREAD_FUNNELED,
192 single = MPI_THREAD_SINGLE,
193 multiple = MPI_THREAD_MULTIPLE
194};
195
196
206class environment {
207public:
208 static bool is_finalized()
209 {
210 int flag = 0;
211 GKO_ASSERT_NO_MPI_ERRORS(MPI_Finalized(&flag));
212 return flag;
213 }
214
215 static bool is_initialized()
216 {
217 int flag = 0;
218 GKO_ASSERT_NO_MPI_ERRORS(MPI_Initialized(&flag));
219 return flag;
220 }
221
227 int get_provided_thread_support() const { return provided_thread_support_; }
228
237 environment(int& argc, char**& argv,
238 const thread_type thread_t = thread_type::serialized)
239 {
240 this->required_thread_support_ = static_cast<int>(thread_t);
241 GKO_ASSERT_NO_MPI_ERRORS(
242 MPI_Init_thread(&argc, &argv, this->required_thread_support_,
243 &(this->provided_thread_support_)));
244 }
245
249 ~environment() { MPI_Finalize(); }
250
251 environment(const environment&) = delete;
252 environment(environment&&) = delete;
253 environment& operator=(const environment&) = delete;
254 environment& operator=(environment&&) = delete;
255
256private:
257 int required_thread_support_;
258 int provided_thread_support_;
259};
260
261
262namespace {
263
264
269class comm_deleter {
270public:
271 using pointer = MPI_Comm*;
272 void operator()(pointer comm) const
273 {
274 GKO_ASSERT(*comm != MPI_COMM_NULL);
275 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_free(comm));
276 delete comm;
277 }
278};
279
280
281} // namespace
282
283
287struct status {
291 status() : status_(MPI_Status{}) {}
292
298 MPI_Status* get() { return &this->status_; }
299
310 template <typename T>
311 int get_count(const T* data) const
312 {
313 int count;
314 MPI_Get_count(&status_, type_impl<T>::get_type(), &count);
315 return count;
316 }
317
318private:
319 MPI_Status status_;
320};
321
322
327class request {
328public:
333 request() : req_(MPI_REQUEST_NULL) {}
334
335 request(const request&) = delete;
336
337 request& operator=(const request&) = delete;
338
339 request(request&& o) noexcept { *this = std::move(o); }
340
341 request& operator=(request&& o) noexcept
342 {
343 if (this != &o) {
344 this->req_ = std::exchange(o.req_, MPI_REQUEST_NULL);
345 }
346 return *this;
347 }
348
349 ~request()
350 {
351 if (req_ != MPI_REQUEST_NULL) {
352 if (MPI_Request_free(&req_) != MPI_SUCCESS) {
353 std::terminate(); // since we can't throw in destructors, we
354 // have to terminate the program
355 }
356 }
357 }
358
364 MPI_Request* get() { return &this->req_; }
365
372 status wait()
373 {
374 status status;
375 GKO_ASSERT_NO_MPI_ERRORS(MPI_Wait(&req_, status.get()));
376 return status;
377 }
378
379
380private:
381 MPI_Request req_;
382};
383
384
392inline std::vector<status> wait_all(std::vector<request>& req)
393{
394 std::vector<status> stat;
395 for (std::size_t i = 0; i < req.size(); ++i) {
396 stat.emplace_back(req[i].wait());
397 }
398 return stat;
399}
400
401
416class communicator {
417public:
428 communicator(const MPI_Comm& comm, bool force_host_buffer = false)
429 : comm_(), force_host_buffer_(force_host_buffer)
430 {
431 this->comm_.reset(new MPI_Comm(comm));
432 }
433
442 communicator(const MPI_Comm& comm, int color, int key)
443 {
444 MPI_Comm comm_out;
445 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_split(comm, color, key, &comm_out));
446 this->comm_.reset(new MPI_Comm(comm_out), comm_deleter{});
447 }
448
457 communicator(const communicator& comm, int color, int key)
458 {
459 MPI_Comm comm_out;
460 GKO_ASSERT_NO_MPI_ERRORS(
461 MPI_Comm_split(comm.get(), color, key, &comm_out));
462 this->comm_.reset(new MPI_Comm(comm_out), comm_deleter{});
463 }
464
474 static communicator create_owning(const MPI_Comm& comm,
475 bool force_host_buffer = false)
476 {
477 communicator comm_out(MPI_COMM_NULL, force_host_buffer);
478 comm_out.comm_.reset(new MPI_Comm(comm), comm_deleter{});
479 return comm_out;
480 }
481
487 communicator(const communicator& other) = default;
488
495 communicator(communicator&& other) { *this = std::move(other); }
496
500 communicator& operator=(const communicator& other) = default;
501
505 communicator& operator=(communicator&& other)
506 {
507 if (this != &other) {
508 comm_ = std::exchange(other.comm_,
509 std::make_shared<MPI_Comm>(MPI_COMM_NULL));
510 force_host_buffer_ = other.force_host_buffer_;
511 }
512 return *this;
513 }
514
520 const MPI_Comm& get() const { return *(this->comm_.get()); }
521
522 bool force_host_buffer() const { return force_host_buffer_; }
523
529 int size() const { return get_num_ranks(); }
530
536 int rank() const { return get_my_rank(); };
537
543 int node_local_rank() const { return get_node_local_rank(); };
544
550 bool operator==(const communicator& rhs) const { return is_identical(rhs); }
551
557 bool operator!=(const communicator& rhs) const { return !(*this == rhs); }
558
568 bool is_identical(const communicator& rhs) const
569 {
570 if (get() == MPI_COMM_NULL || rhs.get() == MPI_COMM_NULL) {
571 return get() == rhs.get();
572 }
573 int flag;
574 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_compare(get(), rhs.get(), &flag));
575 return flag == MPI_IDENT;
576 }
577
590 bool is_congruent(const communicator& rhs) const
591 {
592 if (get() == MPI_COMM_NULL || rhs.get() == MPI_COMM_NULL) {
593 return get() == rhs.get();
594 }
595 int flag;
596 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_compare(get(), rhs.get(), &flag));
597 return flag == MPI_CONGRUENT;
598 }
599
604 void synchronize() const
605 {
606 GKO_ASSERT_NO_MPI_ERRORS(MPI_Barrier(this->get()));
607 }
608
622 template <typename SendType>
623 void send(std::shared_ptr<const Executor> exec, const SendType* send_buffer,
624 const int send_count, const int destination_rank,
625 const int send_tag) const
626 {
627 auto guard = exec->get_scoped_device_id_guard();
628 GKO_ASSERT_NO_MPI_ERRORS(
629 MPI_Send(send_buffer, send_count, type_impl<SendType>::get_type(),
630 destination_rank, send_tag, this->get()));
631 }
632
649 template <typename SendType>
650 request i_send(std::shared_ptr<const Executor> exec,
651 const SendType* send_buffer, const int send_count,
652 const int destination_rank, const int send_tag) const
653 {
654 auto guard = exec->get_scoped_device_id_guard();
655 request req;
656 GKO_ASSERT_NO_MPI_ERRORS(
657 MPI_Isend(send_buffer, send_count, type_impl<SendType>::get_type(),
658 destination_rank, send_tag, this->get(), req.get()));
659 return req;
660 }
661
677 template <typename RecvType>
678 status recv(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
679 const int recv_count, const int source_rank,
680 const int recv_tag) const
681 {
682 auto guard = exec->get_scoped_device_id_guard();
683 status st;
684 GKO_ASSERT_NO_MPI_ERRORS(
685 MPI_Recv(recv_buffer, recv_count, type_impl<RecvType>::get_type(),
686 source_rank, recv_tag, this->get(), st.get()));
687 return st;
688 }
689
705 template <typename RecvType>
706 request i_recv(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
707 const int recv_count, const int source_rank,
708 const int recv_tag) const
709 {
710 auto guard = exec->get_scoped_device_id_guard();
711 request req;
712 GKO_ASSERT_NO_MPI_ERRORS(
713 MPI_Irecv(recv_buffer, recv_count, type_impl<RecvType>::get_type(),
714 source_rank, recv_tag, this->get(), req.get()));
715 return req;
716 }
717
730 template <typename BroadcastType>
731 void broadcast(std::shared_ptr<const Executor> exec, BroadcastType* buffer,
732 int count, int root_rank) const
733 {
734 auto guard = exec->get_scoped_device_id_guard();
735 GKO_ASSERT_NO_MPI_ERRORS(MPI_Bcast(buffer, count,
736 type_impl<BroadcastType>::get_type(),
737 root_rank, this->get()));
738 }
739
755 template <typename BroadcastType>
756 request i_broadcast(std::shared_ptr<const Executor> exec,
757 BroadcastType* buffer, int count, int root_rank) const
758 {
759 auto guard = exec->get_scoped_device_id_guard();
760 request req;
761 GKO_ASSERT_NO_MPI_ERRORS(
762 MPI_Ibcast(buffer, count, type_impl<BroadcastType>::get_type(),
763 root_rank, this->get(), req.get()));
764 return req;
765 }
766
781 template <typename ReduceType>
782 void reduce(std::shared_ptr<const Executor> exec,
783 const ReduceType* send_buffer, ReduceType* recv_buffer,
784 int count, MPI_Op operation, int root_rank) const
785 {
786 auto guard = exec->get_scoped_device_id_guard();
787 GKO_ASSERT_NO_MPI_ERRORS(MPI_Reduce(send_buffer, recv_buffer, count,
788 type_impl<ReduceType>::get_type(),
789 operation, root_rank, this->get()));
790 }
791
808 template <typename ReduceType>
809 request i_reduce(std::shared_ptr<const Executor> exec,
810 const ReduceType* send_buffer, ReduceType* recv_buffer,
811 int count, MPI_Op operation, int root_rank) const
812 {
813 auto guard = exec->get_scoped_device_id_guard();
814 request req;
815 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ireduce(
816 send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
817 operation, root_rank, this->get(), req.get()));
818 return req;
819 }
820
834 template <typename ReduceType>
835 void all_reduce(std::shared_ptr<const Executor> exec,
836 ReduceType* recv_buffer, int count, MPI_Op operation) const
837 {
838 auto guard = exec->get_scoped_device_id_guard();
839 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allreduce(
840 MPI_IN_PLACE, recv_buffer, count, type_impl<ReduceType>::get_type(),
841 operation, this->get()));
842 }
843
859 template <typename ReduceType>
860 request i_all_reduce(std::shared_ptr<const Executor> exec,
861 ReduceType* recv_buffer, int count,
862 MPI_Op operation) const
863 {
864 auto guard = exec->get_scoped_device_id_guard();
865 request req;
866 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallreduce(
867 MPI_IN_PLACE, recv_buffer, count, type_impl<ReduceType>::get_type(),
868 operation, this->get(), req.get()));
869 return req;
870 }
871
886 template <typename ReduceType>
887 void all_reduce(std::shared_ptr<const Executor> exec,
888 const ReduceType* send_buffer, ReduceType* recv_buffer,
889 int count, MPI_Op operation) const
890 {
891 auto guard = exec->get_scoped_device_id_guard();
892 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allreduce(
893 send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
894 operation, this->get()));
895 }
896
913 template <typename ReduceType>
914 request i_all_reduce(std::shared_ptr<const Executor> exec,
915 const ReduceType* send_buffer, ReduceType* recv_buffer,
916 int count, MPI_Op operation) const
917 {
918 auto guard = exec->get_scoped_device_id_guard();
919 request req;
920 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallreduce(
921 send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
922 operation, this->get(), req.get()));
923 return req;
924 }
925
942 template <typename SendType, typename RecvType>
943 void gather(std::shared_ptr<const Executor> exec,
944 const SendType* send_buffer, const int send_count,
945 RecvType* recv_buffer, const int recv_count,
946 int root_rank) const
947 {
948 auto guard = exec->get_scoped_device_id_guard();
949 GKO_ASSERT_NO_MPI_ERRORS(
950 MPI_Gather(send_buffer, send_count, type_impl<SendType>::get_type(),
951 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
952 root_rank, this->get()));
953 }
954
974 template <typename SendType, typename RecvType>
975 request i_gather(std::shared_ptr<const Executor> exec,
976 const SendType* send_buffer, const int send_count,
977 RecvType* recv_buffer, const int recv_count,
978 int root_rank) const
979 {
980 auto guard = exec->get_scoped_device_id_guard();
981 request req;
982 GKO_ASSERT_NO_MPI_ERRORS(MPI_Igather(
983 send_buffer, send_count, type_impl<SendType>::get_type(),
984 recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
985 this->get(), req.get()));
986 return req;
987 }
988
1007 template <typename SendType, typename RecvType>
1008 void gather_v(std::shared_ptr<const Executor> exec,
1009 const SendType* send_buffer, const int send_count,
1010 RecvType* recv_buffer, const int* recv_counts,
1011 const int* displacements, int root_rank) const
1012 {
1013 auto guard = exec->get_scoped_device_id_guard();
1014 GKO_ASSERT_NO_MPI_ERRORS(MPI_Gatherv(
1015 send_buffer, send_count, type_impl<SendType>::get_type(),
1016 recv_buffer, recv_counts, displacements,
1017 type_impl<RecvType>::get_type(), root_rank, this->get()));
1018 }
1019
1040 template <typename SendType, typename RecvType>
1041 request i_gather_v(std::shared_ptr<const Executor> exec,
1042 const SendType* send_buffer, const int send_count,
1043 RecvType* recv_buffer, const int* recv_counts,
1044 const int* displacements, int root_rank) const
1045 {
1046 auto guard = exec->get_scoped_device_id_guard();
1047 request req;
1048 GKO_ASSERT_NO_MPI_ERRORS(MPI_Igatherv(
1049 send_buffer, send_count, type_impl<SendType>::get_type(),
1050 recv_buffer, recv_counts, displacements,
1051 type_impl<RecvType>::get_type(), root_rank, this->get(),
1052 req.get()));
1053 return req;
1054 }
1055
1071 template <typename SendType, typename RecvType>
1072 void all_gather(std::shared_ptr<const Executor> exec,
1073 const SendType* send_buffer, const int send_count,
1074 RecvType* recv_buffer, const int recv_count) const
1075 {
1076 auto guard = exec->get_scoped_device_id_guard();
1077 GKO_ASSERT_NO_MPI_ERRORS(MPI_Allgather(
1078 send_buffer, send_count, type_impl<SendType>::get_type(),
1079 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1080 this->get()));
1081 }
1082
1101 template <typename SendType, typename RecvType>
1102 request i_all_gather(std::shared_ptr<const Executor> exec,
1103 const SendType* send_buffer, const int send_count,
1104 RecvType* recv_buffer, const int recv_count) const
1105 {
1106 auto guard = exec->get_scoped_device_id_guard();
1107 request req;
1108 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallgather(
1109 send_buffer, send_count, type_impl<SendType>::get_type(),
1110 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1111 this->get(), req.get()));
1112 return req;
1113 }
1114
1130 template <typename SendType, typename RecvType>
1131 void scatter(std::shared_ptr<const Executor> exec,
1132 const SendType* send_buffer, const int send_count,
1133 RecvType* recv_buffer, const int recv_count,
1134 int root_rank) const
1135 {
1136 auto guard = exec->get_scoped_device_id_guard();
1137 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scatter(
1138 send_buffer, send_count, type_impl<SendType>::get_type(),
1139 recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
1140 this->get()));
1141 }
1142
1161 template <typename SendType, typename RecvType>
1162 request i_scatter(std::shared_ptr<const Executor> exec,
1163 const SendType* send_buffer, const int send_count,
1164 RecvType* recv_buffer, const int recv_count,
1165 int root_rank) const
1166 {
1167 auto guard = exec->get_scoped_device_id_guard();
1168 request req;
1169 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iscatter(
1170 send_buffer, send_count, type_impl<SendType>::get_type(),
1171 recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
1172 this->get(), req.get()));
1173 return req;
1174 }
1175
1194 template <typename SendType, typename RecvType>
1195 void scatter_v(std::shared_ptr<const Executor> exec,
1196 const SendType* send_buffer, const int* send_counts,
1197 const int* displacements, RecvType* recv_buffer,
1198 const int recv_count, int root_rank) const
1199 {
1200 auto guard = exec->get_scoped_device_id_guard();
1201 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scatterv(
1202 send_buffer, send_counts, displacements,
1203 type_impl<SendType>::get_type(), recv_buffer, recv_count,
1204 type_impl<RecvType>::get_type(), root_rank, this->get()));
1205 }
1206
1227 template <typename SendType, typename RecvType>
1228 request i_scatter_v(std::shared_ptr<const Executor> exec,
1229 const SendType* send_buffer, const int* send_counts,
1230 const int* displacements, RecvType* recv_buffer,
1231 const int recv_count, int root_rank) const
1232 {
1233 auto guard = exec->get_scoped_device_id_guard();
1234 request req;
1235 GKO_ASSERT_NO_MPI_ERRORS(
1236 MPI_Iscatterv(send_buffer, send_counts, displacements,
1237 type_impl<SendType>::get_type(), recv_buffer,
1238 recv_count, type_impl<RecvType>::get_type(),
1239 root_rank, this->get(), req.get()));
1240 return req;
1241 }
1242
1259 template <typename RecvType>
1260 void all_to_all(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
1261 const int recv_count) const
1262 {
1263 auto guard = exec->get_scoped_device_id_guard();
1264 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoall(
1265 MPI_IN_PLACE, recv_count, type_impl<RecvType>::get_type(),
1266 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1267 this->get()));
1268 }
1269
1288 template <typename RecvType>
1289 request i_all_to_all(std::shared_ptr<const Executor> exec,
1290 RecvType* recv_buffer, const int recv_count) const
1291 {
1292 auto guard = exec->get_scoped_device_id_guard();
1293 request req;
1294 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoall(
1295 MPI_IN_PLACE, recv_count, type_impl<RecvType>::get_type(),
1296 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1297 this->get(), req.get()));
1298 return req;
1299 }
1300
1317 template <typename SendType, typename RecvType>
1318 void all_to_all(std::shared_ptr<const Executor> exec,
1319 const SendType* send_buffer, const int send_count,
1320 RecvType* recv_buffer, const int recv_count) const
1321 {
1322 auto guard = exec->get_scoped_device_id_guard();
1323 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoall(
1324 send_buffer, send_count, type_impl<SendType>::get_type(),
1325 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1326 this->get()));
1327 }
1328
1347 template <typename SendType, typename RecvType>
1348 request i_all_to_all(std::shared_ptr<const Executor> exec,
1349 const SendType* send_buffer, const int send_count,
1350 RecvType* recv_buffer, const int recv_count) const
1351 {
1352 auto guard = exec->get_scoped_device_id_guard();
1353 request req;
1354 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoall(
1355 send_buffer, send_count, type_impl<SendType>::get_type(),
1356 recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1357 this->get(), req.get()));
1358 return req;
1359 }
1360
1380 template <typename SendType, typename RecvType>
1381 void all_to_all_v(std::shared_ptr<const Executor> exec,
1382 const SendType* send_buffer, const int* send_counts,
1383 const int* send_offsets, RecvType* recv_buffer,
1384 const int* recv_counts, const int* recv_offsets) const
1385 {
1386 this->all_to_all_v(std::move(exec), send_buffer, send_counts,
1387 send_offsets, type_impl<SendType>::get_type(),
1388 recv_buffer, recv_counts, recv_offsets,
1389 type_impl<RecvType>::get_type());
1390 }
1391
1407 void all_to_all_v(std::shared_ptr<const Executor> exec,
1408 const void* send_buffer, const int* send_counts,
1409 const int* send_offsets, MPI_Datatype send_type,
1410 void* recv_buffer, const int* recv_counts,
1411 const int* recv_offsets, MPI_Datatype recv_type) const
1412 {
1413 auto guard = exec->get_scoped_device_id_guard();
1414 GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoallv(
1415 send_buffer, send_counts, send_offsets, send_type, recv_buffer,
1416 recv_counts, recv_offsets, recv_type, this->get()));
1417 }
1418
1438 request i_all_to_all_v(std::shared_ptr<const Executor> exec,
1439 const void* send_buffer, const int* send_counts,
1440 const int* send_offsets, MPI_Datatype send_type,
1441 void* recv_buffer, const int* recv_counts,
1442 const int* recv_offsets,
1443 MPI_Datatype recv_type) const
1444 {
1445 auto guard = exec->get_scoped_device_id_guard();
1446 request req;
1447 GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoallv(
1448 send_buffer, send_counts, send_offsets, send_type, recv_buffer,
1449 recv_counts, recv_offsets, recv_type, this->get(), req.get()));
1450 return req;
1451 }
1452
1473 template <typename SendType, typename RecvType>
1474 request i_all_to_all_v(std::shared_ptr<const Executor> exec,
1475 const SendType* send_buffer, const int* send_counts,
1476 const int* send_offsets, RecvType* recv_buffer,
1477 const int* recv_counts,
1478 const int* recv_offsets) const
1479 {
1480 return this->i_all_to_all_v(
1481 std::move(exec), send_buffer, send_counts, send_offsets,
1482 type_impl<SendType>::get_type(), recv_buffer, recv_counts,
1483 recv_offsets, type_impl<RecvType>::get_type());
1484 }
1485
1500 template <typename ScanType>
1501 void scan(std::shared_ptr<const Executor> exec, const ScanType* send_buffer,
1502 ScanType* recv_buffer, int count, MPI_Op operation) const
1503 {
1504 auto guard = exec->get_scoped_device_id_guard();
1505 GKO_ASSERT_NO_MPI_ERRORS(MPI_Scan(send_buffer, recv_buffer, count,
1506 type_impl<ScanType>::get_type(),
1507 operation, this->get()));
1508 }
1509
1526 template <typename ScanType>
1527 request i_scan(std::shared_ptr<const Executor> exec,
1528 const ScanType* send_buffer, ScanType* recv_buffer,
1529 int count, MPI_Op operation) const
1530 {
1531 auto guard = exec->get_scoped_device_id_guard();
1532 request req;
1533 GKO_ASSERT_NO_MPI_ERRORS(MPI_Iscan(send_buffer, recv_buffer, count,
1534 type_impl<ScanType>::get_type(),
1535 operation, this->get(), req.get()));
1536 return req;
1537 }
1538
1539private:
1540 std::shared_ptr<MPI_Comm> comm_;
1541 bool force_host_buffer_;
1542
1543 int get_my_rank() const
1544 {
1545 int my_rank = 0;
1546 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_rank(get(), &my_rank));
1547 return my_rank;
1548 }
1549
1550 int get_node_local_rank() const
1551 {
1552 MPI_Comm local_comm;
1553 int rank;
1554 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_split_type(
1555 this->get(), MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &local_comm));
1556 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_rank(local_comm, &rank));
1557 MPI_Comm_free(&local_comm);
1558 return rank;
1559 }
1560
1561 int get_num_ranks() const
1562 {
1563 int size = 1;
1564 GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_size(this->get(), &size));
1565 return size;
1566 }
1567};
1568
1569
1574bool requires_host_buffer(const std::shared_ptr<const Executor>& exec,
1575 const communicator& comm);
1576
1577
1583inline double get_walltime() { return MPI_Wtime(); }
1584
1585
1594template <typename ValueType>
1595class window {
1596public:
1600 enum class create_type { allocate = 1, create = 2, dynamic_create = 3 };
1601
1605 enum class lock_type { shared = 1, exclusive = 2 };
1606
1610 window() : window_(MPI_WIN_NULL) {}
1611
1612 window(const window& other) = delete;
1613
1614 window& operator=(const window& other) = delete;
1615
1622 window(window&& other) : window_{std::exchange(other.window_, MPI_WIN_NULL)}
1623 {}
1624
1631 window& operator=(window&& other)
1632 {
1633 window_ = std::exchange(other.window_, MPI_WIN_NULL);
1634 }
1635
1648 window(std::shared_ptr<const Executor> exec, ValueType* base, int num_elems,
1649 const communicator& comm, const int disp_unit = sizeof(ValueType),
1650 MPI_Info input_info = MPI_INFO_NULL,
1651 create_type c_type = create_type::create)
1652 {
1653 auto guard = exec->get_scoped_device_id_guard();
1654 unsigned size = num_elems * sizeof(ValueType);
1655 if (c_type == create_type::create) {
1656 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_create(
1657 base, size, disp_unit, input_info, comm.get(), &this->window_));
1658 } else if (c_type == create_type::dynamic_create) {
1659 GKO_ASSERT_NO_MPI_ERRORS(
1660 MPI_Win_create_dynamic(input_info, comm.get(), &this->window_));
1661 } else if (c_type == create_type::allocate) {
1662 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_allocate(
1663 size, disp_unit, input_info, comm.get(), base, &this->window_));
1664 } else {
1665 GKO_NOT_IMPLEMENTED;
1666 }
1667 }
1668
1674 MPI_Win get_window() const { return this->window_; }
1675
1682 void fence(int assert = 0) const
1683 {
1684 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_fence(assert, this->window_));
1685 }
1686
1695 void lock(int rank, lock_type lock_t = lock_type::shared,
1696 int assert = 0) const
1697 {
1698 if (lock_t == lock_type::shared) {
1699 GKO_ASSERT_NO_MPI_ERRORS(
1700 MPI_Win_lock(MPI_LOCK_SHARED, rank, assert, this->window_));
1701 } else if (lock_t == lock_type::exclusive) {
1702 GKO_ASSERT_NO_MPI_ERRORS(
1703 MPI_Win_lock(MPI_LOCK_EXCLUSIVE, rank, assert, this->window_));
1704 } else {
1705 GKO_NOT_IMPLEMENTED;
1706 }
1707 }
1708
1715 void unlock(int rank) const
1716 {
1717 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_unlock(rank, this->window_));
1718 }
1719
1726 void lock_all(int assert = 0) const
1727 {
1728 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_lock_all(assert, this->window_));
1729 }
1730
1735 void unlock_all() const
1736 {
1737 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_unlock_all(this->window_));
1738 }
1739
1746 void flush(int rank) const
1747 {
1748 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush(rank, this->window_));
1749 }
1750
1757 void flush_local(int rank) const
1758 {
1759 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_local(rank, this->window_));
1760 }
1761
1766 void flush_all() const
1767 {
1768 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_all(this->window_));
1769 }
1770
1775 void flush_all_local() const
1776 {
1777 GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_local_all(this->window_));
1778 }
1779
1783 void sync() const { GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_sync(this->window_)); }
1784
1788 ~window()
1789 {
1790 if (this->window_ && this->window_ != MPI_WIN_NULL) {
1791 MPI_Win_free(&this->window_);
1792 }
1793 }
1794
1805 template <typename PutType>
1806 void put(std::shared_ptr<const Executor> exec, const PutType* origin_buffer,
1807 const int origin_count, const int target_rank,
1808 const unsigned int target_disp, const int target_count) const
1809 {
1810 auto guard = exec->get_scoped_device_id_guard();
1811 GKO_ASSERT_NO_MPI_ERRORS(
1812 MPI_Put(origin_buffer, origin_count, type_impl<PutType>::get_type(),
1813 target_rank, target_disp, target_count,
1814 type_impl<PutType>::get_type(), this->get_window()));
1815 }
1816
1829 template <typename PutType>
1830 request r_put(std::shared_ptr<const Executor> exec,
1831 const PutType* origin_buffer, const int origin_count,
1832 const int target_rank, const unsigned int target_disp,
1833 const int target_count) const
1834 {
1835 auto guard = exec->get_scoped_device_id_guard();
1836 request req;
1837 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rput(
1838 origin_buffer, origin_count, type_impl<PutType>::get_type(),
1839 target_rank, target_disp, target_count,
1840 type_impl<PutType>::get_type(), this->get_window(), req.get()));
1841 return req;
1842 }
1843
1855 template <typename PutType>
1856 void accumulate(std::shared_ptr<const Executor> exec,
1857 const PutType* origin_buffer, const int origin_count,
1858 const int target_rank, const unsigned int target_disp,
1859 const int target_count, MPI_Op operation) const
1860 {
1861 auto guard = exec->get_scoped_device_id_guard();
1862 GKO_ASSERT_NO_MPI_ERRORS(MPI_Accumulate(
1863 origin_buffer, origin_count, type_impl<PutType>::get_type(),
1864 target_rank, target_disp, target_count,
1865 type_impl<PutType>::get_type(), operation, this->get_window()));
1866 }
1867
1881 template <typename PutType>
1882 request r_accumulate(std::shared_ptr<const Executor> exec,
1883 const PutType* origin_buffer, const int origin_count,
1884 const int target_rank, const unsigned int target_disp,
1885 const int target_count, MPI_Op operation) const
1886 {
1887 auto guard = exec->get_scoped_device_id_guard();
1888 request req;
1889 GKO_ASSERT_NO_MPI_ERRORS(MPI_Raccumulate(
1890 origin_buffer, origin_count, type_impl<PutType>::get_type(),
1891 target_rank, target_disp, target_count,
1892 type_impl<PutType>::get_type(), operation, this->get_window(),
1893 req.get()));
1894 return req;
1895 }
1896
1907 template <typename GetType>
1908 void get(std::shared_ptr<const Executor> exec, GetType* origin_buffer,
1909 const int origin_count, const int target_rank,
1910 const unsigned int target_disp, const int target_count) const
1911 {
1912 auto guard = exec->get_scoped_device_id_guard();
1913 GKO_ASSERT_NO_MPI_ERRORS(
1914 MPI_Get(origin_buffer, origin_count, type_impl<GetType>::get_type(),
1915 target_rank, target_disp, target_count,
1916 type_impl<GetType>::get_type(), this->get_window()));
1917 }
1918
1931 template <typename GetType>
1932 request r_get(std::shared_ptr<const Executor> exec, GetType* origin_buffer,
1933 const int origin_count, const int target_rank,
1934 const unsigned int target_disp, const int target_count) const
1935 {
1936 auto guard = exec->get_scoped_device_id_guard();
1937 request req;
1938 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rget(
1939 origin_buffer, origin_count, type_impl<GetType>::get_type(),
1940 target_rank, target_disp, target_count,
1941 type_impl<GetType>::get_type(), this->get_window(), req.get()));
1942 return req;
1943 }
1944
1958 template <typename GetType>
1959 void get_accumulate(std::shared_ptr<const Executor> exec,
1960 GetType* origin_buffer, const int origin_count,
1961 GetType* result_buffer, const int result_count,
1962 const int target_rank, const unsigned int target_disp,
1963 const int target_count, MPI_Op operation) const
1964 {
1965 auto guard = exec->get_scoped_device_id_guard();
1966 GKO_ASSERT_NO_MPI_ERRORS(MPI_Get_accumulate(
1967 origin_buffer, origin_count, type_impl<GetType>::get_type(),
1968 result_buffer, result_count, type_impl<GetType>::get_type(),
1969 target_rank, target_disp, target_count,
1970 type_impl<GetType>::get_type(), operation, this->get_window()));
1971 }
1972
1988 template <typename GetType>
1989 request r_get_accumulate(std::shared_ptr<const Executor> exec,
1990 GetType* origin_buffer, const int origin_count,
1991 GetType* result_buffer, const int result_count,
1992 const int target_rank,
1993 const unsigned int target_disp,
1994 const int target_count, MPI_Op operation) const
1995 {
1996 auto guard = exec->get_scoped_device_id_guard();
1997 request req;
1998 GKO_ASSERT_NO_MPI_ERRORS(MPI_Rget_accumulate(
1999 origin_buffer, origin_count, type_impl<GetType>::get_type(),
2000 result_buffer, result_count, type_impl<GetType>::get_type(),
2001 target_rank, target_disp, target_count,
2002 type_impl<GetType>::get_type(), operation, this->get_window(),
2003 req.get()));
2004 return req;
2005 }
2006
2017 template <typename GetType>
2018 void fetch_and_op(std::shared_ptr<const Executor> exec,
2019 GetType* origin_buffer, GetType* result_buffer,
2020 const int target_rank, const unsigned int target_disp,
2021 MPI_Op operation) const
2022 {
2023 auto guard = exec->get_scoped_device_id_guard();
2024 GKO_ASSERT_NO_MPI_ERRORS(MPI_Fetch_and_op(
2025 origin_buffer, result_buffer, type_impl<GetType>::get_type(),
2026 target_rank, target_disp, operation, this->get_window()));
2027 }
2028
2029private:
2030 MPI_Win window_;
2031};
2032
2033
2034} // namespace mpi
2035} // namespace experimental
2036} // namespace gko
2037
2038
2039#endif // GINKGO_BUILD_MPI
2040
2041
2042#endif // GKO_PUBLIC_CORE_BASE_MPI_HPP_
A thin wrapper of MPI_Comm that supports most MPI calls.
Definition mpi.hpp:416
status recv(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count, const int source_rank, const int recv_tag) const
Receive data from source rank.
Definition mpi.hpp:678
void scatter_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *displacements, RecvType *recv_buffer, const int recv_count, int root_rank) const
Scatter data from root rank to all ranks in the communicator with offsets.
Definition mpi.hpp:1195
communicator(const communicator &other)=default
Create a copy of a communicator.
request i_broadcast(std::shared_ptr< const Executor > exec, BroadcastType *buffer, int count, int root_rank) const
(Non-blocking) Broadcast data from calling process to all ranks in the communicator
Definition mpi.hpp:756
void gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
Gather data onto the root rank from all ranks in the communicator.
Definition mpi.hpp:943
request i_recv(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count, const int source_rank, const int recv_tag) const
Receive (Non-blocking, Immediate return) data from source rank.
Definition mpi.hpp:706
request i_scatter_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *displacements, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Scatter data from root rank to all ranks in the communicator with offsets.
Definition mpi.hpp:1228
void all_to_all(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
Communicate data from all ranks to all other ranks (MPI_Alltoall).
Definition mpi.hpp:1318
bool is_identical(const communicator &rhs) const
Checks if the rhs communicator is identical to this communicator.
Definition mpi.hpp:568
request i_all_to_all(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
(Non-blocking) Communicate data from all ranks to all other ranks (MPI_Ialltoall).
Definition mpi.hpp:1348
request i_all_to_all_v(std::shared_ptr< const Executor > exec, const void *send_buffer, const int *send_counts, const int *send_offsets, MPI_Datatype send_type, void *recv_buffer, const int *recv_counts, const int *recv_offsets, MPI_Datatype recv_type) const
Communicate data from all ranks to all other ranks with offsets (MPI_Ialltoallv).
Definition mpi.hpp:1438
bool operator!=(const communicator &rhs) const
Compare two communicator objects for non-equality.
Definition mpi.hpp:557
void scatter(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
Scatter data from root rank to all ranks in the communicator.
Definition mpi.hpp:1131
void synchronize() const
This function is used to synchronize the ranks in the communicator.
Definition mpi.hpp:604
int rank() const
Return the rank of the calling process in the communicator.
Definition mpi.hpp:536
request i_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation, int root_rank) const
(Non-blocking) Reduce data into root from all calling processes on the same communicator.
Definition mpi.hpp:809
int size() const
Return the size of the communicator (number of ranks).
Definition mpi.hpp:529
void send(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, const int destination_rank, const int send_tag) const
Send (Blocking) data from calling process to destination rank.
Definition mpi.hpp:623
request i_all_to_all_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *send_offsets, RecvType *recv_buffer, const int *recv_counts, const int *recv_offsets) const
Communicate data from all ranks to all other ranks with offsets (MPI_Ialltoallv).
Definition mpi.hpp:1474
request i_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Gather data onto the root rank from all ranks in the communicator.
Definition mpi.hpp:975
void all_to_all(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count) const
(In-place) Communicate data from all ranks to all other ranks in place (MPI_Alltoall).
Definition mpi.hpp:1260
void all_to_all_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *send_offsets, RecvType *recv_buffer, const int *recv_counts, const int *recv_offsets) const
Communicate data from all ranks to all other ranks with offsets (MPI_Alltoallv).
Definition mpi.hpp:1381
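For illustration, a sketch of exchanging one value with every rank via explicit counts and offsets; the executor setup and use of std::vector are assumptions outside this header.

auto exec = gko::ReferenceExecutor::create();
gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
const int n = comm.size();
std::vector<double> send(n, comm.rank()), recv(n, 0.0);
std::vector<int> counts(n, 1), offsets(n);
for (int i = 0; i < n; ++i) offsets[i] = i;  // element i is exchanged with rank i
comm.all_to_all_v(exec, send.data(), counts.data(), offsets.data(),
                  recv.data(), counts.data(), offsets.data());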
request i_all_reduce(std::shared_ptr< const Executor > exec, ReduceType *recv_buffer, int count, MPI_Op operation) const
(In-place, Non-blocking) Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:860
request i_all_to_all(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count) const
(In-place, Non-blocking) Communicate data from all ranks to all other ranks in place (MPI_Ialltoall).
Definition mpi.hpp:1289
void all_to_all_v(std::shared_ptr< const Executor > exec, const void *send_buffer, const int *send_counts, const int *send_offsets, MPI_Datatype send_type, void *recv_buffer, const int *recv_counts, const int *recv_offsets, MPI_Datatype recv_type) const
Communicate data from all ranks to all other ranks with offsets (MPI_Alltoallv).
Definition mpi.hpp:1407
int node_local_rank() const
Return the node local rank of the calling process in the communicator.
Definition mpi.hpp:543
void broadcast(std::shared_ptr< const Executor > exec, BroadcastType *buffer, int count, int root_rank) const
Broadcast data from calling process to all ranks in the communicator.
Definition mpi.hpp:731
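A minimal sketch of broadcasting a small host array from rank 0; gko::ReferenceExecutor::create() is assumed from Ginkgo's core headers.

auto exec = gko::ReferenceExecutor::create();
gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
std::vector<double> data(4, 0.0);
if (comm.rank() == 0) {
    data = {1.0, 2.0, 3.0, 4.0};  // only the root's contents matter
}
comm.broadcast(exec, data.data(), 4, /*root_rank*/ 0);
// every rank now holds {1, 2, 3, 4}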
static communicator create_owning(const MPI_Comm &comm, bool force_host_buffer=false)
Creates a new communicator and takes ownership of the MPI_Comm.
Definition mpi.hpp:474
const MPI_Comm & get() const
Return the underlying MPI_Comm object.
Definition mpi.hpp:520
communicator(const MPI_Comm &comm, int color, int key)
Create a communicator object from an existing MPI_Comm object using color and key.
Definition mpi.hpp:442
void all_reduce(std::shared_ptr< const Executor > exec, ReduceType *recv_buffer, int count, MPI_Op operation) const
(In-place) Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:835
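A minimal sketch of the in-place overload, assuming a host executor:

auto exec = gko::ReferenceExecutor::create();
gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
int value = comm.rank();                   // each rank contributes its own rank
comm.all_reduce(exec, &value, 1, MPI_SUM);
// value now holds 0 + 1 + ... + (size - 1) on every rank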
communicator & operator=(const communicator &other)=default
void all_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
Gather data onto all ranks from all ranks in the communicator.
Definition mpi.hpp:1072
communicator & operator=(communicator &&other)
Definition mpi.hpp:505
request i_all_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
(Non-blocking) Gather data onto all ranks from all ranks in the communicator.
Definition mpi.hpp:1102
bool operator==(const communicator &rhs) const
Compare two communicator objects for equality.
Definition mpi.hpp:550
bool is_congruent(const communicator &rhs) const
Checks if the rhs communicator is congruent to this communicator.
Definition mpi.hpp:590
void all_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation) const
Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:887
request i_gather_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int *recv_counts, const int *displacements, int root_rank) const
(Non-blocking) Gather data onto the root rank from all ranks in the communicator with offsets.
Definition mpi.hpp:1041
request i_all_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation) const
(Non-blocking) Reduce data from all calling processes on the same communicator.
Definition mpi.hpp:914
communicator(const MPI_Comm &comm, bool force_host_buffer=false)
Non-owning constructor for an existing communicator of type MPI_Comm.
Definition mpi.hpp:428
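To illustrate the ownership distinction (MPI_Comm_dup is plain MPI, not part of this wrapper):

// Non-owning: the wrapper never frees MPI_COMM_WORLD.
gko::experimental::mpi::communicator world(MPI_COMM_WORLD);
// Owning: the duplicated communicator is freed when the wrapper is destroyed.
MPI_Comm dup;
MPI_Comm_dup(MPI_COMM_WORLD, &dup);
auto owned = gko::experimental::mpi::communicator::create_owning(dup);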
request i_scan(std::shared_ptr< const Executor > exec, const ScanType *send_buffer, ScanType *recv_buffer, int count, MPI_Op operation) const
(Non-blocking) Does a scan operation with the given operator.
Definition mpi.hpp:1527
void reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation, int root_rank) const
Reduce data into root from all calling processes on the same communicator.
Definition mpi.hpp:782
request i_scatter(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Scatter data from root rank to all ranks in the communicator.
Definition mpi.hpp:1162
void scan(std::shared_ptr< const Executor > exec, const ScanType *send_buffer, ScanType *recv_buffer, int count, MPI_Op operation) const
Does a scan operation with the given operator.
Definition mpi.hpp:1501
void gather_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int *recv_counts, const int *displacements, int root_rank) const
Gather data onto the root rank from all ranks in the communicator with offsets.
Definition mpi.hpp:1008
request i_send(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, const int destination_rank, const int send_tag) const
Send (Non-blocking, Immediate return) data from calling process to destination rank.
Definition mpi.hpp:650
communicator(communicator &&other)
Move constructor.
Definition mpi.hpp:495
communicator(const communicator &comm, int color, int key)
Create a communicator object from an existing MPI_Comm object using color and key.
Definition mpi.hpp:457
MPI_Datatype get() const
Access the underlying MPI_Datatype.
Definition mpi.hpp:178
contiguous_type(int count, MPI_Datatype old_type)
Constructs a wrapper for a contiguous MPI_Datatype.
Definition mpi.hpp:117
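A short sketch of the RAII datatype wrapper:

// Describes 3 contiguous doubles; MPI_Type_free is called automatically.
gko::experimental::mpi::contiguous_type vec3(3, MPI_DOUBLE);
MPI_Datatype t = vec3.get();  // usable in raw MPI calls while vec3 is alive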
contiguous_type()
Constructs empty wrapper with MPI_DATATYPE_NULL.
Definition mpi.hpp:126
contiguous_type(const contiguous_type &)=delete
Disallow copying of wrapper type.
contiguous_type(contiguous_type &&other) noexcept
Move constructor, leaves other with MPI_DATATYPE_NULL.
Definition mpi.hpp:143
contiguous_type & operator=(contiguous_type &&other) noexcept
Move assignment, leaves other with MPI_DATATYPE_NULL.
Definition mpi.hpp:155
contiguous_type & operator=(const contiguous_type &)=delete
Disallow copying of wrapper type.
~contiguous_type()
Destructs object by freeing wrapped MPI_Datatype.
Definition mpi.hpp:166
Class that sets up and finalizes the MPI environment.
Definition mpi.hpp:206
~environment()
Call MPI_Finalize at the end of the scope of this class.
Definition mpi.hpp:249
int get_provided_thread_support() const
Return the provided thread support.
Definition mpi.hpp:227
environment(int &argc, char **&argv, const thread_type thread_t=thread_type::serialized)
Call MPI_Init_thread and initialize the MPI environment.
Definition mpi.hpp:237
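For illustration, a minimal program skeleton using the RAII environment (assumes a build with GINKGO_BUILD_MPI enabled):

#include <ginkgo/core/base/mpi.hpp>

int main(int argc, char* argv[])
{
    // MPI_Init_thread is called here; MPI_Finalize runs when env leaves scope.
    gko::experimental::mpi::environment env(argc, argv);
    gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
    // ... distributed work using comm ...
    return 0;
}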
The request class is a light, move-only wrapper around the MPI_Request handle.
Definition mpi.hpp:327
request()
The default constructor.
Definition mpi.hpp:333
MPI_Request * get()
Get a pointer to the underlying MPI_Request handle.
Definition mpi.hpp:364
status wait()
Allows a rank to wait on a particular request handle.
Definition mpi.hpp:372
This class wraps the MPI_Window class with RAII functionality.
Definition mpi.hpp:1595
void get(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Get data from the target window.
Definition mpi.hpp:1908
request r_put(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Put data into the target window.
Definition mpi.hpp:1830
window()
The default constructor.
Definition mpi.hpp:1610
void get_accumulate(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, GetType *result_buffer, const int result_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
Get Accumulate data from the target window.
Definition mpi.hpp:1959
void put(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Put data into the target window.
Definition mpi.hpp:1806
~window()
The deleter which calls MPI_Win_free when the window leaves its scope.
Definition mpi.hpp:1788
lock_type
The lock type for passive target synchronization of the windows.
Definition mpi.hpp:1605
window & operator=(window &&other)
The move assignment operator.
Definition mpi.hpp:1631
request r_accumulate(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
(Non-blocking) Accumulate data into the target window.
Definition mpi.hpp:1882
request r_get_accumulate(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, GetType *result_buffer, const int result_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
(Non-blocking) Get Accumulate data (with handle) from the target window.
Definition mpi.hpp:1989
void fetch_and_op(std::shared_ptr< const Executor > exec, GetType *origin_buffer, GetType *result_buffer, const int target_rank, const unsigned int target_disp, MPI_Op operation) const
Fetch and operate on data from the target window (An optimized version of Get_accumulate).
Definition mpi.hpp:2018
void sync() const
Synchronize the public and private buffers for the window object.
Definition mpi.hpp:1783
void unlock(int rank) const
Close the epoch using MPI_Win_unlock for the window object.
Definition mpi.hpp:1715
void fence(int assert=0) const
The active target synchronization using MPI_Win_fence for the window object.
Definition mpi.hpp:1682
void flush(int rank) const
Flush the existing RDMA operations on the target rank for the calling process for the window object.
Definition mpi.hpp:1746
void unlock_all() const
Close the epoch on all ranks using MPI_Win_unlock_all for the window object.
Definition mpi.hpp:1735
create_type
The create type for the window object.
Definition mpi.hpp:1600
window(std::shared_ptr< const Executor > exec, ValueType *base, int num_elems, const communicator &comm, const int disp_unit=sizeof(ValueType), MPI_Info input_info=MPI_INFO_NULL, create_type c_type=create_type::create)
Create a window object with a given data pointer and type.
Definition mpi.hpp:1648
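An illustrative sketch of active-target one-sided access; the executor and std::vector setup are assumptions.

auto exec = gko::ReferenceExecutor::create();
gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
std::vector<double> buf(comm.size(), 0.0);
gko::experimental::mpi::window<double> win(exec, buf.data(), comm.size(), comm);
win.fence();                               // open an access epoch on all ranks
double value = comm.rank();
win.put(exec, &value, 1, /*target_rank*/ 0, /*target_disp*/ comm.rank(), 1);
win.fence();                               // rank 0 now sees one entry per rank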
void accumulate(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
Accumulate data into the target window.
Definition mpi.hpp:1856
void lock_all(int assert=0) const
Create the epoch on all ranks using MPI_Win_lock_all for the window object.
Definition mpi.hpp:1726
void lock(int rank, lock_type lock_t=lock_type::shared, int assert=0) const
Create an epoch using MPI_Win_lock for the window object.
Definition mpi.hpp:1695
void flush_all_local() const
Flush all the local existing RDMA operations on the calling rank for the window object.
Definition mpi.hpp:1775
window(window &&other)
The move constructor.
Definition mpi.hpp:1622
void flush_local(int rank) const
Flush the existing RDMA operations on the calling rank from the target rank for the window object.
Definition mpi.hpp:1757
MPI_Win get_window() const
Get the underlying window object of MPI_Win type.
Definition mpi.hpp:1674
request r_get(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Get data (with handle) from the target window.
Definition mpi.hpp:1932
void flush_all() const
Flush all the existing RDMA operations for the calling process for the window object.
Definition mpi.hpp:1766
A class providing basic support for half precision floating point types.
Definition half.hpp:286
The mpi namespace, contains wrapper for many MPI functions.
Definition mpi.hpp:36
int map_rank_to_device_id(MPI_Comm comm, int num_devices)
Maps each MPI rank to a single device id in a round robin manner.
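A sketch of pairing this with a device executor; gko::CudaExecutor/gko::OmpExecutor creation and the device count are illustrative assumptions.

gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
const int num_devices = 4;  // query the runtime in real code
const int device_id =
    gko::experimental::mpi::map_rank_to_device_id(comm.get(), num_devices);
auto exec = gko::CudaExecutor::create(device_id, gko::OmpExecutor::create());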
bool requires_host_buffer(const std::shared_ptr< const Executor > &exec, const communicator &comm)
Checks if the combination of Executor and communicator requires passing MPI buffers from the host memory.
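A sketch of how this check might gate host staging; executor creation is assumed from Ginkgo's core API.

auto exec = gko::ReferenceExecutor::create();
gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
if (gko::experimental::mpi::requires_host_buffer(exec, comm)) {
    // copy device data into a host buffer before calling comm.send/recv
}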
double get_walltime()
Get the current wall time as reported by MPI (MPI_Wtime).
Definition mpi.hpp:1583
constexpr bool is_gpu_aware()
Return whether GPU-aware functionality is available.
Definition mpi.hpp:42
thread_type
This enum specifies the threading type to be used when creating an MPI environment.
Definition mpi.hpp:189
std::vector< status > wait_all(std::vector< request > &req)
Allows a rank to wait on multiple request handles.
Definition mpi.hpp:392
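For illustration, a non-blocking ring exchange completed with wait_all; the executor and buffer setup are assumptions.

auto exec = gko::ReferenceExecutor::create();
gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
const int next = (comm.rank() + 1) % comm.size();
const int prev = (comm.rank() + comm.size() - 1) % comm.size();
std::vector<double> send(16, comm.rank()), recv(16, 0.0);
std::vector<gko::experimental::mpi::request> reqs;
reqs.push_back(comm.i_recv(exec, recv.data(), 16, prev, /*recv_tag*/ 0));
reqs.push_back(comm.i_send(exec, send.data(), 16, next, /*send_tag*/ 0));
auto statuses = gko::experimental::mpi::wait_all(reqs);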
The Ginkgo namespace.
Definition abstract_factory.hpp:20
STL namespace.
The status struct is a light wrapper around the MPI_Status struct.
Definition mpi.hpp:287
int get_count(const T *data) const
Get the count of the number of elements received by the communication call.
Definition mpi.hpp:311
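A short sketch of querying the delivered element count after a blocking receive; the setup lines are assumptions.

auto exec = gko::ReferenceExecutor::create();
gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
std::vector<int> buffer(64);
// on a rank other than 0:
auto st = comm.recv(exec, buffer.data(), 64, /*source_rank*/ 0, /*recv_tag*/ 0);
int received = st.get_count(buffer.data());  // elements actually delivered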
status()
The default constructor.
Definition mpi.hpp:291
MPI_Status * get()
Get a pointer to the underlying MPI_Status object.
Definition mpi.hpp:298
A struct that is used to determine the MPI_Datatype of a specified type.
Definition mpi.hpp:77
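As a closing sketch, the mapping provided by this trait is what the wrappers above consult internally:

// For registered types, get_type() yields the matching MPI datatype handle.
MPI_Datatype dt = gko::experimental::mpi::type_impl<double>::get_type();
// dt == MPI_DOUBLE; unregistered types leave type_impl empty (no get_type()).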