Improve MPI handling for enums
This commit is contained in:
parent
e89bd8f150
commit
66f2de1083
47
atrip.org
47
atrip.org
@ -156,20 +156,23 @@ As an example, for the doubles amplitudes \( T^{ab}_{ij} \), one need two kinds
|
|||||||
const std::vector<int> lengths(n, 1);
|
const std::vector<int> lengths(n, 1);
|
||||||
const MPI_Datatype types[n] = {usizeDt(), usizeDt()};
|
const MPI_Datatype types[n] = {usizeDt(), usizeDt()};
|
||||||
|
|
||||||
|
static_assert(sizeof(Slice<F>::Location) == 2 * sizeof(size_t),
|
||||||
|
"The Location packing is wrong in your compiler");
|
||||||
|
|
||||||
// measure the displacements in the struct
|
// measure the displacements in the struct
|
||||||
size_t j = 0;
|
size_t j = 0;
|
||||||
MPI_Aint displacements[n];
|
MPI_Aint base_address, displacements[n];
|
||||||
|
MPI_Get_address(&measure, &base_address);
|
||||||
MPI_Get_address(&measure.rank, &displacements[j++]);
|
MPI_Get_address(&measure.rank, &displacements[j++]);
|
||||||
MPI_Get_address(&measure.source, &displacements[j++]);
|
MPI_Get_address(&measure.source, &displacements[j++]);
|
||||||
for (size_t i = 1; i < n; i++) displacements[i] -= displacements[0];
|
for (size_t i = 0; i < n; i++)
|
||||||
displacements[0] = 0;
|
displacements[i] = MPI_Aint_diff(displacements[i], base_address);
|
||||||
|
|
||||||
MPI_Type_create_struct(n, lengths.data(), displacements, types, &dt);
|
MPI_Type_create_struct(n, lengths.data(), displacements, types, &dt);
|
||||||
MPI_Type_commit(&dt);
|
MPI_Type_commit(&dt);
|
||||||
return dt;
|
return dt;
|
||||||
}
|
}
|
||||||
|
|
||||||
static MPI_Datatype enumDt() { return MPI_INT; }
|
|
||||||
static MPI_Datatype usizeDt() { return MPI_UINT64_T; }
|
static MPI_Datatype usizeDt() { return MPI_UINT64_T; }
|
||||||
|
|
||||||
static MPI_Datatype sliceInfo () {
|
static MPI_Datatype sliceInfo () {
|
||||||
@ -179,22 +182,31 @@ As an example, for the doubles amplitudes \( T^{ab}_{ij} \), one need two kinds
|
|||||||
const std::vector<int> lengths(n, 1);
|
const std::vector<int> lengths(n, 1);
|
||||||
const MPI_Datatype types[n]
|
const MPI_Datatype types[n]
|
||||||
= { vector(2, usizeDt())
|
= { vector(2, usizeDt())
|
||||||
, enumDt()
|
/*, MPI_UINT64_T*/
|
||||||
, enumDt()
|
, vector(sizeof(enum Type), MPI_CHAR)
|
||||||
|
/*, MPI_UINT64_T*/
|
||||||
|
, vector(sizeof(enum State), MPI_CHAR)
|
||||||
|
/*, vector(sizeof(Location), MPI_CHAR)*/
|
||||||
, sliceLocation()
|
, sliceLocation()
|
||||||
, enumDt()
|
, vector(sizeof(enum Type), MPI_CHAR)
|
||||||
|
/*, MPI_UINT64_T*/
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static_assert(sizeof(enum Type) == 4, "Enum type not 4 bytes long");
|
||||||
|
static_assert(sizeof(enum State) == 4, "Enum State not 4 bytes long");
|
||||||
|
static_assert(sizeof(enum Name) == 4, "Enum Name not 4 bytes long");
|
||||||
|
|
||||||
// create the displacements from the info measurement struct
|
// create the displacements from the info measurement struct
|
||||||
size_t j = 0;
|
size_t j = 0;
|
||||||
MPI_Aint displacements[n];
|
MPI_Aint base_address, displacements[n];
|
||||||
MPI_Get_address(measure.tuple.data(), &displacements[j++]);
|
MPI_Get_address(&measure, &base_address);
|
||||||
|
MPI_Get_address(&measure.tuple[0], &displacements[j++]);
|
||||||
MPI_Get_address(&measure.type, &displacements[j++]);
|
MPI_Get_address(&measure.type, &displacements[j++]);
|
||||||
MPI_Get_address(&measure.state, &displacements[j++]);
|
MPI_Get_address(&measure.state, &displacements[j++]);
|
||||||
MPI_Get_address(&measure.from, &displacements[j++]);
|
MPI_Get_address(&measure.from, &displacements[j++]);
|
||||||
MPI_Get_address(&measure.recycling, &displacements[j++]);
|
MPI_Get_address(&measure.recycling, &displacements[j++]);
|
||||||
for (size_t i = 1; i < n; i++) displacements[i] -= displacements[0];
|
for (size_t i = 0; i < n; i++)
|
||||||
displacements[0] = 0;
|
displacements[i] = MPI_Aint_diff(displacements[i], base_address);
|
||||||
|
|
||||||
MPI_Type_create_struct(n, lengths.data(), displacements, types, &dt);
|
MPI_Type_create_struct(n, lengths.data(), displacements, types, &dt);
|
||||||
MPI_Type_commit(&dt);
|
MPI_Type_commit(&dt);
|
||||||
@ -207,13 +219,15 @@ As an example, for the doubles amplitudes \( T^{ab}_{ij} \), one need two kinds
|
|||||||
LocalDatabaseElement measure;
|
LocalDatabaseElement measure;
|
||||||
const std::vector<int> lengths(n, 1);
|
const std::vector<int> lengths(n, 1);
|
||||||
const MPI_Datatype types[n]
|
const MPI_Datatype types[n]
|
||||||
= { enumDt()
|
= { vector(sizeof(enum Name), MPI_CHAR)
|
||||||
|
/*= { MPI_UINT64_T*/
|
||||||
, sliceInfo()
|
, sliceInfo()
|
||||||
};
|
};
|
||||||
|
|
||||||
// measure the displacements in the struct
|
// measure the displacements in the struct
|
||||||
size_t j = 0;
|
size_t j = 0;
|
||||||
MPI_Aint displacements[n];
|
MPI_Aint base_address, displacements[n];
|
||||||
|
MPI_Get_address(&measure, &base_address);
|
||||||
MPI_Get_address(&measure.name, &displacements[j++]);
|
MPI_Get_address(&measure.name, &displacements[j++]);
|
||||||
MPI_Get_address(&measure.info, &displacements[j++]);
|
MPI_Get_address(&measure.info, &displacements[j++]);
|
||||||
for (size_t i = 1; i < n; i++) displacements[i] -= displacements[0];
|
for (size_t i = 1; i < n; i++) displacements[i] -= displacements[0];
|
||||||
@ -221,6 +235,9 @@ As an example, for the doubles amplitudes \( T^{ab}_{ij} \), one need two kinds
|
|||||||
|
|
||||||
MPI_Type_create_struct(n, lengths.data(), displacements, types, &dt);
|
MPI_Type_create_struct(n, lengths.data(), displacements, types, &dt);
|
||||||
MPI_Type_commit(&dt);
|
MPI_Type_commit(&dt);
|
||||||
|
/*return vector( 4 + 4 + 48, MPI_CHAR);*/
|
||||||
|
// TODO
|
||||||
|
return vector(sizeof(LocalDatabaseElement), MPI_CHAR);
|
||||||
return dt;
|
return dt;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2260,9 +2277,9 @@ Atrip::Output Atrip::run(Atrip::Input<F> const& in) {
|
|||||||
* double(No)
|
* double(No)
|
||||||
* double(No)
|
* double(No)
|
||||||
* (double(No) + double(Nv))
|
* (double(No) + double(Nv))
|
||||||
* 1
|
* 2.0
|
||||||
* (traits::isComplex<F>() ? 2.0 : 1.0)
|
* (traits::isComplex<F>() ? 2.0 : 1.0)
|
||||||
* 6
|
* 6.0
|
||||||
/ 1e9
|
/ 1e9
|
||||||
;
|
;
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user