Skip to content

Commit 973b539

Browse files
committed
Add and refactor unittests for ParaGlobal
1 parent 3ea114c commit 973b539

File tree

2 files changed

+109
-15
lines changed

2 files changed

+109
-15
lines changed

source/source_base/parallel_global.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ void Parallel_Global::divide_pools(const int& NPROC,
328328
// and MY_BNDGROUP will be the same as well.
329329
if(BNDPAR > 1 && NPROC %(BNDPAR * KPAR) != 0)
330330
{
331-
std::cout << "Error: When BNDPAR = " << BNDPAR << " > 1, number of processes (" << NPROC << ") must be divisible by the number of groups ("
331+
std::cerr << "Error: When BNDPAR = " << BNDPAR << " > 1, number of processes (" << NPROC << ") must be divisible by the number of groups ("
332332
<< BNDPAR * KPAR << ")." << std::endl;
333333
exit(1);
334334
}
@@ -385,14 +385,16 @@ void Parallel_Global::divide_mpi_groups(const int& procs,
385385
{
386386
if (num_groups == 0)
387387
{
388+
std::cerr << "Error: Number of groups must be greater than 0." << std::endl;
389+
// note that WARNING_QUIT writes to stdout
388390
ModuleBase::WARNING_QUIT(
389391
"Parallel_Global::divide_mpi_groups",
390392
"Number of groups must be greater than 0."
391393
);
392394
}
393395
if (procs < num_groups)
394396
{
395-
std::cout << "Error: Number of processes (" << procs << ") must be greater than the number of groups ("
397+
std::cerr << "Error: Number of processes (" << procs << ") must be greater than the number of groups ("
396398
<< num_groups << ")." << std::endl;
397399
ModuleBase::WARNING_QUIT(
398400
"Parallel_Global::divide_mpi_groups",

source/source_base/test_parallel/parallel_global_test.cpp

Lines changed: 105 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include <cstring>
1010
#include <string>
1111

12-
#include "source_base/tool_quit.h"
12+
#include "source_base/global_variable.h"
1313

1414
/************************************************
1515
* unit test of functions in parallel_global.cpp
@@ -66,6 +66,7 @@ class MPIContext
6666
int _size;
6767
};
6868

69+
// --- Normal Test ---
6970
class ParaGlobal : public ::testing::Test
7071
{
7172
protected:
@@ -162,8 +163,58 @@ TEST_F(ParaGlobal, MyProd)
162163
EXPECT_EQ(inout[1], std::complex<double>(-3.0, -3.0));
163164
}
164165

165-
TEST_F(ParaGlobal, InitPools)
166+
167+
168+
TEST_F(ParaGlobal, DivideMPIPools)
169+
{
170+
this->nproc = 12;
171+
mpi.kpar = 3;
172+
this->my_rank = 5;
173+
Parallel_Global::divide_mpi_groups(this->nproc,
174+
mpi.kpar,
175+
this->my_rank,
176+
mpi.nproc_in_pool,
177+
mpi.my_pool,
178+
mpi.rank_in_pool);
179+
EXPECT_EQ(mpi.nproc_in_pool, 4);
180+
EXPECT_EQ(mpi.my_pool, 1);
181+
EXPECT_EQ(mpi.rank_in_pool, 1);
182+
}
183+
184+
185+
// --- DeathTest: Single thread ---
186+
class ParaGlobalDeathTest : public ::testing::Test
187+
{
188+
protected:
189+
MPIContext mpi;
190+
int nproc;
191+
int my_rank;
192+
193+
// DeathTest SetUp:
194+
// Init variable, single thread
195+
void SetUp() override
196+
{
197+
// Only master process runs death test (avoid multi-process conflict)
198+
if (mpi.GetRank() != 0) {return;}
199+
200+
// init log file
201+
GlobalV::ofs_warning.open("warning.log");
202+
// needed by WARNING_QUIT
203+
}
204+
205+
// clean log file
206+
void TearDown() override
207+
{
208+
if (mpi.GetRank() != 0) {return;}
209+
210+
GlobalV::ofs_warning.close();
211+
remove("warning.log");
212+
}
213+
};
214+
215+
TEST_F(ParaGlobalDeathTest, InitPools)
166216
{
217+
GTEST_FLAG_SET(death_test_style, "threadsafe");
167218
nproc = 12;
168219
mpi.kpar = 3;
169220
mpi.nstogroup = 3;
@@ -178,26 +229,67 @@ TEST_F(ParaGlobal, InitPools)
178229
mpi.MY_BNDGROUP,
179230
mpi.nproc_in_pool,
180231
mpi.rank_in_pool,
181-
mpi.my_pool), ::testing::ExitedWithCode(1), "");
182-
std::string output = testing::internal::GetCapturedStdout();
183-
EXPECT_THAT(output, testing::HasSubstr("Error:"));
232+
mpi.my_pool), ::testing::ExitedWithCode(1), "Error:");
184233
}
185234

186-
187-
TEST_F(ParaGlobal, DivideMPIPools)
235+
TEST_F(ParaGlobalDeathTest, DivideMPIPoolsNgEqZero)
188236
{
237+
GTEST_FLAG_SET(death_test_style, "threadsafe");
238+
// test for num_groups == 0,
239+
// Num_group Equals 0
240+
// WARNING_QUIT
189241
this->nproc = 12;
190-
mpi.kpar = 3;
242+
mpi.kpar = 0;
191243
this->my_rank = 5;
192-
Parallel_Global::divide_mpi_groups(this->nproc,
244+
EXPECT_EXIT(
245+
Parallel_Global::divide_mpi_groups(this->nproc,
193246
mpi.kpar,
194247
this->my_rank,
195248
mpi.nproc_in_pool,
196249
mpi.my_pool,
197-
mpi.rank_in_pool);
198-
EXPECT_EQ(mpi.nproc_in_pool, 4);
199-
EXPECT_EQ(mpi.my_pool, 1);
200-
EXPECT_EQ(mpi.rank_in_pool, 1);
250+
mpi.rank_in_pool),
251+
::testing::ExitedWithCode(1),
252+
"Number of groups must be greater than 0."
253+
);
254+
// should WARNING_QUIT inside!
255+
std::string output;
256+
std::ifstream ifs;
257+
ifs.open("warning.log");
258+
getline(ifs,output);
259+
// test output in warning.log file
260+
EXPECT_THAT(output,testing::HasSubstr("warning"));
261+
EXPECT_THAT(output,testing::HasSubstr("Number of groups must be greater than 0."));
262+
ifs.close();
263+
}
264+
265+
TEST_F(ParaGlobalDeathTest, DivideMPIPoolsNgGtProc)
266+
{
267+
GTEST_FLAG_SET(death_test_style, "threadsafe");
268+
// test for procs < num_groups
269+
// Num_group GreaterThan Processors
270+
// WARNING_QUIT
271+
this->nproc = 12;
272+
mpi.kpar = 24;
273+
this->my_rank = 5;
274+
EXPECT_EXIT(
275+
Parallel_Global::divide_mpi_groups(this->nproc,
276+
mpi.kpar,
277+
this->my_rank,
278+
mpi.nproc_in_pool,
279+
mpi.my_pool,
280+
mpi.rank_in_pool)
281+
,testing::ExitedWithCode(1),
282+
"Error: Number of processes.*must be greater than the number of groups"
283+
);
284+
// should WARNING_QUIT inside!
285+
std::string output;
286+
std::ifstream ifs;
287+
ifs.open("warning.log");
288+
getline(ifs,output);
289+
// test output in warning.log file
290+
EXPECT_THAT(output,testing::HasSubstr("warning"));
291+
EXPECT_THAT(output,testing::HasSubstr("Number of processes must be greater than the number of groups."));
292+
ifs.close();
201293
}
202294

203295
int main(int argc, char** argv)

0 commit comments

Comments
 (0)