From 50f471b6d218f9919522cfb03ced681233855c74 Mon Sep 17 00:00:00 2001 From: Patrick Schultz Date: Mon, 23 Mar 2026 07:17:02 -0400 Subject: [PATCH] migrate from testng to munit --- build.yaml | 85 +- hail/CLAUDE.md | 0 hail/generate_splits.py | 55 - hail/hail/package.mill | 64 +- .../src/is/hail/annotations/UnsafeUtils.scala | 6 + ...g4j2.properties => log4j2-test.properties} | 0 hail/hail/test/src/is/hail/ExecStrategy.scala | 21 + hail/hail/test/src/is/hail/HailSuite.scala | 89 +- .../test/src/is/hail/LogTestListener.scala | 36 - .../test/src/is/hail/TestCaseSupport.scala | 12 + hail/hail/test/src/is/hail/TestUtils.scala | 57 +- .../test/src/is/hail/TestUtilsSuite.scala | 16 +- .../hail/annotations/AnnotationsSuite.scala | 4 +- .../ApproxCDFAggregatorSuite.scala | 14 +- .../src/is/hail/annotations/RegionSuite.scala | 139 +- .../annotations/StagedConstructorSuite.scala | 151 +- .../src/is/hail/annotations/UnsafeSuite.scala | 59 +- .../test/src/is/hail/asm4s/ASM4SSuite.scala | 205 +- .../test/src/is/hail/asm4s/CodeSuite.scala | 79 +- .../is/hail/backend/ServiceBackendSuite.scala | 221 +- .../src/is/hail/backend/WorkerSuite.scala | 85 +- .../hail/collection/ArrayBuilderSuite.scala | 17 +- .../is/hail/collection/ArrayStackSuite.scala | 45 +- .../is/hail/collection/BinaryHeapSuite.scala | 569 +++-- .../collection/FlipbookIteratorSuite.scala | 29 +- .../test/src/is/hail/expr/ParserSuite.scala | 8 +- .../is/hail/expr/ir/Aggregators2Suite.scala | 67 +- .../is/hail/expr/ir/AggregatorsSuite.scala | 227 +- .../expr/ir/ArrayDeforestationSuite.scala | 3 +- .../is/hail/expr/ir/ArrayFunctionsSuite.scala | 575 +++-- .../is/hail/expr/ir/BlockMatrixIRSuite.scala | 137 +- .../is/hail/expr/ir/CallFunctionsSuite.scala | 234 +- .../is/hail/expr/ir/DictFunctionsSuite.scala | 185 +- .../hail/expr/ir/DistinctlyKeyedSuite.scala | 12 +- .../src/is/hail/expr/ir/EmitStreamSuite.scala | 105 +- .../expr/ir/ExtractIntervalFiltersSuite.scala | 406 ++-- .../is/hail/expr/ir/FoldConstantsSuite.scala 
| 33 +- .../is/hail/expr/ir/ForwardLetsSuite.scala | 273 +-- .../src/is/hail/expr/ir/FunctionSuite.scala | 34 +- .../hail/expr/ir/GenotypeFunctionsSuite.scala | 40 +- .../test/src/is/hail/expr/ir/IRSuite.scala | 1895 +++++++++-------- .../src/is/hail/expr/ir/IntervalSuite.scala | 84 +- .../is/hail/expr/ir/LiftLiteralsSuite.scala | 3 +- .../is/hail/expr/ir/LocusFunctionsSuite.scala | 39 +- .../is/hail/expr/ir/MathFunctionsSuite.scala | 239 ++- .../src/is/hail/expr/ir/MatrixIRSuite.scala | 132 +- .../src/is/hail/expr/ir/MemoryLeakSuite.scala | 4 +- .../expr/ir/MissingArrayBuilderSuite.scala | 183 +- .../src/is/hail/expr/ir/OrderingSuite.scala | 253 +-- .../test/src/is/hail/expr/ir/PruneSuite.scala | 390 ++-- .../src/is/hail/expr/ir/RandomSuite.scala | 19 +- .../is/hail/expr/ir/RequirednessSuite.scala | 79 +- .../is/hail/expr/ir/SetFunctionsSuite.scala | 24 +- .../src/is/hail/expr/ir/SimplifySuite.scala | 877 ++++---- .../is/hail/expr/ir/StagedBTreeSuite.scala | 85 +- .../is/hail/expr/ir/StagedMinHeapSuite.scala | 60 +- .../hail/expr/ir/StringFunctionsSuite.scala | 171 +- .../is/hail/expr/ir/StringLengthSuite.scala | 4 +- .../is/hail/expr/ir/StringSliceSuite.scala | 43 +- .../src/is/hail/expr/ir/TableIRSuite.scala | 390 ++-- .../hail/expr/ir/TakeByAggregatorSuite.scala | 24 +- .../is/hail/expr/ir/UtilFunctionsSuite.scala | 17 +- .../is/hail/expr/ir/agg/DownsampleSuite.scala | 6 +- .../ir/agg/StagedBlockLinkedListSuite.scala | 12 +- .../expr/ir/analyses/SemanticHashSuite.scala | 520 +++-- .../expr/ir/defs/EncodedLiteralSuite.scala | 5 +- .../lowering/LowerDistributedSortSuite.scala | 14 +- .../is/hail/expr/ir/table/TableGenSuite.scala | 93 +- .../test/src/is/hail/io/ArrayImpexSuite.scala | 6 +- .../test/src/is/hail/io/AvroReaderSuite.scala | 5 +- .../src/is/hail/io/ByteArrayReaderSuite.scala | 20 +- .../test/src/is/hail/io/IndexBTreeSuite.scala | 121 +- .../hail/test/src/is/hail/io/IndexSuite.scala | 362 ++-- .../hail/test/src/is/hail/io/TabixSuite.scala | 76 +- 
.../is/hail/io/compress/BGzipCodecSuite.scala | 67 +- .../is/hail/io/fs/AzureStorageFSSuite.scala | 27 +- .../hail/test/src/is/hail/io/fs/FSSuite.scala | 112 +- .../is/hail/io/fs/GoogleStorageFSSuite.scala | 30 +- .../src/is/hail/linalg/BlockMatrixSuite.scala | 670 +++--- .../is/hail/linalg/GridPartitionerSuite.scala | 64 +- .../is/hail/linalg/MatrixSparsitySuite.scala | 79 +- .../linalg/RichDenseMatrixDoubleSuite.scala | 13 +- .../linalg/RichIndexedRowMatrixSuite.scala | 41 +- .../src/is/hail/linalg/RowMatrixSuite.scala | 71 +- .../is/hail/linalg/RowPartitionerSuite.scala | 29 +- .../lir/CompileTimeRequirednessSuite.scala | 4 +- .../test/src/is/hail/lir/LIRSplitSuite.scala | 4 +- .../test/src/is/hail/methods/ExprSuite.scala | 79 +- .../is/hail/methods/LocalLDPruneSuite.scala | 87 +- .../is/hail/methods/MultiArray2Suite.scala | 97 +- .../test/src/is/hail/methods/SkatSuite.scala | 3 +- .../src/is/hail/rvd/RVDPartitionerSuite.scala | 95 +- .../is/hail/services/BatchClientSuite.scala | 63 +- .../is/hail/sparkextras/RichRDDSuite.scala | 4 +- .../is/hail/stats/FisherExactTestSuite.scala | 4 +- ...neralizedChiSquaredDistributionSuite.scala | 158 +- .../is/hail/stats/LeveneHaldaneSuite.scala | 29 +- .../stats/LogisticRegressionModelSuite.scala | 7 +- .../test/src/is/hail/stats/StatsSuite.scala | 13 +- .../test/src/is/hail/stats/eigSymDSuite.scala | 15 +- .../is/hail/types/encoded/ETypeSuite.scala | 105 +- .../types/physical/PBaseStructSuite.scala | 6 +- .../is/hail/types/physical/PBinarySuite.scala | 4 +- .../is/hail/types/physical/PCallSuite.scala | 4 +- .../hail/types/physical/PContainerTest.scala | 106 +- .../hail/types/physical/PIntervalSuite.scala | 6 +- .../hail/types/physical/PNDArraySuite.scala | 51 +- .../is/hail/types/physical/PTypeSuite.scala | 102 +- .../types/physical/PhysicalTestUtils.scala | 6 +- .../stypes/concrete/SStructViewSuite.scala | 15 +- .../is/hail/types/virtual/TStructSuite.scala | 268 ++- .../BufferedAggregatorIteratorSuite.scala | 12 +- 
.../test/src/is/hail/utils/GraphSuite.scala | 50 +- .../src/is/hail/utils/IntervalSuite.scala | 88 +- .../is/hail/utils/PartitionCountsSuite.scala | 28 +- .../src/is/hail/utils/RowIntervalSuite.scala | 13 +- .../is/hail/utils/SemanticVersionSuite.scala | 7 +- .../utils/SpillingCollectIteratorSuite.scala | 8 +- .../is/hail/utils/TreeTraversalSuite.scala | 35 +- .../src/is/hail/utils/UnionFindSuite.scala | 65 +- .../test/src/is/hail/utils/UtilsSuite.scala | 124 +- .../prettyPrint/PrettyPrintWriterSuite.scala | 44 +- .../src/is/hail/variant/GenotypeSuite.scala | 142 +- .../is/hail/variant/LocusIntervalSuite.scala | 477 +++-- .../hail/variant/ReferenceGenomeSuite.scala | 212 +- .../hail/variant/vsm/PartitioningSuite.scala | 4 +- .../utils/src/is/hail/utils/package.scala | 34 +- .../mill-build/src/MvnCoordinate.scala | 1 + hail/testng-build.xml | 10 - hail/testng-fs.xml | 8 - hail/testng-services.xml | 8 - hail/testng.xml | 12 - 132 files changed, 7453 insertions(+), 7119 deletions(-) create mode 100644 hail/CLAUDE.md delete mode 100644 hail/generate_splits.py rename hail/hail/test/resources/{log4j2.properties => log4j2-test.properties} (100%) create mode 100644 hail/hail/test/src/is/hail/ExecStrategy.scala delete mode 100644 hail/hail/test/src/is/hail/LogTestListener.scala create mode 100644 hail/hail/test/src/is/hail/TestCaseSupport.scala delete mode 100644 hail/testng-build.xml delete mode 100644 hail/testng-fs.xml delete mode 100644 hail/testng-services.xml delete mode 100644 hail/testng.xml diff --git a/build.yaml b/build.yaml index 52ba7d06881..21ad635bf36 100644 --- a/build.yaml +++ b/build.yaml @@ -868,15 +868,8 @@ steps: set -ex cd /io/repo/hail - export MILLOPTS='--no-daemon' HAIL_BUILD_MODE=CI - time retry sh mill --no-daemon --version - - # See `build_hail_jar_and_wheel` - time retry make shadowTestJar - if [ ! -f out/hail/2.12/test/assembly.dest/out.jar ]; then - echo 'no out.jar found after mill assembly returned. 
going to sleep' - sleep 5 - fi + export HAIL_BUILD_MODE=CI + time sh mill --no-daemon hail[2.12].test.{compile,testForkGrouping} inputs: - from: /repo to: /io/repo @@ -900,8 +893,6 @@ steps: time tar czf test.tar.gz -C python test pytest.ini time tar czf resources.tar.gz -C hail/test resources time tar czf data.tar.gz -C python/hail/docs data - time TESTNG_SPLITS=5 python3 generate_splits.py - time tar czf splits.tar.gz testng-splits-*.xml inputs: - from: /repo to: /io/repo @@ -910,8 +901,6 @@ steps: to: /test.tar.gz - from: /io/repo/hail/resources.tar.gz to: /resources.tar.gz - - from: /io/repo/hail/splits.tar.gz - to: /splits.tar.gz - from: /io/repo/hail/data.tar.gz to: /data.tar.gz dependsOn: @@ -1024,22 +1013,15 @@ steps: cpu: '2' script: | set -ex - cd /io - mkdir -p hail/test - tar xzf resources.tar.gz -C hail/test - tar xzf splits.tar.gz - export HAIL_TEST_SKIP_R=1 - java -cp hail-test.jar:$SPARK_HOME/jars/* org.testng.TestNG -listener is.hail.LogTestListener testng-splits-$HAIL_RUN_IMAGE_SPLIT_INDEX.xml + cd /io/repo/hail + export HAIL_BUILD_MODE='CI' + export NO_COLOR=1 + sh mill --no-daemon -j2 hail[2.12].test.testCISplit $HAIL_RUN_IMAGE_SPLIT_INDEX inputs: - - from: /resources.tar.gz - to: /io/resources.tar.gz - - from: /derived/debug/hail/out/hail/2.12/test/assembly.dest/out.jar - to: /io/hail-test.jar - - from: /splits.tar.gz - to: /io/splits.tar.gz - outputs: - - from: /io/test-output - to: /test-output + - from: /repo + to: /io/repo + - from: /derived/debug/hail/out + to: /io/repo/hail/out secrets: - name: test-gsa-key namespace: @@ -3589,10 +3571,7 @@ steps: memory: standard cpu: '2' script: | - set -ex - cd /io - mkdir -p src/test - tar xzf resources.tar.gz -C src/test + set -x export HAIL_CLOUD={{ global.cloud }} export HAIL_DEFAULT_NAMESPACE={{ default_ns.name }} @@ -3600,12 +3579,10 @@ steps: export HAIL_FS_TEST_CLOUD_RESOURCES_URI={{ global.test_storage_uri }}/{{ upload_test_resources_to_blob_storage.token }}/test/resources/fs export 
HAIL_TEST_STORAGE_URI={{ global.test_storage_uri }} - set +e - java -Xms7500M -Xmx7500M \ - -cp hail-test.jar:$SPARK_HOME/jars/* \ - org.testng.TestNG \ - -listener is.hail.LogTestListener \ - testng-fs.xml + cd /io/repo/hail + export HAIL_BUILD_MODE='CI' + export NO_COLOR=1 + sh mill --no-daemon -j2 hail[2.12].test.test-fs exit_code=$? set -e if [[ $exit_code -eq 2 ]] then @@ -3616,12 +3593,10 @@ exit $exit_code fi inputs: - - from: /resources.tar.gz - to: /io/resources.tar.gz - - from: /derived/debug/hail/out/hail/2.12/test/assembly.dest/out.jar - to: /io/hail-test.jar - - from: /repo/hail/testng-fs.xml - to: /io/testng-fs.xml + - from: /repo + to: /io/repo + - from: /derived/debug/hail/out + to: /io/repo/hail/out secrets: - name: test-gsa-key namespace: @@ -3649,21 +3624,15 @@ export HAIL_CLOUD={{ global.cloud }} export HAIL_DEFAULT_NAMESPACE={{ default_ns.name }} - cd /io - mkdir -p src/test - tar xzf resources.tar.gz -C src/test - java -Xms7500M -Xmx7500M \ - -cp hail-test.jar:$SPARK_HOME/jars/* \ - org.testng.TestNG \ - -listener is.hail.LogTestListener \ - testng-services.xml + cd /io/repo/hail + export HAIL_BUILD_MODE='CI' + export NO_COLOR=1 + sh mill --no-daemon -j2 hail[2.12].test.test-services inputs: - - from: /resources.tar.gz - to: /io/resources.tar.gz - - from: /derived/debug/hail/out/hail/2.12/test/assembly.dest/out.jar - to: /io/hail-test.jar - - from: /repo/hail/testng-services.xml - to: /io/testng-services.xml + - from: /repo + to: /io/repo + - from: /derived/debug/hail/out + to: /io/repo/hail/out secrets: - name: test-gsa-key namespace: diff --git a/hail/CLAUDE.md b/hail/CLAUDE.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/hail/generate_splits.py b/hail/generate_splits.py deleted file mode 100644 index 14738e23ec4..00000000000 --- a/hail/generate_splits.py +++ /dev/null @@ -1,55 +0,0 @@ -import os -import random -from pathlib import Path -from typing import List, TypeVar - -T = TypeVar("T") - - -def partition(k: int, 
ls: List[T]) -> List[List[T]]: - assert k > 0 - assert ls - - n = len(ls) - parts = [(n - i + k - 1) // k for i in range(k)] - assert sum(parts) == n - assert max(parts) - min(parts) <= 1 - - out = [] - start = 0 - for part in parts: - out.append(ls[start : start + part]) - start += part - return out - - -test_src_root = Path('hail/test/src') -services_root = test_src_root / 'is/hail/services' -fs_root = test_src_root / 'is/hail/io/fs' -classes = [ - str(Path(dirpath, file).relative_to(test_src_root).with_suffix('')).replace('/', '.') - for dirpath, _dirnames, filenames in map(lambda x: (Path(x[0]), *x[1:]), os.walk(test_src_root)) - for file in filenames - if not dirpath.is_relative_to(services_root) - and not dirpath.is_relative_to(fs_root) - and (file.endswith('.java') or file.endswith('.scala')) -] - -random.shuffle(classes) - -n_splits = int(os.environ['TESTNG_SPLITS']) -splits = partition(n_splits, classes) - -for split_index, split in enumerate(splits): - classes = '\n'.join(f'' for name in split) - with open(f'testng-splits-{split_index}.xml', 'w', encoding='utf-8') as f: - xml = f""" - - - - {classes} - - - -""" - f.write(xml) diff --git a/hail/hail/package.mill b/hail/hail/package.mill index 2a60c3d4696..1d51854425a 100644 --- a/hail/hail/package.mill +++ b/hail/hail/package.mill @@ -2,19 +2,21 @@ package build.hail import mill.* import mill.api.{BuildCtx, Result} +import mill.javalib.testrunner.TestResult import mill.scalalib.* import mill.scalalib.Assembly.* -import mill.scalalib.TestModule.TestNg +import mill.scalalib.TestModule.Munit import mill.util.{Jvm, VcsVersion} - import build.Env import build.Settings import build.HailScalaModule - +import mill.api.daemon.internal.TestReporter import millbuild.BuildConfig.* import millbuild.BuildMode.* import millbuild.MvnCoordinate.* +import scala.collection.immutable.ArraySeq + object `package` extends Cross[RootHailModule](enabledScalaVersions) trait RootHailModule extends CrossScalaModule, HailScalaModule: 
@@ -172,7 +174,7 @@ trait RootHailModule extends CrossScalaModule, HailScalaModule: PathRef(Task.dest) } - object test extends HailScalaTests, TestNg, CrossValue: + object test extends HailScalaTests, Munit, CrossValue: override def forkArgs: T[Seq[String]] = Seq("-Xss4m", "-Xmx4096M") @@ -181,17 +183,57 @@ trait RootHailModule extends CrossScalaModule, HailScalaModule: override def mvnDeps: T[Seq[Dep]] = outer.mvnDeps() ++ Seq( - `guice` :: "5.1.0", - `mockito-scala` :: "1.17.31", + `munit` :: "1.0.3", `scalacheck` :: "1.18.1", - `scalacheck-1-18` :: "3.2.19.0", - `scalatest` :: "3.2.19", - `scalatest-shouldmatchers` :: "3.2.19", - `testng-7-10` :: "3.2.19.0", + "org.scalameta::munit-scalacheck" :: "1.2.0", + `mockito-scala` :: "1.17.31", ) override def runMvnDeps: T[Seq[Dep]] = - outer.runMvnDeps() + outer.runMvnDeps() ++ outer.compileMvnDeps() override def assemblyRules: Seq[Rule] = outer.assemblyRules + + val numCISplits: Int = 5 + + // FIXME: we used to run this with `-Xms7500M -Xmx7500M`, will need to split into separate test module + // to give it different `forkArgs` + def `test-fs`(args: String*): Task.Command[(msg: String, results: Seq[TestResult])] = Task.Command { + testTask( + Task.Anon { args }, + Task.Anon { Seq("is.hail.io.fs.*") }, + )() + } + + def `test-services`(args: String*): Task.Command[(msg: String, results: Seq[TestResult])] = Task.Command { + testTask( + Task.Anon { args }, + Task.Anon { Seq("is.hail.services.*") }, + )() + } + + def testCISplit(args: String*): Task.Command[(msg: String, results: Seq[TestResult])] = { + require(args.length == 1) + Task.Command { + testTask( + Task.Anon { + Seq.empty + }, ciTestGrouping.map(_.apply(args.head.toInt)) + )() + } + } + + def ciTestGrouping: T[Seq[Seq[String]]] = Task { + val k = numCISplits + val classes = discoveredTestClasses().filter { cls => + !cls.startsWith("is.hail.io.fs") && !cls.startsWith("is.hail.services") + } + val n = classes.length + val shuffled = 
scala.util.Random(12345).shuffle(classes) + val offsets = Range(0, k).map(i => (n - i + k - 1) / k).scan(0)(_ + _) + assert(offsets.last == n) + Range(0, k).map { i => + shuffled.slice(offsets(i), offsets(i + 1)) + } + } diff --git a/hail/hail/src/is/hail/annotations/UnsafeUtils.scala b/hail/hail/src/is/hail/annotations/UnsafeUtils.scala index 8234d6547c8..632483cd2cc 100644 --- a/hail/hail/src/is/hail/annotations/UnsafeUtils.scala +++ b/hail/hail/src/is/hail/annotations/UnsafeUtils.scala @@ -30,6 +30,12 @@ object UnsafeUtils { offset & -alignment } + def roundUpAlignment(offset: Int, alignment: Int): Int = { + assert(alignment > 0) + assert((alignment & (alignment - 1)) == 0) // power of 2 + (offset + (alignment - 1)) & ~(alignment - 1) + } + def roundDownAlignment(offset: Int, alignment: Int): Int = { assert(alignment > 0) assert((alignment & (alignment - 1)) == 0) // power of 2 diff --git a/hail/hail/test/resources/log4j2.properties b/hail/hail/test/resources/log4j2-test.properties similarity index 100% rename from hail/hail/test/resources/log4j2.properties rename to hail/hail/test/resources/log4j2-test.properties diff --git a/hail/hail/test/src/is/hail/ExecStrategy.scala b/hail/hail/test/src/is/hail/ExecStrategy.scala new file mode 100644 index 00000000000..d39b9ad84da --- /dev/null +++ b/hail/hail/test/src/is/hail/ExecStrategy.scala @@ -0,0 +1,21 @@ +package is.hail + +object ExecStrategy extends Enumeration { + type ExecStrategy = Value + val Interpret, InterpretUnoptimized, JvmCompile, LoweredJVMCompile, JvmCompileUnoptimized = Value + + val unoptimizedCompileOnly: Set[ExecStrategy] = Set(JvmCompileUnoptimized) + val compileOnly: Set[ExecStrategy] = Set(JvmCompile, JvmCompileUnoptimized) + + val javaOnly: Set[ExecStrategy] = + Set(Interpret, InterpretUnoptimized, JvmCompile, JvmCompileUnoptimized) + + val interpretOnly: Set[ExecStrategy] = Set(Interpret, InterpretUnoptimized) + + val nonLowering: Set[ExecStrategy] = + Set(Interpret, InterpretUnoptimized, 
JvmCompile, JvmCompileUnoptimized) + + val lowering: Set[ExecStrategy] = Set(LoweredJVMCompile) + val backendOnly: Set[ExecStrategy] = Set(LoweredJVMCompile) + val allRelational: Set[ExecStrategy] = interpretOnly.union(lowering) +} diff --git a/hail/hail/test/src/is/hail/HailSuite.scala b/hail/hail/test/src/is/hail/HailSuite.scala index 8b01ab13929..784406dcd27 100644 --- a/hail/hail/test/src/is/hail/HailSuite.scala +++ b/hail/hail/test/src/is/hail/HailSuite.scala @@ -25,9 +25,6 @@ import org.apache.hadoop import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SparkSession} -import org.scalatestplus.testng.TestNGSuite -import org.testng.ITestContext -import org.testng.annotations.{AfterClass, AfterSuite, BeforeClass, BeforeSuite} object HailSuite { private val hcl: HailClassLoader = @@ -36,34 +33,14 @@ object HailSuite { private val flags: HailFeatureFlags = HailFeatureFlags.fromEnv(sys.env + ("lower" -> "1")) - private var backend_ : SparkBackend = _ -} - -class HailSuite extends TestNGSuite with TestUtils with Logging { - - private[this] var ctx_ : ExecuteContext = _ - - override def ctx: ExecuteContext = ctx_ - def backend: Backend = ctx.backend - def fs: FS = ctx.fs - def pool: RegionPool = ctx.r.pool - def sc: SparkContext = ctx.backend.asSpark.sc - def theHailClassLoader: HailClassLoader = ctx.theHailClassLoader - - private[this] lazy val resources: String = - sys.env.getOrElse("MILL_TEST_RESOURCE_DIR", "hail/test/resources") - - def getTestResource(localPath: String): String = s"$resources/$localPath" - - @BeforeSuite - def setupBackend(): Unit = { + private lazy val backend_ : SparkBackend = { RVD.CheckRvdKeyOrderingForTesting = true - HailSuite.backend_ = SparkBackend( + val b = SparkBackend( SparkSession .builder() .config( SparkBackend.createSparkConf( - appName = "Hail.TestNG", + appName = "Hail.MUnit", master = System.getProperty("hail.master"), local = "local[2]", blockSize = 0, @@ -74,10 
+51,34 @@ class HailSuite extends TestNGSuite with TestUtils with Logging { .config("spark.ui.enabled", "false") .getOrCreate() ) + sys.addShutdownHook { + b.spark.stop() + b.close() + IRFunctionRegistry.clearUserFunctions() + }: Unit + b } +} - @BeforeClass - def setupExecuteContext(): Unit = { +class HailSuite extends munit.FunSuite with TestCaseSupport with TestUtils with Logging { + private[this] var ctx_ : ExecuteContext = _ + + override def ctx: ExecuteContext = ctx_ + def backend: Backend = ctx.backend + def fs: FS = ctx.fs + def pool: RegionPool = ctx.r.pool + def sc: SparkContext = ctx.backend.asSpark.sc + def theHailClassLoader: HailClassLoader = ctx.theHailClassLoader + + private[this] lazy val resources: String = + sys.env.getOrElse("MILL_TEST_RESOURCE_DIR", "hail/test/resources") + + def getTestResource(localPath: String): String = s"$resources/$localPath" + + override def beforeAll(): Unit = { + super.beforeAll() + // Force backend initialization + HailSuite.backend_ : Unit val conf = new Configuration(HailSuite.backend_.sc.hadoopConfiguration) val fs = new HadoopFS(new SerializableHadoopConfiguration(conf)) val pool = RegionPool() @@ -100,25 +101,19 @@ class HailSuite extends TestNGSuite with TestUtils with Logging { ) } - @AfterClass - def tearDownExecuteContext(context: ITestContext): Unit = { - ctx_.timer.finish() - ctx_.close() - ctx_.r.pool.close() - ctx_ = null - + override def afterAll(): Unit = { + if (ctx_ != null) { + ctx_.timer.finish() + ctx_.close() + ctx_.r.pool.close() + ctx_ = null + } hadoop.fs.FileSystem.closeAll() if (HailSuite.backend_.sc.isStopped) - throw new RuntimeException(s"'${context.getName}' stopped spark context!") - } + throw new RuntimeException(s"test suite stopped spark context!") - @AfterSuite - def tearDownBackend(): Unit = { - HailSuite.backend_.spark.stop() - HailSuite.backend_.close() - HailSuite.backend_ = null - IRFunctionRegistry.clearUserFunctions() + super.afterAll() } def evaluate( @@ -214,8 +209,6 @@ 
class HailSuite extends TestNGSuite with TestUtils with Logging { logger.error(s"error from strategy $strat", e) if (execStrats.contains(strat)) throw e } - - succeed } } @@ -317,9 +310,9 @@ class HailSuite extends TestNGSuite with TestUtils with Logging { if (execStrats.contains(strat)) throw e } } - val expectedArray = ArraySeq.tabulate(expected.rows)(i => - ArraySeq.tabulate(expected.cols)(j => expected(i, j)) - ) + val expectedArray = Array.tabulate(expected.rows)(i => + Array.tabulate(expected.cols)(j => expected(i, j)).toFastSeq + ).toFastSeq assertNDEvals(BlockMatrixCollect(bm), expectedArray)( filteredExecStrats.filterNot(ExecStrategy.interpretOnly) ) diff --git a/hail/hail/test/src/is/hail/LogTestListener.scala b/hail/hail/test/src/is/hail/LogTestListener.scala deleted file mode 100644 index aa5347480ba..00000000000 --- a/hail/hail/test/src/is/hail/LogTestListener.scala +++ /dev/null @@ -1,36 +0,0 @@ -package is.hail - -import java.io.{PrintWriter, StringWriter} - -import org.testng.{ITestContext, ITestListener, ITestResult} - -class LogTestListener extends ITestListener { - def testString(result: ITestResult): String = - s"${result.getTestClass.getName}.${result.getMethod.getMethodName}" - - override def onTestStart(result: ITestResult): Unit = - System.err.println(s"starting test ${testString(result)}...") - - override def onTestSuccess(result: ITestResult): Unit = - System.err.println(s"test ${testString(result)} SUCCESS") - - override def onTestFailure(result: ITestResult): Unit = { - val cause = result.getThrowable - if (cause != null) { - val sw = new StringWriter() - val pw = new PrintWriter(sw) - cause.printStackTrace(pw) - System.err.println(s"Exception:\n$sw") - } - System.err.println(s"test ${testString(result)} FAILURE\n") - } - - override def onTestSkipped(result: ITestResult): Unit = - System.err.println(s"test ${testString(result)} SKIPPED") - - override def onTestFailedButWithinSuccessPercentage(result: ITestResult): Unit = {} - - override 
def onStart(context: ITestContext): Unit = {} - - override def onFinish(context: ITestContext): Unit = {} -} diff --git a/hail/hail/test/src/is/hail/TestCaseSupport.scala b/hail/hail/test/src/is/hail/TestCaseSupport.scala new file mode 100644 index 00000000000..36b11e2802f --- /dev/null +++ b/hail/hail/test/src/is/hail/TestCaseSupport.scala @@ -0,0 +1,12 @@ +package is.hail + +trait TestCaseSupport { self: munit.FunSuite => + trait TestCases { + var i: Int = 0 + + def test(name: String)(body: => Any)(implicit loc: munit.Location): Unit = { + i += 1 + self.test(s"$name case $i")(body) + } + } +} diff --git a/hail/hail/test/src/is/hail/TestUtils.scala b/hail/hail/test/src/is/hail/TestUtils.scala index dc0229637c3..1cfc0022ada 100644 --- a/hail/hail/test/src/is/hail/TestUtils.scala +++ b/hail/hail/test/src/is/hail/TestUtils.scala @@ -24,33 +24,12 @@ import java.io.PrintWriter import breeze.linalg.{DenseMatrix, Matrix, Vector} import org.apache.spark.SparkException import org.apache.spark.sql.Row -import org.scalatest.{Assertion, Assertions} -object ExecStrategy extends Enumeration { - type ExecStrategy = Value - val Interpret, InterpretUnoptimized, JvmCompile, LoweredJVMCompile, JvmCompileUnoptimized = Value - - val unoptimizedCompileOnly: Set[ExecStrategy] = Set(JvmCompileUnoptimized) - val compileOnly: Set[ExecStrategy] = Set(JvmCompile, JvmCompileUnoptimized) - - val javaOnly: Set[ExecStrategy] = - Set(Interpret, InterpretUnoptimized, JvmCompile, JvmCompileUnoptimized) - - val interpretOnly: Set[ExecStrategy] = Set(Interpret, InterpretUnoptimized) - - val nonLowering: Set[ExecStrategy] = - Set(Interpret, InterpretUnoptimized, JvmCompile, JvmCompileUnoptimized) - - val lowering: Set[ExecStrategy] = Set(LoweredJVMCompile) - val backendOnly: Set[ExecStrategy] = Set(LoweredJVMCompile) - val allRelational: Set[ExecStrategy] = interpretOnly.union(lowering) -} - -trait TestUtils extends Assertions { +trait TestUtils extends munit.Assertions { def ctx: ExecuteContext = ??? 
- def interceptException[E <: Throwable: Manifest](regex: String)(f: => Any): Assertion = { + def interceptException[E <: Throwable: Manifest](regex: String)(f: => Any): Unit = { val thrown = intercept[E](f) val p = regex.r.findFirstIn(thrown.getMessage).isDefined val msg = @@ -61,20 +40,20 @@ trait TestUtils extends Assertions { assert(p, msg) } - def interceptFatal(regex: String)(f: => Any): Assertion = + def interceptFatal(regex: String)(f: => Any): Unit = interceptException[HailException](regex)(f) - def interceptSpark(regex: String)(f: => Any): Assertion = + def interceptSpark(regex: String)(f: => Any): Unit = interceptException[SparkException](regex)(f) - def interceptAssertion(regex: String)(f: => Any): Assertion = + def interceptAssertion(regex: String)(f: => Any): Unit = interceptException[AssertionError](regex)(f) def assertVectorEqualityDouble( A: Vector[Double], B: Vector[Double], tolerance: Double = utils.defaultTolerance, - ): Assertion = { + ): Unit = { assert(A.size == B.size) assert((0 until A.size).forall(i => D_==(A(i), B(i), tolerance))) } @@ -83,7 +62,7 @@ trait TestUtils extends Assertions { A: Matrix[Double], B: Matrix[Double], tolerance: Double = utils.defaultTolerance, - ): Assertion = { + ): Unit = { assert(A.rows == B.rows) assert(A.cols == B.cols) assert((0 until A.rows).forall(i => @@ -262,13 +241,13 @@ trait TestUtils extends Assertions { } } - def assertEvalSame(x: IR): Assertion = + def assertEvalSame(x: IR): Unit = assertEvalSame(x, Env.empty, FastSeq()) - def assertEvalSame(x: IR, args: IndexedSeq[(Any, Type)]): Assertion = + def assertEvalSame(x: IR, args: IndexedSeq[(Any, Type)]): Unit = assertEvalSame(x, Env.empty, args) - def assertEvalSame(x: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)]): Assertion = { + def assertEvalSame(x: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)]): Unit = { val t = x.typ val (i, i2, c) = @@ -290,7 +269,7 @@ trait TestUtils extends Assertions { assert(t.valuesSimilar(i2, c), 
s"interpret (optimize = false) $i vs compile $c") } - def assertThrows[E <: Throwable: Manifest](x: IR, regex: String): Assertion = + def assertThrows[E <: Throwable: Manifest](x: IR, regex: String): Unit = assertThrows[E](x, Env.empty[(Any, Type)], FastSeq.empty[(Any, Type)], regex) def assertThrows[E <: Throwable: Manifest]( @@ -298,7 +277,7 @@ trait TestUtils extends Assertions { env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], regex: String, - ): Assertion = + ): Unit = ctx.local() { ctx => ctx.flags.set(Optimize, "1") interceptException[E](regex)(Interpret[Any](ctx, x, env, args)) @@ -308,14 +287,14 @@ trait TestUtils extends Assertions { interceptException[E](regex)(eval(x, env, args, None, None, ctx)) } - def assertFatal(x: IR, regex: String): Assertion = + def assertFatal(x: IR, regex: String): Unit = assertThrows[HailException](x, regex) - def assertFatal(x: IR, args: IndexedSeq[(Any, Type)], regex: String): Assertion = + def assertFatal(x: IR, args: IndexedSeq[(Any, Type)], regex: String): Unit = assertThrows[HailException](x, Env.empty[(Any, Type)], args, regex) def assertFatal(x: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], regex: String) - : Assertion = + : Unit = assertThrows[HailException](x, env, args, regex) def assertCompiledThrows[E <: Throwable: Manifest]( @@ -323,13 +302,13 @@ trait TestUtils extends Assertions { env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], regex: String, - ): Assertion = + ): Unit = interceptException[E](regex)(eval(x, env, args, None, None, ctx)) - def assertCompiledThrows[E <: Throwable: Manifest](x: IR, regex: String): Assertion = + def assertCompiledThrows[E <: Throwable: Manifest](x: IR, regex: String): Unit = assertCompiledThrows[E](x, Env.empty[(Any, Type)], FastSeq.empty[(Any, Type)], regex) - def assertCompiledFatal(x: IR, regex: String): Assertion = + def assertCompiledFatal(x: IR, regex: String): Unit = assertCompiledThrows[HailException](x, regex) def importVCF( diff --git 
a/hail/hail/test/src/is/hail/TestUtilsSuite.scala b/hail/hail/test/src/is/hail/TestUtilsSuite.scala index 415b63fd613..78660a5cef3 100644 --- a/hail/hail/test/src/is/hail/TestUtilsSuite.scala +++ b/hail/hail/test/src/is/hail/TestUtilsSuite.scala @@ -1,11 +1,11 @@ package is.hail import breeze.linalg.{DenseMatrix, DenseVector} -import org.testng.annotations.Test +import munit.FailException class TestUtilsSuite extends HailSuite { - @Test def matrixEqualityTest(): Unit = { + test("matrixEquality") { val M = DenseMatrix((1d, 0d), (0d, 1d)) val M1 = DenseMatrix((1d, 0d), (0d, 1.0001d)) val V = DenseVector(0d, 1d) @@ -15,11 +15,11 @@ class TestUtilsSuite extends HailSuite { assertMatrixEqualityDouble(M, M1, 0.001) assertVectorEqualityDouble(V, 2d * V1) - assertThrows[Exception](assertVectorEqualityDouble(V, V1)) - assertThrows[Exception](assertMatrixEqualityDouble(M, M1)) + intercept[FailException](assertVectorEqualityDouble(V, V1)): Unit + intercept[FailException](assertMatrixEqualityDouble(M, M1)): Unit } - @Test def constantVectorTest(): Unit = { + test("constantVector") { assert(isConstant(DenseVector())) assert(isConstant(DenseVector(0))) assert(isConstant(DenseVector(0, 0))) @@ -28,11 +28,10 @@ class TestUtilsSuite extends HailSuite { assert(!isConstant(DenseVector(0, 0, 1))) } - @Test def removeConstantColsTest(): Unit = { + test("removeConstantCols") { val M = DenseMatrix((0, 0, 1, 1, 0), (0, 1, 0, 1, 1)) val M1 = DenseMatrix((0, 1, 0), (1, 0, 1)) - - assert(removeConstantCols(M) == M1) + assertEquals(removeConstantCols(M), M1) } } diff --git a/hail/hail/test/src/is/hail/annotations/AnnotationsSuite.scala b/hail/hail/test/src/is/hail/annotations/AnnotationsSuite.scala index 1f9a9415f3b..5409dc7bff4 100644 --- a/hail/hail/test/src/is/hail/annotations/AnnotationsSuite.scala +++ b/hail/hail/test/src/is/hail/annotations/AnnotationsSuite.scala @@ -2,11 +2,9 @@ package is.hail.annotations import is.hail.HailSuite -import org.testng.annotations.Test - /** This testing 
suite evaluates the functionality of the [[is.hail.annotations]] package */ class AnnotationsSuite extends HailSuite { - @Test def testExtendedOrdering(): Unit = { + test("ExtendedOrdering") { val ord = ExtendedOrdering.extendToNull(implicitly[Ordering[Int]]) val rord = ord.reverse diff --git a/hail/hail/test/src/is/hail/annotations/ApproxCDFAggregatorSuite.scala b/hail/hail/test/src/is/hail/annotations/ApproxCDFAggregatorSuite.scala index 91942459117..944480bc6f3 100644 --- a/hail/hail/test/src/is/hail/annotations/ApproxCDFAggregatorSuite.scala +++ b/hail/hail/test/src/is/hail/annotations/ApproxCDFAggregatorSuite.scala @@ -2,19 +2,14 @@ package is.hail.annotations import is.hail.expr.ir.agg._ -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test - -class ApproxCDFAggregatorSuite extends TestNGSuite { - @Test - def testMerge(): Unit = { +class ApproxCDFAggregatorSuite extends munit.FunSuite { + test("merge") { val array: Array[Double] = Array(1, 3, 5, 0, 0, 0, 2, 4, 6) ApproxCDFHelper.merge(array, 0, 3, array, 6, 9, array, 3) assert(array.view.slice(3, 9) sameElements Range(1, 7)) } - @Test - def testCompactLevelZero(): Unit = { + test("compact level zero") { val rand = new java.util.Random(1) // first Boolean is `true` val levels: Array[Int] = Array(0, 4, 7, 10) val items: Array[Double] = Array(7, 2, 6, 4, 1, 3, 8, 0, 5, 9) @@ -24,8 +19,7 @@ class ApproxCDFAggregatorSuite extends TestNGSuite { assert(items.view.slice(1, 10) sameElements Array(2, 7, 1, 3, 6, 8, 0, 5, 9)) } - @Test - def testCompactLevel(): Unit = { + test("compact level") { val rand = new java.util.Random(1) // first Boolean is `true` val levels: Array[Int] = Array(0, 3, 6, 9) val items: Array[Double] = Array(7, 2, 4, 1, 3, 8, 0, 5, 9) diff --git a/hail/hail/test/src/is/hail/annotations/RegionSuite.scala b/hail/hail/test/src/is/hail/annotations/RegionSuite.scala index 2b1590e0632..7c88fe7f402 100644 --- a/hail/hail/test/src/is/hail/annotations/RegionSuite.scala +++ 
b/hail/hail/test/src/is/hail/annotations/RegionSuite.scala @@ -5,75 +5,73 @@ import is.hail.utils.using import scala.collection.mutable.ArrayBuffer -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test +class RegionSuite extends munit.FunSuite { -class RegionSuite extends TestNGSuite { - - @Test def testRegionSizes(): Unit = + test("region sizes") { RegionPool.scoped { pool => pool.scopedSmallRegion(region => Array.range(0, 30).foreach(_ => region.allocate(1, 500))) pool.scopedTinyRegion(region => Array.range(0, 30).foreach(_ => region.allocate(1, 60))) } + } - @Test def testRegionAllocationSimple(): Unit = { + test("region allocation simple") { using(RegionPool(strictMemoryCheck = true)) { pool => - assert(pool.numFreeBlocks() == 0) - assert(pool.numRegions() == 0) - assert(pool.numFreeRegions() == 0) + assertEquals(pool.numFreeBlocks(), 0) + assertEquals(pool.numRegions(), 0) + assertEquals(pool.numFreeRegions(), 0) val r = pool.getRegion(Region.REGULAR) - assert(pool.numRegions() == 1) - assert(pool.numFreeRegions() == 0) - assert(pool.numFreeBlocks() == 0) + assertEquals(pool.numRegions(), 1) + assertEquals(pool.numFreeRegions(), 0) + assertEquals(pool.numFreeBlocks(), 0) r.clear() - assert(pool.numRegions() == 1) - assert(pool.numFreeRegions() == 0) - assert(pool.numFreeBlocks() == 0) + assertEquals(pool.numRegions(), 1) + assertEquals(pool.numFreeRegions(), 0) + assertEquals(pool.numFreeBlocks(), 0) r.allocate(Region.SIZES(Region.REGULAR) - 1): Unit r.allocate(16): Unit r.clear() - assert(pool.numRegions() == 1) - assert(pool.numFreeRegions() == 0) - assert(pool.numFreeBlocks() == 1) + assertEquals(pool.numRegions(), 1) + assertEquals(pool.numFreeRegions(), 0) + assertEquals(pool.numFreeBlocks(), 1) val r2 = pool.getRegion(Region.SMALL) - assert(pool.numRegions() == 2) - assert(pool.numFreeRegions() == 0) - assert(pool.numFreeBlocks() == 1) + assertEquals(pool.numRegions(), 2) + assertEquals(pool.numFreeRegions(), 0) + 
assertEquals(pool.numFreeBlocks(), 1) val r3 = pool.getRegion(Region.REGULAR) - assert(pool.numRegions() == 3) - assert(pool.numFreeRegions() == 0) - assert(pool.numFreeBlocks() == 0) + assertEquals(pool.numRegions(), 3) + assertEquals(pool.numFreeRegions(), 0) + assertEquals(pool.numFreeBlocks(), 0) r.invalidate() r2.invalidate() r3.invalidate() - assert(pool.numRegions() == 3) - assert(pool.numFreeRegions() == 3) - assert(pool.numFreeBlocks() == 3) + assertEquals(pool.numRegions(), 3) + assertEquals(pool.numFreeRegions(), 3) + assertEquals(pool.numFreeBlocks(), 3) val r4 = pool.getRegion(Region.TINIER) - assert(pool.numRegions() == 3) - assert(pool.numFreeRegions() == 2) - assert(pool.numFreeBlocks() == 3) + assertEquals(pool.numRegions(), 3) + assertEquals(pool.numFreeRegions(), 2) + assertEquals(pool.numFreeBlocks(), 3) r4.invalidate() } } - @Test def testRegionAllocation(): Unit = { + test("region allocation") { RegionPool.scoped { pool => case class Counts(regions: Int, freeRegions: Int) { def allocate(n: Int): Counts = @@ -92,7 +90,7 @@ class RegionSuite extends TestNGSuite { def assertAfterEquals(c: => Counts): Unit = { before = after after = Counts(pool.numRegions(), pool.numFreeRegions()) - assert(after == c) + assertEquals(after, c) } pool.scopedRegion { region => @@ -115,7 +113,7 @@ class RegionSuite extends TestNGSuite { } } - @Test def testRegionReferences(): Unit = { + test("region references") { RegionPool.scoped { pool => def offset(region: Region) = region.allocate(0) @@ -124,7 +122,7 @@ class RegionSuite extends TestNGSuite { def assertUsesRegions[T](n: Int)(f: => T): T = { val usedRegionCount = numUsed() val res = f - assert(usedRegionCount == numUsed() - n) + assertEquals(usedRegionCount, numUsed() - n) res } @@ -140,9 +138,9 @@ class RegionSuite extends TestNGSuite { offset(r) } - using(region.getParentReference(2, Region.TINY))(r => assert(offset(r) == off2)) + using(region.getParentReference(2, Region.TINY))(r => assertEquals(offset(r), 
off2)) - using(region.getParentReference(4, Region.SMALL))(r => assert(offset(r) == off4)) + using(region.getParentReference(4, Region.SMALL))(r => assertEquals(offset(r), off4)) assertUsesRegions(-1) { region.unreferenceRegionAtIndex(2) @@ -153,82 +151,81 @@ class RegionSuite extends TestNGSuite { } } - @Test def allocationAtStartOfBlockIsCorrect(): Unit = { + test("allocation at start of block is correct") { using(RegionPool(strictMemoryCheck = true)) { pool => val region = pool.getRegion(Region.REGULAR) val off1 = region.allocate(1, 10) val off2 = region.allocate(1, 10) region.invalidate() - assert(off2 - off1 == 10) + assertEquals(off2 - off1, 10L) } } - @Test def blocksAreNotReleasedUntilRegionIsReleased(): Unit = { + test("blocks are not released until region is released") { using(RegionPool(strictMemoryCheck = true)) { pool => val region = pool.getRegion(Region.REGULAR) val nBlocks = 5 (0 until (Region.SIZES(Region.REGULAR)).toInt * nBlocks by 256).foreach { _ => region.allocate(1, 256) } - assert(pool.numFreeBlocks() == 0) + assertEquals(pool.numFreeBlocks(), 0) region.invalidate() - assert(pool.numFreeBlocks() == 5) + assertEquals(pool.numFreeBlocks(), 5) } } - @Test def largeChunksAreNotReturnedToBlockPool(): Unit = { + test("large chunks are not returned to block pool") { using(RegionPool(strictMemoryCheck = true)) { pool => val region = pool.getRegion(Region.REGULAR) region.allocate(4, Region.SIZES(Region.REGULAR) - 4): Unit - assert(pool.numFreeBlocks() == 0) + assertEquals(pool.numFreeBlocks(), 0) region.allocate(4, 1024 * 1024): Unit region.invalidate() - assert(pool.numFreeBlocks() == 1) + assertEquals(pool.numFreeBlocks(), 1) } } - @Test def referencedRegionsAreNotFreedUntilReferencingRegionIsFreed(): Unit = { + test("referenced regions are not freed until referencing region is freed") { using(RegionPool(strictMemoryCheck = true)) { pool => val r1 = pool.getRegion() val r2 = pool.getRegion() r2.addReferenceTo(r1) r1.invalidate() - 
assert(pool.numRegions() == 2) - assert(pool.numFreeRegions() == 0) + assertEquals(pool.numRegions(), 2) + assertEquals(pool.numFreeRegions(), 0) r2.invalidate() - assert(pool.numRegions() == 2) - assert(pool.numFreeRegions() == 2) + assertEquals(pool.numRegions(), 2) + assertEquals(pool.numFreeRegions(), 2) } } - @Test def blockSizesWorkAsExpected(): Unit = { + test("block sizes work as expected") { using(RegionPool(strictMemoryCheck = true)) { pool => - assert(pool.numFreeRegions() == 0) - assert(pool.numFreeBlocks() == 0) + assertEquals(pool.numFreeRegions(), 0) + assertEquals(pool.numFreeBlocks(), 0) val region1 = pool.getRegion() - assert(region1.blockSize == Region.REGULAR) + assertEquals(region1.blockSize, Region.REGULAR) region1.invalidate() - assert(pool.numFreeRegions() == 1) - assert(pool.numFreeBlocks() == 1) + assertEquals(pool.numFreeRegions(), 1) + assertEquals(pool.numFreeBlocks(), 1) val region2 = pool.getRegion(Region.SMALL) - assert(region2.blockSize == Region.SMALL) + assertEquals(region2.blockSize, Region.SMALL) - assert(pool.numFreeRegions() == 0) - assert(pool.numFreeBlocks() == 1) + assertEquals(pool.numFreeRegions(), 0) + assertEquals(pool.numFreeBlocks(), 1) region2.invalidate() - assert(pool.numFreeRegions() == 1) - assert(pool.numFreeBlocks() == 2) + assertEquals(pool.numFreeRegions(), 1) + assertEquals(pool.numFreeBlocks(), 2) } } - @Test - def testChunkCache(): Unit = { + test("chunk cache") { RegionPool.scoped { pool => val operations = ArrayBuffer[(String, Long)]() @@ -247,22 +244,22 @@ class RegionSuite extends TestNGSuite { ab += chunkCache.getChunk(pool, 400L)._1 chunkCache.freeChunkToCache(ab.pop()) ab += chunkCache.getChunk(pool, 50L)._1 - assert(operations(0) == (("allocate", 512))) + assertEquals(operations(0), (("allocate", 512L))) // 512 size chunk freed from cache to not exceed peak memory - assert(operations(1) == (("free", 0L))) - assert(operations(2) == (("allocate", 64))) + assertEquals(operations(1), (("free", 0L))) + 
assertEquals(operations(2), (("allocate", 64L))) chunkCache.freeChunkToCache(ab.pop()) // No additional allocate should be made as uses cache ab += chunkCache.getChunk(pool, 50L)._1 - assert(operations.length == 3) + assertEquals(operations.length, 3) ab += chunkCache.getChunk(pool, 40L)._1 chunkCache.freeChunksToCache(ab) - assert(operations(3) == (("allocate", 64))) - assert(operations.length == 4) + assertEquals(operations(3), (("allocate", 64L))) + assertEquals(operations.length, 4) chunkCache.freeAll(pool) - assert(operations(4) == (("free", 0L))) - assert(operations(5) == (("free", 0L))) - assert(operations.length == 6) + assertEquals(operations(4), (("free", 0L))) + assertEquals(operations(5), (("free", 0L))) + assertEquals(operations.length, 6) } } } diff --git a/hail/hail/test/src/is/hail/annotations/StagedConstructorSuite.scala b/hail/hail/test/src/is/hail/annotations/StagedConstructorSuite.scala index cef44deb5fd..a0fb5558d28 100644 --- a/hail/hail/test/src/is/hail/annotations/StagedConstructorSuite.scala +++ b/hail/hail/test/src/is/hail/annotations/StagedConstructorSuite.scala @@ -13,19 +13,15 @@ import is.hail.types.physical.stypes.primitives.SInt32Value import is.hail.utils._ import org.apache.spark.sql.Row -import org.scalatest.matchers.must.Matchers.{be, include} -import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.testng.annotations.Test +import org.scalacheck.Prop.forAll -class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { +class StagedConstructorSuite extends HailSuite with munit.ScalaCheckSuite { val showRVInfo = true def sm = ctx.stateManager - @Test - def testCanonicalString(): Unit = { + test("canonical string") { val rt = PCanonicalString() val input = "hello" val fb = EmitFunctionBuilder[Region, String, Long](ctx, "fb") @@ -63,13 +59,11 @@ class StagedConstructorSuite extends HailSuite with 
ScalaCheckDrivenPropertyChec println(rv2.pretty(rt)) } - assert(rv.pretty(rt) == rv2.pretty(rt)) - assert(rt.loadString(rv.offset) == - rt.loadString(rv2.offset)) + assertEquals(rv.pretty(rt), rv2.pretty(rt)) + assertEquals(rt.loadString(rv.offset), rt.loadString(rv2.offset)) } - @Test - def testInt(): Unit = { + test("int") { val rt = PInt32() val input = 3 val fb = EmitFunctionBuilder[Region, Int, Long](ctx, "fb") @@ -102,12 +96,11 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec println(rv2.pretty(rt)) } - assert(rv.pretty(rt) == rv2.pretty(rt)) - assert(Region.loadInt(rv.offset) == Region.loadInt(rv2.offset)) + assertEquals(rv.pretty(rt), rv2.pretty(rt)) + assertEquals(Region.loadInt(rv.offset), Region.loadInt(rv2.offset)) } - @Test - def testArray(): Unit = { + test("array") { val rt = PCanonicalArray(PInt32()) val input = 3 val fb = EmitFunctionBuilder[Region, Int, Long](ctx, "fb") @@ -138,14 +131,15 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec println(rv2.pretty(rt)) } - assert(rt.loadLength(rv.offset) == 1) - assert(rv.pretty(rt) == rv2.pretty(rt)) - assert(Region.loadInt(rt.loadElement(rv.offset, 0)) == - Region.loadInt(rt.loadElement(rv2.offset, 0))) + assertEquals(rt.loadLength(rv.offset), 1) + assertEquals(rv.pretty(rt), rv2.pretty(rt)) + assertEquals( + Region.loadInt(rt.loadElement(rv.offset, 0)), + Region.loadInt(rt.loadElement(rv2.offset, 0)), + ) } - @Test - def testStruct(): Unit = { + test("struct") { val pstring = PCanonicalString() val rt = PCanonicalStruct("a" -> pstring, "b" -> PInt32()) val input = 3 @@ -185,15 +179,18 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec println(rv2.pretty(rt)) } - assert(rv.pretty(rt) == rv2.pretty(rt)) - assert(rt.types(0).asInstanceOf[PString].loadString(rt.loadField(rv.offset, 0)) == - rt.types(0).asInstanceOf[PString].loadString(rt.loadField(rv2.offset, 0))) - 
assert(Region.loadInt(rt.loadField(rv.offset, 1)) == - Region.loadInt(rt.loadField(rv2.offset, 1))) + assertEquals(rv.pretty(rt), rv2.pretty(rt)) + assertEquals( + rt.types(0).asInstanceOf[PString].loadString(rt.loadField(rv.offset, 0)), + rt.types(0).asInstanceOf[PString].loadString(rt.loadField(rv2.offset, 0)), + ) + assertEquals( + Region.loadInt(rt.loadField(rv.offset, 1)), + Region.loadInt(rt.loadField(rv2.offset, 1)), + ) } - @Test - def testArrayOfStruct(): Unit = { + test("array of struct") { val structType = PCanonicalStruct("a" -> PInt32(), "b" -> PCanonicalString()) val arrayType = PCanonicalArray(structType) val input = "hello" @@ -252,14 +249,13 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec println(rv2.pretty(arrayType)) } - assert(rv.pretty(arrayType) == rv2.pretty(arrayType)) + assertEquals(rv.pretty(arrayType), rv2.pretty(arrayType)) assert(new UnsafeIndexedSeq(arrayType, rv.region, rv.offset).sameElements( new UnsafeIndexedSeq(arrayType, rv2.region, rv2.offset) )) } - @Test - def testMissingRandomAccessArray(): Unit = { + test("missing random access array") { val rt = PCanonicalArray(PCanonicalStruct("a" -> PInt32(), "b" -> PCanonicalString())) val intVal = 20 val strVal = "a string with a partner of 20" @@ -294,14 +290,13 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec } rvb2.endArray() rv2.setOffset(rvb2.end()) - assert(rv.pretty(rt) == rv2.pretty(rt)) + assertEquals(rv.pretty(rt), rv2.pretty(rt)) assert(new UnsafeIndexedSeq(rt, rv.region, rv.offset).sameElements( new UnsafeIndexedSeq(rt, rv2.region, rv2.offset) )) } - @Test - def testSetFieldPresent(): Unit = { + test("set field present") { val rt = PCanonicalStruct("a" -> PInt32(), "b" -> PCanonicalString(), "c" -> PFloat64()) val intVal = 30 val floatVal = 39.273d @@ -331,15 +326,18 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec rvb2.endStruct() rv2.setOffset(rvb2.end()) - 
assert(rv.pretty(rt) == rv2.pretty(rt)) - assert(Region.loadInt(rt.loadField(rv.offset, 0)) == - Region.loadInt(rt.loadField(rv2.offset, 0))) - assert(Region.loadDouble(rt.loadField(rv.offset, 2)) == - Region.loadDouble(rt.loadField(rv2.offset, 2))) + assertEquals(rv.pretty(rt), rv2.pretty(rt)) + assertEquals( + Region.loadInt(rt.loadField(rv.offset, 0)), + Region.loadInt(rt.loadField(rv2.offset, 0)), + ) + assertEquals( + Region.loadDouble(rt.loadField(rv.offset, 2)), + Region.loadDouble(rt.loadField(rv2.offset, 2)), + ) } - @Test - def testStructWithArray(): Unit = { + test("struct with array") { val tArray = PCanonicalArray(PInt32()) val rt = PCanonicalStruct("a" -> PCanonicalString(), "b" -> tArray) val input = "hello" @@ -403,13 +401,14 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec println(rv2.pretty(rt)) } - assert(rv.pretty(rt) == rv2.pretty(rt)) - assert(new UnsafeRow(rt, rv.region, rv.offset) == - new UnsafeRow(rt, rv2.region, rv2.offset)) + assertEquals(rv.pretty(rt), rv2.pretty(rt)) + assertEquals( + new UnsafeRow(rt, rv.region, rv.offset), + new UnsafeRow(rt, rv2.region, rv2.offset), + ) } - @Test - def testMissingArray(): Unit = { + test("missing array") { val rt = PCanonicalArray(PInt32()) val input = 3 val fb = EmitFunctionBuilder[Region, Int, Long](ctx, "fb") @@ -439,7 +438,7 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec println(rv2.pretty(rt)) } - assert(rv.pretty(rt) == rv2.pretty(rt)) + assertEquals(rv.pretty(rt), rv2.pretty(rt)) assert(new UnsafeIndexedSeq(rt, rv.region, rv.offset).sameElements( new UnsafeIndexedSeq(rt, rv2.region, rv2.offset) )) @@ -448,8 +447,7 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec def printRegion(region: Region, string: String): Unit = println(region.prettyBits()) - @Test - def testAddPrimitive(): Unit = { + test("add primitive") { val t = PCanonicalStruct("a" -> PInt32(), "b" -> PBoolean(), "c" -> 
PFloat64()) val fb = EmitFunctionBuilder[Region, Int, Boolean, Double, Long](ctx, "fb") @@ -481,8 +479,8 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec ) } - assert(run(3, true, 42.0) == ((3, true, 42.0))) - assert(run(42, false, -1.0) == ((42, false, -1.0))) + assertEquals(run(3, true, 42.0), ((3, true, 42.0))) + assertEquals(run(42, false, -1.0), ((42, false, -1.0))) } def emitCopy(ctx: ExecuteContext, ptype: PType, deepCopy: Boolean) @@ -497,7 +495,7 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec fb.resultWithIndex() } - @Test def testShallowCopyOfPointersFailsAcrossRegions(): Unit = { + test("shallow copy of pointers fails across regions") { val ptype = PCanonicalStruct(required = true, "a" -> PCanonicalArray(PInt32())) val value = genVal(ctx, ptype).sample.get val ShallowCopy = emitCopy(ctx, ptype, deepCopy = false) @@ -517,30 +515,29 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec } } - ex.getMessage should include("invalid memory access") + assert(ex.getMessage.contains("invalid memory access")) } - @Test def testDeepCopy(): Unit = - forAll(genPTypeVal[PCanonicalStruct](ctx)) { case (t, a: Row) => - val DeepCopy = emitCopy(ctx, t, deepCopy = true) - - val copy: Row = - ctx.scopedExecution { (hcl, fs, htc, r1) => - val validPtr: Long = - using(RegionPool(strictMemoryCheck = true)) { p2 => - using(p2.getRegion()) { r2 => - val offset = ScalaToRegionValue(sm, r2, t, a) - DeepCopy(hcl, fs, htc, r1)(r1, offset) - } + property("deep copy") = forAll(genPTypeVal[PCanonicalStruct](ctx)) { case (t, a: Row) => + val DeepCopy = emitCopy(ctx, t, deepCopy = true) + + val copy: Row = + ctx.scopedExecution { (hcl, fs, htc, r1) => + val validPtr: Long = + using(RegionPool(strictMemoryCheck = true)) { p2 => + using(p2.getRegion()) { r2 => + val offset = ScalaToRegionValue(sm, r2, t, a) + DeepCopy(hcl, fs, htc, r1)(r1, offset) } + } - SafeRow(t, validPtr) - } + 
SafeRow(t, validPtr) + } - copy should be(a) - } + assertEquals(copy, a) + } - @Test def testUnstagedCopy(): Unit = { + test("unstaged copy") { val t1 = PCanonicalArray( PCanonicalStruct( true, @@ -564,17 +561,17 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec pool.scopedRegion { r => val rvb = new RegionValueBuilder(sm, r) val v1 = t2.unstagedStoreJavaObject(sm, value, r) - assert(SafeRow.read(t2, v1) == value) + assertEquals(SafeRow.read(t2, v1), value) rvb.clear() rvb.start(t1) rvb.addRegionValue(t2, r, v1) val v2 = rvb.end() - assert(SafeRow.read(t1, v2) == value) + assertEquals(SafeRow.read(t1, v2), value) } } - @Test def testStagedCopy(): Unit = { + test("staged copy") { val t1 = PCanonicalStruct( false, "a" -> PCanonicalArray( @@ -601,7 +598,7 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec val valueT2 = t2.types(0) pool.scopedRegion { r => val v1 = valueT2.unstagedStoreJavaObject(sm, value, r) - assert(SafeRow.read(valueT2, v1) == value) + assertEquals(SafeRow.read(valueT2, v1), value) val f1 = EmitFunctionBuilder[Long](ctx, "stagedCopy1") f1.emitWithBuilder { cb => @@ -614,7 +611,7 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec ).a } val cp1 = f1.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r)() - assert(SafeRow.read(t2, cp1) == Row(value)) + assertEquals(SafeRow.read(t2, cp1), Row(value)) val f2 = EmitFunctionBuilder[Long](ctx, "stagedCopy2") f2.emitWithBuilder { cb => @@ -627,7 +624,7 @@ class StagedConstructorSuite extends HailSuite with ScalaCheckDrivenPropertyChec ).a } val cp2 = f2.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r)() - assert(SafeRow.read(t1, cp2) == Row(value)) + assertEquals(SafeRow.read(t1, cp2), Row(value)) } } } diff --git a/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala b/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala index 38f5d7c9c08..9481e389142 100644 --- 
a/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala +++ b/hail/hail/test/src/is/hail/annotations/UnsafeSuite.scala @@ -18,10 +18,9 @@ import org.json4s.jackson.Serialization import org.scalacheck._ import org.scalacheck.Arbitrary.arbitrary import org.scalacheck.Gen._ -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.testng.annotations.{DataProvider, Test} +import org.scalacheck.Prop.forAll -class UnsafeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { +class UnsafeSuite extends HailSuite with munit.ScalaCheckSuite { def subsetType(t: Type): Type = { t match { case t: TStruct => @@ -57,10 +56,6 @@ class UnsafeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { def sm = ctx.stateManager - @DataProvider(name = "codecs") - def codecs(): Array[Array[Any]] = - codecs(ctx) - def codecs(ctx: ExecuteContext): Array[Array[Any]] = (BufferSpec.specs ++ Array(TypedCodecSpec( ctx, @@ -69,13 +64,25 @@ class UnsafeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { ))) .map(x => Array[Any](x)) - @Test(dataProvider = "codecs") def testCodecSerialization(codec: Spec): Unit = { - implicit val formats = AbstractRVDSpec.formats - assert(Serialization.read[Spec](codec.toString) == codec) + object checkCodecSerialization extends TestCases { + def apply( + codec: => Spec + )(implicit loc: munit.Location + ): Unit = test("codec serialization") { + implicit val formats = AbstractRVDSpec.formats + val c = codec + assertEquals(Serialization.read[Spec](c.toString), c) + } + } + { + lazy val specs = codecs(ctx) + specs.foreach { case Array(codec: Spec) => + checkCodecSerialization(codec) + } } - @Test def testCodec(): Unit = { + property("codec") { val region = Region(pool = pool) val region2 = Region(pool = pool) val region3 = Region(pool = pool) @@ -147,7 +154,7 @@ class UnsafeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { } } - @Test def testCodecForNonWrappedTypes(): Unit = { + test("codec for non-wrapped 
types") { val valuesAndTypes = FastSeq( 5 -> PInt32(), 6L -> PInt64(), @@ -170,7 +177,7 @@ class UnsafeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val serialized = baos.toByteArray val (decT, dec) = cs2.buildDecoder(ctx, t.virtualType) - assert(decT == t) + assertEquals(decT, t) val res = dec((new ByteArrayInputStream(serialized)), theHailClassLoader).readRegionValue(region) @@ -180,7 +187,7 @@ class UnsafeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { } } - @Test def testBufferWriteReadDoubles(): Unit = { + test("buffer write read doubles") { val a = Array(1.0, -349.273, 0.0, 9925.467, 0.001) BufferSpec.specs.foreach { bufferSpec => @@ -198,7 +205,7 @@ class UnsafeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { } } - @Test def testRegionValue(): Unit = { + property("region value") { val region = Region(pool = pool) val region2 = Region(pool = pool) val rvb = new RegionValueBuilder(sm, region) @@ -293,7 +300,7 @@ class UnsafeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { if v != null } yield (t, v) - @Test def testPacking(): Unit = { + test("packing") { def makeStruct(types: PType*): PCanonicalStruct = PCanonicalStruct(types.zipWithIndex.map { case (t, i) => (s"f$i", t) }: _*) @@ -308,8 +315,8 @@ class UnsafeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { PBoolean(), // 12-13 PBoolean(), ) // 13-14 - assert(t1.byteOffsets.toSeq == Seq(4, 8, 16, 1, 2, 3, 12, 13)) - assert(t1.byteSize == 24) + assertEquals(t1.byteOffsets.toSeq, Seq(4L, 8L, 16L, 1L, 2L, 3L, 12L, 13L)) + assertEquals(t1.byteSize, 24L) val t2 = makeStruct( // missing bytes 0, 1 PBoolean(), // 2-3 @@ -325,21 +332,22 @@ class UnsafeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { PBoolean(), ) // 48-49 - assert(t2.byteOffsets.toSeq == Seq(2, 4, 8, 16, 12, 24, 32, 28, 3, 40, 48)) - assert(t2.byteSize == 49) + assertEquals(t2.byteOffsets.toSeq, Seq(2L, 4L, 8L, 16L, 12L, 24L, 32L, 28L, 3L, 40L, 48L)) + 
assertEquals(t2.byteSize, 49L) val t3 = makeStruct((0 until 512).map(_ => PFloat64()): _*) - assert(t3.byteSize == (512 / 8) + 512 * 8) + assertEquals(t3.byteSize, (512L / 8) + 512 * 8) val t4 = makeStruct((0 until 256).flatMap(_ => Iterator(PInt32(), PInt32(), PFloat64(), PBoolean()) ): _*) - assert(t4.byteSize == 256 * 4 / 8 + 256 * 4 * 2 + 256 * 8 + 256) + assertEquals(t4.byteSize, 256L * 4 / 8 + 256 * 4 * 2 + 256 * 8 + 256) } - @Test def testEmptySize(): Unit = - assert(PCanonicalStruct().byteSize == 0) + test("empty size") { + assertEquals(PCanonicalStruct().byteSize, 0L) + } - @Test def testUnsafeOrdering(): Unit = { + property("unsafe ordering") { val region = Region(pool = pool) val region2 = Region(pool = pool) @@ -383,6 +391,7 @@ class UnsafeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { println(s"a2=$a2") println(s"c1=$c1, c2=$c2, c3=$c3") } + p } } } diff --git a/hail/hail/test/src/is/hail/asm4s/ASM4SSuite.scala b/hail/hail/test/src/is/hail/asm4s/ASM4SSuite.scala index 88f474ceb6d..ac4a827d355 100644 --- a/hail/hail/test/src/is/hail/asm4s/ASM4SSuite.scala +++ b/hail/hail/test/src/is/hail/asm4s/ASM4SSuite.scala @@ -10,15 +10,14 @@ import scala.language.postfixOps import java.io.PrintWriter import org.scalacheck.Gen.choose -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.testng.annotations.{DataProvider, Test} +import org.scalacheck.Prop.forAll trait Z2Z { def apply(z: Boolean): Boolean } -class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { +class ASM4SSuite extends HailSuite with munit.ScalaCheckSuite { override val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) - @Test def not(): Unit = { + test("not") { val notb = FunctionBuilder[Z2Z]( "is/hail/asm4s/Z2Z", ArraySeq(NotGenericTypeInfo[Boolean]), @@ -30,30 +29,30 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(not(false)) } - @Test def mux(): Unit = { + test("mux") { val gb = 
FunctionBuilder[Boolean, Int]("G") gb.emit(gb.getArg[Boolean](1).mux(11, -1)) val g = gb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(g(true) == 11) - assert(g(false) == -1) + assertEquals(g(true), 11) + assertEquals(g(false), -1) } - @Test def add(): Unit = { + test("add") { val fb = FunctionBuilder[Int, Int]("F") fb.emit(fb.getArg[Int](1) + 5) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(f(-2) == 3) + assertEquals(f(-2), 3) } - @Test def iinc(): Unit = { + test("iinc") { val fb = FunctionBuilder[Int]("F") val l = fb.newLocal[Int]() fb.emit(Code(l := 0, l ++, l += 2, l)) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(f() == 3) + assertEquals(f(), 3) } - @Test def array(): Unit = { + test("array") { val hb = FunctionBuilder[Int, Int]("H") val arr = hb.newLocal[Array[Int]]() hb.emit(Code( @@ -64,46 +63,46 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { arr(hb.getArg[Int](1)), )) val h = hb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(h(0) == 6) - assert(h(1) == 7) - assert(h(2) == -6) + assertEquals(h(0), 6) + assertEquals(h(1), 7) + assertEquals(h(2), -6) } - @Test def get(): Unit = { + test("get") { val fb = FunctionBuilder[Foo, Int]("F") fb.emit(fb.getArg[Foo](1).getField[Int]("i")) val i = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) val a = new Foo - assert(i(a) == 5) + assertEquals(i(a), 5) } - @Test def invoke(): Unit = { + test("invoke") { val fb = FunctionBuilder[Foo, Int]("F") fb.emit(fb.getArg[Foo](1).invoke[Int]("f")) val i = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) val a = new Foo - assert(i(a) == 6) + assertEquals(i(a), 6) } - @Test def invoke2(): Unit = { + test("invoke2") { val fb = FunctionBuilder[Foo, Int]("F") fb.emit(fb.getArg[Foo](1).invoke[Int, Int]("g", 6)) val j = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) val a = new Foo - assert(j(a) == 11) + assertEquals(j(a), 11) } - @Test def 
newInstance(): Unit = { + test("newInstance") { val fb = FunctionBuilder[Int]("F") fb.emit(Code.newInstance[Foo]().invoke[Int]("f")) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(f() == 6) + assertEquals(f(), 6) } - @Test def put(): Unit = { + test("put") { val fb = FunctionBuilder[Int]("F") val inst = fb.newLocal[Foo]() fb.emit(Code( @@ -112,10 +111,10 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { inst.getField[Int]("i"), )) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(f() == -2) + assertEquals(f(), -2) } - @Test def staticPut(): Unit = { + test("staticPut") { val fb = FunctionBuilder[Int]("F") val inst = fb.newLocal[Foo]() fb.emit(Code( @@ -124,17 +123,17 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { Code.getStatic[Foo, Int]("j"), )) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(f() == -2) + assertEquals(f(), -2) } - @Test def f2(): Unit = { + test("f2") { val fb = FunctionBuilder[Int, Int, Int]("F") fb.emit(fb.getArg[Int](1) + fb.getArg[Int](2)) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(f(3, 5) == 8) + assertEquals(f(3, 5), 8) } - @Test def compare(): Unit = { + test("compare") { val fb = FunctionBuilder[Int, Int, Boolean]("F") fb.emit(fb.getArg[Int](1) > fb.getArg[Int](2)) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) @@ -143,7 +142,7 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(!f(2, 5)) } - @Test def fact(): Unit = { + test("fact") { val fb = FunctionBuilder[Int, Int]("Fact") val i = fb.getArg[Int](1) fb.emitWithBuilder[Int] { cb => @@ -159,11 +158,11 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { } val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(f(3) == 6) - assert(f(4) == 24) + assertEquals(f(3), 6) + assertEquals(f(4), 24) } - @Test def dcmp(): Unit = { + test("dcmp") { val fb = 
FunctionBuilder[Double, Double, Boolean]("F") fb.emit(fb.getArg[Double](1) > fb.getArg[Double](2)) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) @@ -174,7 +173,7 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(!f(2.3, 5.2)) } - @Test def anewarray(): Unit = { + test("anewarray") { val fb = FunctionBuilder[Int]("F") val arr = fb.newLocal[Array[Foo]]() fb.emit(Code( @@ -184,7 +183,7 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { arr(0).getField[Int]("i") + arr(1).getField[Int]("i"), )) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(f() == 10) + assertEquals(f(), 10) } def fib(_n: Int): Int = { @@ -203,7 +202,7 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { b } - @Test def fibonacci(): Unit = { + property("fibonacci") { val Fib = FunctionBuilder[Int, Int]("Fib") Fib.emitWithBuilder[Int] { cb => val n = Fib.getArg[Int](1) @@ -228,51 +227,51 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { forAll(choose(0, 100))(i => fib(i) == f(i)) } - // type inference helper - private[this] def refl[A](a: (Code[A], Code[A]) => Code[Boolean]) = - a - - @DataProvider(name = "DoubleComparisonOperator") - def doubleComparisonOperator(): Array[(Code[Double], Code[Double]) => Code[Boolean]] = - Array( - refl(_ < _), - refl(_ <= _), - refl(_ >= _), - refl(_ > _), - refl(_ ceq _), - refl(_ cne _), - ) - - @Test(dataProvider = "DoubleComparisonOperator") - def nanDoubleAlwaysComparesFalse(op: (Code[Double], Code[Double]) => Code[Boolean]): Unit = - forAll { (x: Double) => - val F = FunctionBuilder[Double, Double, Boolean]("CMP") - F.emit(op(F.getArg[Double](1), F.getArg[Double](2))) - val cmp = F.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - !cmp(Double.NaN, x) + object checkNanDoubleComparisons extends TestCases { + def apply( + op: (Code[Double], Code[Double]) => Code[Boolean], + expected: Boolean, + )(implicit loc: 
munit.Location + ): Unit = property(s"nan double always compares false") { + forAll { (x: Double) => + val F = FunctionBuilder[Double, Double, Boolean]("CMP") + F.emit(op(F.getArg[Double](1), F.getArg[Double](2))) + val cmp = F.result(ctx.shouldWriteIRFiles())(theHailClassLoader) + cmp(Double.NaN, x) == expected && cmp(x, Double.NaN) == expected + } } + } - @DataProvider(name = "FloatComparisonOperator") - def floatComparisonOperator(): Array[(Code[Float], Code[Float]) => Code[Boolean]] = - Array( - refl(_ < _), - refl(_ <= _), - refl(_ >= _), - refl(_ > _), - refl(_ ceq _), - refl(_ cne _), - ) - - @Test(dataProvider = "FloatComparisonOperator") - def nanFloatAlwaysComparesFalse(op: (Code[Float], Code[Float]) => Code[Boolean]): Unit = - forAll { (x: Float) => - val F = FunctionBuilder[Float, Float, Boolean]("CMP") - F.emit(op(F.getArg[Float](1), F.getArg[Float](2))) - val cmp = F.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - !cmp(Float.NaN, x) + checkNanDoubleComparisons(_ < _, false) + checkNanDoubleComparisons(_ <= _, false) + checkNanDoubleComparisons(_ >= _, false) + checkNanDoubleComparisons(_ > _, false) + checkNanDoubleComparisons(_ ceq _, false) + checkNanDoubleComparisons(_ cne _, true) + + object checkNanFloatComparisons extends TestCases { + def apply( + op: (Code[Float], Code[Float]) => Code[Boolean], + expected: Boolean, + )(implicit loc: munit.Location + ): Unit = property(s"nan float always compares false case") { + forAll { (x: Float) => + val F = FunctionBuilder[Float, Float, Boolean]("CMP") + F.emit(op(F.getArg[Float](1), F.getArg[Float](2))) + val cmp = F.result(ctx.shouldWriteIRFiles())(theHailClassLoader) + cmp(Float.NaN, x) == expected && cmp(x, Float.NaN) == expected + } } + } + + checkNanFloatComparisons(_ < _, false) + checkNanFloatComparisons(_ <= _, false) + checkNanFloatComparisons(_ >= _, false) + checkNanFloatComparisons(_ > _, false) + checkNanFloatComparisons(_ ceq _, false) + checkNanFloatComparisons(_ cne _, true) - @Test 
def defineOpsAsMethods(): Unit = { + test("defineOpsAsMethods") { val fb = FunctionBuilder[Int, Int, Int, Int]("F") val add = fb.genMethod[Int, Int, Int]("add") val sub = fb.genMethod[Int, Int, Int]("sub") @@ -301,12 +300,12 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val f = fb.result(ctx.shouldWriteIRFiles(), Some(new PrintWriter(System.out)))(theHailClassLoader) - assert(f(0, 1, 1) == 2) - assert(f(1, 5, 1) == 4) - assert(f(2, 2, 8) == 16) + assertEquals(f(0, 1, 1), 2) + assertEquals(f(1, 5, 1), 4) + assertEquals(f(2, 2, 8), 16) } - @Test def checkLocalVarsOnMethods(): Unit = { + test("checkLocalVarsOnMethods") { val fb = FunctionBuilder[Int, Int, Int]("F") val add = fb.genMethod[Int, Int, Int]("add") @@ -323,10 +322,10 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { fb.emitWithBuilder(cb => cb.invoke(add, cb.this_, fb.getArg[Int](1), fb.getArg[Int](2))) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(f(1, 1) == 2) + assertEquals(f(1, 1), 2) } - @Test def checkClassFields(): Unit = { + test("checkClassFields") { def readField[T: TypeInfo](arg1: Int, arg2: Long, arg3: Boolean): T = { val fb = FunctionBuilder[Int, Long, Boolean, T]("F") @@ -348,12 +347,12 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { f(arg1, arg2, arg3) } - assert(readField[Int](1, 2L, true) == 1) - assert(readField[Long](1, 2L, true) == 2L) + assertEquals(readField[Int](1, 2L, true), 1) + assertEquals(readField[Long](1, 2L, true), 2L) assert(readField[Boolean](1, 2L, true)) } - @Test def checkClassFieldsFromMethod(): Unit = { + test("checkClassFieldsFromMethod") { def readField[T: TypeInfo](arg1: Int, arg2: Long, arg3: Boolean): T = { val fb = FunctionBuilder[Int, Long, Boolean, T]("F") val mb = fb.genMethod[Int, Long, Boolean, T]("m") @@ -378,12 +377,12 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { f(arg1, arg2, arg3) } - assert(readField[Int](1, 2L, 
true) == 1) - assert(readField[Long](1, 2L, true) == 2L) + assertEquals(readField[Int](1, 2L, true), 1) + assertEquals(readField[Long](1, 2L, true), 2L) assert(readField[Boolean](1, 2L, true)) } - @Test def lazyFieldEvaluatesOnce(): Unit = { + test("lazyFieldEvaluatesOnce") { val F = FunctionBuilder[Int]("LazyField") val a = F.genFieldThisRef[Int]("a") val lzy = F.genLazyFieldThisRef(a + 1, "lzy") @@ -396,10 +395,10 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { )) val f = F.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(f() == 1) + assertEquals(f(), 1) } - @Test def testInitialize(): Unit = { + test("initialize") { val fb = FunctionBuilder[Boolean, Int]("F") fb.emitWithBuilder { cb => val a = cb.newLocal[Int]("a") @@ -407,31 +406,31 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { a } val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(f(true) == 0) - assert(f(false) == 5) + assertEquals(f(true), 0) + assertEquals(f(false), 5) } - @Test def testInit(): Unit = { + test("init") { val Main = FunctionBuilder[Int]("Main") val a = Main.genFieldThisRef[Int]("a") Main.emitInit(a := 1) Main.emit(a) val test = Main.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(test() == 1) + assertEquals(test(), 1) } - @Test def testClinit(): Unit = { + test("clinit") { val Main = FunctionBuilder[Int]("Main") val a = Main.newStaticField[Int]("a") Main.emitClinit(a.put(1)) Main.emit(a.get()) val test = Main.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(test() == 1) + assertEquals(test(), 1) } - @Test def testClassInstances(): Unit = { + test("classInstances") { val Counter = FunctionBuilder[Int]("Counter") val x = Counter.genFieldThisRef[Int]("x") Counter.emitInit(x := 0) @@ -454,10 +453,10 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { Counter.result(ctx.shouldWriteIRFiles())(theHailClassLoader): Unit val test = 
Main.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - assert(test() == 6) + assertEquals(test(), 6) } - @Test def testIf(): Unit = { + property("if") { val Main = FunctionBuilder[Int, Int]("If") Main.emitWithBuilder[Int] { cb => val a = cb.mb.getArg[Int](1) @@ -470,7 +469,7 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { forAll((x: Int) => abs(x) == x.abs) } - @Test def testWhile(): Unit = { + property("while") { val Main = FunctionBuilder[Int, Int, Int]("While") Main.emitWithBuilder[Int] { cb => val a = cb.mb.getArg[Int](1) @@ -493,7 +492,7 @@ class ASM4SSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { forAll(choose(-10, 10), choose(-10, 10))((x, y) => add(x, y) == x + y) } - @Test def testFor(): Unit = { + property("for") { val Main = FunctionBuilder[Int, Int, Int]("For") Main.emitWithBuilder[Int] { cb => val a = cb.mb.getArg[Int](1) diff --git a/hail/hail/test/src/is/hail/asm4s/CodeSuite.scala b/hail/hail/test/src/is/hail/asm4s/CodeSuite.scala index 012b5e16ad1..244d3e668ad 100644 --- a/hail/hail/test/src/is/hail/asm4s/CodeSuite.scala +++ b/hail/hail/test/src/is/hail/asm4s/CodeSuite.scala @@ -12,11 +12,10 @@ import is.hail.types.physical.stypes.primitives.{ import is.hail.types.virtual.{TInt32, TInt64, TStruct} import org.apache.spark.sql.Row -import org.testng.annotations.Test class CodeSuite extends HailSuite { - @Test def testForLoop(): Unit = { + test("forLoop") { val fb = EmitFunctionBuilder[Int](ctx, "foo") fb.emitWithBuilder[Int] { cb => val i = cb.newLocal[Int]("i") @@ -27,10 +26,10 @@ class CodeSuite extends HailSuite { } val result = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r)() - assert(result == 10) + assertEquals(result, 10) } - @Test def testSizeBasic(): Unit = { + test("sizeBasic") { val int64 = new SInt64Value(5L) val int32 = new SInt32Value(2) val struct = new SStackStructValue( @@ -49,15 +48,16 @@ class CodeSuite extends HailSuite { fb.result()(theHailClassLoader)() } - 
assert(testSizeHelper(int64) == 8L) - assert(testSizeHelper(int32) == 4L) - assert( - testSizeHelper(struct) == 16L + assertEquals(testSizeHelper(int64), 8L) + assertEquals(testSizeHelper(int32), 4L) + assertEquals( + testSizeHelper(struct), + 16L, ) // 1 missing byte that gets 4 byte aligned, 8 bytes for long, 4 bytes for missing int - assert(testSizeHelper(str) == 7L) // 4 byte header, 3 bytes for the 3 letters. + assertEquals(testSizeHelper(str), 7L) // 4 byte header, 3 bytes for the 3 letters. } - @Test def testArraySizeInBytes(): Unit = { + test("arraySizeInBytes") { val fb = EmitFunctionBuilder[Region, Long](ctx, "test_size_in_bytes") val mb = fb.apply_method val ptype = PCanonicalArray(PInt32()) @@ -73,12 +73,13 @@ class CodeSuite extends HailSuite { } sarray.sizeToStoreInBytes(cb).value }) - assert( - fb.result()(theHailClassLoader)(ctx.r) == 36L + assertEquals( + fb.result()(theHailClassLoader)(ctx.r), + 36L, ) // 2 missing bytes 8 byte aligned + 8 header bytes + 5 elements * 4 bytes for ints. } - @Test def testIntervalSizeInBytes(): Unit = { + test("intervalSizeInBytes") { val fb = EmitFunctionBuilder[Region, Long](ctx, "test_size_in_bytes") val mb = fb.apply_method @@ -111,33 +112,49 @@ class CodeSuite extends HailSuite { ) sval.sizeToStoreInBytes(cb).value }) - assert(fb.result()(theHailClassLoader)(ctx.r) == 72L) // 2 28 byte structs, plus 2 1 byte booleans that get 8 byte for an extra 8 bytes, plus missing bytes. + assertEquals( + fb.result()(theHailClassLoader)(ctx.r), + 72L, + ) // 2 28 byte structs, plus 2 1 byte booleans that get 8 byte for an extra 8 bytes, plus missing bytes. 
} - @Test def testHash(): Unit = { + test("hash") { val fields = IndexedSeq( PField("a", PCanonicalString(), 0), PField("b", PInt32(), 1), PField("c", PFloat32(), 2), ) - assert(hashTestNumHelper(new SInt32Value(6)) == hashTestNumHelper(new SInt32Value(6))) - assert(hashTestNumHelper(new SInt64Value(5000000000L)) == hashTestNumHelper( - new SInt64Value(5000000000L) - )) - assert( - hashTestNumHelper(new SFloat32Value(3.14f)) == hashTestNumHelper(new SFloat32Value(3.14f)) + assertEquals(hashTestNumHelper(new SInt32Value(6)), hashTestNumHelper(new SInt32Value(6))) + assertEquals( + hashTestNumHelper(new SInt64Value(5000000000L)), + hashTestNumHelper( + new SInt64Value(5000000000L) + ), + ) + assertEquals( + hashTestNumHelper(new SFloat32Value(3.14f)), + hashTestNumHelper(new SFloat32Value(3.14f)), + ) + assertEquals( + hashTestNumHelper(new SFloat64Value(5000000000.89d)), + hashTestNumHelper( + new SFloat64Value(5000000000.89d) + ), + ) + assertEquals(hashTestStringHelper("dog"), hashTestStringHelper("dog")) + assertEquals( + hashTestArrayHelper(IndexedSeq(1, 2, 3, 4, 5, 6)), + hashTestArrayHelper(IndexedSeq(1, 2, + 3, 4, 5, 6)), ) - assert(hashTestNumHelper(new SFloat64Value(5000000000.89d)) == hashTestNumHelper( - new SFloat64Value(5000000000.89d) - )) - assert(hashTestStringHelper("dog") == hashTestStringHelper("dog")) - assert(hashTestArrayHelper(IndexedSeq(1, 2, 3, 4, 5, 6)) == hashTestArrayHelper(IndexedSeq(1, 2, - 3, 4, 5, 6))) assert(hashTestArrayHelper(IndexedSeq(1, 2)) != hashTestArrayHelper(IndexedSeq(3, 4, 5, 6, 7))) - assert(hashTestStructHelper(Row("wolf", 8, .009f), fields) == hashTestStructHelper( - Row("wolf", 8, .009f), - fields, - )) + assertEquals( + hashTestStructHelper(Row("wolf", 8, .009f), fields), + hashTestStructHelper( + Row("wolf", 8, .009f), + fields, + ), + ) assert(hashTestStructHelper(Row("w", 8, .009f), fields) != hashTestStructHelper( Row("opaque", 8, .009f), fields, diff --git 
a/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala b/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala index 0aa09f2def5..3c0a031a1f2 100644 --- a/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/hail/test/src/is/hail/backend/ServiceBackendSuite.scala @@ -20,13 +20,10 @@ import java.util.concurrent.CountDownLatch import org.mockito.ArgumentMatchersSugar.any import org.mockito.IdiomaticMockito import org.mockito.MockitoSugar.when -import org.scalatest.OptionValues -import org.scalatest.matchers.should.Matchers.{a, convertToAnyShouldWrapper} -import org.testng.annotations.{DataProvider, Test} -class ServiceBackendSuite extends HailSuite with IdiomaticMockito with OptionValues { +class ServiceBackendSuite extends HailSuite with IdiomaticMockito { - @Test def testExecutesSinglePartitionLocally(): Unit = + test("ExecutesSinglePartitionLocally") { runMock { (ctx, _, batchClient, backend) => val contexts = ArraySeq.tabulate(10)(_ => Array.emptyByteArray) @@ -42,8 +39,9 @@ class ServiceBackendSuite extends HailSuite with IdiomaticMockito with OptionVal batchClient.newJobGroup(any[JobGroupRequest]) wasNever called } + } - @Test def testCollectIncrementally(): Unit = + test("CollectIncrementally") { runMock { (ctx, jobConfig, batchClient, backend) => // the service backend expects that each job write its output to a well-known // location when it finishes. 
@@ -66,17 +64,20 @@ class ServiceBackendSuite extends HailSuite with IdiomaticMockito with OptionVal when(batchClient.newJobGroup(any[JobGroupRequest])) thenAnswer { jobGroup: JobGroupRequest => - jobGroup.batch_id shouldBe backend.batchConfig.batchId - jobGroup.absolute_parent_id shouldBe backend.batchConfig.jobGroupId + assertEquals(jobGroup.batch_id, backend.batchConfig.batchId) + assertEquals(jobGroup.absolute_parent_id, backend.batchConfig.jobGroupId) val jobs = jobGroup.jobs - jobs.length shouldEqual contexts.length + assertEquals(jobs.length, contexts.length) jobs.foreach { payload => - payload.regions shouldBe jobConfig.regions - payload.resources.value shouldBe JobResources( - preemptible = true, - cpu = jobConfig.worker_cores, - memory = jobConfig.worker_memory, - storage = jobConfig.storage, + assertEquals(payload.regions.map(_.toSeq), jobConfig.regions.map(_.toSeq)) + assertEquals( + payload.resources.get, + JobResources( + preemptible = true, + cpu = jobConfig.worker_cores, + memory = jobConfig.worker_memory, + storage = jobConfig.storage, + ), ) } @@ -92,8 +93,8 @@ class ServiceBackendSuite extends HailSuite with IdiomaticMockito with OptionVal any[Option[String]], )) thenAnswer { (batchId: Int, _: Int, s: Option[JobState], t: Option[String]) => - s shouldBe Some(JobStates.Success) - t shouldBe (if (getJobGroupJobsCalled > 0) endTime else None) + assertEquals(s, Some(JobStates.Success)) + assertEquals(t, if (getJobGroupJobsCalled > 0) endTime else None) // require more than one call // withhold one job to simulate delays in marking a job complete @@ -118,8 +119,8 @@ class ServiceBackendSuite extends HailSuite with IdiomaticMockito with OptionVal // make the driver poll for results while the job group is running when(batchClient.getJobGroup(any[Int], any[Int])) thenAnswer { (id: Int, jobGroupId: Int) => - id shouldEqual backend.batchConfig.batchId - jobGroupId shouldEqual backend.batchConfig.jobGroupId + 1 + assertEquals(id, 
backend.batchConfig.batchId) + assertEquals(jobGroupId, backend.batchConfig.jobGroupId + 1) val complete = getJobGroupJobsCalled >= 2 JobGroupResponse( batch_id = id, @@ -143,7 +144,7 @@ class ServiceBackendSuite extends HailSuite with IdiomaticMockito with OptionVal failure.foreach(throw _) - results.length shouldBe contexts.length + assertEquals(results.length, contexts.length) batchClient.newJobGroup(any[JobGroupRequest]) wasCalled once batchClient.getJobGroup(any[Int], any[Int]) wasCalled thrice batchClient.getJobGroupJobs( @@ -153,100 +154,106 @@ class ServiceBackendSuite extends HailSuite with IdiomaticMockito with OptionVal any[Option[String]], ) wasCalled thrice } + } + + object checkFailedJobGroup extends TestCases { + def apply( + useFastRestarts: String + )(implicit loc: munit.Location + ): Unit = test(s"FailedJobGroup(useFastRestarts=$useFastRestarts)") { + runMock { (ctx, _, batchClient, backend) => + ctx.local(flags = ctx.flags + (UseFastRestarts -> useFastRestarts)) { ctx => + val contexts = ArraySeq.tabulate(100)(_ => Array.emptyByteArray) + val startJobId = 2356 + when(batchClient.newJobGroup(any[JobGroupRequest])) thenAnswer { + _: JobGroupRequest => (backend.batchConfig.jobGroupId + 1, startJobId) + } - @DataProvider(name = "UseFastRestarts") - def useFastRestarts: Array[Array[Any]] = - Array(Array(null), Array("1")) + val resultsDir = Path(ctx.tmpdir) / "mapCollectPartitions" / tokenUrlSafe + resultsDir.createDirectory(): Unit - @Test(dataProvider = "UseFastRestarts") - def testFailedJobGroup(useFastRestarts: String): Unit = - runMock { (ctx, _, batchClient, backend) => - ctx.local(flags = ctx.flags + (UseFastRestarts -> useFastRestarts)) { ctx => - val contexts = ArraySeq.tabulate(100)(_ => Array.emptyByteArray) - val startJobId = 2356 - when(batchClient.newJobGroup(any[JobGroupRequest])) thenAnswer { - _: JobGroupRequest => (backend.batchConfig.jobGroupId + 1, startJobId) - } + val successes = ArraySeq(13, 34, 81) // arbitrary indices + if (ctx.flags.isDefined(UseFastRestarts)) 
+ for (i <- successes) + ctx.fs.writePDOS((resultsDir / f"result.$i").toString()) { + os => WireProtocol.write(os, i, Right(i.toString.getBytes())) + } - val resultsDir = Path(ctx.tmpdir) / "mapCollectPartitions" / tokenUrlSafe - resultsDir.createDirectory(): Unit - - val successes = ArraySeq(13, 34, 81) // arbitrary indices - if (ctx.flags.isDefined(UseFastRestarts)) - for (i <- successes) + val failures = ArraySeq(21) + val expectedCause = new NoSuchMethodError("") + for (i <- failures) ctx.fs.writePDOS((resultsDir / f"result.$i").toString()) { - os => WireProtocol.write(os, i, Right(i.toString.getBytes())) + os => WireProtocol.write(os, i, Left(expectedCause)) } - val failures = ArraySeq(21) - val expectedCause = new NoSuchMethodError("") - for (i <- failures) - ctx.fs.writePDOS((resultsDir / f"result.$i").toString()) { - os => WireProtocol.write(os, i, Left(expectedCause)) + when(batchClient.getJobGroup(any[Int], any[Int])) thenAnswer { + (id: Int, jobGroupId: Int) => + JobGroupResponse( + batch_id = id, + job_group_id = jobGroupId, + state = Failure, + complete = false, + n_jobs = contexts.length, + n_completed = successes.length + failures.length, + n_succeeded = successes.length, + n_failed = failures.length, + n_cancelled = contexts.length - failures.length - successes.length, + ) } - when(batchClient.getJobGroup(any[Int], any[Int])) thenAnswer { - (id: Int, jobGroupId: Int) => - JobGroupResponse( - batch_id = id, - job_group_id = jobGroupId, - state = Failure, - complete = false, - n_jobs = contexts.length, - n_completed = successes.length + failures.length, - n_succeeded = successes.length, - n_failed = failures.length, - n_cancelled = contexts.length - failures.length - successes.length, - ) - } - - when(batchClient.getJobGroupJobs( - any[Int], - any[Int], - any[Option[JobState]], - any[Option[String]], - )) thenAnswer { - (batchId: Int, _: Int, s: Option[JobState], _: Option[String]) => - s match { - case Some(JobStates.Failed) => - 
LazyList(failures.map(i => - JobListEntry( - batch_id = batchId, - job_id = i + startJobId, - state = JobStates.Failed, - exit_code = Some(1), - end_time = Some(""), - ) - )) - - case Some(JobStates.Success) => - ctx.flags.isDefined(UseFastRestarts) shouldBe true - LazyList(successes.map(i => - JobListEntry( - batch_id = batchId, - job_id = i + startJobId, - state = JobStates.Success, - exit_code = Some(0), - end_time = Some(""), - ) - )) - } - } + when(batchClient.getJobGroupJobs( + any[Int], + any[Int], + any[Option[JobState]], + any[Option[String]], + )) thenAnswer { + (batchId: Int, _: Int, s: Option[JobState], _: Option[String]) => + s match { + case Some(JobStates.Failed) => + LazyList(failures.map(i => + JobListEntry( + batch_id = batchId, + job_id = i + startJobId, + state = JobStates.Failed, + exit_code = Some(1), + end_time = Some(""), + ) + )) + + case Some(JobStates.Success) => + assert(ctx.flags.isDefined(UseFastRestarts)) + LazyList(successes.map(i => + JobListEntry( + batch_id = batchId, + job_id = i + startJobId, + state = JobStates.Success, + exit_code = Some(0), + end_time = Some(""), + ) + )) + } + } - val (failure, result) = - backend.runtimeContext(ctx).mapCollectPartitions( - Array.emptyByteArray, - contexts, - "stage1", - )((_, _, _, _) => (_, bytes) => bytes) + val (failure, result) = + backend.runtimeContext(ctx).mapCollectPartitions( + Array.emptyByteArray, + contexts, + "stage1", + )((_, _, _, _) => (_, bytes) => bytes) - val (shortMessage, expanded, id) = handleForPython(expectedCause) - failure.value shouldBe HailWorkerException(failures.head, shortMessage, expanded, id) - if (ctx.flags.isDefined(UseFastRestarts)) result.map(_._2) shouldBe successes + val (shortMessage, expanded, id) = handleForPython(expectedCause) + assertEquals(failure.get, HailWorkerException(failures.head, shortMessage, expanded, id)) + if (ctx.flags.isDefined(UseFastRestarts)) + assertEquals(result.map(_._2), successes) + } } } + } + + checkFailedJobGroup(null) + 
checkFailedJobGroup("1") - @Test def testCancelledJobGroup(): Unit = + test("CancelledJobGroup") { runMock { (ctx, _, batchClient, backend) => val contexts = ArraySeq.tabulate(2)(_ => Array.emptyByteArray) val startJobId = 2356 @@ -277,10 +284,11 @@ class ServiceBackendSuite extends HailSuite with IdiomaticMockito with OptionVal "stage1", )((_, _, _, _) => (_, bytes) => bytes) - failure.value shouldBe a[CancellationException] + assert(failure.get.isInstanceOf[CancellationException]) } + } - @Test def testInterrupt(): Unit = + test("Interrupt") { runMock { (ctx, _, batchClient, backend) => val contexts = ArraySeq.tabulate(2)(_ => Array.emptyByteArray) val jobGroupId = Random.nextInt() @@ -319,8 +327,8 @@ class ServiceBackendSuite extends HailSuite with IdiomaticMockito with OptionVal when(batchClient.cancelJobGroup(any[Int], any[Int])) thenAnswer { (batchId: Int, jgId: Int) => - batchId shouldBe backend.batchConfig.batchId - jgId shouldBe jobGroupId + assertEquals(batchId, backend.batchConfig.batchId) + assertEquals(jgId, jobGroupId) } @volatile var failure: Option[Throwable] = @@ -341,9 +349,10 @@ class ServiceBackendSuite extends HailSuite with IdiomaticMockito with OptionVal t.interrupt() t.join() - failure.value shouldBe a[CancellationException] + assert(failure.get.isInstanceOf[CancellationException]) batchClient.cancelJobGroup(any[Int], any[Int]) wasCalled once } + } def runMock(test: (ExecuteContext, BatchJobConfig, BatchClient, ServiceBackend) => Any): Unit = withObjectSpied[is.hail.utils.UtilsType] { diff --git a/hail/hail/test/src/is/hail/backend/WorkerSuite.scala b/hail/hail/test/src/is/hail/backend/WorkerSuite.scala index a423459c854..ae3b0d3786d 100644 --- a/hail/hail/test/src/is/hail/backend/WorkerSuite.scala +++ b/hail/hail/test/src/is/hail/backend/WorkerSuite.scala @@ -5,55 +5,48 @@ import is.hail.utils.{handleForPython, using, HailWorkerException} import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import 
org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper -import org.scalatestplus.scalacheck.CheckerAsserting.assertingNatureOfAssertion -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test - -class WorkerSuite extends TestNGSuite with ScalaCheckDrivenPropertyChecks { - - @Test def testWriteReadSuccess(): Unit = - forAll { (partitionId: Int, payload: Array[Byte]) => - val buffer = - using(new ByteArrayOutputStream()) { bs => - using(new DataOutputStream(bs)) { os => - WireProtocol.write(os, partitionId, Right(payload)) - bs.toByteArray - } - } +import org.scalacheck.Prop.forAll - val (result, readPartition) = - using(new ByteArrayInputStream(buffer)) { bs => - using(new DataInputStream(bs))(is => WireProtocol.read(is).getOrElse(null)) - } +class WorkerSuite extends munit.ScalaCheckSuite { - readPartition shouldBe partitionId - result shouldBe payload - } - - @Test def testWriteReadFailure(): Unit = - forAll { (partitionId: Int, payload: Throwable) => - val buffer = - using(new ByteArrayOutputStream()) { bs => - using(new DataOutputStream(bs)) { os => - WireProtocol.write(os, partitionId, Left(payload)) - bs.toByteArray - } + property("WriteReadSuccess")(forAll { (partitionId: Int, payload: Array[Byte]) => + val buffer = + using(new ByteArrayOutputStream()) { bs => + using(new DataOutputStream(bs)) { os => + WireProtocol.write(os, partitionId, Right(payload)) + bs.toByteArray } - - val exception: HailWorkerException = - using(new ByteArrayInputStream(buffer)) { bs => - using(new DataInputStream(bs))(is => WireProtocol.read(is).left.getOrElse(null)) + } + + val (result, readPartition) = + using(new ByteArrayInputStream(buffer)) { bs => + using(new DataInputStream(bs))(is => WireProtocol.read(is).getOrElse(null)) + } + + readPartition == partitionId && java.util.Arrays.equals(result, payload) + }) + + property("WriteReadFailure")(forAll { (partitionId: Int, 
payload: Throwable) => + val buffer = + using(new ByteArrayOutputStream()) { bs => + using(new DataOutputStream(bs)) { os => + WireProtocol.write(os, partitionId, Left(payload)) + bs.toByteArray } - - val (short, expanded, errorId) = handleForPython(payload) - exception shouldBe HailWorkerException( - partitionId = partitionId, - shortMessage = short, - expandedMessage = expanded, - errorId = errorId, - ) - } + } + + val exception: HailWorkerException = + using(new ByteArrayInputStream(buffer)) { bs => + using(new DataInputStream(bs))(is => WireProtocol.read(is).left.getOrElse(null)) + } + + val (short, expanded, errorId) = handleForPython(payload) + exception == HailWorkerException( + partitionId = partitionId, + shortMessage = short, + expandedMessage = expanded, + errorId = errorId, + ) + }) } diff --git a/hail/hail/test/src/is/hail/collection/ArrayBuilderSuite.scala b/hail/hail/test/src/is/hail/collection/ArrayBuilderSuite.scala index 1b2dcc0a2ef..2c98a0da6bf 100644 --- a/hail/hail/test/src/is/hail/collection/ArrayBuilderSuite.scala +++ b/hail/hail/test/src/is/hail/collection/ArrayBuilderSuite.scala @@ -1,28 +1,25 @@ package is.hail.collection -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test - -class ArrayBuilderSuite extends TestNGSuite { - @Test def addOneElement(): Unit = { +class ArrayBuilderSuite extends munit.FunSuite { + test("addOneElement") { val ab = new IntArrayBuilder(0) ab += 3 val a = ab.result() - assert(a.length == 1) - assert(a(0) == 3) + assertEquals(a.length, 1) + assertEquals(a(0), 3) } - @Test def addArray(): Unit = { + test("addArray") { val ab = new IntArrayBuilder(0) ab ++= Array.fill[Int](5)(2) val a = ab.result() - assert(a.length == 5) + assertEquals(a.length, 5) assert(a.forall(_ == 2)) val ab2 = new IntArrayBuilder(0) ab2 ++= (Array.fill[Int](4)(3), 2) val a2 = ab2.result() - assert(a2.length == 2) + assertEquals(a2.length, 2) assert(a2.forall(_ == 3)) ab2(0) = 5 diff --git 
a/hail/hail/test/src/is/hail/collection/ArrayStackSuite.scala b/hail/hail/test/src/is/hail/collection/ArrayStackSuite.scala index 0a674e50068..9a7f1174e57 100644 --- a/hail/hail/test/src/is/hail/collection/ArrayStackSuite.scala +++ b/hail/hail/test/src/is/hail/collection/ArrayStackSuite.scala @@ -1,46 +1,43 @@ package is.hail.collection -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test - -class ArrayStackSuite extends TestNGSuite { - @Test def test(): Unit = { +class ArrayStackSuite extends munit.FunSuite { + test("basic operations") { val s = new IntArrayStack(4) assert(s.isEmpty) assert(!s.nonEmpty) - assert(s.size == 0) - assert(s.capacity == 4) + assertEquals(s.size, 0) + assertEquals(s.capacity, 4) s.push(13) assert(!s.isEmpty) assert(s.nonEmpty) - assert(s.size == 1) - assert(s.capacity == 4) - assert(s.top == 13) - assert(s(0) == 13) + assertEquals(s.size, 1) + assertEquals(s.capacity, 4) + assertEquals(s.top, 13) + assertEquals(s(0), 13) s.push(0) s.push(-1) s.push(11) s.push(-13) - assert(s.size == 5) + assertEquals(s.size, 5) assert(s.capacity >= 5) - assert(s.top == -13) + assertEquals(s.top, -13) - assert(s(0) == -13) - assert(s(1) == 11) - assert(s(2) == -1) - assert(s(3) == 0) - assert(s(4) == 13) + assertEquals(s(0), -13) + assertEquals(s(1), 11) + assertEquals(s(2), -1) + assertEquals(s(3), 0) + assertEquals(s(4), 13) s(2) = 39 - assert(s.pop() == -13) - assert(s.top == 11) - assert(s.size == 4) + assertEquals(s.pop(), -13) + assertEquals(s.top, 11) + assertEquals(s.size, 4) - assert(s.pop() == 11) - assert(s.pop() == 39) - assert(s.size == 2) + assertEquals(s.pop(), 11) + assertEquals(s.pop(), 39) + assertEquals(s.size, 2) s.pop(): Unit s.pop(): Unit diff --git a/hail/hail/test/src/is/hail/collection/BinaryHeapSuite.scala b/hail/hail/test/src/is/hail/collection/BinaryHeapSuite.scala index 88601306abe..414c685180f 100644 --- a/hail/hail/test/src/is/hail/collection/BinaryHeapSuite.scala +++ 
b/hail/hail/test/src/is/hail/collection/BinaryHeapSuite.scala @@ -9,84 +9,76 @@ import scala.util.Random import org.scalacheck.Arbitrary._ import org.scalacheck.Gen import org.scalacheck.Gen._ -import org.scalatest -import org.scalatest.matchers.should.Matchers._ -import org.scalatestplus.scalacheck.CheckerAsserting.assertingNatureOfAssertion -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.testng.annotations.Test - -class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { - @Test - def insertOneIsMax(): Unit = { +import org.scalacheck.Prop.forAll + +class BinaryHeapSuite extends munit.ScalaCheckSuite { + test("insertOneIsMax") { val bh = new BinaryHeap[Int]() bh.insert(1, 10) - assert(bh.max() === 1) - assert(bh.max() === 1) - assert(bh.size === 1) - assert(bh.contains(1) === true) - assert(bh.extractMax() === 1) - assert(bh.contains(1) === false) - assert(bh.size === 0) - - assertThrows[Exception](bh.max()) - assertThrows[Exception](bh.extractMax()) + assertEquals(bh.max(), 1) + assertEquals(bh.max(), 1) + assertEquals(bh.size, 1) + assert(bh.contains(1)) + assertEquals(bh.extractMax(), 1) + assert(!bh.contains(1)) + assertEquals(bh.size, 0) + + intercept[Exception](bh.max()): Unit + intercept[Exception](bh.extractMax()): Unit } - @Test - def twoElements(): Unit = { + test("twoElements") { val bh = new BinaryHeap[Int]() bh.insert(1, 5) - assert(bh.contains(1) == true) + assert(bh.contains(1)) bh.insert(2, 10) - assert(bh.contains(2) == true) - assert(bh.max() === 2) - assert(bh.max() === 2) - assert(bh.size === 2) - assert(bh.extractMax() === 2) - assert(bh.contains(2) == false) - assert(bh.size === 1) - assert(bh.max() === 1) - assert(bh.max() === 1) - assert(bh.size === 1) - assert(bh.extractMax() === 1) - assert(bh.contains(1) == false) - assert(bh.size === 0) + assert(bh.contains(2)) + assertEquals(bh.max(), 2) + assertEquals(bh.max(), 2) + assertEquals(bh.size, 2) + assertEquals(bh.extractMax(), 2) + 
assert(!bh.contains(2)) + assertEquals(bh.size, 1) + assertEquals(bh.max(), 1) + assertEquals(bh.max(), 1) + assertEquals(bh.size, 1) + assertEquals(bh.extractMax(), 1) + assert(!bh.contains(1)) + assertEquals(bh.size, 0) } - @Test - def threeElements(): Unit = { + test("threeElements") { val bh = new BinaryHeap[Int]() bh.insert(1, -10) - assert(bh.contains(1) == true) + assert(bh.contains(1)) bh.insert(2, -5) - assert(bh.contains(2) == true) + assert(bh.contains(2)) bh.insert(3, -7) - assert(bh.contains(3) == true) - assert(bh.max() === 2) - assert(bh.max() === 2) - assert(bh.size === 3) - assert(bh.extractMax() === 2) - assert(bh.size === 2) - - assert(bh.max() === 3) - assert(bh.max() === 3) - assert(bh.size === 2) - assert(bh.extractMax() === 3) - assert(bh.size === 1) - - assert(bh.max() === 1) - assert(bh.max() === 1) - assert(bh.size === 1) - assert(bh.extractMax() === 1) - assert(bh.size === 0) - - assert(bh.contains(1) == false) - assert(bh.contains(2) == false) - assert(bh.contains(3) == false) + assert(bh.contains(3)) + assertEquals(bh.max(), 2) + assertEquals(bh.max(), 2) + assertEquals(bh.size, 3) + assertEquals(bh.extractMax(), 2) + assertEquals(bh.size, 2) + + assertEquals(bh.max(), 3) + assertEquals(bh.max(), 3) + assertEquals(bh.size, 2) + assertEquals(bh.extractMax(), 3) + assertEquals(bh.size, 1) + + assertEquals(bh.max(), 1) + assertEquals(bh.max(), 1) + assertEquals(bh.size, 1) + assertEquals(bh.extractMax(), 1) + assertEquals(bh.size, 0) + + assert(!bh.contains(1)) + assert(!bh.contains(2)) + assert(!bh.contains(3)) } - @Test - def decreaseToKey1(): Unit = { + test("decreaseToKey1") { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -94,14 +86,13 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { bh.insert(3, 10) bh.decreasePriorityTo(2, -10) - assert(bh.max() === 3) - assert(bh.extractMax() === 3) - assert(bh.extractMax() === 1) - assert(bh.extractMax() === 2) + assertEquals(bh.max(), 3) + assertEquals(bh.extractMax(), 3) + 
assertEquals(bh.extractMax(), 1) + assertEquals(bh.extractMax(), 2) } - @Test - def decreaseToKey2(): Unit = { + test("decreaseToKey2") { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -109,14 +100,13 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { bh.insert(3, 10) bh.decreasePriorityTo(1, -10) - assert(bh.max() === 2) - assert(bh.extractMax() === 2) - assert(bh.extractMax() === 3) - assert(bh.extractMax() === 1) + assertEquals(bh.max(), 2) + assertEquals(bh.extractMax(), 2) + assertEquals(bh.extractMax(), 3) + assertEquals(bh.extractMax(), 1) } - @Test - def decreaseToKeyButNoOrderingChange(): Unit = { + test("decreaseToKeyButNoOrderingChange") { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -124,14 +114,13 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { bh.insert(3, 10) bh.decreasePriorityTo(3, 1) - assert(bh.max() === 2) - assert(bh.extractMax() === 2) - assert(bh.extractMax() === 3) - assert(bh.extractMax() === 1) + assertEquals(bh.max(), 2) + assertEquals(bh.extractMax(), 2) + assertEquals(bh.extractMax(), 3) + assertEquals(bh.extractMax(), 1) } - @Test - def decreaseKey1(): Unit = { + test("decreaseKey1") { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -139,14 +128,13 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { bh.insert(3, 10) bh.decreasePriority(2, _ - 110) - assert(bh.max() === 3) - assert(bh.extractMax() === 3) - assert(bh.extractMax() === 1) - assert(bh.extractMax() === 2) + assertEquals(bh.max(), 3) + assertEquals(bh.extractMax(), 3) + assertEquals(bh.extractMax(), 1) + assertEquals(bh.extractMax(), 2) } - @Test - def decreaseKey2(): Unit = { + test("decreaseKey2") { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -154,14 +142,13 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { bh.insert(3, 10) bh.decreasePriority(1, _ - 10) - assert(bh.max() === 2) - assert(bh.extractMax() === 2) - assert(bh.extractMax() === 3) - assert(bh.extractMax() === 1) + 
assertEquals(bh.max(), 2) + assertEquals(bh.extractMax(), 2) + assertEquals(bh.extractMax(), 3) + assertEquals(bh.extractMax(), 1) } - @Test - def decreaseKeyButNoOrderingChange(): Unit = { + test("decreaseKeyButNoOrderingChange") { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -169,14 +156,13 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { bh.insert(3, 10) bh.decreasePriority(3, _ - 9) - assert(bh.max() === 2) - assert(bh.extractMax() === 2) - assert(bh.extractMax() === 3) - assert(bh.extractMax() === 1) + assertEquals(bh.max(), 2) + assertEquals(bh.extractMax(), 2) + assertEquals(bh.extractMax(), 3) + assertEquals(bh.extractMax(), 1) } - @Test - def increaseToKey1(): Unit = { + test("increaseToKey1") { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -184,14 +170,13 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { bh.insert(3, 10) bh.increasePriorityTo(3, 200) - assert(bh.max() === 3) - assert(bh.extractMax() === 3) - assert(bh.extractMax() === 2) - assert(bh.extractMax() === 1) + assertEquals(bh.max(), 3) + assertEquals(bh.extractMax(), 3) + assertEquals(bh.extractMax(), 2) + assertEquals(bh.extractMax(), 1) } - @Test - def increaseToKeys(): Unit = { + test("increaseToKeys") { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -202,14 +187,13 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { bh.increasePriorityTo(2, 300) bh.increasePriorityTo(1, 250) - assert(bh.max() === 2) - assert(bh.extractMax() === 2) - assert(bh.extractMax() === 1) - assert(bh.extractMax() === 3) + assertEquals(bh.max(), 2) + assertEquals(bh.extractMax(), 2) + assertEquals(bh.extractMax(), 1) + assertEquals(bh.extractMax(), 3) } - @Test - def increaseKey1(): Unit = { + test("increaseKey1") { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -217,14 +201,13 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { bh.insert(3, 10) bh.increasePriority(3, _ + 190) - assert(bh.max() === 3) - assert(bh.extractMax() === 3) - 
assert(bh.extractMax() === 2) - assert(bh.extractMax() === 1) + assertEquals(bh.max(), 3) + assertEquals(bh.extractMax(), 3) + assertEquals(bh.extractMax(), 2) + assertEquals(bh.extractMax(), 1) } - @Test - def increaseKeys(): Unit = { + test("increaseKeys") { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -235,29 +218,28 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { bh.increasePriority(2, _ + 200) bh.increasePriority(1, _ + 250) - assert(bh.max() === 2) - assert(bh.extractMax() === 2) - assert(bh.extractMax() === 1) - assert(bh.extractMax() === 3) + assertEquals(bh.max(), 2) + assertEquals(bh.extractMax(), 2) + assertEquals(bh.extractMax(), 1) + assertEquals(bh.extractMax(), 3) } - @Test - def samePriority(): Unit = { + test("samePriority") { val bh = new BinaryHeap[Int]() bh.insert(1, 0) bh.insert(2, 100) bh.insert(3, 0) - assert(bh.max() === 2) - assert(bh.extractMax() === 2) + assertEquals(bh.max(), 2) + assertEquals(bh.extractMax(), 2) - (bh.extractMax(), bh.extractMax()) should (equal((1, 3)) or equal((3, 1))) + val pair = (bh.extractMax(), bh.extractMax()) + assert(pair == ((1, 3)) || pair == ((3, 1))) } - @Test - def successivelyMoreInserts(): Unit = - scalatest.Inspectors.forAll(Seq(2, 4, 8, 16, 32)) { count => + test("successivelyMoreInserts") { + Seq(2, 4, 8, 16, 32).foreach { count => val bh = new BinaryHeap[Int](8) val trace = ArrayBuffer.empty[String] trace += bh.toString() @@ -268,25 +250,26 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { trace += bh.toString() bh.checkHeapProperty() } - assert(bh.size === count) - assert(bh.max() === count - 1) + assertEquals(bh.size, count) + assertEquals(bh.max(), count - 1) - scalatest.Inspectors.forAll(0 until count) { i => + (0 until count).foreach { i => val actual = bh.extractMax() trace += bh.toString() bh.checkHeapProperty() val expected = count - i - 1 - assert( - actual === expected, + assertEquals( + actual, + expected, s"[$count] $actual did not equal $expected, 
heap: $bh; trace ${trace.mkString("\n")}", ) } assert(bh.isEmpty) } + } - @Test - def growPastCapacity4(): Unit = { + test("growPastCapacity4") { val bh = new BinaryHeap[Int](4) bh.insert(1, 0) bh.insert(2, 0) @@ -295,39 +278,38 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { bh.insert(5, 0) } - @Test - def growPastCapacity32(): Unit = { + test("growPastCapacity32") { val bh = new BinaryHeap[Int](32) for (i <- 0 to 32) bh.insert(i, 0) } - @Test - def shrinkCapacity(): Unit = { + test("shrinkCapacity") { val bh = new BinaryHeap[Int](8) val trace = ArrayBuffer.empty[String] trace += bh.toString() bh.checkHeapProperty() - scalatest.Inspectors.forAll(0 until 64) { i => + (0 until 64).foreach { i => bh.insert(i, i.toLong) trace += bh.toString() bh.checkHeapProperty() } - assert(bh.size === 64, s"trace: ${trace.mkString("\n")}") - assert(bh.max() === 63, s"trace: ${trace.mkString("\n")}") + assertEquals(bh.size, 64, s"trace: ${trace.mkString("\n")}") + assertEquals(bh.max(), 63, s"trace: ${trace.mkString("\n")}") // shrinking happens when size is <1/4 of capacity - scalatest.Inspectors.forAll(0 until (32 + 16 + 1)) { i => + (0 until (32 + 16 + 1)).foreach { i => val actual = bh.extractMax() val expected = 64 - i - 1 trace += bh.toString() bh.checkHeapProperty() - assert( - actual === expected, + assertEquals( + actual, + expected, s"$actual did not equal $expected, trace: ${trace.mkString("\n")}", ) } - assert(bh.size === 15, s"trace: ${trace.mkString("\n")}") - assert(bh.max() === 14, s"trace: ${trace.mkString("\n")}") + assertEquals(bh.size, 15, s"trace: ${trace.mkString("\n")}") + assertEquals(bh.max(), 14, s"trace: ${trace.mkString("\n")}") } sealed trait HeapOp @@ -362,44 +344,34 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { def insert(t: Long, rank: Long): Unit = m += (t -> rank) } - @Test - def sameAsReferenceImplementation(): Unit = { - val ops = - for { - maxOrExtract <- containerOfN[IndexedSeq, HeapOp](64, Gen.oneOf(Max(), 
ExtractMax())) - ranks <- distinctContainerOfN[IndexedSeq, Long](64, arbitrary[Long]) - inserts = ranks.map(r => Insert(r, r)) - } yield Random.shuffle(inserts ++ maxOrExtract) - - forAll(ops) { opList => - val bh = new BinaryHeap[Long]() - val ref = new LongPriorityQueueReference() - val trace = ArrayBuffer.empty[String] - trace += bh.toString() - scalatest.Inspectors.forAll(opList) { - case Max() => - if (bh.isEmpty && ref.isEmpty) - assert(true, s"trace; ${trace.mkString("\n")}") - else - assert(bh.max() === ref.max(), s"trace; ${trace.mkString("\n")}") - trace += bh.toString() - bh.checkHeapProperty() - succeed - case ExtractMax() => - if (bh.isEmpty && ref.isEmpty) - assert(true, s"trace; ${trace.mkString("\n")}") - else - assert(bh.max() === ref.max(), s"trace; ${trace.mkString("\n")}") - trace += bh.toString() - bh.checkHeapProperty() - succeed - case Insert(t, rank) => - bh.insert(t, rank) - ref.insert(t, rank) - trace += bh.toString() - bh.checkHeapProperty() - assert(bh.size === ref.size, s"trace; ${trace.mkString("\n")}") - } + property("sameAsReferenceImplementation") = forAll( + for { + maxOrExtract <- containerOfN[IndexedSeq, HeapOp](64, Gen.oneOf(Max(), ExtractMax())) + ranks <- distinctContainerOfN[IndexedSeq, Long](64, arbitrary[Long]) + inserts = ranks.map(r => Insert(r, r)) + } yield Random.shuffle(inserts ++ maxOrExtract) + ) { opList => + val bh = new BinaryHeap[Long]() + val ref = new LongPriorityQueueReference() + val trace = ArrayBuffer.empty[String] + trace += bh.toString() + opList.foreach { + case Max() => + if (!(bh.isEmpty && ref.isEmpty)) + assertEquals(bh.max(), ref.max(), s"trace; ${trace.mkString("\n")}") + trace += bh.toString() + bh.checkHeapProperty() + case ExtractMax() => + if (!(bh.isEmpty && ref.isEmpty)) + assertEquals(bh.max(), ref.max(), s"trace; ${trace.mkString("\n")}") + trace += bh.toString() + bh.checkHeapProperty() + case Insert(t, rank) => + bh.insert(t, rank) + ref.insert(t, rank) + trace += bh.toString() + 
bh.checkHeapProperty() + assertEquals(bh.size, ref.size, s"trace; ${trace.mkString("\n")}") } } @@ -412,140 +384,135 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { 0 } - @Test - def tieBreakingDoesntChangeExistingFunctionality(): Unit = { + test("tieBreakingDoesntChangeExistingFunctionality") { val bh = new BinaryHeap[Int](maybeTieBreaker = evensFirst) bh.insert(1, -10) - assert(bh.contains(1) == true) + assert(bh.contains(1)) bh.insert(2, -5) - assert(bh.contains(2) == true) + assert(bh.contains(2)) bh.insert(3, -7) - assert(bh.contains(3) == true) - assert(bh.max() === 2) - assert(bh.max() === 2) - assert(bh.size === 3) - assert(bh.extractMax() === 2) - assert(bh.size === 2) - - assert(bh.max() === 3) - assert(bh.max() === 3) - assert(bh.size === 2) - assert(bh.extractMax() === 3) - assert(bh.size === 1) - - assert(bh.max() === 1) - assert(bh.max() === 1) - assert(bh.size === 1) - assert(bh.extractMax() === 1) - assert(bh.size === 0) - - assert(bh.contains(1) == false) - assert(bh.contains(2) == false) - assert(bh.contains(3) == false) + assert(bh.contains(3)) + assertEquals(bh.max(), 2) + assertEquals(bh.max(), 2) + assertEquals(bh.size, 3) + assertEquals(bh.extractMax(), 2) + assertEquals(bh.size, 2) + + assertEquals(bh.max(), 3) + assertEquals(bh.max(), 3) + assertEquals(bh.size, 2) + assertEquals(bh.extractMax(), 3) + assertEquals(bh.size, 1) + + assertEquals(bh.max(), 1) + assertEquals(bh.max(), 1) + assertEquals(bh.size, 1) + assertEquals(bh.extractMax(), 1) + assertEquals(bh.size, 0) + + assert(!bh.contains(1)) + assert(!bh.contains(2)) + assert(!bh.contains(3)) } - @Test - def tieBreakingHappens(): Unit = { + test("tieBreakingHappens") { val bh = new BinaryHeap[Int](maybeTieBreaker = evensFirst) bh.insert(1, -10) - assert(bh.contains(1) == true) + assert(bh.contains(1)) bh.insert(2, -5) - assert(bh.contains(2) == true) + assert(bh.contains(2)) bh.insert(3, -5) - assert(bh.contains(3) == true) - assert(bh.max() === 2) - assert(bh.max() 
=== 2) - assert(bh.size === 3) - assert(bh.extractMax() === 2) - assert(bh.size === 2) - - assert(bh.max() === 3) - assert(bh.max() === 3) - assert(bh.size === 2) - assert(bh.extractMax() === 3) - assert(bh.size === 1) - - assert(bh.max() === 1) - assert(bh.max() === 1) - assert(bh.size === 1) - assert(bh.extractMax() === 1) - assert(bh.size === 0) - - assert(bh.contains(1) == false) - assert(bh.contains(2) == false) - assert(bh.contains(3) == false) + assert(bh.contains(3)) + assertEquals(bh.max(), 2) + assertEquals(bh.max(), 2) + assertEquals(bh.size, 3) + assertEquals(bh.extractMax(), 2) + assertEquals(bh.size, 2) + + assertEquals(bh.max(), 3) + assertEquals(bh.max(), 3) + assertEquals(bh.size, 2) + assertEquals(bh.extractMax(), 3) + assertEquals(bh.size, 1) + + assertEquals(bh.max(), 1) + assertEquals(bh.max(), 1) + assertEquals(bh.size, 1) + assertEquals(bh.extractMax(), 1) + assertEquals(bh.size, 0) + + assert(!bh.contains(1)) + assert(!bh.contains(2)) + assert(!bh.contains(3)) } - @Test - def tieBreakingThreeWayDeterministic(): Unit = { + test("tieBreakingThreeWayDeterministic") { val bh = new BinaryHeap[Int](maybeTieBreaker = evensFirst) bh.insert(1, -5) - assert(bh.contains(1) == true) + assert(bh.contains(1)) bh.insert(2, -5) - assert(bh.contains(2) == true) + assert(bh.contains(2)) bh.insert(3, -5) - assert(bh.contains(3) == true) - assert(bh.max() === 2) - assert(bh.max() === 2) - assert(bh.size === 3) - assert(bh.extractMax() === 2) - assert(bh.size === 2) + assert(bh.contains(3)) + assertEquals(bh.max(), 2) + assertEquals(bh.max(), 2) + assertEquals(bh.size, 3) + assertEquals(bh.extractMax(), 2) + assertEquals(bh.size, 2) val x = bh.max() val y = if (x == 3) 1 else 3 - assert(x === 3 || x === 1) - assert(bh.max() === x) - assert(bh.size === 2) - assert(bh.extractMax() === x) - assert(bh.size === 1) - - assert(bh.max() === y) - assert(bh.max() === y) - assert(bh.size === 1) - assert(bh.extractMax() === y) - assert(bh.size === 0) - - 
assert(bh.contains(1) == false) - assert(bh.contains(2) == false) - assert(bh.contains(3) == false) + assert(x == 3 || x == 1) + assertEquals(bh.max(), x) + assertEquals(bh.size, 2) + assertEquals(bh.extractMax(), x) + assertEquals(bh.size, 1) + + assertEquals(bh.max(), y) + assertEquals(bh.max(), y) + assertEquals(bh.size, 1) + assertEquals(bh.extractMax(), y) + assertEquals(bh.size, 0) + + assert(!bh.contains(1)) + assert(!bh.contains(2)) + assert(!bh.contains(3)) } - @Test - def tieBreakingThreeWayNonDeterministic(): Unit = { + test("tieBreakingThreeWayNonDeterministic") { val bh = new BinaryHeap[Int](maybeTieBreaker = evensFirst) bh.insert(0, -5) - assert(bh.contains(0) == true) + assert(bh.contains(0)) bh.insert(2, -5) - assert(bh.contains(2) == true) + assert(bh.contains(2)) bh.insert(3, -5) - assert(bh.contains(3) == true) + assert(bh.contains(3)) val firstMax = bh.max() val nextMax = if (firstMax == 2) 0 else 2 - assert(firstMax === 2 || firstMax === 0) - assert(bh.max() === firstMax) - assert(bh.size === 3) - assert(bh.extractMax() === firstMax) - assert(bh.size === 2) - - assert(bh.max() === nextMax) - assert(bh.max() === nextMax) - assert(bh.size === 2) - assert(bh.extractMax() === nextMax) - assert(bh.size === 1) - - assert(bh.max() === 3) - assert(bh.max() === 3) - assert(bh.size === 1) - assert(bh.extractMax() === 3) - assert(bh.size === 0) - - assert(bh.contains(0) == false) - assert(bh.contains(2) == false) - assert(bh.contains(3) == false) + assert(firstMax == 2 || firstMax == 0) + assertEquals(bh.max(), firstMax) + assertEquals(bh.size, 3) + assertEquals(bh.extractMax(), firstMax) + assertEquals(bh.size, 2) + + assertEquals(bh.max(), nextMax) + assertEquals(bh.max(), nextMax) + assertEquals(bh.size, 2) + assertEquals(bh.extractMax(), nextMax) + assertEquals(bh.size, 1) + + assertEquals(bh.max(), 3) + assertEquals(bh.max(), 3) + assertEquals(bh.size, 1) + assertEquals(bh.extractMax(), 3) + assertEquals(bh.size, 0) + + assert(!bh.contains(0)) + 
assert(!bh.contains(2)) + assert(!bh.contains(3)) } - @Test - def tieBreakingAfterPriorityChange(): Unit = { + test("tieBreakingAfterPriorityChange") { val bh = new BinaryHeap[Int](maybeTieBreaker = evensFirst) bh.insert(1, 15) bh.insert(2, 10) @@ -553,34 +520,34 @@ class BinaryHeapSuite extends ScalaCheckDrivenPropertyChecks { bh.insert(4, 0) println(bh) - assert(bh.max() === 1) + assertEquals(bh.max(), 1) bh.decreasePriorityTo(1, 10) - assert(bh.max() === 2) + assertEquals(bh.max(), 2) bh.decreasePriorityTo(1, 5) - assert(bh.max() === 2) + assertEquals(bh.max(), 2) bh.increasePriorityTo(1, 10) - assert(bh.max() === 2) + assertEquals(bh.max(), 2) bh.increasePriorityTo(1, 15) - assert(bh.max() === 1) + assertEquals(bh.max(), 1) bh.decreasePriorityTo(1, 10) - assert(bh.extractMax() === 2) - assert(bh.max() === 1) + assertEquals(bh.extractMax(), 2) + assertEquals(bh.max(), 1) bh.increasePriorityTo(4, 10) - assert(bh.extractMax() === 4) - assert(bh.extractMax() === 1) - assert(bh.extractMax() === 3) + assertEquals(bh.extractMax(), 4) + assertEquals(bh.extractMax(), 1) + assertEquals(bh.extractMax(), 3) assert(bh.isEmpty) } diff --git a/hail/hail/test/src/is/hail/collection/FlipbookIteratorSuite.scala b/hail/hail/test/src/is/hail/collection/FlipbookIteratorSuite.scala index 5c8de2edd6e..14d9e73ddf1 100644 --- a/hail/hail/test/src/is/hail/collection/FlipbookIteratorSuite.scala +++ b/hail/hail/test/src/is/hail/collection/FlipbookIteratorSuite.scala @@ -8,8 +8,6 @@ import is.hail.utils.{Muple, OrderingView} import scala.collection.mutable.ArrayBuffer import scala.reflect.runtime.universe._ -import org.testng.annotations.Test - class FlipbookIteratorSuite extends HailSuite { class Box[A: TypeTag] extends AnyRef { @@ -141,26 +139,25 @@ class FlipbookIteratorSuite extends HailSuite { ) } - @Test def flipbookIteratorStartsWithRightValue(): Unit = { + test("flipbookIteratorStartsWithRightValue") { val it: FlipbookIterator[Box[Int]] = makeTestIterator(1, 2, 3, 4, 5) - 
assert(it.value.value == 1) + assertEquals(it.value.value, 1) } - @Test def makeTestIteratorWorks(): Unit = { + test("makeTestIteratorWorks") { assert(makeTestIterator(1, 2, 3, 4, 5) shouldBe Iterator.range(1, 6)) - assert(makeTestIterator[Int]() shouldBe Iterator.empty) } - @Test def toFlipbookIteratorOnFlipbookIteratorIsIdentity(): Unit = { + test("toFlipbookIteratorOnFlipbookIteratorIsIdentity") { val it1 = makeTestIterator(1, 2, 3) val it2 = Iterator(1, 2, 3) assert(it1.toFlipbookIterator shouldBe it2) assert(makeTestIterator[Int]().toFlipbookIterator shouldBe Iterator.empty) } - @Test def toStaircaseWorks(): Unit = { + test("toStaircaseWorks") { val testIt = makeTestIterator(1, 1, 2, 3, 3, 3) val it = Iterator( Iterator(1, 1), @@ -170,7 +167,7 @@ class FlipbookIteratorSuite extends HailSuite { assert(testIt.staircased(boxOrdView) shouldBe it) } - @Test def orderedZipJoinWorks(): Unit = { + test("orderedZipJoinWorks") { val left = makeTestIterator(1, 2, 4, 1000, 1000) val right = makeTestIterator(2, 3, 4, 1000, 1000) val zipped = left.orderedZipJoin( @@ -185,7 +182,7 @@ class FlipbookIteratorSuite extends HailSuite { assert(zipped shouldBe it) } - @Test def innerJoinDistinctWorks(): Unit = { + test("innerJoinDistinctWorks") { val left = makeTestIterator(1, 2, 2, 4, 1000, 1000) val right = makeTestIterator(2, 4, 4, 5, 1000, 1000) val joined = left.innerJoinDistinct( @@ -201,7 +198,7 @@ class FlipbookIteratorSuite extends HailSuite { assert(joined shouldBe it) } - @Test def leftJoinDistinctWorks(): Unit = { + test("leftJoinDistinctWorks") { val left = makeTestIterator(1, 2, 2, 4, 1000, 1000) val right = makeTestIterator(2, 4, 4, 5, 1000, 1000) val joined = left.leftJoinDistinct( @@ -215,7 +212,7 @@ class FlipbookIteratorSuite extends HailSuite { assert(joined shouldBe it) } - @Test def innerJoinWorks(): Unit = { + test("innerJoinWorks") { val left = makeTestIterator(1, 2, 2, 4, 5, 5, 1000, 1000) val right = makeTestIterator(2, 2, 4, 4, 5, 6, 1000, 1000) val 
joined = left.innerJoin( @@ -232,7 +229,7 @@ class FlipbookIteratorSuite extends HailSuite { assert(joined shouldBe it) } - @Test def leftJoinWorks(): Unit = { + test("leftJoinWorks") { val left = makeTestIterator(1, 2, 2, 4, 5, 5, 1000, 1000) val right = makeTestIterator(2, 2, 4, 4, 5, 6, 1000, 1000) val joined = left.leftJoin( @@ -261,7 +258,7 @@ class FlipbookIteratorSuite extends HailSuite { assert(joined shouldBe it) } - @Test def rightJoinWorks(): Unit = { + test("rightJoinWorks") { val left = makeTestIterator(1, 2, 2, 4, 5, 5, 1000, 1000) val right = makeTestIterator(2, 2, 4, 4, 5, 6, 1000, 1000) val joined = left.rightJoin( @@ -290,7 +287,7 @@ class FlipbookIteratorSuite extends HailSuite { assert(joined shouldBe it) } - @Test def outerJoinWorks(): Unit = { + test("outerJoinWorks") { val left = makeTestIterator(1, 2, 2, 4, 5, 5, 1000, 1000) val right = makeTestIterator(2, 2, 4, 4, 5, 6, 1000, 1000) val joined = left.outerJoin( @@ -323,7 +320,7 @@ class FlipbookIteratorSuite extends HailSuite { assert(joined shouldBe it) } - @Test def multiZipJoinWorks(): Unit = { + test("multiZipJoinWorks") { val one = makeTestIterator(1, 2, 2, 4, 5, 5, 1000, 1000) val two = makeTestIterator(2, 3, 4, 5, 5, 6, 1000, 1000) val three = makeTestIterator(2, 3, 4, 4, 5, 6, 1000, 1000) diff --git a/hail/hail/test/src/is/hail/expr/ParserSuite.scala b/hail/hail/test/src/is/hail/expr/ParserSuite.scala index 3eeaf3b90a2..16e2a034a8d 100644 --- a/hail/hail/test/src/is/hail/expr/ParserSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ParserSuite.scala @@ -3,17 +3,15 @@ package is.hail.expr import is.hail.HailSuite import is.hail.collection.compat.immutable.ArraySeq -import org.testng.annotations.Test - class ParserSuite extends HailSuite { - @Test def testOneOfLiteral(): Unit = { + test("OneOfLiteral") { val strings = ArraySeq("A", "B", "AB", "AA", "CAD", "EF") val p = Parser.oneOfLiteral(strings) - strings.foreach(s => assert(p.parse(s) == s)) + strings.foreach(s => 
assertEquals(p.parse(s), s)) assert(p.parseOpt("hello^&").isEmpty) assert(p.parseOpt("ABhello").isEmpty) - assert(Parser.rep(p).parse("ABCADEF") == List("AB", "CAD", "EF")) + assertEquals(Parser.rep(p).parse("ABCADEF"), List("AB", "CAD", "EF")) } } diff --git a/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala b/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala index c5055b0dedb..837903d33ad 100644 --- a/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala @@ -14,11 +14,12 @@ import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual._ import is.hail.variant.{Call0, Call1, Call2} +import scala.concurrent.duration.Duration + import org.apache.spark.sql.Row -import org.scalatest.Inspectors.forAll -import org.testng.annotations.Test class Aggregators2Suite extends HailSuite { + override val munitTimeout = Duration(120, "s") def assertAggEqualsProcessed( aggSig: PhysicalAggSig, @@ -213,7 +214,7 @@ class Aggregators2Suite extends HailSuite { def collectAggSig(t: Type): PhysicalAggSig = PhysicalAggSig(Collect(), CollectStateSig(VirtualTypeWithReq(PType.canonical(t)))) - @Test def TestCount(): Unit = { + test("Count") { val seqOpArgs = ArraySeq.fill(rows.length)(FastSeq[IR]()) assertAggEquals( countAggSig, @@ -224,7 +225,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testSum(): Unit = { + test("Sum") { val a = Ref(freshName(), arrayType) val seqOpArgs = ArraySeq.tabulate(rows.length)(i => FastSeq[IR](GetField(ArrayRef(a, i), "b")) @@ -238,7 +239,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testPrevNonnullStr(): Unit = { + test("PrevNonnullStr") { val aggSig = PhysicalAggSig(PrevNonnull(), TypedStateSig(VirtualTypeWithReq(PCanonicalString()))) val a = Ref(freshName(), arrayType) @@ -255,7 +256,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testPrevNonnull(): Unit = { + test("PrevNonnull") { val 
a = Ref(freshName(), arrayType) val seqOpArgs = ArraySeq.tabulate(rows.length)(i => FastSeq[IR](ArrayRef(a, i))) @@ -268,7 +269,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testProduct(): Unit = { + test("Product") { val aggSig = PhysicalAggSig( Product(), TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true)), @@ -286,7 +287,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testCallStats(): Unit = { + test("CallStats") { val t = TStruct("x" -> TCall) val calls = FastSeq( @@ -338,7 +339,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testTakeBy(): Unit = { + test("TakeBy") { val t = TStruct( "a" -> TStruct("x" -> TInt32, "y" -> TInt64), "b" -> TInt32, @@ -396,7 +397,7 @@ class Aggregators2Suite extends HailSuite { (TStruct("x" -> TInt32, "y" -> TInt64), GetField(_, "a")), ) - def test( + def check( n: Int, data: IndexedSeq[Row], valueType: Type, @@ -436,17 +437,17 @@ class Aggregators2Suite extends HailSuite { perm <- permutations so <- FastSeq(Ascending, Descending) } - test(n, perm, t, identity[IR], identity[Row], TInt32, GetField(_, "b"), so) + check(n, perm, t, identity[IR], identity[Row], TInt32, GetField(_, "b"), so) // test key and value types for { (vt, valueF, resultF) <- valueTransformations (kt, keyF) <- keyTransformations } - test(4, permutations.last, vt, valueF, resultF, kt, keyF) + check(4, permutations.last, vt, valueF, resultF, kt, keyF) // test stable sort - test(7, rows, t, identity[IR], identity[Row], TInt64, _ => I64(5L)) + check(7, rows, t, identity[IR], identity[Row], TInt64, _ => I64(5L)) // test GC behavior by passing a large collection val rows2 = ArraySeq.tabulate(1200)(i => Row(i, i.toString)) @@ -491,7 +492,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testTake(): Unit = { + test("Take") { val t = TStruct( "a" -> TStruct("x" -> TInt32, "y" -> TInt64), "b" -> TInt32, @@ -543,7 +544,7 @@ class Aggregators2Suite extends HailSuite { ((x: IR) => 
GetField(x, f.name), (r: Row) => if (r == null) null else r.get(f.index), f.typ) }.filter(_._3 == TString) - forAll(transformations) { case (irF, rowF, subT) => + transformations.foreach { case (irF, rowF, subT) => val aggSig = PhysicalAggSig(Take(), TakeStateSig(VirtualTypeWithReq(PType.canonical(subT)))) val seqOpArgs = ArraySeq.tabulate(rows.length)(i => FastSeq[IR](irF(ArrayRef(a, i)))) @@ -574,7 +575,7 @@ class Aggregators2Suite extends HailSuite { )) } - @Test def testMin(): Unit = { + test("Min") { val aggSig = PhysicalAggSig(Min(), TypedStateSig(VirtualTypeWithReq(PInt64(false)))) val a = Ref(freshName(), arrayType) val seqOpArgs = ArraySeq.tabulate(rows.length)(i => FastSeq[IR](GetField(ArrayRef(a, i), "b"))) @@ -596,7 +597,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testMax(): Unit = { + test("Max") { val aggSig = PhysicalAggSig(Max(), TypedStateSig(VirtualTypeWithReq(PInt64(false)))) val a = Ref(freshName(), arrayType) val seqOpArgs = ArraySeq.tabulate(rows.length)(i => FastSeq[IR](GetField(ArrayRef(a, i), "b"))) @@ -618,7 +619,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testCollectLongs(): Unit = { + test("CollectLongs") { val a = Ref(freshName(), arrayType) val seqOpArgs = ArraySeq.tabulate(rows.length)(i => FastSeq[IR](GetField(ArrayRef(a, i), "b"))) assertAggEquals( @@ -630,7 +631,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testCollectStrs(): Unit = { + test("CollectStrs") { val a = Ref(freshName(), arrayType) val seqOpArgs = ArraySeq.tabulate(rows.length)(i => FastSeq[IR](GetField(ArrayRef(a, i), "a"))) @@ -643,7 +644,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testCollectBig(): Unit = { + test("CollectBig") { val seqOpArgs = ArraySeq.tabulate(100)(i => FastSeq(I64(i.toLong))) assertAggEquals( collectAggSig(TInt64), @@ -654,7 +655,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testArrayElementsAgg(): Unit = { + test("ArrayElementsAgg") { val alState 
= ArrayLenAggSig(knownLength = false, FastSeq(pnnAggSig, countAggSig, sumAggSig)) val value = FastSeq( @@ -710,7 +711,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testNestedArrayElementsAgg(): Unit = { + test("NestedArrayElementsAgg") { val alstate1 = ArrayLenAggSig(knownLength = false, FastSeq(sumAggSig)) val alstate2 = ArrayLenAggSig(knownLength = false, FastSeq[PhysicalAggSig](alstate1)) @@ -752,7 +753,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testArrayElementsAggTake(): Unit = { + test("ArrayElementsAggTake") { val value = FastSeq( FastSeq(Row("a", 0L), Row("b", 0L), Row("c", 0L), Row("f", 0L)), FastSeq(Row("a", 1L), null, Row("c", 1L), null), @@ -792,7 +793,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testGroup(): Unit = { + test("Group") { val group = GroupedAggSig( VirtualTypeWithReq(PCanonicalString()), FastSeq(pnnAggSig, countAggSig, sumAggSig), @@ -834,7 +835,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testNestedGroup(): Unit = { + test("NestedGroup") { val group1 = GroupedAggSig( VirtualTypeWithReq(PCanonicalString()), @@ -894,7 +895,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testCollectAsSet(): Unit = { + test("CollectAsSet") { val rows = FastSeq(Row("abcd", 5L), null, Row(null, -2L), Row("abcd", 7L), null, Row("foo", null)) val rref = Ref(freshName(), arrayType) @@ -927,7 +928,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testDownsample(): Unit = { + test("Downsample") { val aggSig = PhysicalAggSig( Downsample(), DownsampleStateSig(VirtualTypeWithReq(PCanonicalArray(PCanonicalString()))), @@ -981,7 +982,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testLoweringMatrixMapColsWithAggFilterAndLets(): Unit = { + test("LoweringMatrixMapColsWithAggFilterAndLets") { val t = MatrixType( TStruct.empty, FastSeq("col_idx"), @@ -1016,7 +1017,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testRunAggScan(): 
Unit = { + test("RunAggScan") { implicit val execStrats = ExecStrategy.compileOnly val sig = PhysicalAggSig( Sum(), @@ -1034,7 +1035,7 @@ class Aggregators2Suite extends HailSuite { assertEvalsTo(x, FastSeq(0.0, 0.0, 1.0, 3.0, 6.0)) } - @Test def testNestedRunAggScan(): Unit = { + test("NestedRunAggScan") { implicit val execStrats = ExecStrategy.compileOnly val sig = PhysicalAggSig( Sum(), @@ -1063,7 +1064,7 @@ class Aggregators2Suite extends HailSuite { ) } - @Test def testRunAggBasic(): Unit = { + test("RunAggBasic") { implicit val execStrats = ExecStrategy.compileOnly val sig = PhysicalAggSig(Sum(), TypedStateSig(VirtualTypeWithReq(PFloat64(true)))) val x = RunAgg( @@ -1078,7 +1079,7 @@ class Aggregators2Suite extends HailSuite { assertEvalsTo(x, Row(-4.0)) } - @Test def testRunAggNested(): Unit = { + test("RunAggNested") { implicit val execStrats = ExecStrategy.compileOnly val sumSig = PhysicalAggSig(Sum(), TypedStateSig(VirtualTypeWithReq(PFloat64(true)))) val takeSig = PhysicalAggSig(Take(), TakeStateSig(VirtualTypeWithReq(PFloat64(true)))) @@ -1109,7 +1110,7 @@ class Aggregators2Suite extends HailSuite { assertEvalsTo(x, FastSeq(-1d, 0d, 1d, 2d, 3d)) } - @Test(enabled = false) def testAggStateAndCombOp(): Unit = { + test("AggStateAndCombOp".ignore) { implicit val execStrats = ExecStrategy.compileOnly val takeSig = PhysicalAggSig(Take(), TakeStateSig(VirtualTypeWithReq(PInt64(true)))) val x = bindIR( diff --git a/hail/hail/test/src/is/hail/expr/ir/AggregatorsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/AggregatorsSuite.scala index 9ca456af5b5..071bc1ac07c 100644 --- a/hail/hail/test/src/is/hail/expr/ir/AggregatorsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/AggregatorsSuite.scala @@ -15,7 +15,6 @@ import is.hail.types.virtual._ import is.hail.variant.Call2 import org.apache.spark.sql.Row -import org.testng.annotations.Test class AggregatorsSuite extends HailSuite { @@ -52,7 +51,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def 
nestedAgg(): Unit = { + test("nestedAgg") { val agg = ToArray(mapIR(StreamRange(0, 10, 1))(_ => ApplyAggOp(Count())())) assertEvalsTo( agg, @@ -61,7 +60,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def sumFloat64(): Unit = { + test("sumFloat64") { runAggregator(Sum(), TFloat64, (0 to 100).map(_.toDouble), 5050.0) runAggregator(Sum(), TFloat64, FastSeq(), 0.0) runAggregator(Sum(), TFloat64, FastSeq(42.0), 42.0) @@ -69,10 +68,11 @@ class AggregatorsSuite extends HailSuite { runAggregator(Sum(), TFloat64, FastSeq(null, null, null), 0.0) } - @Test def sumInt64(): Unit = + test("sumInt64") { runAggregator(Sum(), TInt64, FastSeq(-1L, 2L, 3L), 4L) + } - @Test def collectBoolean(): Unit = { + test("collectBoolean") { runAggregator( Collect(), TBoolean, @@ -81,22 +81,27 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def collectInt(): Unit = + test("collectInt") { runAggregator(Collect(), TInt32, FastSeq(10, null, 5), FastSeq(10, null, 5)) + } - @Test def collectLong(): Unit = + test("collectLong") { runAggregator(Collect(), TInt64, FastSeq(10L, null, 5L), FastSeq(10L, null, 5L)) + } - @Test def collectFloat(): Unit = + test("collectFloat") { runAggregator(Collect(), TFloat32, FastSeq(10f, null, 5f), FastSeq(10f, null, 5f)) + } - @Test def collectDouble(): Unit = + test("collectDouble") { runAggregator(Collect(), TFloat64, FastSeq(10d, null, 5d), FastSeq(10d, null, 5d)) + } - @Test def collectString(): Unit = + test("collectString") { runAggregator(Collect(), TString, FastSeq("hello", null, "foo"), FastSeq("hello", null, "foo")) + } - @Test def collectArray(): Unit = { + test("collectArray") { runAggregator( Collect(), TArray(TInt32), @@ -105,7 +110,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def collectStruct(): Unit = { + test("collectStruct") { runAggregator( Collect(), TStruct("a" -> TInt32, "b" -> TBoolean), @@ -114,7 +119,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def count(): Unit = { + test("count") { 
runAggregator( Count(), TStruct("x" -> TString), @@ -125,7 +130,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def collectAsSetBoolean(): Unit = { + test("collectAsSetBoolean") { runAggregator( CollectAsSet(), TBoolean, @@ -135,14 +140,14 @@ class AggregatorsSuite extends HailSuite { runAggregator(CollectAsSet(), TBoolean, FastSeq(true, null, true), Set(true, null)) } - @Test def collectAsSetNumeric(): Unit = { + test("collectAsSetNumeric") { runAggregator(CollectAsSet(), TInt32, FastSeq(10, null, 5, 5, null), Set(10, null, 5)) runAggregator(CollectAsSet(), TInt64, FastSeq(10L, null, 5L, 5L, null), Set(10L, null, 5L)) runAggregator(CollectAsSet(), TFloat32, FastSeq(10f, null, 5f, 5f, null), Set(10f, null, 5f)) runAggregator(CollectAsSet(), TFloat64, FastSeq(10d, null, 5d, 5d, null), Set(10d, null, 5d)) } - @Test def collectAsSetString(): Unit = { + test("collectAsSetString") { runAggregator( CollectAsSet(), TString, @@ -151,21 +156,22 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def collectAsSetArray(): Unit = { + test("collectAsSetArray") { val inputCollection = FastSeq(FastSeq(1, 2, 3), null, FastSeq(), null, FastSeq(1, 2, 3)) val expected = Set(FastSeq(1, 2, 3), null, FastSeq()) runAggregator(CollectAsSet(), TArray(TInt32), inputCollection, expected) } - @Test def collectAsSetStruct(): Unit = + test("collectAsSetStruct") { runAggregator( CollectAsSet(), TStruct("a" -> TInt32, "b" -> TBoolean), FastSeq(Row(5, true), Row(3, false), null, Row(0, false), null, Row(5, true)), Set(Row(5, true), Row(3, false), null, Row(0, false)), ) + } - @Test def callStats(): Unit = { + test("callStats") { runAggregator( CallStats(), TCall, @@ -177,22 +183,25 @@ class AggregatorsSuite extends HailSuite { // FIXME Max Boolean not supported by old-style MaxAggregator - @Test def maxInt32(): Unit = { + test("maxInt32") { runAggregator(Max(), TInt32, FastSeq(), null) runAggregator(Max(), TInt32, FastSeq(null), null) runAggregator(Max(), TInt32, FastSeq(-2, 
null, 7), 7) } - @Test def maxInt64(): Unit = + test("maxInt64") { runAggregator(Max(), TInt64, FastSeq(-2L, null, 7L), 7L) + } - @Test def maxFloat32(): Unit = + test("maxFloat32") { runAggregator(Max(), TFloat32, FastSeq(-2.0f, null, 7.2f), 7.2f) + } - @Test def maxFloat64(): Unit = + test("maxFloat64") { runAggregator(Max(), TFloat64, FastSeq(-2.0, null, 7.2), 7.2) + } - @Test def takeInt32(): Unit = { + test("takeInt32") { runAggregator( Take(), TInt32, @@ -202,7 +211,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeInt64(): Unit = { + test("takeInt64") { runAggregator( Take(), TInt64, @@ -212,7 +221,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeFloat32(): Unit = { + test("takeFloat32") { runAggregator( Take(), TFloat32, @@ -222,7 +231,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeFloat64(): Unit = { + test("takeFloat64") { runAggregator( Take(), TFloat64, @@ -232,7 +241,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeCall(): Unit = { + test("takeCall") { runAggregator( Take(), TCall, @@ -242,7 +251,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeString(): Unit = { + test("takeString") { runAggregator( Take(), TString, @@ -252,8 +261,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test - def sumMultivar(): Unit = { + test("sumMultivar") { assertEvalsTo( ApplyAggOp( FastSeq(), @@ -285,56 +293,55 @@ class AggregatorsSuite extends HailSuite { ) } - @Test - def arraySumFloat64OnEmpty(): Unit = + test("arraySumFloat64OnEmpty") { assertArraySumEvalsTo[Double]( TFloat64, FastSeq(), null, ) + } - @Test - def arraySumFloat64OnSingletonMissing(): Unit = + test("arraySumFloat64OnSingletonMissing") { assertArraySumEvalsTo[Double]( TFloat64, FastSeq(null), null, ) + } - @Test - def arraySumFloat64OnAllMissing(): Unit = + test("arraySumFloat64OnAllMissing") { assertArraySumEvalsTo[Double]( TFloat64, FastSeq(null, null, null), null, ) + } - @Test - def 
arraySumInt64OnEmpty(): Unit = + test("arraySumInt64OnEmpty") { assertArraySumEvalsTo[Long]( TInt64, FastSeq(), null, ) + } - @Test - def arraySumInt64OnSingletonMissing(): Unit = + test("arraySumInt64OnSingletonMissing") { assertArraySumEvalsTo[Long]( TInt64, FastSeq(null), null, ) + } - @Test - def arraySumInt64OnAllMissing(): Unit = + test("arraySumInt64OnAllMissing") { assertArraySumEvalsTo[Long]( TInt64, FastSeq(null, null, null), null, ) + } - @Test - def arraySumFloat64OnSmallArray(): Unit = + test("arraySumFloat64OnSmallArray") { assertArraySumEvalsTo( TFloat64, FastSeq( @@ -344,9 +351,9 @@ class AggregatorsSuite extends HailSuite { ), FastSeq(11.0, 22.0), ) + } - @Test - def arraySumInt64OnSmallArray(): Unit = + test("arraySumInt64OnSmallArray") { assertArraySumEvalsTo( TInt64, FastSeq( @@ -356,9 +363,9 @@ class AggregatorsSuite extends HailSuite { ), FastSeq(11L, 22L), ) + } - @Test - def arraySumInt64FirstElementMissing(): Unit = + test("arraySumInt64FirstElementMissing") { assertArraySumEvalsTo( TInt64, FastSeq( @@ -368,6 +375,7 @@ class AggregatorsSuite extends HailSuite { ), FastSeq(43L, 36L), ) + } private[this] def assertTakeByEvalsTo( aggType: Type, @@ -386,10 +394,11 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByNGreater(): Unit = + test("takeByNGreater") { assertTakeByEvalsTo(TInt32, TInt32, 5, FastSeq(Row(3, 4)), FastSeq(3)) + } - @Test def takeByBooleanBoolean(): Unit = { + test("takeByBooleanBoolean") { assertTakeByEvalsTo( TBoolean, TBoolean, @@ -399,7 +408,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByBooleanInt(): Unit = { + test("takeByBooleanInt") { assertTakeByEvalsTo( TBoolean, TInt32, @@ -416,7 +425,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByBooleanLong(): Unit = { + test("takeByBooleanLong") { assertTakeByEvalsTo( TBoolean, TInt64, @@ -433,7 +442,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByBooleanFloat(): Unit = { + 
test("takeByBooleanFloat") { assertTakeByEvalsTo( TBoolean, TFloat32, @@ -450,7 +459,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByBooleanDouble(): Unit = { + test("takeByBooleanDouble") { assertTakeByEvalsTo( TBoolean, TFloat64, @@ -467,7 +476,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByBooleanAnnotation(): Unit = { + test("takeByBooleanAnnotation") { assertTakeByEvalsTo( TBoolean, TString, @@ -484,7 +493,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByIntBoolean(): Unit = { + test("takeByIntBoolean") { assertTakeByEvalsTo( TInt32, TBoolean, @@ -494,7 +503,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByIntInt(): Unit = { + test("takeByIntInt") { assertTakeByEvalsTo( TInt32, TInt32, @@ -504,7 +513,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByIntLong(): Unit = { + test("takeByIntLong") { assertTakeByEvalsTo( TInt32, TInt64, @@ -514,7 +523,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByIntFloat(): Unit = { + test("takeByIntFloat") { assertTakeByEvalsTo( TInt32, TFloat32, @@ -524,7 +533,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByIntDouble(): Unit = { + test("takeByIntDouble") { assertTakeByEvalsTo( TInt32, TFloat64, @@ -534,7 +543,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByIntAnnotation(): Unit = { + test("takeByIntAnnotation") { assertTakeByEvalsTo( TInt32, TString, @@ -551,7 +560,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByLongBoolean(): Unit = { + test("takeByLongBoolean") { assertTakeByEvalsTo( TInt64, TBoolean, @@ -561,7 +570,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByLongInt(): Unit = { + test("takeByLongInt") { assertTakeByEvalsTo( TInt64, TInt32, @@ -571,7 +580,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByLongLong(): Unit = { + test("takeByLongLong") { assertTakeByEvalsTo( TInt64, TInt64, 
@@ -588,7 +597,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByLongFloat(): Unit = { + test("takeByLongFloat") { assertTakeByEvalsTo( TInt64, TFloat32, @@ -605,7 +614,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByLongDouble(): Unit = { + test("takeByLongDouble") { assertTakeByEvalsTo( TInt64, TFloat64, @@ -622,7 +631,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByLongAnnotation(): Unit = { + test("takeByLongAnnotation") { assertTakeByEvalsTo( TInt64, TString, @@ -639,7 +648,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByFloatBoolean(): Unit = { + test("takeByFloatBoolean") { assertTakeByEvalsTo( TFloat32, TBoolean, @@ -649,7 +658,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByFloatInt(): Unit = { + test("takeByFloatInt") { assertTakeByEvalsTo( TFloat32, TInt32, @@ -659,7 +668,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByFloatLong(): Unit = { + test("takeByFloatLong") { assertTakeByEvalsTo( TFloat32, TInt64, @@ -676,7 +685,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByFloatFloat(): Unit = { + test("takeByFloatFloat") { assertTakeByEvalsTo( TFloat32, TFloat32, @@ -693,7 +702,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByFloatDouble(): Unit = { + test("takeByFloatDouble") { assertTakeByEvalsTo( TFloat32, TFloat64, @@ -710,7 +719,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByFloatAnnotation(): Unit = { + test("takeByFloatAnnotation") { assertTakeByEvalsTo( TFloat32, TString, @@ -727,7 +736,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByDoubleBoolean(): Unit = { + test("takeByDoubleBoolean") { assertTakeByEvalsTo( TFloat64, TBoolean, @@ -737,7 +746,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByDoubleInt(): Unit = { + test("takeByDoubleInt") { assertTakeByEvalsTo( TFloat64, TInt32, @@ -747,7 +756,7 @@ class 
AggregatorsSuite extends HailSuite { ) } - @Test def takeByDoubleLong(): Unit = { + test("takeByDoubleLong") { assertTakeByEvalsTo( TFloat64, TInt64, @@ -764,7 +773,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByDoubleFloat(): Unit = { + test("takeByDoubleFloat") { assertTakeByEvalsTo( TFloat64, TFloat32, @@ -781,7 +790,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByDoubleDouble(): Unit = { + test("takeByDoubleDouble") { assertTakeByEvalsTo( TFloat64, TFloat64, @@ -798,7 +807,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByDoubleAnnotation(): Unit = { + test("takeByDoubleAnnotation") { assertTakeByEvalsTo( TFloat64, TString, @@ -815,7 +824,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByAnnotationBoolean(): Unit = { + test("takeByAnnotationBoolean") { assertTakeByEvalsTo( TString, TBoolean, @@ -825,7 +834,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByAnnotationInt(): Unit = { + test("takeByAnnotationInt") { assertTakeByEvalsTo( TString, TInt32, @@ -835,7 +844,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByAnnotationLong(): Unit = { + test("takeByAnnotationLong") { assertTakeByEvalsTo( TString, TInt64, @@ -852,7 +861,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByAnnotationFloat(): Unit = { + test("takeByAnnotationFloat") { assertTakeByEvalsTo( TString, TFloat32, @@ -869,7 +878,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByAnnotationDouble(): Unit = { + test("takeByAnnotationDouble") { assertTakeByEvalsTo( TString, TFloat64, @@ -886,7 +895,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByAnnotationAnnotation(): Unit = { + test("takeByAnnotationAnnotation") { assertTakeByEvalsTo( TString, TString, @@ -903,7 +912,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def takeByCallLong(): Unit = { + test("takeByCallLong") { assertTakeByEvalsTo( TCall, TInt64, @@ 
-940,8 +949,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test - def keyedCount(): Unit = { + test("keyedCount") { runKeyedAggregator( Count(), Ref(Name("k"), TInt32), @@ -982,8 +990,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test - def keyedCollect(): Unit = { + test("keyedCollect") { runKeyedAggregator( Collect(), Ref(Name("k"), TBoolean), @@ -1003,8 +1010,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test - def keyedCallStats(): Unit = { + test("keyedCallStats") { runKeyedAggregator( CallStats(), Ref(Name("k"), TBoolean), @@ -1026,8 +1032,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test - def keyedTakeBy(): Unit = { + test("keyedTakeBy") { runKeyedAggregator( TakeBy(), Ref(Name("k"), TString), @@ -1046,8 +1051,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test - def keyedKeyedCollect(): Unit = { + test("keyedKeyedCollect") { val agg = FastSeq(Row("EUR", true, 1), Row("EUR", false, 2), Row("AFR", true, 3), Row("AFR", null, 4)) val aggType = TStruct("k1" -> TString, "k2" -> TBoolean, "x" -> TInt32) @@ -1070,8 +1074,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test - def keyedKeyedCallStats(): Unit = { + test("keyedKeyedCallStats") { val agg = FastSeq( Row("EUR", "CASE", null), Row("EUR", "CONTROL", Call2(0, 1)), @@ -1104,8 +1107,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test - def keyedKeyedTakeBy(): Unit = { + test("keyedKeyedTakeBy") { val agg = FastSeq( Row("case", "a", 0.2, 5), Row("control", "b", 0.4, 0), @@ -1139,8 +1141,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test - def keyedKeyedKeyedCollect(): Unit = { + test("keyedKeyedKeyedCollect") { val agg = FastSeq( Row("EUR", "CASE", true, 1), Row("EUR", "CONTROL", true, 2), @@ -1171,7 +1172,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def downsampleWhenEmpty(): Unit = { + test("downsampleWhenEmpty") { runAggregator( Downsample(), TStruct("x" -> TFloat64, "y" -> TFloat64, "label" -> TArray(TString)), @@ -1186,7 
+1187,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def testAggFilter(): Unit = { + test("AggFilter") { val aggType = TStruct("x" -> TBoolean, "y" -> TInt64) val agg = FastSeq(Row(true, -1L), Row(true, 1L), Row(false, 3L), Row(true, 5L)) @@ -1201,7 +1202,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def testAggExplode(): Unit = { + test("AggExplode") { val aggType = TStruct("x" -> TArray(TInt64)) val agg = FastSeq( Row(FastSeq[Long](1, 4)), @@ -1219,7 +1220,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def testArrayElementsAggregator(): Unit = { + test("ArrayElementsAggregator") { implicit val execStrats = ExecStrategy.interpretOnly def getAgg(n: Int, m: Int): IR = { @@ -1237,7 +1238,7 @@ class AggregatorsSuite extends HailSuite { assertEvalsTo(getAgg(10, 10), IndexedSeq.range(0, 10).map(_ * 10L)) } - @Test def testArrayElementsAggregatorEmpty(): Unit = { + test("ArrayElementsAggregatorEmpty") { implicit val execStrats = ExecStrategy.interpretOnly def getAgg(n: Int, m: Int, knownLength: Option[IR]): IR = { @@ -1263,7 +1264,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def testImputeTypeSimple(): Unit = { + test("ImputeTypeSimple") { runAggregator(ImputeType(), TString, FastSeq(null), Row(false, false, true, true, true, true)) runAggregator( ImputeType(), @@ -1285,7 +1286,7 @@ class AggregatorsSuite extends HailSuite { ) } - @Test def testFoldAgg(): Unit = { + test("FoldAgg") { val myIR = streamAggIR( mapIR(rangeIR(100))(idx => makestruct(("idx", idx), ("unused", idx + idx))) )(foo => aggFoldIR(I32(0))(_ + GetField(foo, "idx"))(_ + _)) @@ -1304,7 +1305,7 @@ class AggregatorsSuite extends HailSuite { assertEvalsTo(myLoweredTableIR, 4950) } - @Test def testFoldScan(): Unit = { + test("FoldScan") { val foo = Ref(freshName(), TStruct("idx" -> TInt32, "unused" -> TInt32)) val myIR = ToArray( @@ -1318,7 +1319,7 @@ class AggregatorsSuite extends HailSuite { } // fails because there is no "lowest binding referenced in 
an init op" - @Test def testStreamAgg(): Unit = { + test("StreamAgg") { implicit val execStrats = Set(ExecStrategy.JvmCompileUnoptimized) val foo = StreamRange(I32(0), I32(10), I32(1)) @@ -1333,7 +1334,7 @@ class AggregatorsSuite extends HailSuite { assertEvalsTo(ir, 1) } - @Test def testLetBoundInitOpArg(): Unit = { + test("LetBoundInitOpArg") { implicit val execStrats = ExecStrategy.allRelational var tir: TableIR = TableRange(10, 3) tir = diff --git a/hail/hail/test/src/is/hail/expr/ir/ArrayDeforestationSuite.scala b/hail/hail/test/src/is/hail/expr/ir/ArrayDeforestationSuite.scala index a884003209d..746d817ddd3 100644 --- a/hail/hail/test/src/is/hail/expr/ir/ArrayDeforestationSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/ArrayDeforestationSuite.scala @@ -7,7 +7,6 @@ import is.hail.expr.ir.defs.{ } import org.apache.spark.sql.Row -import org.testng.annotations.Test class ArrayDeforestationSuite extends HailSuite { implicit val execStrats: ExecStrategy.ValueSet = ExecStrategy.values @@ -51,7 +50,7 @@ class ArrayDeforestationSuite extends HailSuite { } } - @Test def testArrayFold(): Unit = { + test("ArrayFold") { assertEvalsTo(arrayFoldWithStructWithPrimitiveValues(5, -5, -6), Row(5, 4)) assertEvalsTo(arrayFoldWithStruct(5, -5, -6), Row(Row(4, 0), Row(5, 0))) } diff --git a/hail/hail/test/src/is/hail/expr/ir/ArrayFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/ArrayFunctionsSuite.scala index 891c25273d6..c69c754bcce 100644 --- a/hail/hail/test/src/is/hail/expr/ir/ArrayFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/ArrayFunctionsSuite.scala @@ -7,63 +7,83 @@ import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.defs.{ArraySlice, F32, F64, I32, In, MakeArray, NA, Str} import is.hail.types.virtual._ -import org.testng.annotations.{DataProvider, Test} - class ArrayFunctionsSuite extends HailSuite { val naa = NA(TArray(TInt32)) implicit val execStrats: Set[ExecStrategy] = ExecStrategy.javaOnly - @DataProvider(name = "basic") - def 
basicData(): Array[Array[Any]] = Array( - Array(FastSeq(3, 7)), - Array(null), - Array(FastSeq(3, null, 7, null)), - Array(FastSeq()), - ) + def lift(f: (Int, Int) => Int) + : (IndexedSeq[Integer], IndexedSeq[Integer]) => IndexedSeq[Integer] = { + case (a, b) => + Option(a).zip(Option(b)).headOption.map { case (a0, b0) => + a0.zip(b0).map { case (i, j) => + Option(i).zip(Option(j)).headOption.map[Integer] { case (m, n) => f(m, n) }.orNull + } + }.orNull + } - @DataProvider(name = "basicPairs") - def basicPairsData(): Array[Array[Any]] = basicData().flatten.combinations(2).toArray + // basic DataProvider tests - @Test(dataProvider = "basic") - def isEmpty(a: IndexedSeq[Integer]): Unit = - assertEvalsTo(invoke("isEmpty", TBoolean, toIRArray(a)), Option(a).map(_.isEmpty).orNull) + val basicData: Array[IndexedSeq[Integer]] = + Array(FastSeq(3, 7), null, FastSeq(3, null, 7, null), FastSeq()) - @Test(dataProvider = "basic") - def append(a: IndexedSeq[Integer]): Unit = - assertEvalsTo( - invoke("append", TArray(TInt32), toIRArray(a), I32(1)), - Option(a).map(_ :+ 1).orNull, - ) + object checkIsEmpty extends TestCases { + def apply(a: IndexedSeq[Integer])(implicit loc: munit.Location): Unit = test("isEmpty") { + assertEvalsTo(invoke("isEmpty", TBoolean, toIRArray(a)), Option(a).map(_.isEmpty).orNull) + } + } - @Test(dataProvider = "basic") - def appendNull(a: IndexedSeq[Integer]): Unit = - assertEvalsTo( - invoke("append", TArray(TInt32), toIRArray(a), NA(TInt32)), - Option(a).map(_ :+ null).orNull, - ) + basicData.foreach(checkIsEmpty(_)) - @Test(dataProvider = "basic") - def sum(a: IndexedSeq[Integer]): Unit = { - assertEvalsTo( - invoke("sum", TInt32, toIRArray(a)), - Option(a).flatMap(_.foldLeft[Option[Int]](Some(0))((comb, x) => - comb.flatMap(c => Option(x).map(_ + c)) - )).orNull, - ) + object checkAppend extends TestCases { + def apply(a: IndexedSeq[Integer])(implicit loc: munit.Location): Unit = test("append") { + assertEvalsTo( + invoke("append", TArray(TInt32), 
toIRArray(a), I32(1)), + Option(a).map(_ :+ 1).orNull, + ) + } } - @Test(dataProvider = "basic") - def product(a: IndexedSeq[Integer]): Unit = { - assertEvalsTo( - invoke("product", TInt32, toIRArray(a)), - Option(a).flatMap(_.foldLeft[Option[Int]](Some(1))((comb, x) => - comb.flatMap(c => Option(x).map(_ * c)) - )).orNull, - ) + basicData.foreach(checkAppend(_)) + + object checkAppendNull extends TestCases { + def apply(a: IndexedSeq[Integer])(implicit loc: munit.Location): Unit = test("appendNull") { + assertEvalsTo( + invoke("append", TArray(TInt32), toIRArray(a), NA(TInt32)), + Option(a).map(_ :+ null).orNull, + ) + } } - @Test def mean(): Unit = { + basicData.foreach(checkAppendNull(_)) + + object checkSum extends TestCases { + def apply(a: IndexedSeq[Integer])(implicit loc: munit.Location): Unit = test("sum") { + assertEvalsTo( + invoke("sum", TInt32, toIRArray(a)), + Option(a).flatMap(_.foldLeft[Option[Int]](Some(0))((comb, x) => + comb.flatMap(c => Option(x).map(_ + c)) + )).orNull, + ) + } + } + + basicData.foreach(checkSum(_)) + + object checkProduct extends TestCases { + def apply(a: IndexedSeq[Integer])(implicit loc: munit.Location): Unit = test("product") { + assertEvalsTo( + invoke("product", TInt32, toIRArray(a)), + Option(a).flatMap(_.foldLeft[Option[Int]](Some(1))((comb, x) => + comb.flatMap(c => Option(x).map(_ * c)) + )).orNull, + ) + } + } + + basicData.foreach(checkProduct(_)) + + test("mean") { assertEvalsTo(invoke("mean", TFloat64, IRArray(3, 7)), 5.0) assertEvalsTo(invoke("mean", TFloat64, IRArray(3, null, 7)), null) assertEvalsTo(invoke("mean", TFloat64, IRArray(3, 7, 11)), 7.0) @@ -72,7 +92,7 @@ class ArrayFunctionsSuite extends HailSuite { assertEvalsTo(invoke("mean", TFloat64, naa), null) } - @Test def median(): Unit = { + test("median") { assertEvalsTo(invoke("median", TInt32, IRArray(5)), 5) assertEvalsTo(invoke("median", TInt32, IRArray(5, null, null)), 5) assertEvalsTo(invoke("median", TInt32, IRArray(3, 7)), 5) @@ -84,30 +104,67 @@ 
class ArrayFunctionsSuite extends HailSuite { assertEvalsTo(invoke("median", TInt32, naa), null) } - @Test(dataProvider = "basicPairs") - def extend(a: IndexedSeq[Integer], b: IndexedSeq[Integer]): Unit = - assertEvalsTo( - invoke("extend", TArray(TInt32), toIRArray(a), toIRArray(b)), - Option(a).zip(Option(b)).headOption.map { case (x, y) => x ++ y }.orNull, - ) + // basicPairs DataProvider test - @DataProvider(name = "sort") - def sortData(): Array[Array[Any]] = Array( - Array(FastSeq(3, 9, 7), FastSeq(3, 7, 9), FastSeq(9, 7, 3)), - Array(null, null, null), - Array(FastSeq(3, null, 1, null, 3), FastSeq(1, 3, 3, null, null), FastSeq(3, 3, 1, null, null)), - Array(FastSeq(1, null, 3, null, 1), FastSeq(1, 1, 3, null, null), FastSeq(3, 1, 1, null, null)), - Array(FastSeq(), FastSeq(), FastSeq()), - ) + object checkExtend extends TestCases { + def apply(a: IndexedSeq[Integer], b: IndexedSeq[Integer])(implicit loc: munit.Location): Unit = + test("extend") { + assertEvalsTo( + invoke("extend", TArray(TInt32), toIRArray(a), toIRArray(b)), + Option(a).zip(Option(b)).headOption.map { case (x, y) => x ++ y }.orNull, + ) + } + } - @Test(dataProvider = "sort") - def min(a: IndexedSeq[Integer], asc: IndexedSeq[Integer], desc: IndexedSeq[Integer]): Unit = - assertEvalsTo( - invoke("min", TInt32, toIRArray(a)), - Option(asc).filter(!_.contains(null)).flatMap(_.headOption).orNull, + basicData.combinations(2).foreach { case Array(a, b) => checkExtend(a, b) } + + // sort DataProvider tests + + object checkMin extends TestCases { + def apply( + a: IndexedSeq[Integer], + asc: IndexedSeq[Integer], + desc: IndexedSeq[Integer], + )(implicit + loc: munit.Location + ): Unit = test("min") { + assertEvalsTo( + invoke("min", TInt32, toIRArray(a)), + Option(asc).filter(!_.contains(null)).flatMap(_.headOption).orNull, + ) + } + } + + object checkMax extends TestCases { + def apply( + a: IndexedSeq[Integer], + asc: IndexedSeq[Integer], + desc: IndexedSeq[Integer], + )(implicit + loc: 
munit.Location + ): Unit = test("max") { + assertEvalsTo( + invoke("max", TInt32, toIRArray(a)), + Option(desc).filter(!_.contains(null)).flatMap(_.headOption).orNull, + ) + } + } + + { + val sortData: Array[(IndexedSeq[Integer], IndexedSeq[Integer], IndexedSeq[Integer])] = Array( + (FastSeq(3, 9, 7), FastSeq(3, 7, 9), FastSeq(9, 7, 3)), + (null, null, null), + (FastSeq(3, null, 1, null, 3), FastSeq(1, 3, 3, null, null), FastSeq(3, 3, 1, null, null)), + (FastSeq(1, null, 3, null, 1), FastSeq(1, 1, 3, null, null), FastSeq(3, 1, 1, null, null)), + (FastSeq(), FastSeq(), FastSeq()), ) + for ((a, asc, desc) <- sortData) { + checkMin(a, asc, desc) + checkMax(a, asc, desc) + } + } - @Test def testMinMaxNans(): Unit = { + test("MinMaxNans") { assertAllEvalTo( ( invoke( @@ -156,121 +213,171 @@ class ArrayFunctionsSuite extends HailSuite { ) } - @Test(dataProvider = "sort") - def max(a: IndexedSeq[Integer], asc: IndexedSeq[Integer], desc: IndexedSeq[Integer]): Unit = - assertEvalsTo( - invoke("max", TInt32, toIRArray(a)), - Option(desc).filter(!_.contains(null)).flatMap(_.headOption).orNull, + // argminmax DataProvider tests + + object checkArgmin extends TestCases { + def apply( + a: IndexedSeq[Integer], + argmin: Integer, + argmax: Integer, + )(implicit + loc: munit.Location + ): Unit = test("argmin") { + assertEvalsTo(invoke("argmin", TInt32, toIRArray(a)), argmin) + } + } + + object checkArgmax extends TestCases { + def apply( + a: IndexedSeq[Integer], + argmin: Integer, + argmax: Integer, + )(implicit + loc: munit.Location + ): Unit = test("argmax") { + assertEvalsTo(invoke("argmax", TInt32, toIRArray(a)), argmax) + } + } + + { + val argMinMaxData: Array[(IndexedSeq[Integer], Integer, Integer)] = Array( + (FastSeq(3, 9, 7), 0, 1), + (null, null, null), + (FastSeq(3, null, 1, null, 3), 2, 0), + (FastSeq(1, null, 3, null, 1), 0, 2), + (FastSeq(), null, null), ) + for ((a, amin, amax) <- argMinMaxData) { + checkArgmin(a, amin, amax) + checkArgmax(a, amin, amax) + } + } - 
@DataProvider(name = "argminmax") - def argMinMaxData(): Array[Array[Any]] = Array( - Array(FastSeq(3, 9, 7), 0, 1), - Array(null, null, null), - Array(FastSeq(3, null, 1, null, 3), 2, 0), - Array(FastSeq(1, null, 3, null, 1), 0, 2), - Array(FastSeq(), null, null), - ) + // uniqueMinMaxIndex DataProvider tests + + object checkUniqueMinIndex extends TestCases { + def apply( + a: IndexedSeq[Integer], + argmin: Integer, + argmax: Integer, + )(implicit + loc: munit.Location + ): Unit = test("uniqueMinIndex") { + assertEvalsTo(invoke("uniqueMinIndex", TInt32, toIRArray(a)), argmin) + } + } - @Test(dataProvider = "argminmax") - def argmin(a: IndexedSeq[Integer], argmin: Integer, argmax: Integer): Unit = - assertEvalsTo(invoke("argmin", TInt32, toIRArray(a)), argmin) - - @Test(dataProvider = "argminmax") - def argmax(a: IndexedSeq[Integer], argmin: Integer, argmax: Integer): Unit = - assertEvalsTo(invoke("argmax", TInt32, toIRArray(a)), argmax) - - @DataProvider(name = "uniqueMinMaxIndex") - def uniqueMinMaxData(): Array[Array[Any]] = Array( - Array(FastSeq(3, 9, 7), 0, 1), - Array(null, null, null), - Array(FastSeq(3, null, 1, null, 3), 2, null), - Array(FastSeq(1, null, 3, null, 1), null, 2), - Array(FastSeq(), null, null), - ) + object checkUniqueMaxIndex extends TestCases { + def apply( + a: IndexedSeq[Integer], + argmin: Integer, + argmax: Integer, + )(implicit + loc: munit.Location + ): Unit = test("uniqueMaxIndex") { + assertEvalsTo(invoke("uniqueMaxIndex", TInt32, toIRArray(a)), argmax) + } + } - @Test(dataProvider = "uniqueMinMaxIndex") - def uniqueMinIndex(a: IndexedSeq[Integer], argmin: Integer, argmax: Integer): Unit = - assertEvalsTo(invoke("uniqueMinIndex", TInt32, toIRArray(a)), argmin) - - @Test(dataProvider = "uniqueMinMaxIndex") - def uniqueMaxIndex(a: IndexedSeq[Integer], argmin: Integer, argmax: Integer): Unit = - assertEvalsTo(invoke("uniqueMaxIndex", TInt32, toIRArray(a)), argmax) - - @DataProvider(name = "arrayOpsData") - def arrayOpsData(): 
Array[Array[Any]] = Array[Any]( - FastSeq(3, 9, 7, 1), - FastSeq(null, 2, null, 8), - FastSeq(5, 3, null, null), - null, - ).combinations(2).toArray - - @DataProvider(name = "arrayOpsOperations") - def arrayOpsOperations: Array[Array[Any]] = Array[(String, (Int, Int) => Int)]( - ("add", _ + _), - ("sub", _ - _), - ("mul", _ * _), - ("floordiv", _ / _), - ("mod", _ % _), - ).map(_.productIterator.toArray) - - @DataProvider(name = "arrayOps") - def arrayOpsPairs(): Array[Array[Any]] = - for { - Array(a, b) <- arrayOpsData() - Array(s, f) <- arrayOpsOperations - } yield Array(a, b, s, f) + { + val uniqueMinMaxData: Array[(IndexedSeq[Integer], Integer, Integer)] = Array( + (FastSeq(3, 9, 7), 0, 1), + (null, null, null), + (FastSeq(3, null, 1, null, 3), 2, null), + (FastSeq(1, null, 3, null, 1), null, 2), + (FastSeq(), null, null), + ) + for ((a, amin, amax) <- uniqueMinMaxData) { + checkUniqueMinIndex(a, amin, amax) + checkUniqueMaxIndex(a, amin, amax) + } + } - def lift(f: (Int, Int) => Int) - : (IndexedSeq[Integer], IndexedSeq[Integer]) => IndexedSeq[Integer] = { - case (a, b) => - Option(a).zip(Option(b)).headOption.map { case (a0, b0) => - a0.zip(b0).map { case (i, j) => - Option(i).zip(Option(j)).headOption.map[Integer] { case (m, n) => f(m, n) }.orNull - } - }.orNull + // arrayOps DataProvider tests + + object checkArrayOps extends TestCases { + def apply( + a: IndexedSeq[Integer], + b: IndexedSeq[Integer], + s: String, + f: (Int, Int) => Int, + )(implicit loc: munit.Location + ): Unit = test("arrayOps") { + assertEvalsTo(invoke(s, TArray(TInt32), toIRArray(a), toIRArray(b)), lift(f)(a, b)) + } } - @Test(dataProvider = "arrayOps") - def arrayOps(a: IndexedSeq[Integer], b: IndexedSeq[Integer], s: String, f: (Int, Int) => Int) - : Unit = - assertEvalsTo(invoke(s, TArray(TInt32), toIRArray(a), toIRArray(b)), lift(f)(a, b)) + object checkArrayOpsFPDiv extends TestCases { + def apply(a: IndexedSeq[Integer], b: IndexedSeq[Integer])(implicit loc: munit.Location): Unit = 
+ test("arrayOpsFPDiv") { + assertEvalsTo( + invoke("div", TArray(TFloat64), toIRArray(a), toIRArray(b)), + Option(a).zip(Option(b)).headOption.map { case (a0, b0) => + a0.zip(b0).map { case (i, j) => + Option(i).zip(Option(j)).headOption.map[java.lang.Double] { case (m, n) => + m.toDouble / n + }.orNull + } + }.orNull, + ) + } + } - @Test(dataProvider = "arrayOpsData") - def arrayOpsFPDiv(a: IndexedSeq[Integer], b: IndexedSeq[Integer]): Unit = { - assertEvalsTo( - invoke("div", TArray(TFloat64), toIRArray(a), toIRArray(b)), - Option(a).zip(Option(b)).headOption.map { case (a0, b0) => - a0.zip(b0).map { case (i, j) => - Option(i).zip(Option(j)).headOption.map[java.lang.Double] { case (m, n) => - m.toDouble / n - }.orNull - } - }.orNull, - ) + object checkArrayOpsPow extends TestCases { + def apply(a: IndexedSeq[Integer], b: IndexedSeq[Integer])(implicit loc: munit.Location): Unit = + test("arrayOpsPow") { + assertEvalsTo( + invoke("pow", TArray(TFloat64), toIRArray(a), toIRArray(b)), + Option(a).zip(Option(b)).headOption.map { case (a0, b0) => + a0.zip(b0).map { case (i, j) => + Option(i).zip(Option(j)).headOption.map[java.lang.Double] { case (m, n) => + math.pow(m.toDouble, n.toDouble) + }.orNull + } + }.orNull, + ) + } } - @Test(dataProvider = "arrayOpsData") - def arrayOpsPow(a: IndexedSeq[Integer], b: IndexedSeq[Integer]): Unit = { - assertEvalsTo( - invoke("pow", TArray(TFloat64), toIRArray(a), toIRArray(b)), - Option(a).zip(Option(b)).headOption.map { case (a0, b0) => - a0.zip(b0).map { case (i, j) => - Option(i).zip(Option(j)).headOption.map[java.lang.Double] { case (m, n) => - math.pow(m.toDouble, n.toDouble) - }.orNull - } - }.orNull, - ) + object checkArrayOpsDifferentLength extends TestCases { + def apply(s: String, f: (Int, Int) => Int)(implicit loc: munit.Location): Unit = + test("arrayOpsDifferentLength") { + assertFatal(invoke(s, TArray(TInt32), IRArray(1, 2, 3), IRArray(1, 2)), "length mismatch") + assertFatal(invoke(s, TArray(TInt32), IRArray(1, 
2), IRArray(1, 2, 3)), "length mismatch") + } } - @Test(dataProvider = "arrayOpsOperations") - def arrayOpsDifferentLength(s: String, f: (Int, Int) => Int): Unit = { - assertFatal(invoke(s, TArray(TInt32), IRArray(1, 2, 3), IRArray(1, 2)), "length mismatch") - assertFatal(invoke(s, TArray(TInt32), IRArray(1, 2), IRArray(1, 2, 3)), "length mismatch") + { + val arrayOpsDataValues: Array[Array[IndexedSeq[Integer]]] = Array[IndexedSeq[Integer]]( + FastSeq(3, 9, 7, 1), + FastSeq(null, 2, null, 8), + FastSeq(5, 3, null, null), + null, + ).combinations(2).toArray.map(_.toArray) + + val arrayOpsOps: Array[(String, (Int, Int) => Int)] = Array( + ("add", _ + _), + ("sub", _ - _), + ("mul", _ * _), + ("floordiv", _ / _), + ("mod", _ % _), + ) + + for { + Array(a, b) <- arrayOpsDataValues + (s, f) <- arrayOpsOps + } checkArrayOps(a, b, s, f) + + for (Array(a, b) <- arrayOpsDataValues) { + checkArrayOpsFPDiv(a, b) + checkArrayOpsPow(a, b) + } + + for ((s, f) <- arrayOpsOps) + checkArrayOpsDifferentLength(s, f) } - @Test def indexing(): Unit = { + test("indexing") { val a = IRArray(0, null, 2) assertEvalsTo(invoke("indexArray", TInt32, a, I32(0)), 0) assertEvalsTo(invoke("indexArray", TInt32, a, I32(1)), null) @@ -283,7 +390,7 @@ class ArrayFunctionsSuite extends HailSuite { assertEvalsTo(invoke("indexArray", TInt32, a, NA(TInt32)), null) } - @Test def slicing(): Unit = { + test("slicing") { val a = IRArray(0, null, 2) assertEvalsTo(ArraySlice(a, I32(1), None), FastSeq(null, 2)) assertEvalsTo(ArraySlice(a, I32(-2), None), FastSeq(null, 2)) @@ -311,23 +418,29 @@ class ArrayFunctionsSuite extends HailSuite { assertEvalsTo(ArraySlice(a, I32(3), Some(I32(2))), FastSeq()) } - @DataProvider(name = "flatten") - def flattenData(): Array[Array[Any]] = Array( - Array(FastSeq(FastSeq(3, 9, 7), FastSeq(3, 7, 9)), FastSeq(3, 9, 7, 3, 7, 9)), - Array(FastSeq(null, FastSeq(1)), FastSeq(1)), - Array(FastSeq(null, null), FastSeq()), - Array(FastSeq(FastSeq(null), FastSeq(), FastSeq(7)), 
FastSeq(null, 7)), - Array(FastSeq(FastSeq(), FastSeq()), FastSeq()), - ) + // flatten DataProvider test + + object checkFlatten extends TestCases { + def apply( + in: IndexedSeq[IndexedSeq[Integer]], + expected: IndexedSeq[Any], + )(implicit + loc: munit.Location + ): Unit = test("flatten") { + assertEvalsTo( + invoke("flatten", TArray(TInt32), MakeArray(in.map(toIRArray(_)), TArray(TArray(TInt32)))), + expected, + ) + } + } - @Test(dataProvider = "flatten") - def flatten(in: IndexedSeq[IndexedSeq[Integer]], expected: IndexedSeq[Int]): Unit = - assertEvalsTo( - invoke("flatten", TArray(TInt32), MakeArray(in.map(toIRArray(_)), TArray(TArray(TInt32)))), - expected, - ) + checkFlatten(FastSeq(FastSeq(3, 9, 7), FastSeq(3, 7, 9)), FastSeq(3, 9, 7, 3, 7, 9)) + checkFlatten(FastSeq(null, FastSeq(1)), FastSeq(1)) + checkFlatten(FastSeq(null, null), FastSeq()) + checkFlatten(FastSeq(FastSeq(null), FastSeq(), FastSeq(7)), FastSeq[Any](null, 7)) + checkFlatten(FastSeq(FastSeq(), FastSeq()), FastSeq()) - @Test def testContains(): Unit = { + test("Contains") { val t = TArray(TString) assertEvalsTo( @@ -373,53 +486,65 @@ class ArrayFunctionsSuite extends HailSuite { ) } - @DataProvider(name = "scatter") - def scatterData: Array[Array[Any]] = Array( - Array(FastSeq("a", "b", "c"), FastSeq(1, 3, 4), FastSeq(null, "a", null, "b", "c")), - Array(FastSeq("a", "b", "c"), FastSeq(2, 0, 3), FastSeq("b", null, "a", "c", null)), - Array(FastSeq(), FastSeq(), FastSeq(null, null, null)), - Array(FastSeq(), FastSeq(), FastSeq()), - ) - - @Test(dataProvider = "scatter") - def testScatter(elts: IndexedSeq[String], indices: IndexedSeq[Int], expected: IndexedSeq[String]) - : Unit = { - val t1 = TArray(TInt32) - val t2 = TArray(TString) + // scatter DataProvider test + + object checkScatter extends TestCases { + def apply( + elts: IndexedSeq[String], + indices: IndexedSeq[Int], + expected: IndexedSeq[String], + )(implicit loc: munit.Location + ): Unit = test("Scatter") { + val t1 = 
TArray(TInt32) + val t2 = TArray(TString) + + assertEvalsTo( + invoke("scatter", t2, FastSeq(TString), In(0, t2), In(1, t1), expected.length), + args = FastSeq(elts -> t2, indices -> t1), + expected = expected, + ) + } + } - assertEvalsTo( - invoke("scatter", t2, FastSeq(TString), In(0, t2), In(1, t1), expected.length), - args = FastSeq(elts -> t2, indices -> t1), - expected = expected, - ) + checkScatter(FastSeq("a", "b", "c"), FastSeq(1, 3, 4), FastSeq(null, "a", null, "b", "c")) + checkScatter(FastSeq("a", "b", "c"), FastSeq(2, 0, 3), FastSeq("b", null, "a", "c", null)) + checkScatter(FastSeq(), FastSeq(), FastSeq(null, null, null)) + checkScatter(FastSeq(), FastSeq(), FastSeq()) + + // scatter_errors DataProvider test + + object checkScatterErrors extends TestCases { + def apply( + elts: IndexedSeq[String], + indices: IndexedSeq[Int], + length: Int, + regex: String, + )(implicit loc: munit.Location + ): Unit = test("ScatterErrors") { + val t1 = TArray(TInt32) + val t2 = TArray(TString) + + assertFatal( + invoke("scatter", t2, FastSeq(TString), In(0, t2), In(1, t1), length), + args = FastSeq(elts -> t2, indices -> t1), + regex = regex, + ) + } } - @DataProvider(name = "scatter_errors") - def scatterErrorData: Array[Array[Any]] = Array( - Array(FastSeq("a", "b", "c"), FastSeq(1, 3, 4), 4, "indices array contained index 4"), - Array( - FastSeq("a", "b"), - FastSeq(1, 3, 4), - 4, - "values and indices arrays have different lengths", - ), - Array(FastSeq("a", "b", "c"), FastSeq(1, 2, 2), 2, "values array is larger than result length"), + checkScatterErrors(FastSeq("a", "b", "c"), FastSeq(1, 3, 4), 4, "indices array contained index 4") + + checkScatterErrors( + FastSeq("a", "b"), + FastSeq(1, 3, 4), + 4, + "values and indices arrays have different lengths", ) - @Test(dataProvider = "scatter_errors") - def testScatterErrors( - elts: IndexedSeq[String], - indices: IndexedSeq[Int], - length: Int, - regex: String, - ): Unit = { - val t1 = TArray(TInt32) - val t2 = 
TArray(TString) - - assertFatal( - invoke("scatter", t2, FastSeq(TString), In(0, t2), In(1, t1), length), - args = FastSeq(elts -> t2, indices -> t1), - regex = regex, - ) - } + checkScatterErrors( + FastSeq("a", "b", "c"), + FastSeq(1, 2, 2), + 2, + "values array is larger than result length", + ) } diff --git a/hail/hail/test/src/is/hail/expr/ir/BlockMatrixIRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/BlockMatrixIRSuite.scala index 8127f5acdfb..95a51c7c8fb 100644 --- a/hail/hail/test/src/is/hail/expr/ir/BlockMatrixIRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/BlockMatrixIRSuite.scala @@ -19,8 +19,6 @@ import java.lang.Math.floorDiv import breeze.linalg.{DenseMatrix => BDM} import breeze.math.Ring.ringFromField -import org.scalatest.Inspectors.forAll -import org.testng.annotations.{DataProvider, Test} class BlockMatrixIRSuite extends HailSuite { @@ -50,56 +48,64 @@ class BlockMatrixIRSuite extends HailSuite { BlockMatrixMap2(left, right, l.name, r.name, ApplyBinaryPrimOp(op, l, r), strategy) } - @DataProvider(name = "valToBMData") - def valToBMData(): Array[Array[Any]] = Array( - Array( - ValueToBlockMatrix(F64(1), FastSeq(1, 1), BLOCK_SIZE), - BDM.fill[Double](1, 1)(1), - ), - Array(ones, BDM.fill[Double](N_ROWS, N_COLS)(1)), - Array( - ValueToBlockMatrix( - child = ToArray(mapIR(rangeIR(64))(it => it.toD)), - shape = FastSeq(8, 8), - blockSize = 3, - ), - BDM.tabulate[Double](8, 8)((i, j) => i * 8.0 + j), + object checkValueToBlockMatrix extends TestCases { + def apply( + bmir: BlockMatrixIR, + bdm: BDM[Double], + )(implicit loc: munit.Location + ): Unit = test("ValueToBlockMatrix") { + assertBMEvalsTo(bmir, bdm) + } + } + + checkValueToBlockMatrix( + ValueToBlockMatrix(F64(1), FastSeq(1, 1), BLOCK_SIZE), + BDM.fill[Double](1, 1)(1), + ) + + checkValueToBlockMatrix(ones, BDM.fill[Double](N_ROWS, N_COLS)(1)) + + checkValueToBlockMatrix( + ValueToBlockMatrix( + child = ToArray(mapIR(rangeIR(64))(it => it.toD)), + shape = FastSeq(8, 8), + blockSize = 
3, ), - Array( - ValueToBlockMatrix( - child = ToArray(mapIR(rangeIR(27))(it => it.toD)), - shape = FastSeq(3, 9), - blockSize = 3, - ), - BDM.tabulate[Double](3, 9)((i, j) => i * 9.0 + j), + BDM.tabulate[Double](8, 8)((i, j) => i * 8.0 + j), + ) + + checkValueToBlockMatrix( + ValueToBlockMatrix( + child = ToArray(mapIR(rangeIR(27))(it => it.toD)), + shape = FastSeq(3, 9), + blockSize = 3, ), - Array( - ValueToBlockMatrix( - child = MakeNDArray.fill(F64(1), FastSeq(I64(8), I64(8)), True()), - shape = FastSeq(8, 8), - blockSize = 3, - ), - BDM.fill[Double](8, 8)(1), + BDM.tabulate[Double](3, 9)((i, j) => i * 9.0 + j), + ) + + checkValueToBlockMatrix( + ValueToBlockMatrix( + child = MakeNDArray.fill(F64(1), FastSeq(I64(8), I64(8)), True()), + shape = FastSeq(8, 8), + blockSize = 3, ), - Array( - ValueToBlockMatrix( - child = MakeNDArray( - ToArray(mapIR(rangeIR(27))(it => it.toD)), - MakeTuple.ordered(FastSeq(I64(3), I64(9))), - True(), - ErrorIDs.NO_ERROR, - ), - shape = FastSeq(3, 9), - blockSize = 3, + BDM.fill[Double](8, 8)(1), + ) + + checkValueToBlockMatrix( + ValueToBlockMatrix( + child = MakeNDArray( + ToArray(mapIR(rangeIR(27))(it => it.toD)), + MakeTuple.ordered(FastSeq(I64(3), I64(9))), + True(), + ErrorIDs.NO_ERROR, ), - BDM.tabulate[Double](3, 9)((i, j) => i * 9.0 + j), + shape = FastSeq(3, 9), + blockSize = 3, ), + BDM.tabulate[Double](3, 9)((i, j) => i * 9.0 + j), ) - @Test(dataProvider = "valToBMData") - def testValueToBlockMatrix(bmir: ValueToBlockMatrix, bdm: BDM[Double]): Unit = - assertBMEvalsTo(bmir, bdm) - def rangeBlockMatrix(nRows: Int, nCols: Int, blockSize: Int = 2): BlockMatrixIR = ValueToBlockMatrix( ToArray(mapIR(rangeIR(nRows * nCols))(i => i.toD)), @@ -128,16 +134,14 @@ class BlockMatrixIRSuite extends HailSuite { else 0.0 } - @Test - def testBlockMatrixSparsify(): Unit = { + test("BlockMatrixSparsify") { val blocks = FastSeq((1, 0), (0, 1), (2, 1), (0, 2)) val bm = sparseRangeBlockMatrix(5, 5, blocks) val expected = 
sparseRangeBreezeMatrix(5, 5, blocks) assertBMEvalsTo(bm, expected) } - @Test - def testBlockMatrixSparseTranspose(): Unit = { + test("BlockMatrixSparseTranspose") { val blocks = FastSeq((1, 0), (0, 1), (2, 1), (0, 2)) val bm = sparseRangeBlockMatrix(5, 5, blocks) val expected = sparseRangeBreezeMatrix(5, 5, blocks) @@ -147,8 +151,7 @@ class BlockMatrixIRSuite extends HailSuite { ) } - @Test - def testBlockMatrixSparseSumCols(): Unit = { + test("BlockMatrixSparseSumCols") { val blocks = FastSeq((1, 0), (0, 1), (2, 1), (0, 2)) val bm = sparseRangeBlockMatrix(5, 5, blocks) assertBMEvalsTo( @@ -157,8 +160,7 @@ class BlockMatrixIRSuite extends HailSuite { ) } - @Test - def testBlockMatrixSparseSumRows(): Unit = { + test("BlockMatrixSparseSumRows") { val blocks = FastSeq((1, 0), (0, 1), (2, 1), (0, 2)) val bm = sparseRangeBlockMatrix(5, 5, blocks) assertBMEvalsTo( @@ -167,8 +169,7 @@ class BlockMatrixIRSuite extends HailSuite { ) } - @Test - def testBlockMatrixSparseSumAll(): Unit = { + test("BlockMatrixSparseSumAll") { val blocks = FastSeq((1, 0), (0, 1), (2, 1), (0, 2)) val bm = sparseRangeBlockMatrix(5, 5, blocks) assertBMEvalsTo( @@ -177,7 +178,7 @@ class BlockMatrixIRSuite extends HailSuite { ) } - @Test def testBlockMatrixWriteRead(): Unit = { + test("BlockMatrixWriteRead") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.interpretOnly val tempPath = ctx.createTmpPath("test-blockmatrix-write-read", "bm") Interpret[Unit]( @@ -191,7 +192,7 @@ class BlockMatrixIRSuite extends HailSuite { ) } - @Test def testBlockMatrixMap(): Unit = { + test("BlockMatrixMap") { val element = Ref(freshName(), TFloat64) val sqrtIR = BlockMatrixMap( ones, @@ -220,7 +221,7 @@ class BlockMatrixIRSuite extends HailSuite { assertBMEvalsTo(absIR, BDM.fill[Double](3, 3)(1)) } - @Test def testBlockMatrixMap2(): Unit = { + test("BlockMatrixMap2") { val onesAddOnes = makeMap2(ones, ones, Add(), UnionBlocks) val onesSubOnes = makeMap2(ones, ones, Subtract(), UnionBlocks) val onesMulOnes 
= makeMap2(ones, ones, Multiply(), IntersectionBlocks) @@ -232,7 +233,7 @@ class BlockMatrixIRSuite extends HailSuite { assertBMEvalsTo(onesDivOnes, BDM.fill[Double](3, 3)(1.0 / 1.0)) } - @Test def testBlockMatrixBroadcastValue_Scalars(): Unit = { + test("BlockMatrixBroadcastValue_Scalars") { val broadcastTwo = BlockMatrixBroadcast( ValueToBlockMatrix( MakeArray(IndexedSeq[F64](F64(2)), TArray(TFloat64)), @@ -255,7 +256,7 @@ class BlockMatrixIRSuite extends HailSuite { assertBMEvalsTo(onesDivTwo, BDM.fill[Double](3, 3)(1.0 / 2.0)) } - @Test def testBlockMatrixBroadcastValue_Vectors(): Unit = { + test("BlockMatrixBroadcastValue_Vectors") { val vectorLiteral = MakeArray(IndexedSeq[F64](F64(1), F64(2), F64(3)), TArray(TFloat64)) val broadcastRowVector = BlockMatrixBroadcast( @@ -278,7 +279,7 @@ class BlockMatrixIRSuite extends HailSuite { (FloatingPointDivide(), NeedsDense, (i: Double, j: Double) => i / j), ) - forAll(ops) { case (op, merge, f) => + ops.foreach { case (op, merge, f) => val rightRowOp = makeMap2(ones, broadcastRowVector, op, merge) val rightColOp = makeMap2(ones, broadcastColVector, op, merge) val leftRowOp = makeMap2(broadcastRowVector, ones, op, merge) @@ -296,7 +297,7 @@ class BlockMatrixIRSuite extends HailSuite { } } - @Test def testBlockMatrixFilter(): Unit = { + test("BlockMatrixFilter") { val nRows = 5 val nCols = 8 val original = BDM.tabulate[Double](nRows, nCols)((i, j) => i.toDouble * nCols + j) @@ -319,7 +320,7 @@ class BlockMatrixIRSuite extends HailSuite { ) } - @Test def testBlockMatrixSlice(): Unit = { + test("BlockMatrixSlice") { val nRows = 12 val nCols = 8 val original = BDM.tabulate[Double](nRows, nCols)((i, j) => i.toDouble * nCols + j) @@ -336,13 +337,13 @@ class BlockMatrixIRSuite extends HailSuite { ) } - @Test def testBlockMatrixDot(): Unit = { + test("BlockMatrixDot") { val m1 = BDM.tabulate[Double](5, 4)((i, j) => (i.toDouble + 1) * j) val m2 = BDM.tabulate[Double](4, 6)((i, j) => (i.toDouble + 5) * (j - 2)) 
assertBMEvalsTo(BlockMatrixDot(toIR(m1), toIR(m2)), m1 * m2) } - @Test def testBlockMatrixRandom(): Unit = { + test("BlockMatrixRandom") { val gaussian = BlockMatrixRandom(0, gaussian = true, shape = ArraySeq(5L, 6L), blockSize = 3) val uniform = BlockMatrixRandom(0, gaussian = false, shape = ArraySeq(5L, 6L), blockSize = 3) @@ -358,7 +359,7 @@ class BlockMatrixIRSuite extends HailSuite { ) } - @Test def readBlockMatrixIR(): Unit = { + test("readBlockMatrixIR") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val etype = EBlockMatrixNDArray(EFloat64Required, required = true) val path = @@ -387,7 +388,7 @@ class BlockMatrixIRSuite extends HailSuite { ) } - @Test def readWriteBlockMatrix(): Unit = { + test("readWriteBlockMatrix") { val original = getTestResource("blockmatrix_example/0") val expected = BlockMatrix.read(ctx, original).toBreezeMatrix() @@ -404,7 +405,7 @@ class BlockMatrixIRSuite extends HailSuite { assertBMEvalsTo(BlockMatrixRead(BlockMatrixNativeReader(ctx.fs, path)), expected) } - @Test def testBlockMatrixDensify(): Unit = { + test("BlockMatrixDensify") { val dense = fill(1, nRows = 10, nCols = 10, blockSize = 5) val sparse = BlockMatrixSparsify(dense, PerBlockSparsifier(FastSeq(0, 1, 2))) val densified = BlockMatrixDensify(sparse) diff --git a/hail/hail/test/src/is/hail/expr/ir/CallFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/CallFunctionsSuite.scala index 847a530a9f0..accd6774e12 100644 --- a/hail/hail/test/src/is/hail/expr/ir/CallFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/CallFunctionsSuite.scala @@ -8,56 +8,42 @@ import is.hail.expr.ir.defs.{False, I32, Str, True} import is.hail.types.virtual.{TArray, TBoolean, TCall, TInt32} import is.hail.variant._ -import org.testng.annotations.{DataProvider, Test} - class CallFunctionsSuite extends HailSuite { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.javaOnly - @DataProvider(name = "basic") - def basicData(): Array[Array[Any]] = { - 
assert(true) - Array( - Array(Call0()), - Array(Call1(0, false)), - Array(Call1(1, true)), - Array(Call2(1, 0, true)), - Array(Call2(0, 1, false)), - Array(CallN(ArraySeq(1, 1), false)), - Array(Call.parse("0|1")), - ) - } - - @DataProvider(name = "diploid") - def uphasedDiploidData(): Array[Array[Any]] = { - assert(true) - Array( - Array(Call2(0, 0, false)), - Array(Call2(1, 0, false)), - Array(Call2(0, 1, false)), - Array(Call2(3, 1, false)), - Array(Call2(3, 3, false)), - ) - } - - @DataProvider(name = "basicWithIndex") - def basicDataWIndex(): Array[Array[Any]] = { - assert(true) - Array( - Array(Call1(0, false), 0), - Array(Call1(1, true), 0), - Array(Call2(1, 0, true), 0), - Array(Call2(1, 0, true), 1), - Array(Call2(0, 1, false), 0), - Array(Call2(0, 1, false), 1), - Array(CallN(ArraySeq(1, 1), false), 0), - Array(CallN(ArraySeq(1, 1), false), 1), - Array(Call.parse("0|1"), 0), - Array(Call.parse("0|1"), 1), - ) - } - - @Test def constructors(): Unit = { + val basicData: Array[Call] = Array( + Call0(), + Call1(0, false), + Call1(1, true), + Call2(1, 0, true), + Call2(0, 1, false), + CallN(ArraySeq(1, 1), false), + Call.parse("0|1"), + ) + + val diploidData: Array[Call] = Array( + Call2(0, 0, false), + Call2(1, 0, false), + Call2(0, 1, false), + Call2(3, 1, false), + Call2(3, 3, false), + ) + + val basicWithIndexData: Array[(Call, Int)] = Array( + (Call1(0, false), 0), + (Call1(1, true), 0), + (Call2(1, 0, true), 0), + (Call2(1, 0, true), 1), + (Call2(0, 1, false), 0), + (Call2(0, 1, false), 1), + (CallN(ArraySeq(1, 1), false), 0), + (CallN(ArraySeq(1, 1), false), 1), + (Call.parse("0|1"), 0), + (Call.parse("0|1"), 1), + ) + + test("constructors") { assertEvalsTo(invoke("Call", TCall, False()), Call0()) assertEvalsTo(invoke("Call", TCall, I32(0), True()), Call1(0, true)) assertEvalsTo(invoke("Call", TCall, I32(1), False()), Call1(1, false)) @@ -69,69 +55,117 @@ class CallFunctionsSuite extends HailSuite { assertEvalsTo(invoke("Call", TCall, Str("0|1")), 
Call2(0, 1, true)) } - @Test(dataProvider = "basic") - def isPhased(c: Call): Unit = - assertEvalsTo(invoke("isPhased", TBoolean, IRCall(c)), Option(c).map(Call.isPhased).orNull) + object checkIsPhased extends TestCases { + def apply(c: Call)(implicit loc: munit.Location): Unit = test("isPhased") { + assertEvalsTo(invoke("isPhased", TBoolean, IRCall(c)), Option(c).map(Call.isPhased).orNull) + } + } - @Test(dataProvider = "basic") - def isHomRef(c: Call): Unit = - assertEvalsTo(invoke("isHomRef", TBoolean, IRCall(c)), Option(c).map(Call.isHomRef).orNull) + basicData.foreach(checkIsPhased(_)) - @Test(dataProvider = "basic") - def isHet(c: Call): Unit = - assertEvalsTo(invoke("isHet", TBoolean, IRCall(c)), Option(c).map(Call.isHet).orNull) + object checkIsHomRef extends TestCases { + def apply(c: Call)(implicit loc: munit.Location): Unit = test("isHomRef") { + assertEvalsTo(invoke("isHomRef", TBoolean, IRCall(c)), Option(c).map(Call.isHomRef).orNull) + } + } - @Test(dataProvider = "basic") - def isHomVar(c: Call): Unit = - assertEvalsTo(invoke("isHomVar", TBoolean, IRCall(c)), Option(c).map(Call.isHomVar).orNull) + basicData.foreach(checkIsHomRef(_)) - @Test(dataProvider = "basic") - def isNonRef(c: Call): Unit = - assertEvalsTo(invoke("isNonRef", TBoolean, IRCall(c)), Option(c).map(Call.isNonRef).orNull) + object checkIsHet extends TestCases { + def apply(c: Call)(implicit loc: munit.Location): Unit = test("isHet") { + assertEvalsTo(invoke("isHet", TBoolean, IRCall(c)), Option(c).map(Call.isHet).orNull) + } + } - @Test(dataProvider = "basic") - def isHetNonRef(c: Call): Unit = - assertEvalsTo( - invoke("isHetNonRef", TBoolean, IRCall(c)), - Option(c).map(Call.isHetNonRef).orNull, - ) + basicData.foreach(checkIsHet(_)) - @Test(dataProvider = "basic") - def isHetRef(c: Call): Unit = - assertEvalsTo(invoke("isHetRef", TBoolean, IRCall(c)), Option(c).map(Call.isHetRef).orNull) + object checkIsHomVar extends TestCases { + def apply(c: Call)(implicit loc: munit.Location): 
Unit = test("isHomVar") { + assertEvalsTo(invoke("isHomVar", TBoolean, IRCall(c)), Option(c).map(Call.isHomVar).orNull) + } + } - @Test(dataProvider = "basic") - def nNonRefAlleles(c: Call): Unit = - assertEvalsTo( - invoke("nNonRefAlleles", TInt32, IRCall(c)), - Option(c).map(Call.nNonRefAlleles).orNull, - ) + basicData.foreach(checkIsHomVar(_)) - @Test(dataProvider = "basicWithIndex") - def alleleByIndex(c: Call, idx: Int): Unit = - assertEvalsTo( - invoke("index", TInt32, IRCall(c), I32(idx)), - Option(c).map(c => Call.alleleByIndex(c, idx)).orNull, - ) + object checkIsNonRef extends TestCases { + def apply(c: Call)(implicit loc: munit.Location): Unit = test("isNonRef") { + assertEvalsTo(invoke("isNonRef", TBoolean, IRCall(c)), Option(c).map(Call.isNonRef).orNull) + } + } - @Test(dataProvider = "basicWithIndex") - def downcode(c: Call, idx: Int): Unit = - assertEvalsTo( - invoke("downcode", TCall, IRCall(c), I32(idx)), - Option(c).map(c => Call.downcode(c, idx)).orNull, - ) + basicData.foreach(checkIsNonRef(_)) - @Test(dataProvider = "diploid") - def unphasedDiploidGtIndex(c: Call): Unit = - assertEvalsTo( - invoke("unphasedDiploidGtIndex", TInt32, IRCall(c)), - Option(c).map(c => Call.unphasedDiploidGtIndex(c)).orNull, - ) + object checkIsHetNonRef extends TestCases { + def apply(c: Call)(implicit loc: munit.Location): Unit = test("isHetNonRef") { + assertEvalsTo( + invoke("isHetNonRef", TBoolean, IRCall(c)), + Option(c).map(Call.isHetNonRef).orNull, + ) + } + } - @Test(dataProvider = "basic") - def oneHotAlleles(c: Call): Unit = - assertEvalsTo( - invoke("oneHotAlleles", TArray(TInt32), IRCall(c), I32(2)), - Option(c).map(c => Call.oneHotAlleles(c, 2)).orNull, - ) + basicData.foreach(checkIsHetNonRef(_)) + + object checkIsHetRef extends TestCases { + def apply(c: Call)(implicit loc: munit.Location): Unit = test("isHetRef") { + assertEvalsTo(invoke("isHetRef", TBoolean, IRCall(c)), Option(c).map(Call.isHetRef).orNull) + } + } + + 
basicData.foreach(checkIsHetRef(_)) + + object checkNNonRefAlleles extends TestCases { + def apply(c: Call)(implicit loc: munit.Location): Unit = test("nNonRefAlleles") { + assertEvalsTo( + invoke("nNonRefAlleles", TInt32, IRCall(c)), + Option(c).map(Call.nNonRefAlleles).orNull, + ) + } + } + + basicData.foreach(checkNNonRefAlleles(_)) + + object checkAlleleByIndex extends TestCases { + def apply(c: Call, idx: Int)(implicit loc: munit.Location): Unit = test("alleleByIndex") { + assertEvalsTo( + invoke("index", TInt32, IRCall(c), I32(idx)), + Option(c).map(c => Call.alleleByIndex(c, idx)).orNull, + ) + } + } + + basicWithIndexData.foreach { case (c, idx) => checkAlleleByIndex(c, idx) } + + object checkDowncode extends TestCases { + def apply(c: Call, idx: Int)(implicit loc: munit.Location): Unit = test("downcode") { + assertEvalsTo( + invoke("downcode", TCall, IRCall(c), I32(idx)), + Option(c).map(c => Call.downcode(c, idx)).orNull, + ) + } + } + + basicWithIndexData.foreach { case (c, idx) => checkDowncode(c, idx) } + + object checkUnphasedDiploidGtIndex extends TestCases { + def apply(c: Call)(implicit loc: munit.Location): Unit = test("unphasedDiploidGtIndex") { + assertEvalsTo( + invoke("unphasedDiploidGtIndex", TInt32, IRCall(c)), + Option(c).map(c => Call.unphasedDiploidGtIndex(c)).orNull, + ) + } + } + + diploidData.foreach(checkUnphasedDiploidGtIndex(_)) + + object checkOneHotAlleles extends TestCases { + def apply(c: Call)(implicit loc: munit.Location): Unit = test("oneHotAlleles") { + assertEvalsTo( + invoke("oneHotAlleles", TArray(TInt32), IRCall(c), I32(2)), + Option(c).map(c => Call.oneHotAlleles(c, 2)).orNull, + ) + } + } + + basicData.foreach(checkOneHotAlleles(_)) } diff --git a/hail/hail/test/src/is/hail/expr/ir/DictFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/DictFunctionsSuite.scala index 54ca2648b2b..675fa057907 100644 --- a/hail/hail/test/src/is/hail/expr/ir/DictFunctionsSuite.scala +++ 
b/hail/hail/test/src/is/hail/expr/ir/DictFunctionsSuite.scala @@ -8,7 +8,6 @@ import is.hail.expr.ir.defs.{NA, ToSet, ToStream} import is.hail.types.virtual._ import org.apache.spark.sql.Row -import org.testng.annotations.{DataProvider, Test} class DictFunctionsSuite extends HailSuite { def tuplesToMap(a: IndexedSeq[(Integer, Integer)]): Map[Integer, Integer] = @@ -16,95 +15,135 @@ class DictFunctionsSuite extends HailSuite { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.javaOnly - @DataProvider(name = "basic") - def basicData(): Array[Array[Any]] = Array( - Array(IndexedSeq((1, 3), (2, 7))), - Array(IndexedSeq((1, 3), (2, null), null, (null, 1), (3, 7))), - Array(IndexedSeq()), - Array(IndexedSeq(null)), - Array(null), + val basicData: Array[IndexedSeq[(Integer, Integer)]] = Array( + IndexedSeq((1, 3), (2, 7)), + IndexedSeq((1, 3), (2, null), null, (null, 1), (3, 7)), + IndexedSeq(), + IndexedSeq(null), + null, ) - @Test(dataProvider = "basic") - def dictFromArray(a: IndexedSeq[(Integer, Integer)]): Unit = { - assertEvalsTo(invoke("dict", TDict(TInt32, TInt32), toIRPairArray(a)), tuplesToMap(a)) - assertEvalsTo(toIRDict(a), tuplesToMap(a)) + object checkDictFromArray extends TestCases { + def apply( + a: IndexedSeq[(Integer, Integer)] + )(implicit loc: munit.Location + ): Unit = test("dictFromArray") { + assertEvalsTo(invoke("dict", TDict(TInt32, TInt32), toIRPairArray(a)), tuplesToMap(a)) + assertEvalsTo(toIRDict(a), tuplesToMap(a)) + } } - @Test(dataProvider = "basic") - def dictFromSet(a: IndexedSeq[(Integer, Integer)]): Unit = - assertEvalsTo( - invoke("dict", TDict(TInt32, TInt32), ToSet(ToStream(toIRPairArray(a)))), - tuplesToMap(a), - ) - - @Test(dataProvider = "basic") - def isEmpty(a: IndexedSeq[(Integer, Integer)]): Unit = - assertEvalsTo( - invoke("isEmpty", TBoolean, toIRDict(a)), - Option(a).map(_.forall(_ == null)).orNull, - ) - - @DataProvider(name = "dictToArray") - def dictToArrayData(): Array[Array[Any]] = Array( - Array(FastSeq(1 -> 
3, 2 -> 7), FastSeq(Row(1, 3), Row(2, 7))), - Array( - FastSeq(1 -> 3, 2 -> null, null, (null, 1), 3 -> 7), + basicData.foreach(checkDictFromArray(_)) + + object checkDictFromSet extends TestCases { + def apply( + a: IndexedSeq[(Integer, Integer)] + )(implicit loc: munit.Location + ): Unit = test("dictFromSet") { + assertEvalsTo( + invoke("dict", TDict(TInt32, TInt32), ToSet(ToStream(toIRPairArray(a)))), + tuplesToMap(a), + ) + } + } + + basicData.foreach(checkDictFromSet(_)) + + object checkIsEmpty extends TestCases { + def apply( + a: IndexedSeq[(Integer, Integer)] + )(implicit loc: munit.Location + ): Unit = test("isEmpty") { + assertEvalsTo( + invoke("isEmpty", TBoolean, toIRDict(a)), + Option(a).map(_.forall(_ == null)).orNull, + ) + } + } + + basicData.foreach(checkIsEmpty(_)) + + val dictToArrayData: Array[(IndexedSeq[(Integer, Integer)], IndexedSeq[Row])] = Array( + (FastSeq[(Integer, Integer)]((1, 3), (2, 7)), FastSeq(Row(1, 3), Row(2, 7))), + ( + FastSeq[(Integer, Integer)]((1, 3), (2, null), null, (null, 1), (3, 7)), FastSeq(Row(1, 3), Row(2, null), Row(3, 7), Row(null, 1)), ), - Array(FastSeq(), FastSeq()), - Array(FastSeq(null), FastSeq()), - Array(null, null), + (FastSeq[(Integer, Integer)](), FastSeq()), + (FastSeq[(Integer, Integer)](null), FastSeq()), + (null, null), ) - @Test(dataProvider = "dictToArray") - def dictToArray(a: IndexedSeq[(Integer, Integer)], expected: IndexedSeq[Row]): Unit = { - implicit val execStrats = Set(ExecStrategy.JvmCompile) - assertEvalsTo(invoke("dictToArray", TArray(TTuple(TInt32, TInt32)), toIRDict(a)), expected) + object checkDictToArray extends TestCases { + def apply( + a: IndexedSeq[(Integer, Integer)], + expected: IndexedSeq[Row], + )(implicit loc: munit.Location + ): Unit = test("dictToArray") { + implicit val execStrats = Set(ExecStrategy.JvmCompile) + assertEvalsTo(invoke("dictToArray", TArray(TTuple(TInt32, TInt32)), toIRDict(a)), expected) + } } - @DataProvider(name = "keysAndValues") - def 
keysAndValuesData(): Array[Array[Any]] = Array( - Array(FastSeq(1 -> 3, 2 -> 7), FastSeq(1, 2), FastSeq(3, 7)), - Array( - FastSeq(1 -> 3, 2 -> null, null, (null, 1), 3 -> 7), - FastSeq(1, 2, 3, null), - FastSeq(3, null, 7, 1), + dictToArrayData.foreach { case (a, expected) => checkDictToArray(a, expected) } + + val keysAndValuesData + : Array[(IndexedSeq[(Integer, Integer)], IndexedSeq[Integer], IndexedSeq[Integer])] = Array( + (FastSeq[(Integer, Integer)]((1, 3), (2, 7)), FastSeq[Integer](1, 2), FastSeq[Integer](3, 7)), + ( + FastSeq[(Integer, Integer)]((1, 3), (2, null), null, (null, 1), (3, 7)), + FastSeq[Integer](1, 2, 3, null), + FastSeq[Integer](3, null, 7, 1), ), - Array(FastSeq(), FastSeq(), FastSeq()), - Array(FastSeq(null), FastSeq(), FastSeq()), - Array(null, null, null), + (FastSeq[(Integer, Integer)](), FastSeq[Integer](), FastSeq[Integer]()), + (FastSeq[(Integer, Integer)](null), FastSeq[Integer](), FastSeq[Integer]()), + (null, null, null), ) - @Test(dataProvider = "keysAndValues") - def keySet( - a: IndexedSeq[(Integer, Integer)], - keys: IndexedSeq[Integer], - values: IndexedSeq[Integer], - ): Unit = - assertEvalsTo(invoke("keySet", TSet(TInt32), toIRDict(a)), Option(keys).map(_.toSet).orNull) - - @Test(dataProvider = "keysAndValues") - def keys( - a: IndexedSeq[(Integer, Integer)], - keys: IndexedSeq[Integer], - values: IndexedSeq[Integer], - ): Unit = - assertEvalsTo(invoke("keys", TArray(TInt32), toIRDict(a)), keys) - - @Test(dataProvider = "keysAndValues") - def values( - a: IndexedSeq[(Integer, Integer)], - keys: IndexedSeq[Integer], - values: IndexedSeq[Integer], - ): Unit = - assertEvalsTo(invoke("values", TArray(TInt32), toIRDict(a)), values) + object checkKeySet extends TestCases { + def apply( + a: IndexedSeq[(Integer, Integer)], + keys: IndexedSeq[Integer], + values: IndexedSeq[Integer], + )(implicit loc: munit.Location + ): Unit = test("keySet") { + assertEvalsTo(invoke("keySet", TSet(TInt32), toIRDict(a)), 
Option(keys).map(_.toSet).orNull) + } + } + + keysAndValuesData.foreach { case (a, keys, values) => checkKeySet(a, keys, values) } + + object checkKeys extends TestCases { + def apply( + a: IndexedSeq[(Integer, Integer)], + keys: IndexedSeq[Integer], + values: IndexedSeq[Integer], + )(implicit loc: munit.Location + ): Unit = test("keys") { + assertEvalsTo(invoke("keys", TArray(TInt32), toIRDict(a)), keys) + } + } + + keysAndValuesData.foreach { case (a, keys, values) => checkKeys(a, keys, values) } + + object checkValues extends TestCases { + def apply( + a: IndexedSeq[(Integer, Integer)], + keys: IndexedSeq[Integer], + values: IndexedSeq[Integer], + )(implicit loc: munit.Location + ): Unit = test("values") { + assertEvalsTo(invoke("values", TArray(TInt32), toIRDict(a)), values) + } + } + + keysAndValuesData.foreach { case (a, keys, values) => checkValues(a, keys, values) } val d = IRDict((1, 3), (3, 7), (5, null), (null, 5)) val dwoutna = IRDict((1, 3), (3, 7), (5, null)) val na = NA(TInt32) - @Test def dictGet(): Unit = { + test("dictGet") { assertEvalsTo(invoke("get", TInt32, NA(TDict(TInt32, TInt32)), 1, na), null) assertEvalsTo(invoke("get", TInt32, d, 0, na), null) assertEvalsTo(invoke("get", TInt32, d, 1, na), 3) @@ -130,7 +169,7 @@ class DictFunctionsSuite extends HailSuite { assertFatal(invoke("index", TInt32, IRDict(), 100), "dictionary") } - @Test def dictContains(): Unit = { + test("dictContains") { assertEvalsTo(invoke("contains", TBoolean, d, 0), false) assertEvalsTo(invoke("contains", TBoolean, d, 1), true) assertEvalsTo(invoke("contains", TBoolean, d, 2), false) diff --git a/hail/hail/test/src/is/hail/expr/ir/DistinctlyKeyedSuite.scala b/hail/hail/test/src/is/hail/expr/ir/DistinctlyKeyedSuite.scala index 2bb722cb11f..4497896f99f 100644 --- a/hail/hail/test/src/is/hail/expr/ir/DistinctlyKeyedSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/DistinctlyKeyedSuite.scala @@ -7,10 +7,8 @@ import is.hail.expr.ir.defs.{ TableWrite, ToArray, } -import 
org.testng.annotations.Test - class DistinctlyKeyedSuite extends HailSuite { - @Test def distinctlyKeyedRangeTableBase(): Unit = { + test("distinctlyKeyedRangeTableBase") { val tableRange = TableRange(10, 2) val tableFilter = TableFilter( tableRange, @@ -26,7 +24,7 @@ class DistinctlyKeyedSuite extends HailSuite { assert(tableIRSeq.forall(tableIR => distinctlyKeyedAnalysis.contains(tableIR))) } - @Test def readTableKeyByDistinctlyKeyedAnalysis(): Unit = { + test("readTableKeyByDistinctlyKeyedAnalysis") { val rt = TableRange(40, 4) val idxRef = GetField(Ref(TableIR.rowName, rt.typ.rowType), "idx") val at = TableMapRows( @@ -61,7 +59,7 @@ class DistinctlyKeyedSuite extends HailSuite { assert(distinctlyKeyedAnalysis4.contains(intactKeysTable)) } - @Test def nonDistinctlyKeyedParent(): Unit = { + test("nonDistinctlyKeyedParent") { val tableRange1 = TableRange(10, 2) val tableRange2 = TableRange(10, 2) val row = Ref(TableIR.rowName, tableRange2.typ.rowType) @@ -83,7 +81,7 @@ class DistinctlyKeyedSuite extends HailSuite { assert(distinctlyKeyedSeq.forall(tableIR => distinctlyKeyedAnalysis.contains(tableIR))) } - @Test def distinctlyKeyedParent(): Unit = { + test("distinctlyKeyedParent") { val tableRange1 = TableRange(10, 2) val tableRange2 = TableRange(10, 2) val row = Ref(TableIR.rowName, tableRange2.typ.rowType) @@ -102,7 +100,7 @@ class DistinctlyKeyedSuite extends HailSuite { assert(distinctlyKeyedAnalysis.contains(tableDistinct)) } - @Test def iRparent(): Unit = { + test("iRparent") { val tableRange = TableRange(10, 2) val tableFilter = TableFilter( tableRange, diff --git a/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala b/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala index 69c98d1943d..231d095f65f 100644 --- a/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/EmitStreamSuite.scala @@ -22,9 +22,6 @@ import is.hail.utils._ import is.hail.variant.Call2 import org.apache.spark.sql.Row -import 
org.scalatest.Inspectors.forAll -import org.scalatest.enablers.InspectorAsserting.assertingNatureOfAssertion -import org.testng.annotations.Test class EmitStreamSuite extends HailSuite { @@ -193,10 +190,11 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testEmitNA(): Unit = + test("EmitNA") { assert(evalStream(NA(TStream(TInt32))) == null) + } - @Test def testEmitMake(): Unit = { + test("EmitMake") { val typ = TStream(TInt32) val tests: Array[(IR, IndexedSeq[Any])] = Array( MakeStream(IndexedSeq[IR](1, 2, NA(TInt32), 3), typ) -> IndexedSeq(1, 2, null, 3), @@ -209,13 +207,13 @@ class EmitStreamSuite extends HailSuite { MakeStream(IndexedSeq[IR](Str("hi"), Str("world")), TStream(TString)) -> IndexedSeq("hi", "world"), ) - forAll(tests) { case (ir, v) => + tests.foreach { case (ir, v) => assert(evalStream(ir) == v, Pretty(ctx, ir)) assert(evalStreamLen(ir).contains(v.length), Pretty(ctx, ir)) } } - @Test def testEmitRange(): Unit = { + test("EmitRange") { val tripleType = PCanonicalStruct(false, "start" -> PInt32(), "stop" -> PInt32(), "step" -> PInt32()) val range = compileStream( @@ -250,7 +248,7 @@ class EmitStreamSuite extends HailSuite { assert(range(null) == null) } - @Test def testEmitSeqSample(): Unit = { + test("EmitSeqSample") { val N = 20 val n = 2 @@ -275,8 +273,8 @@ class EmitStreamSuite extends HailSuite { assert(IndexedSeq.forall(e => e >= 0 && e < N)) } - forAll(0 until N) { i => - forAll(i + 1 until N) { j => + (0 until N).foreach { i => + (i + 1 until N).foreach { j => val entry = results(i)(j) // Expected value of entry is 5263. 
assert(entry > 4880 && entry < 5650) @@ -284,20 +282,20 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testEmitToStream(): Unit = { + test("EmitToStream") { val tests: Array[(IR, IndexedSeq[Any])] = Array( ToStream(MakeArray(IndexedSeq[IR](), TArray(TInt32))) -> IndexedSeq(), ToStream(MakeArray(IndexedSeq[IR](1, 2, 3, 4), TArray(TInt32))) -> IndexedSeq(1, 2, 3, 4), ToStream(NA(TArray(TInt32))) -> null, ) - forAll(tests) { case (ir, v) => + tests.foreach { case (ir, v) => val expectedLen = Option(v).map(_.length) assert(evalStream(ir) == v, Pretty(ctx, ir)) assert(evalStreamLen(ir) == expectedLen, Pretty(ctx, ir)) } } - @Test def testEmitLet(): Unit = { + test("EmitLet") { val ir = bindIRs(3, 10) { case Seq(start, end) => flatMapIR(StreamRange(start, end, 1)) { i => @@ -308,7 +306,7 @@ class EmitStreamSuite extends HailSuite { assert(evalStreamLen(ir).isEmpty, Pretty(ctx, ir)) } - @Test def testEmitMap(): Unit = { + test("EmitMap") { def ten = StreamRange(I32(0), I32(10), I32(1)) val tests: Array[(IR, IndexedSeq[Any])] = Array( @@ -317,13 +315,13 @@ class EmitStreamSuite extends HailSuite { mapIR(mapIR(ten)(_ + 1))(y => y * y) -> (0 until 10).map(i => (i + 1) * (i + 1)), mapIR(ten)(_ => NA(TInt32)) -> IndexedSeq.tabulate(10)(_ => null), ) - forAll(tests) { case (ir, v) => + tests.foreach { case (ir, v) => assert(evalStream(ir) == v, Pretty(ctx, ir)) assert(evalStreamLen(ir).contains(v.length), Pretty(ctx, ir)) } } - @Test def testEmitFilter(): Unit = { + test("EmitFilter") { def ten = StreamRange(I32(0), I32(10), I32(1)) val tests: Array[(IR, IndexedSeq[Any])] = Array( @@ -334,13 +332,13 @@ class EmitStreamSuite extends HailSuite { filterIR(mapIR(ten)(_ => NA(TInt32)))(_ => True()) -> IndexedSeq.tabulate(10)(_ => null), ) - forAll(tests) { case (ir, v) => + tests.foreach { case (ir, v) => assert(evalStream(ir) == v, Pretty(ctx, ir)) assert(evalStreamLen(ir).isEmpty, Pretty(ctx, ir)) } } - @Test def testEmitFlatMap(): Unit = { + test("EmitFlatMap") { 
val tests: Array[(IR, IndexedSeq[Any])] = Array( flatMapIR(rangeIR(6))(rangeIR(_)) -> @@ -358,7 +356,7 @@ class EmitStreamSuite extends HailSuite { } -> IndexedSeq(0, 0, 1, 1, 2, 2, 3, 3), ) - forAll(tests) { case (ir, v) => + tests.foreach { case (ir, v) => assert(evalStream(ir) == v, Pretty(ctx, ir)) if (v != null) assert(evalStreamLen(ir).isEmpty, Pretty(ctx, ir)) @@ -366,7 +364,7 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testStreamBufferedAggregator(): Unit = { + test("StreamBufferedAggregator") { val resultArrayToCompare = (0 until 12).map(i => Row(Row(i, i + 1), 1)) val streamType = TStream(TStruct("a" -> TInt64, "b" -> TInt64)) val numSeq = (0L until 12).map(i => IndexedSeq(I64(i), I64(i + 1))) @@ -412,7 +410,7 @@ class EmitStreamSuite extends HailSuite { } - @Test def testStreamBufferedAggregatorCombine(): Unit = { + test("StreamBufferedAggregatorCombine") { val resultArrayToCompare = IndexedSeq(Row(Row(1), 2)) val streamType = TStream(TStruct("a" -> TInt64)) val elemOne = MakeStruct(IndexedSeq(("a", I64(1)))) @@ -456,7 +454,7 @@ class EmitStreamSuite extends HailSuite { assert(evalStream(result) == resultArrayToCompare) } - @Test def testStreamBufferedAggregatorCollectAggregator(): Unit = { + test("StreamBufferedAggregatorCollectAggregator") { val resultArrayToCompare = IndexedSeq(Row(Row(1), IndexedSeq(1, 3)), Row(Row(2), IndexedSeq(2, 4))) val streamType = TStream(TStruct("a" -> TInt64, "b" -> TInt64)) @@ -502,7 +500,7 @@ class EmitStreamSuite extends HailSuite { assert(evalStream(result) == resultArrayToCompare) } - @Test def testStreamBufferedAggregatorMultipleAggregators(): Unit = { + test("StreamBufferedAggregatorMultipleAggregators") { val resultArrayToCompare = IndexedSeq( Row(Row(1), Row(3, IndexedSeq(1L, 3L, 2L))), Row(Row(2), Row(2, IndexedSeq(2L, 4L))), @@ -593,7 +591,7 @@ class EmitStreamSuite extends HailSuite { assert(evalStream(result) == resultArrayToCompare) } - @Test def testEmitJoinRightDistinct(): Unit = { + 
test("EmitJoinRightDistinct") { val eltType = TStruct("k" -> TInt32, "v" -> TString) def join(lstream: IR, rstream: IR, joinType: String): IR = @@ -639,7 +637,7 @@ class EmitStreamSuite extends HailSuite { IndexedSeq(Row("A", "a"), Row("B1", "b"), Row("B2", "b"), Row(null, "c")), ), ) - forAll(tests) { case (lstream, rstream, expectedLeft, expectedOuter) => + tests.foreach { case (lstream, rstream, expectedLeft, expectedOuter) => val l = leftjoin(lstream, rstream) val o = outerjoin(lstream, rstream) assert(evalStream(l) == expectedLeft, Pretty(ctx, l)) @@ -649,7 +647,7 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testEmitJoinRightDistinctInterval(): Unit = { + test("EmitJoinRightDistinctInterval") { val lEltType = TStruct("k" -> TInt32, "v" -> TString) val rEltType = TStruct("k" -> TInterval(TInt32), "v" -> TString) @@ -723,7 +721,7 @@ class EmitStreamSuite extends HailSuite { ), ) - forAll(tests) { case (lstream, rstream, expectedLeft, expectedInner) => + tests.foreach { case (lstream, rstream, expectedLeft, expectedInner) => val l = leftjoin(lstream, rstream) val i = innerjoin(lstream, rstream) assert(evalStream(l) == expectedLeft, Pretty(ctx, l)) @@ -733,7 +731,7 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testStreamJoinOuterWithKeyRepeats(): Unit = { + test("StreamJoinOuterWithKeyRepeats") { val lEltType = TStruct("k" -> TInt32, "idx_left" -> TInt32) val lRows = FastSeq( Row(1, 1), @@ -768,26 +766,26 @@ class EmitStreamSuite extends HailSuite { assert(compiled == expected) } - @Test def testEmitScan(): Unit = { + test("EmitScan") { val tests: Array[(IR, IndexedSeq[Any])] = Array( streamScanIR(MakeStream(IndexedSeq(), TStream(TInt32)), 9)(_ + _) -> IndexedSeq(9), streamScanIR(mapIR(rangeIR(4))(x => x * x), 1)(_ + _) -> IndexedSeq(1, 1 /*1+0*0*/, 2 /*1+1*1*/, 6 /*2+2*2*/, 15 /*6+3*3*/ ), ) - forAll(tests) { case (ir, v) => + tests.foreach { case (ir, v) => assert(evalStream(ir) == v, Pretty(ctx, ir)) 
assert(evalStreamLen(ir).contains(v.length), Pretty(ctx, ir)) } } - @Test def testEmitAggScan(): Unit = { + test("EmitAggScan") { def assertAggScan(ir: IR, inType: Type, tests: (Any, Any)*): Unit = unoptimized { ctx => val aggregate = compileStream( LoweringPipeline.compileLowerer(ctx, ir).asInstanceOf[IR], PType.canonical(inType), ) - forAll(tests) { case (inp, expected) => + tests.foreach { case (inp, expected) => assert(aggregate(inp) == expected, Pretty(ctx, ir)) } } @@ -824,7 +822,7 @@ class EmitStreamSuite extends HailSuite { ) } - @Test def testEmitFromIterator(): Unit = { + test("EmitFromIterator") { val intsPType = PInt32(true) val f1 = compileStreamWithIter( @@ -861,7 +859,7 @@ class EmitStreamSuite extends HailSuite { assert(f3(IndexedSeq().iterator) == IndexedSeq()) } - @Test def testEmitIf(): Unit = { + test("EmitIf") { def xs = MakeStream(IndexedSeq[IR](5, 3, 6), TStream(TInt32)) def ys = StreamRange(0, 4, 1) @@ -880,13 +878,13 @@ class EmitStreamSuite extends HailSuite { -> IndexedSeq(0, 1, 2, 3, 5, 3, 6, 0, 1, 2, 3), ) val lens: Array[Option[Int]] = Array(Some(3), Some(4), Some(3), None, None, None) - forAll(tests zip lens) { case ((ir, v), len) => + (tests zip lens).foreach { case ((ir, v), len) => assert(evalStream(ir) == v, Pretty(ctx, ir)) assert(evalStreamLen(ir) == len, Pretty(ctx, ir)) } } - @Test def testZipIfNA(): Unit = { + test("ZipIfNA") { val t = PCanonicalStruct( true, @@ -929,7 +927,7 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testFold(): Unit = { + test("Fold") { val ints = Literal(TArray(TInt32), IndexedSeq(1, 2, 3, 4)) val strsLit = Literal(TArray(TString), IndexedSeq("one", "two", "three", "four")) val strs = @@ -955,7 +953,7 @@ class EmitStreamSuite extends HailSuite { ) } - @Test def testGrouped(): Unit = { + test("Grouped") { // empty => empty assertEvalsTo( ToArray( @@ -1022,7 +1020,7 @@ class EmitStreamSuite extends HailSuite { } - @Test def testMakeStream(): Unit = { + test("MakeStream") { assertEvalsTo( 
ToArray( MakeStream(IndexedSeq(I32(1), NA(TInt32), I32(2)), TStream(TInt32)) @@ -1041,7 +1039,7 @@ class EmitStreamSuite extends HailSuite { ) } - @Test def testMultiplicity(): Unit = { + test("Multiplicity") { val target = Ref(freshName(), TStream(TInt32)) val tests = IndexedSeq( StreamRange(0, 10, 1) -> 0, @@ -1057,7 +1055,7 @@ class EmitStreamSuite extends HailSuite { streamScanIR(streamScanIR(target, 0)((_, i) => i), 0)((_, i) => i) -> 1, ) - forAll(tests) { case (ir, v) => + tests.foreach { case (ir, v) => assert(StreamUtils.multiplicity(ir, target.name) == v, Pretty(ctx, ir)) } } @@ -1102,7 +1100,7 @@ class EmitStreamSuite extends HailSuite { ) } - @Test def testMemoryRangeFold(): Unit = + test("MemoryRangeFold") { assertMemoryDoesNotScaleWithStreamSize() { size => foldIR( mapIR(flatMapIR(StreamRange(0, size, 1, true))(x => StreamRange(0, x, 1, true))) { i => @@ -1111,8 +1109,9 @@ class EmitStreamSuite extends HailSuite { I32(0), ) { case (acc, value) => maxIR(acc, invoke("length", TInt32, value)) } } + } - @Test def testStreamJoinMemory(): Unit = { + test("StreamJoinMemory") { assertMemoryDoesNotScaleWithStreamSize() { size => sumIR(joinRightDistinctIR( @@ -1152,7 +1151,7 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testStreamGroupedMemory(): Unit = { + test("StreamGroupedMemory") { assertMemoryDoesNotScaleWithStreamSize() { size => sumIR(mapIR(StreamGrouped(rangeIR(size), 100))(stream => I32(1))) } @@ -1162,14 +1161,15 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testStreamFilterMemory(): Unit = + test("StreamFilterMemory") { assertMemoryDoesNotScaleWithStreamSize(highSize = 100000) { size => StreamLen(filterIR(mapIR(StreamRange(0, size, 1, true))(i => invoke("str", TString, i))) { str => invoke("length", TInt32, str) > (size * 9 / 10).toString.size }) } + } - @Test def testStreamFlatMapMemory(): Unit = { + test("StreamFlatMapMemory") { assertMemoryDoesNotScaleWithStreamSize() { size => 
sumIR(flatMapIR(filteredRangeStructs(size)) { struct => StreamRange(0, invoke("length", TInt32, GetField(struct, "foo2")), 1, true) @@ -1183,14 +1183,15 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testGroupedFlatMapMemManagementMismatch(): Unit = + test("GroupedFlatMapMemManagementMismatch") { assertMemoryDoesNotScaleWithStreamSize() { size => foldLength(flatMapIR(mapIR(StreamGrouped(rangeStructs(size), 16))(x => ToArray(x))) { a => ToStream(a, false) }) } + } - @Test def testStreamTakeWhile(): Unit = { + test("StreamTakeWhile") { val makestream = MakeStream(FastSeq(I32(1), I32(2), I32(0), I32(1), I32(-1)), TStream(TInt32)) assert(evalStream(takeWhile(makestream)(r => r > 0)) == IndexedSeq(1, 2)) assert(evalStream(StreamTake(makestream, I32(3))) == IndexedSeq(1, 2, 0)) @@ -1201,7 +1202,7 @@ class EmitStreamSuite extends HailSuite { )) } - @Test def testStreamDropWhile(): Unit = { + test("StreamDropWhile") { val makestream = MakeStream(FastSeq(I32(1), I32(2), I32(0), I32(1), I32(-1)), TStream(TInt32)) assert(evalStream(dropWhile(makestream)(r => r > 0)) == IndexedSeq(0, 1, -1)) assert(evalStream(StreamDrop(makestream, I32(3))) == IndexedSeq(1, -1)) @@ -1214,7 +1215,7 @@ class EmitStreamSuite extends HailSuite { } - @Test def testStreamTakeDropMemory(): Unit = { + test("StreamTakeDropMemory") { assertMemoryDoesNotScaleWithStreamSize() { size => foldLength(StreamTake(rangeStructs(size), (size / 2).toI)) } @@ -1232,12 +1233,12 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testStreamIota(): Unit = { + test("StreamIota") { assert(evalStream(takeWhile(iota(0, 2))(elt => elt < 10)) == IndexedSeq(0, 2, 4, 6, 8)) assert(evalStream(StreamTake(iota(5, -5), 3)) == IndexedSeq(5, 0, -5)) } - @Test def testStreamIntervalJoin(): Unit = { + test("StreamIntervalJoin") { val keyStream = mapIR(StreamRange(0, 9, 1, requiresMemoryManagementPerElement = true)) { i => MakeStruct(FastSeq("i" -> i)) } diff --git 
a/hail/hail/test/src/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala b/hail/hail/test/src/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala index 499ac6d71ae..8e6b6676d62 100644 --- a/hail/hail/test/src/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala @@ -9,10 +9,8 @@ import is.hail.utils.{Interval, IntervalEndpoint} import is.hail.variant.{Locus, ReferenceGenome} import org.apache.spark.sql.Row -import org.scalatest.Inspectors.forAll -import org.testng.annotations.Test -class ExtractIntervalFiltersSuite extends HailSuite { outer => +class ExtractIntervalFiltersSuite extends HailSuite { val ref1 = Ref(freshName(), TStruct("w" -> TInt32, "x" -> TInt32, "y" -> TBoolean)) val unknownBool = GetField(ref1, "y") @@ -118,7 +116,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => check(IsNA(filter), rowRef, key, probes, naResidual, naIntervals) } - @Test def testIsNA(): Unit = { + test("IsNA") { val testRows = FastSeq( Row(0, 0, true), Row(0, null, true), @@ -134,7 +132,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => ) } - @Test def testKeyComparison(): Unit = { + test("KeyComparison") { def check( op: ComparisonOp[Boolean], point: IR, @@ -243,7 +241,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => ).isEmpty) } - @Test def testLiteralContains(): Unit = { + test("LiteralContains") { def check( node: IR, trueIntervals: IndexedSeq[Interval], @@ -259,53 +257,51 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => checkAll(node, ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals) } - forAll { - Array( - Literal(TSet(TInt32), Set(null, 10, 1)), - Literal(TArray(TInt32), FastSeq(10, 1, null)), - Literal(TDict(TInt32, TString), Map(1 -> "foo", (null, "bar"), 10 -> "baz")), - ) - } { lit => - check( - invoke("contains", TBoolean, lit, k1), - FastSeq( - Interval(Row(1), Row(1), true, true), - Interval(Row(10), Row(10), 
true, true), - Interval(Row(null), Row(), true, true), - ), - FastSeq( - Interval(Row(), Row(1), true, false), - Interval(Row(1), Row(10), false, false), - Interval(Row(10), Row(null), false, false), - ), - FastSeq(), - ) - } + Array( + Literal(TSet(TInt32), Set(null, 10, 1)), + Literal(TArray(TInt32), FastSeq(10, 1, null)), + Literal(TDict(TInt32, TString), Map(1 -> "foo", (null, "bar"), 10 -> "baz")), + ) + .foreach { lit => + check( + invoke("contains", TBoolean, lit, k1), + FastSeq( + Interval(Row(1), Row(1), true, true), + Interval(Row(10), Row(10), true, true), + Interval(Row(null), Row(), true, true), + ), + FastSeq( + Interval(Row(), Row(1), true, false), + Interval(Row(1), Row(10), false, false), + Interval(Row(10), Row(null), false, false), + ), + FastSeq(), + ) + } - forAll { - Array( - Literal(TSet(TInt32), Set(10, 1)), - Literal(TArray(TInt32), FastSeq(10, 1)), - Literal(TDict(TInt32, TString), Map(1 -> "foo", 10 -> "baz")), - ) - } { lit => - check( - invoke("contains", TBoolean, lit, k1), - FastSeq( - Interval(Row(1), Row(1), true, true), - Interval(Row(10), Row(10), true, true), - ), - FastSeq( - Interval(Row(), Row(1), true, false), - Interval(Row(1), Row(10), false, false), - Interval(Row(10), Row(), false, true), - ), - FastSeq(), - ) - } + Array( + Literal(TSet(TInt32), Set(10, 1)), + Literal(TArray(TInt32), FastSeq(10, 1)), + Literal(TDict(TInt32, TString), Map(1 -> "foo", 10 -> "baz")), + ) + .foreach { lit => + check( + invoke("contains", TBoolean, lit, k1), + FastSeq( + Interval(Row(1), Row(1), true, true), + Interval(Row(10), Row(10), true, true), + ), + FastSeq( + Interval(Row(), Row(1), true, false), + Interval(Row(1), Row(10), false, false), + Interval(Row(10), Row(), false, true), + ), + FastSeq(), + ) + } } - @Test def testLiteralContainsStruct(): Unit = { + test("LiteralContainsStruct") { def check( node: IR, trueIntervals: IndexedSeq[Interval], @@ -332,62 +328,60 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => ) } - 
forAll { - Array( - Literal(TSet(structT1), Set(Row(1, 2), Row(3, 4), Row(3, null))), - Literal(TArray(structT1), FastSeq(Row(3, 4), Row(1, 2), Row(3, null))), - Literal( - TDict(structT1, TString), - Map(Row(1, 2) -> "foo", Row(3, 4) -> "bar", Row(3, null) -> "baz"), - ), - ) - } { lit => - forAll(fullKeyRefs) { k => - check( - invoke("contains", TBoolean, lit, k), - IndexedSeq( - Interval(Row(1, 2), Row(1, 2), true, true), - Interval(Row(3, 4), Row(3, 4), true, true), - Interval(Row(3, null), Row(3, null), true, true), - ), - IndexedSeq( - Interval(Row(), Row(1, 2), true, false), - Interval(Row(1, 2), Row(3, 4), false, false), - Interval(Row(3, 4), Row(3, null), false, false), - Interval(Row(3, null), Row(), false, true), - ), - IndexedSeq(), - ) + Array( + Literal(TSet(structT1), Set(Row(1, 2), Row(3, 4), Row(3, null))), + Literal(TArray(structT1), FastSeq(Row(3, 4), Row(1, 2), Row(3, null))), + Literal( + TDict(structT1, TString), + Map(Row(1, 2) -> "foo", Row(3, 4) -> "bar", Row(3, null) -> "baz"), + ), + ) + .foreach { lit => + fullKeyRefs.foreach { k => + check( + invoke("contains", TBoolean, lit, k), + IndexedSeq( + Interval(Row(1, 2), Row(1, 2), true, true), + Interval(Row(3, 4), Row(3, 4), true, true), + Interval(Row(3, null), Row(3, null), true, true), + ), + IndexedSeq( + Interval(Row(), Row(1, 2), true, false), + Interval(Row(1, 2), Row(3, 4), false, false), + Interval(Row(3, 4), Row(3, null), false, false), + Interval(Row(3, null), Row(), false, true), + ), + IndexedSeq(), + ) + } } - } - forAll { - Array( - Literal(TSet(structT2), Set(Row(1), Row(3), Row(null))), - Literal(TArray(structT2), FastSeq(Row(3), Row(null), Row(1))), - Literal(TDict(structT2, TString), Map(Row(1) -> "foo", Row(null) -> "baz", Row(3) -> "bar")), - ) - } { lit => - forAll(prefixKeyRefs) { k => - check( - invoke("contains", TBoolean, lit, k), - IndexedSeq( - Interval(Row(1), Row(1), true, true), - Interval(Row(3), Row(3), true, true), - Interval(Row(null), Row(), true, true), 
- ), - IndexedSeq( - Interval(Row(), Row(1), true, false), - Interval(Row(1), Row(3), false, false), - Interval(Row(3), Row(null), false, false), - ), - IndexedSeq(), - ) + Array( + Literal(TSet(structT2), Set(Row(1), Row(3), Row(null))), + Literal(TArray(structT2), FastSeq(Row(3), Row(null), Row(1))), + Literal(TDict(structT2, TString), Map(Row(1) -> "foo", Row(null) -> "baz", Row(3) -> "bar")), + ) + .foreach { lit => + prefixKeyRefs.foreach { k => + check( + invoke("contains", TBoolean, lit, k), + IndexedSeq( + Interval(Row(1), Row(1), true, true), + Interval(Row(3), Row(3), true, true), + Interval(Row(null), Row(), true, true), + ), + IndexedSeq( + Interval(Row(), Row(1), true, false), + Interval(Row(1), Row(3), false, false), + Interval(Row(3), Row(null), false, false), + ), + IndexedSeq(), + ) + } } - } } - @Test def testIntervalContains(): Unit = { + test("IntervalContains") { val interval = Interval(1, 5, false, true) val testRows = FastSeq( Row(0, 0, true), @@ -412,7 +406,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => ) } - @Test def testIntervalContainsStruct(): Unit = { + test("IntervalContainsStruct") { val fullInterval = Interval(Row(1, 1), Row(2, 2), false, true) val prefixInterval = Interval(Row(1), Row(2), false, true) @@ -446,7 +440,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => ) } - forAll(fullKeyRefs) { k => + fullKeyRefs.foreach { k => check( invoke("contains", TBoolean, Literal(TInterval(structT1), fullInterval), k), FastSeq(fullInterval), @@ -458,7 +452,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => ) } - forAll(prefixKeyRefs) { k => + prefixKeyRefs.foreach { k => check( invoke("contains", TBoolean, Literal(TInterval(structT2), prefixInterval), k), FastSeq(prefixInterval), @@ -471,7 +465,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => } } - @Test def testLocusContigComparison(): Unit = { + test("LocusContigComparison") { val ref = Ref(freshName(), TStruct("x" 
-> TLocus(ReferenceGenome.GRCh38))) val k = GetField(ref, "x") @@ -503,7 +497,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => checkAll(not(ir1), ref, ref, testRows, falseIntervals, trueIntervals, naIntervals) } - @Test def testLocusPositionComparison(): Unit = { + test("LocusPositionComparison") { val ref = Ref(freshName(), TStruct("x" -> TLocus(ReferenceGenome.GRCh38))) val k = GetField(ref, "x") val pos = invoke("position", TInt32, k) @@ -645,7 +639,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => ).isEmpty) } - @Test def testLocusContigContains(): Unit = { + test("LocusContigContains") { val ref = Ref(freshName(), TStruct("x" -> TLocus(ReferenceGenome.GRCh38))) val k = GetField(ref, "x") val contig = invoke("contig", TString, k) @@ -665,106 +659,104 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => checkAll(node, ref, ref, testRows, trueIntervals, falseIntervals, naIntervals) } - forAll { - Array( - Literal(TSet(TString), Set("chr10", "chr1", null, "foo")), - Literal(TArray(TString), FastSeq("foo", "chr10", null, "chr1")), - Literal( - TDict(TString, TString), - Map("chr1" -> "foo", "chr10" -> "bar", "foo" -> "baz", (null, "quux")), - ), - ) - } { lit => - check( - invoke("contains", TBoolean, lit, contig), - FastSeq( - Interval( - Row(Locus("chr1", 1)), - Row(Locus("chr1", grch38.contigLength("chr1"))), - true, - false, - ), - Interval( - Row(Locus("chr10", 1)), - Row(Locus("chr10", grch38.contigLength("chr10"))), - true, - false, - ), - Interval(Row(null), Row(), true, true), - ), - FastSeq( - Interval( - Row(), - Row(Locus("chr1", 1)), - true, - false, - ), - Interval( - Row(Locus("chr1", grch38.contigLength("chr1"))), - Row(Locus("chr10", 1)), - true, - false, + Array( + Literal(TSet(TString), Set("chr10", "chr1", null, "foo")), + Literal(TArray(TString), FastSeq("foo", "chr10", null, "chr1")), + Literal( + TDict(TString, TString), + Map("chr1" -> "foo", "chr10" -> "bar", "foo" -> "baz", (null, "quux")), 
+ ), + ) + .foreach { lit => + check( + invoke("contains", TBoolean, lit, contig), + FastSeq( + Interval( + Row(Locus("chr1", 1)), + Row(Locus("chr1", grch38.contigLength("chr1"))), + true, + false, + ), + Interval( + Row(Locus("chr10", 1)), + Row(Locus("chr10", grch38.contigLength("chr10"))), + true, + false, + ), + Interval(Row(null), Row(), true, true), ), - Interval( - Row(Locus("chr10", grch38.contigLength("chr10"))), - Row(null), - true, - false, + FastSeq( + Interval( + Row(), + Row(Locus("chr1", 1)), + true, + false, + ), + Interval( + Row(Locus("chr1", grch38.contigLength("chr1"))), + Row(Locus("chr10", 1)), + true, + false, + ), + Interval( + Row(Locus("chr10", grch38.contigLength("chr10"))), + Row(null), + true, + false, + ), ), - ), - FastSeq(), - ) - } + FastSeq(), + ) + } - forAll { - Array( - Literal(TSet(TString), Set("chr10", "chr1", "foo")), - Literal(TArray(TString), FastSeq("foo", "chr10", "chr1")), - Literal(TDict(TString, TString), Map("chr1" -> "foo", "chr10" -> "bar", "foo" -> "baz")), - ) - } { lit => - check( - invoke("contains", TBoolean, lit, contig), - FastSeq( - Interval( - Row(Locus("chr1", 1)), - Row(Locus("chr1", grch38.contigLength("chr1"))), - true, - false, - ), - Interval( - Row(Locus("chr10", 1)), - Row(Locus("chr10", grch38.contigLength("chr10"))), - true, - false, - ), - ), - FastSeq( - Interval( - Row(), - Row(Locus("chr1", 1)), - true, - false, - ), - Interval( - Row(Locus("chr1", grch38.contigLength("chr1"))), - Row(Locus("chr10", 1)), - true, - false, + Array( + Literal(TSet(TString), Set("chr10", "chr1", "foo")), + Literal(TArray(TString), FastSeq("foo", "chr10", "chr1")), + Literal(TDict(TString, TString), Map("chr1" -> "foo", "chr10" -> "bar", "foo" -> "baz")), + ) + .foreach { lit => + check( + invoke("contains", TBoolean, lit, contig), + FastSeq( + Interval( + Row(Locus("chr1", 1)), + Row(Locus("chr1", grch38.contigLength("chr1"))), + true, + false, + ), + Interval( + Row(Locus("chr10", 1)), + Row(Locus("chr10", 
grch38.contigLength("chr10"))), + true, + false, + ), ), - Interval( - Row(Locus("chr10", grch38.contigLength("chr10"))), - Row(), - true, - true, + FastSeq( + Interval( + Row(), + Row(Locus("chr1", 1)), + true, + false, + ), + Interval( + Row(Locus("chr1", grch38.contigLength("chr1"))), + Row(Locus("chr10", 1)), + true, + false, + ), + Interval( + Row(Locus("chr10", grch38.contigLength("chr10"))), + Row(), + true, + true, + ), ), - ), - FastSeq(), - ) - } + FastSeq(), + ) + } } - @Test def testIntervalListFold(): Unit = { + test("IntervalListFold") { val inIntervals = FastSeq( Interval(0, 10, true, false), Interval(20, 25, true, false), @@ -836,7 +828,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => ) } - @Test def testDisjunction(): Unit = { + test("Disjunction") { def check( node: IR, trueIntervals: IndexedSeq[Interval], @@ -908,7 +900,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => ) } - @Test def testConjunction(): Unit = { + test("Conjunction") { def check( node: IR, trueIntervals: IndexedSeq[Interval], @@ -965,7 +957,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => ) } - @Test def testCoalesce(): Unit = { + test("Coalesce") { def check( node: IR, trueIntervals: IndexedSeq[Interval], @@ -1001,7 +993,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => ) } - @Test def testIf(): Unit = { + test("If") { def check( node: IR, trueIntervals: IndexedSeq[Interval], @@ -1034,7 +1026,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => ) } - @Test def testSwitch(): Unit = { + test("Switch") { def check( node: IR, trueIntervals: IndexedSeq[Interval], @@ -1085,7 +1077,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => ) } - @Test def testRelationalChildren(): Unit = { + test("RelationalChildren") { val testRows = FastSeq( Row(0, 0, true), Row(0, 10, true), @@ -1102,7 +1094,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => check(filter, ref1, k1Full, 
testRows, filter, FastSeq(Interval(Row(), Row(), true, true))) } - @Test def testIntegration(): Unit = { + test("Integration") { val tab1 = TableRange(10, 5) def k = GetField(Ref(TableIR.rowName, tab1.typ.rowType), "idx") diff --git a/hail/hail/test/src/is/hail/expr/ir/FoldConstantsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/FoldConstantsSuite.scala index 2e11c8c32ea..14dabaa5167 100644 --- a/hail/hail/test/src/is/hail/expr/ir/FoldConstantsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/FoldConstantsSuite.scala @@ -7,36 +7,33 @@ import is.hail.expr.ir.defs.{ } import is.hail.types.virtual.{TFloat64, TInt32} -import org.testng.annotations.{DataProvider, Test} - class FoldConstantsSuite extends HailSuite { - @Test def testRandomBlocksFolding(): Unit = { + test("RandomBlocksFolding") { val x = Apply( "rand_norm", FastSeq.empty, FastSeq(RNGSplitStatic(RNGStateLiteral(), 0L), F64(0d), F64(0d)), TFloat64, ) - assert(FoldConstants(ctx, x) == x) + assertEquals(FoldConstants(ctx, x), x) } - @Test def testErrorCatching(): Unit = { + test("ErrorCatching") { val ir = invoke("toInt32", TInt32, Str("")) - assert(FoldConstants(ctx, ir) == ir) + assertEquals(FoldConstants(ctx, ir), ir) } - @DataProvider(name = "aggNodes") - def aggNodes(): Array[Array[Any]] = { - Array[IR]( - AggLet(freshName(), I32(1), I32(1), false), - AggLet(freshName(), I32(1), I32(1), true), - ApplyAggOp(Sum())(I64(1)), - ApplyScanOp(Sum())(I64(1)), - ).map(x => Array[Any](x)) + object aggNodes extends TestCases { + def apply(node: IR)(implicit loc: munit.Location): Unit = + test("AggNodesDoNotFold") { + assertEquals(FoldConstants(ctx, node), node) + } } - @Test def testAggNodesConstruction(): Unit = aggNodes(): Unit - - @Test(dataProvider = "aggNodes") def testAggNodesDoNotFold(node: IR): Unit = - assert(FoldConstants(ctx, node) == node) + Array[IR]( + AggLet(freshName(), I32(1), I32(1), false), + AggLet(freshName(), I32(1), I32(1), true), + ApplyAggOp(Sum())(I64(1)), + ApplyScanOp(Sum())(I64(1)), 
+ ).foreach(aggNodes(_)) } diff --git a/hail/hail/test/src/is/hail/expr/ir/ForwardLetsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/ForwardLetsSuite.scala index f6f8854b5eb..7dc927c91ce 100644 --- a/hail/hail/test/src/is/hail/expr/ir/ForwardLetsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/ForwardLetsSuite.scala @@ -8,16 +8,26 @@ import is.hail.expr.Nat import is.hail.expr.ir.defs._ import is.hail.types.virtual._ -import org.testng.annotations.{BeforeMethod, DataProvider, Test} - class ForwardLetsSuite extends HailSuite { - @BeforeMethod - def resetUidCounter(): Unit = + override def beforeEach(context: BeforeEach): Unit = { + super.beforeEach(context) is.hail.expr.ir.uidCounter = 0 + } + + def aggMin(value: IR): ApplyAggOp = ApplyAggOp(Min())(value) - @DataProvider(name = "nonForwardingOps") - def nonForwardingOps(): Array[Array[IR]] = { + object checkNonForwardingOps extends TestCases { + def apply(ir: IR)(implicit loc: munit.Location): Unit = + test("NonForwardingOps") { + val after = ForwardLets(ctx, ir) + val normalizedBefore = NormalizeNames()(ctx, ir) + val normalizedAfter = NormalizeNames()(ctx, after) + assertEquals(normalizedBefore, normalizedAfter) + } + } + + { val a = ToArray(StreamRange(I32(0), I32(10), I32(1))) val x = Ref(freshName(), TInt32) Array( @@ -41,11 +51,19 @@ class ForwardLetsSuite extends HailSuite { )), ApplyBinaryPrimOp(Add(), ApplyBinaryPrimOp(Add(), x, x), I32(1)), streamAggIR(ToStream(a))(y => ApplyAggOp(Sum())(x + y)), - ).map(ir => Array[IR](Let(FastSeq(x.name -> (In(0, TInt32) + In(0, TInt32))), ir))) + ).map(ir => Let(FastSeq(x.name -> (In(0, TInt32) + In(0, TInt32))), ir)) + .foreach(checkNonForwardingOps(_)) + } + + object checkNonForwardingNonEvalOps extends TestCases { + def apply(ir: IR)(implicit loc: munit.Location): Unit = + test("NonForwardingNonEvalOps") { + val after = ForwardLets(ctx, ir) + assert(after.isInstanceOf[Block]) + } } - @DataProvider(name = "nonForwardingNonEvalOps") - def 
nonForwardingNonEvalOps(): Array[Array[IR]] = { + { val x = Ref(freshName(), TInt32) val y = Ref(freshName(), TInt32) val z = Ref(freshName(), TInt32) @@ -66,23 +84,38 @@ class ForwardLetsSuite extends HailSuite { TInt32, If(y < x, Recur(f, FastSeq[IR](y - I32(1)), TInt32), x), ), - ).map(ir => Array[IR](Let(FastSeq(x.name -> (In(0, TInt32) + In(0, TInt32))), ir))) + ).map(ir => Let(FastSeq(x.name -> (In(0, TInt32) + In(0, TInt32))), ir)) + .foreach(checkNonForwardingNonEvalOps(_)) } - def aggMin(value: IR): ApplyAggOp = ApplyAggOp(Min())(value) + object checkNonForwardingAggOps extends TestCases { + def apply(ir: IR)(implicit loc: munit.Location): Unit = + test("NonForwardingAggOps") { + val after = ForwardLets(ctx, ir) + assert(after.isInstanceOf[Block]) + } + } - @DataProvider(name = "nonForwardingAggOps") - def nonForwardingAggOps(): Array[Array[IR]] = { + { val a = StreamRange(I32(0), I32(10), I32(1)) val x = Ref(freshName(), TInt32) Array( aggArrayPerElement(ToArray(a))((y, _) => aggMin(x + y)), aggExplodeIR(a)(y => aggMin(y + x)), - ).map(ir => Array[IR](AggLet(x.name, In(0, TInt32) + In(0, TInt32), ir, false))) + ).map(ir => AggLet(x.name, In(0, TInt32) + In(0, TInt32), ir, false)) + .foreach(checkNonForwardingAggOps(_)) } - @DataProvider(name = "forwardingOps") - def forwardingOps(): Array[Array[IR]] = { + object checkForwardingOps extends TestCases { + def apply(ir: IR)(implicit loc: munit.Location): Unit = + test("ForwardingOps") { + val after = ForwardLets(ctx, ir) + assert(!after.isInstanceOf[Block]) + assertEvalSame(ir, args = ArraySeq(5 -> TInt32)) + } + } + + { val x = Ref(freshName(), TInt32) Array( MakeStruct(FastSeq("a" -> I32(1), "b" -> ApplyBinaryPrimOp(Add(), x, I32(2)))), @@ -92,27 +125,102 @@ class ForwardLetsSuite extends HailSuite { ApplyUnaryPrimOp(Negate, x), ToArray(mapIR(rangeIR(x))(foo => foo)), ToArray(filterIR(rangeIR(x))(foo => foo <= I32(0))), - ).map(ir => Array[IR](Let(FastSeq(x.name -> (In(0, TInt32) + In(0, TInt32))), ir))) + 
).map(ir => Let(FastSeq(x.name -> (In(0, TInt32) + In(0, TInt32))), ir)) + .foreach(checkForwardingOps(_)) + } + + object checkForwardingAggOps extends TestCases { + def apply(ir: IR)(implicit loc: munit.Location): Unit = + test("ForwardingAggOps") { + val after = ForwardLets(ctx, ir) + assert(!after.isInstanceOf[Block]) + } } - @DataProvider(name = "forwardingAggOps") - def forwardingAggOps(): Array[Array[IR]] = { + { val x = Ref(freshName(), TInt32) val other = Ref(freshName(), TInt32) Array( AggFilter(x.ceq(I32(0)), aggMin(other), false), aggMin(x + other), - ).map(ir => Array[IR](AggLet(x.name, In(0, TInt32) + In(0, TInt32), ir, false))) + ).map(ir => AggLet(x.name, In(0, TInt32) + In(0, TInt32), ir, false)) + .foreach(checkForwardingAggOps(_)) } - @Test def assertDataProvidersWork(): Unit = { - nonForwardingOps(): Unit - forwardingOps(): Unit - nonForwardingAggOps(): Unit - forwardingAggOps(): Unit + object checkTrivialCases extends TestCases { + def apply(input: IR, _expected: IR, reason: String)(implicit loc: munit.Location): Unit = + test("TrivialCases") { + val normalize: (ExecuteContext, BaseIR) => BaseIR = + NormalizeNames(allowFreeVariables = true) + val result = normalize(ctx, ForwardLets(ctx, input)) + val expected = normalize(ctx, _expected) + assertEquals( + result, + normalize(ctx, expected), + s"\ninput:\n${Pretty.sexprStyle(input)}\nexpected:\n${Pretty.sexprStyle(expected)}\ngot:\n${Pretty.sexprStyle(result)}\n$reason", + ) + } } - @Test def testBlock(): Unit = { + { + val pi = Math.atan(1) * 4 + val r = Ref(freshName(), TFloat64) + + checkTrivialCases( + bindIR(I32(0))(_ => I32(2)), + I32(2), + """"x" is unused.""", + ) + checkTrivialCases( + bindIR(I32(0))(x => x), + I32(0), + """"x" is constant and is used once.""", + ) + checkTrivialCases( + bindIR(I32(2))(x => x * x), + I32(2) * I32(2), + """"x" is a primitive constant (ForwardLets does not evaluate).""", + ) + checkTrivialCases( + bindIRs(I32(2), F64(pi), r) { case Seq(two, pi, r) => + 
ApplyBinaryPrimOp(Multiply(), ApplyBinaryPrimOp(Multiply(), Cast(two, TFloat64), pi), r) + }, + ApplyBinaryPrimOp( + Multiply(), + ApplyBinaryPrimOp(Multiply(), Cast(I32(2), TFloat64), F64(pi)), + r, + ), + """Forward constant primitive values and simple use ref.""", + ) + checkTrivialCases( + IRBuilder.scoped { b => + val x0 = b.strictMemoize(I32(2)) + val x1 = b.strictMemoize(Cast(x0, TFloat64)) + val x2 = b.strictMemoize(ApplyBinaryPrimOp(FloatingPointDivide(), x1, F64(2))) + val x3 = b.strictMemoize(F64(pi)) + val x4 = b.strictMemoize(ApplyBinaryPrimOp(Multiply(), x3, x1)) + val x5 = b.strictMemoize(ApplyBinaryPrimOp(Multiply(), x2, x2)) + val x6 = b.strictMemoize(ApplyBinaryPrimOp(Multiply(), x3, x5)) + MakeStruct(FastSeq("radius" -> x2, "circumference" -> x4, "area" -> x6)) + }, + IRBuilder.scoped { b => + val x1 = b.strictMemoize(Cast(I32(2), TFloat64)) + val x2 = b.strictMemoize(ApplyBinaryPrimOp(FloatingPointDivide(), x1, F64(2))) + MakeStruct(FastSeq( + "radius" -> x2, + "circumference" -> ApplyBinaryPrimOp(Multiply(), F64(pi), x1), + "area" -> ApplyBinaryPrimOp( + Multiply(), + F64(pi), + ApplyBinaryPrimOp(Multiply(), x2, x2), + ), + )) + }, + "Cascading Let-bindings are forwarded", + ) + } + + test("Block") { val x = Ref(freshName(), TInt32) val y = Ref(freshName(), TInt32) val ir = Block( @@ -121,115 +229,10 @@ class ForwardLetsSuite extends HailSuite { ) val after: IR = ForwardLets(ctx, ir) val expected = ApplyAggOp(Sum())(I32(1)) - assert(NormalizeNames()(ctx, after) == NormalizeNames()(ctx, expected)) - } - - @Test(dataProvider = "nonForwardingOps") - def testNonForwardingOps(ir: IR): Unit = { - val after = ForwardLets(ctx, ir) - val normalizedBefore = NormalizeNames()(ctx, ir) - val normalizedAfter = NormalizeNames()(ctx, after) - assert(normalizedBefore == normalizedAfter) - } - - @Test(dataProvider = "nonForwardingNonEvalOps") - def testNonForwardingNonEvalOps(ir: IR): Unit = { - val after = ForwardLets(ctx, ir) - 
assert(after.isInstanceOf[Block]) - } - - @Test(dataProvider = "nonForwardingAggOps") - def testNonForwardingAggOps(ir: IR): Unit = { - val after = ForwardLets(ctx, ir) - assert(after.isInstanceOf[Block]) - } - - @Test(dataProvider = "forwardingOps") - def testForwardingOps(ir: IR): Unit = { - val after = ForwardLets(ctx, ir) - assert(!after.isInstanceOf[Block]) - assertEvalSame(ir, args = ArraySeq(5 -> TInt32)) - } - - @Test(dataProvider = "forwardingAggOps") - def testForwardingAggOps(ir: IR): Unit = { - val after = ForwardLets(ctx, ir) - assert(!after.isInstanceOf[Block]) - } - - @DataProvider(name = "TrivialIRCases") - def trivalIRCases: Array[Array[Any]] = { - val pi = Math.atan(1) * 4 - - val r = Ref(freshName(), TFloat64) - Array( - Array( - bindIR(I32(0))(_ => I32(2)), - I32(2), - """"x" is unused.""", - ), - Array( - bindIR(I32(0))(x => x), - I32(0), - """"x" is constant and is used once.""", - ), - Array( - bindIR(I32(2))(x => x * x), - I32(2) * I32(2), - """"x" is a primitive constant (ForwardLets does not evaluate).""", - ), - Array( - bindIRs(I32(2), F64(pi), r) { case Seq(two, pi, r) => - ApplyBinaryPrimOp(Multiply(), ApplyBinaryPrimOp(Multiply(), Cast(two, TFloat64), pi), r) - }, - ApplyBinaryPrimOp( - Multiply(), - ApplyBinaryPrimOp(Multiply(), Cast(I32(2), TFloat64), F64(pi)), - r, - ), - """Forward constant primitive values and simple use ref.""", - ), - Array( - IRBuilder.scoped { b => - val x0 = b.strictMemoize(I32(2)) - val x1 = b.strictMemoize(Cast(x0, TFloat64)) - val x2 = b.strictMemoize(ApplyBinaryPrimOp(FloatingPointDivide(), x1, F64(2))) - val x3 = b.strictMemoize(F64(pi)) - val x4 = b.strictMemoize(ApplyBinaryPrimOp(Multiply(), x3, x1)) - val x5 = b.strictMemoize(ApplyBinaryPrimOp(Multiply(), x2, x2)) - val x6 = b.strictMemoize(ApplyBinaryPrimOp(Multiply(), x3, x5)) - MakeStruct(FastSeq("radius" -> x2, "circumference" -> x4, "area" -> x6)) - }, - IRBuilder.scoped { b => - val x1 = b.strictMemoize(Cast(I32(2), TFloat64)) - val x2 = 
b.strictMemoize(ApplyBinaryPrimOp(FloatingPointDivide(), x1, F64(2))) - MakeStruct(FastSeq( - "radius" -> x2, - "circumference" -> ApplyBinaryPrimOp(Multiply(), F64(pi), x1), - "area" -> ApplyBinaryPrimOp( - Multiply(), - F64(pi), - ApplyBinaryPrimOp(Multiply(), x2, x2), - ), - )) - }, - "Cascading Let-bindings are forwarded", - ), - ) - } - - @Test(dataProvider = "TrivialIRCases") - def testTrivialCases(input: IR, _expected: IR, reason: String): Unit = { - val normalize: (ExecuteContext, BaseIR) => BaseIR = NormalizeNames(allowFreeVariables = true) - val result = normalize(ctx, ForwardLets(ctx, input)) - val expected = normalize(ctx, _expected) - assert( - result == normalize(ctx, expected), - s"\ninput:\n${Pretty.sexprStyle(input)}\nexpected:\n${Pretty.sexprStyle(expected)}\ngot:\n${Pretty.sexprStyle(result)}\n$reason", - ) + assertEquals(NormalizeNames()(ctx, after), NormalizeNames()(ctx, expected)) } - @Test def testAggregators(): Unit = { + test("Aggregators") { val row = Ref(freshName(), TStruct("idx" -> TInt32)) val aggEnv = Env[Type](row.name -> row.typ) @@ -238,7 +241,7 @@ class ForwardLetsSuite extends HailSuite { TypeCheck(ctx, ForwardLets(ctx, ir0), BindingEnv(Env.empty, agg = Some(aggEnv))) } - @Test def testNestedBindingOverwrites(): Unit = { + test("NestedBindingOverwrites") { val x = Ref(freshName(), TInt32) val env = Env[Type](x.name -> TInt32) def xCast = Cast(x, TFloat64) @@ -248,7 +251,7 @@ class ForwardLetsSuite extends HailSuite { TypeCheck(ctx, ForwardLets(ctx, ir), BindingEnv(env)) } - @Test def testLetsDoNotForwardInsideArrayAggWithNoOps(): Unit = { + test("LetsDoNotForwardInsideArrayAggWithNoOps") { val y = Ref(freshName(), TInt32) val x = bindIR( streamAggIR(ToStream(In(0, TArray(TInt32))))(_ => y) diff --git a/hail/hail/test/src/is/hail/expr/ir/FunctionSuite.scala b/hail/hail/test/src/is/hail/expr/ir/FunctionSuite.scala index 5bd9878aeea..1aac0479ae8 100644 --- a/hail/hail/test/src/is/hail/expr/ir/FunctionSuite.scala +++ 
b/hail/hail/test/src/is/hail/expr/ir/FunctionSuite.scala @@ -11,8 +11,6 @@ import is.hail.types.physical.stypes.interfaces._ import is.hail.types.virtual._ import is.hail.variant.Call2 -import org.testng.annotations.Test - object ScalaTestObject { def testFunction(): Int = 1 } @@ -53,36 +51,35 @@ class FunctionSuite extends HailSuite { TestRegisterFunctions.registerAll() - @Test - def testCodeFunction(): Unit = + test("CodeFunction") { assertEvalsTo( invoke("triangle", TInt32, In(0, TInt32)), FastSeq(5 -> TInt32), (5 * (5 + 1)) / 2, ) + } - @Test - def testStaticFunction(): Unit = + test("StaticFunction") { assertEvalsTo( invoke("compare", TInt32, In(0, TInt32), I32(0)) > 0, FastSeq(5 -> TInt32), true, ) + } - @Test - def testScalaFunction(): Unit = + test("ScalaFunction") { assertEvalsTo(invoke("foobar1", TInt32), 1) + } - @Test - def testIRConversion(): Unit = + test("IRConversion") { assertEvalsTo(invoke("addone", TInt32, In(0, TInt32)), FastSeq(5 -> TInt32), 6) + } - @Test - def testScalaFunctionCompanion(): Unit = + test("ScalaFunctionCompanion") { assertEvalsTo(invoke("foobar2", TInt32), 2) + } - @Test - def testVariableUnification(): Unit = { + test("VariableUnification") { assert(IRFunctionRegistry.lookup( "testCodeUnification", TInt32, @@ -105,16 +102,15 @@ class FunctionSuite extends HailSuite { ).isDefined) } - @Test - def testUnphasedDiploidGtIndexCall(): Unit = + test("UnphasedDiploidGtIndexCall") { assertEvalsTo( invoke("UnphasedDiploidGtIndexCall", TCall, In(0, TInt32)), FastSeq(0 -> TInt32), Call2.fromUnphasedDiploidGtIndex(0), ) + } - @Test - def testGetOrGenMethod(): Unit = { + test("GetOrGenMethod") { val fb = EmitFunctionBuilder[Int](ctx, "foo") val i = fb.genFieldThisRef[Int]() val mb1 = fb.getOrGenEmitMethod("foo", "foo", FastSeq[ParamType](), UnitInfo) { mb => @@ -130,7 +126,7 @@ class FunctionSuite extends HailSuite { i } pool.scopedRegion { r => - assert(fb.resultWithIndex().apply(theHailClassLoader, ctx.fs, ctx.taskContext, r)() == 2) + 
assertEquals(fb.resultWithIndex().apply(theHailClassLoader, ctx.fs, ctx.taskContext, r)(), 2) } } } diff --git a/hail/hail/test/src/is/hail/expr/ir/GenotypeFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/GenotypeFunctionsSuite.scala index a4b0c7ffd20..ede7afab68e 100644 --- a/hail/hail/test/src/is/hail/expr/ir/GenotypeFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/GenotypeFunctionsSuite.scala @@ -6,31 +6,33 @@ import is.hail.collection.FastSeq import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual.TFloat64 -import org.testng.annotations.{DataProvider, Test} - class GenotypeFunctionsSuite extends HailSuite { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.javaOnly - @DataProvider(name = "gps") - def gpData(): Array[Array[Any]] = Array( - Array(FastSeq(1.0, 0.0, 0.0), 0.0), - Array(FastSeq(0.0, 1.0, 0.0), 1.0), - Array(FastSeq(0.0, 0.0, 1.0), 2.0), - Array(FastSeq(0.5, 0.5, 0.0), 0.5), - Array(FastSeq(0.0, 0.5, 0.5), 1.5), - Array(null, null), - Array(FastSeq(null, null, null), null), - Array(FastSeq(null, 0.5, 0.5), 1.5), - Array(FastSeq(0.0, null, 1.0), null), - Array(FastSeq(0.0, 0.5, null), null), - ) + object checkDosage extends TestCases { + def apply( + gp: IndexedSeq[java.lang.Double], + expected: java.lang.Double, + )(implicit + loc: munit.Location + ): Unit = test("dosage") { + assertEvalsTo(invoke("dosage", TFloat64, toIRDoubleArray(gp)), expected) + } + } - @Test(dataProvider = "gps") - def testDosage(gp: IndexedSeq[java.lang.Double], expected: java.lang.Double): Unit = - assertEvalsTo(invoke("dosage", TFloat64, toIRDoubleArray(gp)), expected) + checkDosage(FastSeq(1.0, 0.0, 0.0), 0.0) + checkDosage(FastSeq(0.0, 1.0, 0.0), 1.0) + checkDosage(FastSeq(0.0, 0.0, 1.0), 2.0) + checkDosage(FastSeq(0.5, 0.5, 0.0), 0.5) + checkDosage(FastSeq(0.0, 0.5, 0.5), 1.5) + checkDosage(null, null) + checkDosage(FastSeq(null, null, null), null) + checkDosage(FastSeq(null, 0.5, 0.5), 1.5) + checkDosage(FastSeq(0.0, null, 1.0), 
null) + checkDosage(FastSeq(0.0, 0.5, null), null) - @Test def testDosageLength(): Unit = { + test("dosageLength") { assertFatal(invoke("dosage", TFloat64, IRDoubleArray(1.0, 1.5)), "length") assertFatal(invoke("dosage", TFloat64, IRDoubleArray(1.0, 1.5, 0.0, 0.0)), "length") } diff --git a/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala index 9a1fc652998..39a5f00e574 100644 --- a/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala @@ -3,7 +3,6 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy import is.hail.annotations.{BroadcastRow, ExtendedOrdering, SafeNDArray} -import is.hail.backend.ExecuteContext import is.hail.collection.{FastSeq, IntArrayBuilder} import is.hail.collection.compat.immutable.ArraySeq import is.hail.collection.implicits.toRichIterable @@ -32,14 +31,28 @@ import scala.collection.mutable import org.apache.spark.sql.Row import org.json4s.jackson.{JsonMethods, Serialization} -import org.scalatest.Inspectors.forAll -import org.scalatest.enablers.InspectorAsserting.assertingNatureOfAssertion -import org.testng.annotations.{DataProvider, Test} class IRSuite extends HailSuite { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.nonLowering - @Test def testRandDifferentLengthUIDStrings(): Unit = { + override def beforeAll(): Unit = { + super.beforeAll() + CompileAndEvaluate[Unit]( + ctx, + invoke( + "index_bgen", + TInt64, + ArraySeq[Type](TLocus("GRCh37")), + Str(getTestResource("example.8bits.bgen")), + Str(getTestResource("example.8bits.bgen.idx2")), + Literal(TDict(TString, TString), Map("01" -> "1")), + False(), + I32(1000000), + ), + ) + } + + test("RandDifferentLengthUIDStrings") { implicit val execStrats = ExecStrategy.lowering val staticUID: Long = 112233 var rng = RNGSplit(RNGSplitStatic(RNGStateLiteral(), staticUID), I64(12345)) @@ -59,29 +72,37 @@ class IRSuite extends 
HailSuite { assert(expected1 != expected3) } - @Test def testI32(): Unit = + test("I32") { assertEvalsTo(I32(5), 5) + } - @Test def testI64(): Unit = + test("I64") { assertEvalsTo(I64(5), 5L) + } - @Test def testF32(): Unit = + test("F32") { assertEvalsTo(F32(3.14f), 3.14f) + } - @Test def testF64(): Unit = + test("F64") { assertEvalsTo(F64(3.14), 3.14) + } - @Test def testStr(): Unit = + test("Str") { assertEvalsTo(Str("Hail"), "Hail") + } - @Test def testTrue(): Unit = + test("True") { assertEvalsTo(True(), true) + } - @Test def testFalse(): Unit = + test("False") { assertEvalsTo(False(), false) + } + // FIXME Void() doesn't work because we can't handle a void type in a tuple - @Test def testCast(): Unit = { + test("Cast") { assertAllEvalTo( (Cast(I32(5), TInt32), 5), (Cast(I32(5), TInt64), 5L), @@ -105,7 +126,7 @@ class IRSuite extends HailSuite { ) } - @Test def testCastRename(): Unit = { + test("CastRename") { assertEvalsTo(CastRename(MakeStruct(FastSeq(("x", I32(1)))), TStruct("foo" -> TInt32)), Row(1)) assertEvalsTo( CastRename( @@ -116,10 +137,11 @@ class IRSuite extends HailSuite { ) } - @Test def testNA(): Unit = + test("NA") { assertEvalsTo(NA(TInt32), null) + } - @Test def testCoalesce(): Unit = { + test("Coalesce") { assertEvalsTo(Coalesce(FastSeq(In(0, TInt32))), FastSeq((null, TInt32)), null) assertEvalsTo(Coalesce(FastSeq(In(0, TInt32))), FastSeq((1, TInt32)), 1) assertEvalsTo(Coalesce(FastSeq(NA(TInt32), In(0, TInt32))), FastSeq((null, TInt32)), null) @@ -132,7 +154,7 @@ class IRSuite extends HailSuite { assertEvalsTo(Coalesce(FastSeq(NA(TInt32), I32(1), Die("foo", TInt32))), 1) } - @Test def testCoalesceWithDifferentRequiredeness(): Unit = { + test("CoalesceWithDifferentRequiredeness") { val t1 = In(0, TArray(TInt32)) val t2 = NA(TArray(TInt32)) val value = FastSeq(1, 2, 3, 4) @@ -146,7 +168,7 @@ class IRSuite extends HailSuite { val f64na = NA(TFloat64) val bna = NA(TBoolean) - @Test def testApplyUnaryPrimOpNegate(): Unit = { + 
test("ApplyUnaryPrimOpNegate") { assertAllEvalTo( (ApplyUnaryPrimOp(Negate, I32(5)), -5), (ApplyUnaryPrimOp(Negate, i32na), null), @@ -159,13 +181,13 @@ class IRSuite extends HailSuite { ) } - @Test def testApplyUnaryPrimOpBang(): Unit = { + test("ApplyUnaryPrimOpBang") { assertEvalsTo(ApplyUnaryPrimOp(Bang, False()), true) assertEvalsTo(ApplyUnaryPrimOp(Bang, True()), false) assertEvalsTo(ApplyUnaryPrimOp(Bang, bna), null) } - @Test def testApplyUnaryPrimOpBitFlip(): Unit = { + test("ApplyUnaryPrimOpBitFlip") { assertAllEvalTo( (ApplyUnaryPrimOp(BitNot, I32(0xdeadbeef)), ~0xdeadbeef), (ApplyUnaryPrimOp(BitNot, I32(-0xdeadbeef)), ~(-0xdeadbeef)), @@ -176,7 +198,7 @@ class IRSuite extends HailSuite { ) } - @Test def testApplyUnaryPrimOpBitCount(): Unit = { + test("ApplyUnaryPrimOpBitCount") { assertAllEvalTo( (ApplyUnaryPrimOp(BitCount, I32(0xdeadbeef)), Integer.bitCount(0xdeadbeef)), (ApplyUnaryPrimOp(BitCount, I32(-0xdeadbeef)), Integer.bitCount(-0xdeadbeef)), @@ -193,7 +215,7 @@ class IRSuite extends HailSuite { ) } - @Test def testApplyBinaryPrimOpAdd(): Unit = { + test("ApplyBinaryPrimOpAdd") { def assertSumsTo(t: Type, x: Any, y: Any, sum: Any): Unit = assertEvalsTo(ApplyBinaryPrimOp(Add(), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), sum) assertSumsTo(TInt32, 5, 3, 8) @@ -217,7 +239,7 @@ class IRSuite extends HailSuite { assertSumsTo(TFloat64, null, null, null) } - @Test def testApplyBinaryPrimOpSubtract(): Unit = { + test("ApplyBinaryPrimOpSubtract") { def assertExpected(t: Type, x: Any, y: Any, expected: Any): Unit = assertEvalsTo( ApplyBinaryPrimOp(Subtract(), In(0, t), In(1, t)), @@ -246,7 +268,7 @@ class IRSuite extends HailSuite { assertExpected(TFloat64, null, null, null) } - @Test def testApplyBinaryPrimOpMultiply(): Unit = { + test("ApplyBinaryPrimOpMultiply") { def assertExpected(t: Type, x: Any, y: Any, expected: Any): Unit = assertEvalsTo( ApplyBinaryPrimOp(Multiply(), In(0, t), In(1, t)), @@ -275,7 +297,7 @@ class IRSuite extends HailSuite { 
assertExpected(TFloat64, null, null, null) } - @Test def testApplyBinaryPrimOpFloatingPointDivide(): Unit = { + test("ApplyBinaryPrimOpFloatingPointDivide") { def assertExpected(t: Type, x: Any, y: Any, expected: Any): Unit = assertEvalsTo( ApplyBinaryPrimOp(FloatingPointDivide(), In(0, t), In(1, t)), @@ -304,7 +326,7 @@ class IRSuite extends HailSuite { assertExpected(TFloat64, null, null, null) } - @Test def testApplyBinaryPrimOpRoundToNegInfDivide(): Unit = { + test("ApplyBinaryPrimOpRoundToNegInfDivide") { def assertExpected(t: Type, x: Any, y: Any, expected: Any): Unit = assertEvalsTo( ApplyBinaryPrimOp(RoundToNegInfDivide(), In(0, t), In(1, t)), @@ -333,7 +355,7 @@ class IRSuite extends HailSuite { assertExpected(TFloat64, null, null, null) } - @Test def testApplyBinaryPrimOpBitAnd(): Unit = { + test("ApplyBinaryPrimOpBitAnd") { def assertExpected(t: Type, x: Any, y: Any, expected: Any): Unit = assertEvalsTo( ApplyBinaryPrimOp(BitAnd(), In(0, t), In(1, t)), @@ -358,7 +380,7 @@ class IRSuite extends HailSuite { assertExpected(TInt64, null, null, null) } - @Test def testApplyBinaryPrimOpBitOr(): Unit = { + test("ApplyBinaryPrimOpBitOr") { def assertExpected(t: Type, x: Any, y: Any, expected: Any): Unit = assertEvalsTo( ApplyBinaryPrimOp(BitOr(), In(0, t), In(1, t)), @@ -383,7 +405,7 @@ class IRSuite extends HailSuite { assertExpected(TInt64, null, null, null) } - @Test def testApplyBinaryPrimOpBitXOr(): Unit = { + test("ApplyBinaryPrimOpBitXOr") { def assertExpected(t: Type, x: Any, y: Any, expected: Any): Unit = assertEvalsTo( ApplyBinaryPrimOp(BitXOr(), In(0, t), In(1, t)), @@ -408,7 +430,7 @@ class IRSuite extends HailSuite { assertExpected(TInt64, null, null, null) } - @Test def testApplyBinaryPrimOpLeftShift(): Unit = { + test("ApplyBinaryPrimOpLeftShift") { def assertShiftsTo(t: Type, x: Any, y: Any, expected: Any): Unit = assertEvalsTo( ApplyBinaryPrimOp(LeftShift(), In(0, t), In(1, TInt32)), @@ -429,7 +451,7 @@ class IRSuite extends HailSuite { 
assertShiftsTo(TInt64, null, null, null) } - @Test def testApplyBinaryPrimOpRightShift(): Unit = { + test("ApplyBinaryPrimOpRightShift") { def assertShiftsTo(t: Type, x: Any, y: Any, expected: Any): Unit = assertEvalsTo( ApplyBinaryPrimOp(RightShift(), In(0, t), In(1, TInt32)), @@ -450,7 +472,7 @@ class IRSuite extends HailSuite { assertShiftsTo(TInt64, null, null, null) } - @Test def testApplyBinaryPrimOpLogicalRightShift(): Unit = { + test("ApplyBinaryPrimOpLogicalRightShift") { def assertShiftsTo(t: Type, x: Any, y: Any, expected: Any): Unit = assertEvalsTo( ApplyBinaryPrimOp(LogicalRightShift(), In(0, t), In(1, TInt32)), @@ -471,7 +493,7 @@ class IRSuite extends HailSuite { assertShiftsTo(TInt64, null, null, null) } - @Test def testApplyComparisonOpGT(): Unit = { + test("ApplyComparisonOpGT") { def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean): Unit = assertEvalsTo(ApplyComparisonOp(GT, In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) @@ -493,7 +515,7 @@ class IRSuite extends HailSuite { } - @Test def testApplyComparisonOpGTEQ(): Unit = { + test("ApplyComparisonOpGTEQ") { def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean): Unit = assertEvalsTo( ApplyComparisonOp(GTEQ, In(0, t), In(1, t)), @@ -518,7 +540,7 @@ class IRSuite extends HailSuite { assertComparesTo(TFloat64, 1.0, 0.0, true) } - @Test def testApplyComparisonOpLT(): Unit = { + test("ApplyComparisonOpLT") { def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean): Unit = assertEvalsTo(ApplyComparisonOp(LT, In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) @@ -540,7 +562,7 @@ class IRSuite extends HailSuite { } - @Test def testApplyComparisonOpLTEQ(): Unit = { + test("ApplyComparisonOpLTEQ") { def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean): Unit = assertEvalsTo( ApplyComparisonOp(LTEQ, In(0, t), In(1, t)), @@ -566,7 +588,7 @@ class IRSuite extends HailSuite { } - @Test def testApplyComparisonOpEQ(): Unit = { + test("ApplyComparisonOpEQ") 
{ def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean): Unit = assertEvalsTo(ApplyComparisonOp(EQ, In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) @@ -587,7 +609,7 @@ class IRSuite extends HailSuite { assertComparesTo(TFloat64, 1.0, 0.0, expected = false) } - @Test def testApplyComparisonOpNE(): Unit = { + test("ApplyComparisonOpNE") { def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean): Unit = assertEvalsTo( ApplyComparisonOp(NEQ, In(0, t), In(1, t)), @@ -612,33 +634,39 @@ class IRSuite extends HailSuite { assertComparesTo(TFloat64, 1.0, 0.0, expected = true) } - @Test def testDieCodeBUilder(): Unit = + test("DieCodeBUilder") { assertFatal(Die("msg1", TInt32) + Die("msg2", TInt32), "msg1") - @Test def testIf(): Unit = { + } + + test("If") { assertEvalsTo(If(True(), I32(5), I32(7)), 5) assertEvalsTo(If(False(), I32(5), I32(7)), 7) assertEvalsTo(If(NA(TBoolean), I32(5), I32(7)), null) assertEvalsTo(If(True(), NA(TInt32), I32(7)), null) } - @DataProvider(name = "SwitchEval") - def switchEvalRules: Array[Array[Any]] = - Array( - Array(I32(-1), I32(Int.MinValue), FastSeq(0, Int.MaxValue).map(I32), Int.MinValue), - Array(I32(0), I32(Int.MinValue), FastSeq(0, Int.MaxValue).map(I32), 0), - Array(I32(1), I32(Int.MinValue), FastSeq(0, Int.MaxValue).map(I32), Int.MaxValue), - Array(I32(2), I32(Int.MinValue), FastSeq(0, Int.MaxValue).map(I32), Int.MinValue), - Array(NA(TInt32), I32(Int.MinValue), FastSeq(0, Int.MaxValue).map(I32), null), - Array(I32(-1), NA(TInt32), FastSeq(0, Int.MaxValue).map(I32), null), - Array(I32(0), NA(TInt32), FastSeq(NA(TInt32), I32(0)), null), - ) + object checkSwitch extends TestCases { + def apply( + x: IR, + default: IR, + cases: IndexedSeq[IR], + result: Any, + )(implicit loc: munit.Location + ): Unit = test(s"Switch") { + assertEvalsTo(Switch(x, default, cases), result) + } + } - @Test(dataProvider = "SwitchEval") - def testSwitch(x: IR, default: IR, cases: IndexedSeq[IR], result: Any): Unit = - 
assertEvalsTo(Switch(x, default, cases), result) + checkSwitch(-1, Int.MinValue, FastSeq[IR](0, Int.MaxValue), Int.MinValue) + checkSwitch(0, Int.MinValue, FastSeq[IR](0, Int.MaxValue), 0) + checkSwitch(1, Int.MinValue, FastSeq[IR](0, Int.MaxValue), Int.MaxValue) + checkSwitch(2, Int.MinValue, FastSeq[IR](0, Int.MaxValue), Int.MinValue) + checkSwitch(NA(TInt32), Int.MinValue, FastSeq[IR](0, Int.MaxValue), null) + checkSwitch(-1, NA(TInt32), FastSeq[IR](0, Int.MaxValue), null) + checkSwitch(0, NA(TInt32), FastSeq[IR](NA(TInt32), 0), null) - @Test def testLet(): Unit = { + test("Let") { assertEvalsTo(bindIR(I32(5))(x => x), 5) assertEvalsTo(bindIR(NA(TInt32))(x => x), null) assertEvalsTo(bindIR(I32(5))(_ => NA(TInt32)), null) @@ -670,7 +698,7 @@ class IRSuite extends HailSuite { ) } - @Test def testMakeArray(): Unit = { + test("MakeArray") { assertEvalsTo( MakeArray(FastSeq(I32(5), NA(TInt32), I32(-3)), TArray(TInt32)), FastSeq(5, null, -3), @@ -678,218 +706,218 @@ class IRSuite extends HailSuite { assertEvalsTo(MakeArray(FastSeq(), TArray(TInt32)), FastSeq()) } - @Test def testGetNestedElementPTypesI32(): Unit = { + test("GetNestedElementPTypesI32") { var types = IndexedSeq(PInt32(true)) var res = InferPType.getCompatiblePType(types) - assert(res == PInt32(true)) + assertEquals(res, PInt32(true)) types = IndexedSeq(PInt32(false)) res = InferPType.getCompatiblePType(types) - assert(res == PInt32(false)) + assertEquals(res, PInt32(false)) types = IndexedSeq(PInt32(false), PInt32(true)) res = InferPType.getCompatiblePType(types) - assert(res == PInt32(false)) + assertEquals(res, PInt32(false)) types = IndexedSeq(PInt32(true), PInt32(true)) res = InferPType.getCompatiblePType(types) - assert(res == PInt32(true)) + assertEquals(res, PInt32(true)) } - @Test def testGetNestedElementPTypesI64(): Unit = { + test("GetNestedElementPTypesI64") { var types = IndexedSeq(PInt64(true)) var res = InferPType.getCompatiblePType(types) - assert(res == PInt64(true)) + assertEquals(res, 
PInt64(true)) types = IndexedSeq(PInt64(false)) res = InferPType.getCompatiblePType(types) - assert(res == PInt64(false)) + assertEquals(res, PInt64(false)) types = IndexedSeq(PInt64(false), PInt64(true)) res = InferPType.getCompatiblePType(types) - assert(res == PInt64(false)) + assertEquals(res, PInt64(false)) types = IndexedSeq(PInt64(true), PInt64(true)) res = InferPType.getCompatiblePType(types) - assert(res == PInt64(true)) + assertEquals(res, PInt64(true)) } - @Test def testGetNestedElementPFloat32(): Unit = { + test("GetNestedElementPFloat32") { var types = IndexedSeq(PFloat32(true)) var res = InferPType.getCompatiblePType(types) - assert(res == PFloat32(true)) + assertEquals(res, PFloat32(true)) types = IndexedSeq(PFloat32(false)) res = InferPType.getCompatiblePType(types) - assert(res == PFloat32(false)) + assertEquals(res, PFloat32(false)) types = IndexedSeq(PFloat32(false), PFloat32(true)) res = InferPType.getCompatiblePType(types) - assert(res == PFloat32(false)) + assertEquals(res, PFloat32(false)) types = IndexedSeq(PFloat32(true), PFloat32(true)) res = InferPType.getCompatiblePType(types) - assert(res == PFloat32(true)) + assertEquals(res, PFloat32(true)) } - @Test def testGetNestedElementPFloat64(): Unit = { + test("GetNestedElementPFloat64") { var types = IndexedSeq(PFloat64(true)) var res = InferPType.getCompatiblePType(types) - assert(res == PFloat64(true)) + assertEquals(res, PFloat64(true)) types = IndexedSeq(PFloat64(false)) res = InferPType.getCompatiblePType(types) - assert(res == PFloat64(false)) + assertEquals(res, PFloat64(false)) types = IndexedSeq(PFloat64(false), PFloat64(true)) res = InferPType.getCompatiblePType(types) - assert(res == PFloat64(false)) + assertEquals(res, PFloat64(false)) types = IndexedSeq(PFloat64(true), PFloat64(true)) res = InferPType.getCompatiblePType(types) - assert(res == PFloat64(true)) + assertEquals(res, PFloat64(true)) } - @Test def testGetNestedElementPCanonicalString(): Unit = { + 
test("GetNestedElementPCanonicalString") { var types = IndexedSeq(PCanonicalString(true)) var res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalString(true)) + assertEquals(res, PCanonicalString(true)) types = IndexedSeq(PCanonicalString(false)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalString(false)) + assertEquals(res, PCanonicalString(false)) types = IndexedSeq(PCanonicalString(false), PCanonicalString(true)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalString(false)) + assertEquals(res, PCanonicalString(false)) types = IndexedSeq(PCanonicalString(true), PCanonicalString(true)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalString(true)) + assertEquals(res, PCanonicalString(true)) } - @Test def testGetNestedPCanonicalArray(): Unit = { + test("GetNestedPCanonicalArray") { var types = IndexedSeq(PCanonicalArray(PInt32(true), true)) var res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalArray(PInt32(true), true)) + assertEquals(res, PCanonicalArray(PInt32(true), true)) types = IndexedSeq(PCanonicalArray(PInt32(true), false)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalArray(PInt32(true), false)) + assertEquals(res, PCanonicalArray(PInt32(true), false)) types = IndexedSeq(PCanonicalArray(PInt32(false), true)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalArray(PInt32(false), true)) + assertEquals(res, PCanonicalArray(PInt32(false), true)) types = IndexedSeq(PCanonicalArray(PInt32(false), false)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalArray(PInt32(false), false)) + assertEquals(res, PCanonicalArray(PInt32(false), false)) types = IndexedSeq( PCanonicalArray(PInt32(true), true), PCanonicalArray(PInt32(true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalArray(PInt32(true), true)) + assertEquals(res, PCanonicalArray(PInt32(true), true)) types = 
IndexedSeq( PCanonicalArray(PInt32(false), true), PCanonicalArray(PInt32(true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalArray(PInt32(false), true)) + assertEquals(res, PCanonicalArray(PInt32(false), true)) types = IndexedSeq( PCanonicalArray(PInt32(false), true), PCanonicalArray(PInt32(true), false), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalArray(PInt32(false), false)) + assertEquals(res, PCanonicalArray(PInt32(false), false)) types = IndexedSeq( PCanonicalArray(PCanonicalArray(PInt32(true), true), true), PCanonicalArray(PCanonicalArray(PInt32(true), true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalArray(PCanonicalArray(PInt32(true), true), true)) + assertEquals(res, PCanonicalArray(PCanonicalArray(PInt32(true), true), true)) types = IndexedSeq( PCanonicalArray(PCanonicalArray(PInt32(true), true), true), PCanonicalArray(PCanonicalArray(PInt32(false), true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalArray(PCanonicalArray(PInt32(false), true), true)) + assertEquals(res, PCanonicalArray(PCanonicalArray(PInt32(false), true), true)) types = IndexedSeq( PCanonicalArray(PCanonicalArray(PInt32(true), false), true), PCanonicalArray(PCanonicalArray(PInt32(false), true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalArray(PCanonicalArray(PInt32(false), false), true)) + assertEquals(res, PCanonicalArray(PCanonicalArray(PInt32(false), false), true)) types = IndexedSeq( PCanonicalArray(PCanonicalArray(PInt32(true), false), false), PCanonicalArray(PCanonicalArray(PInt32(false), true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalArray(PCanonicalArray(PInt32(false), false), false)) + assertEquals(res, PCanonicalArray(PCanonicalArray(PInt32(false), false), false)) } - @Test def testGetNestedElementPCanonicalDict(): Unit = { + test("GetNestedElementPCanonicalDict") { var types = 
IndexedSeq(PCanonicalDict(PInt32(true), PCanonicalString(true), true)) var res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(true), PCanonicalString(true), true)) + assertEquals(res, PCanonicalDict(PInt32(true), PCanonicalString(true), true)) types = IndexedSeq(PCanonicalDict(PInt32(false), PCanonicalString(true), true)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(false), PCanonicalString(true), true)) + assertEquals(res, PCanonicalDict(PInt32(false), PCanonicalString(true), true)) types = IndexedSeq(PCanonicalDict(PInt32(true), PCanonicalString(false), true)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(true), PCanonicalString(false), true)) + assertEquals(res, PCanonicalDict(PInt32(true), PCanonicalString(false), true)) types = IndexedSeq(PCanonicalDict(PInt32(true), PCanonicalString(true), false)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(true), PCanonicalString(true), false)) + assertEquals(res, PCanonicalDict(PInt32(true), PCanonicalString(true), false)) types = IndexedSeq(PCanonicalDict(PInt32(false), PCanonicalString(false), false)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(false), PCanonicalString(false), false)) + assertEquals(res, PCanonicalDict(PInt32(false), PCanonicalString(false), false)) types = IndexedSeq( PCanonicalDict(PInt32(true), PCanonicalString(true), true), PCanonicalDict(PInt32(true), PCanonicalString(true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(true), PCanonicalString(true), true)) + assertEquals(res, PCanonicalDict(PInt32(true), PCanonicalString(true), true)) types = IndexedSeq( PCanonicalDict(PInt32(true), PCanonicalString(true), false), PCanonicalDict(PInt32(true), PCanonicalString(true), false), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(true), 
PCanonicalString(true), false)) + assertEquals(res, PCanonicalDict(PInt32(true), PCanonicalString(true), false)) types = IndexedSeq( PCanonicalDict(PInt32(false), PCanonicalString(true), true), PCanonicalDict(PInt32(true), PCanonicalString(true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(false), PCanonicalString(true), true)) + assertEquals(res, PCanonicalDict(PInt32(false), PCanonicalString(true), true)) types = IndexedSeq( PCanonicalDict(PInt32(false), PCanonicalString(true), true), PCanonicalDict(PInt32(true), PCanonicalString(false), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(false), PCanonicalString(false), true)) + assertEquals(res, PCanonicalDict(PInt32(false), PCanonicalString(false), true)) types = IndexedSeq( PCanonicalDict(PInt32(false), PCanonicalString(true), false), PCanonicalDict(PInt32(true), PCanonicalString(false), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(false), PCanonicalString(false), false)) + assertEquals(res, PCanonicalDict(PInt32(false), PCanonicalString(false), false)) types = IndexedSeq( PCanonicalDict( @@ -900,11 +928,14 @@ class IRSuite extends HailSuite { PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(true), PCanonicalString(true), true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict( - PInt32(true), - PCanonicalDict(PInt32(true), PCanonicalString(true), true), - true, - )) + assertEquals( + res, + PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(true), PCanonicalString(true), true), + true, + ), + ) types = IndexedSeq( PCanonicalDict( @@ -915,11 +946,14 @@ class IRSuite extends HailSuite { PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(true), PCanonicalString(true), true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict( - PInt32(true), - PCanonicalDict(PInt32(false), PCanonicalString(true), true), - true, - 
)) + assertEquals( + res, + PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(false), PCanonicalString(true), true), + true, + ), + ) types = IndexedSeq( PCanonicalDict( @@ -934,11 +968,14 @@ class IRSuite extends HailSuite { ), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict( - PInt32(true), - PCanonicalDict(PInt32(false), PCanonicalString(false), true), - true, - )) + assertEquals( + res, + PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(false), PCanonicalString(false), true), + true, + ), + ) types = IndexedSeq( PCanonicalDict( @@ -953,11 +990,14 @@ class IRSuite extends HailSuite { ), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict( - PInt32(true), - PCanonicalDict(PInt32(false), PCanonicalString(false), true), - true, - )) + assertEquals( + res, + PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(false), PCanonicalString(false), true), + true, + ), + ) types = IndexedSeq( PCanonicalDict( @@ -972,54 +1012,57 @@ class IRSuite extends HailSuite { ), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict( - PInt32(true), - PCanonicalDict(PInt32(false), PCanonicalString(false), false), - true, - )) + assertEquals( + res, + PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(false), PCanonicalString(false), false), + true, + ), + ) } - @Test def testGetNestedElementPCanonicalStruct(): Unit = { + test("GetNestedElementPCanonicalStruct") { var types = IndexedSeq(PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true))) var res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true))) + assertEquals(res, PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true))) types = IndexedSeq(PCanonicalStruct(false, "a" -> PInt32(true), "b" -> PInt32(true))) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct(false, "a" -> PInt32(true), "b" -> PInt32(true))) + assertEquals(res, 
PCanonicalStruct(false, "a" -> PInt32(true), "b" -> PInt32(true))) types = IndexedSeq(PCanonicalStruct(true, "a" -> PInt32(false), "b" -> PInt32(true))) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct(true, "a" -> PInt32(false), "b" -> PInt32(true))) + assertEquals(res, PCanonicalStruct(true, "a" -> PInt32(false), "b" -> PInt32(true))) types = IndexedSeq(PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(false))) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(false))) + assertEquals(res, PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(false))) types = IndexedSeq(PCanonicalStruct(false, "a" -> PInt32(false), "b" -> PInt32(false))) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct(false, "a" -> PInt32(false), "b" -> PInt32(false))) + assertEquals(res, PCanonicalStruct(false, "a" -> PInt32(false), "b" -> PInt32(false))) types = IndexedSeq( PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true)), PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true)), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true))) + assertEquals(res, PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true))) types = IndexedSeq( PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true)), PCanonicalStruct(true, "a" -> PInt32(false), "b" -> PInt32(false)), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct(true, "a" -> PInt32(false), "b" -> PInt32(false))) + assertEquals(res, PCanonicalStruct(true, "a" -> PInt32(false), "b" -> PInt32(false))) types = IndexedSeq( PCanonicalStruct(false, "a" -> PInt32(true), "b" -> PInt32(true)), PCanonicalStruct(true, "a" -> PInt32(false), "b" -> PInt32(false)), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct(false, "a" -> PInt32(false), "b" -> PInt32(false))) + 
assertEquals(res, PCanonicalStruct(false, "a" -> PInt32(false), "b" -> PInt32(false))) types = IndexedSeq( PCanonicalStruct( @@ -1029,11 +1072,14 @@ class IRSuite extends HailSuite { ) ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct( - true, - "a" -> PCanonicalStruct(true, "c" -> PInt32(true), "d" -> PInt32(true)), - "b" -> PInt32(true), - )) + assertEquals( + res, + PCanonicalStruct( + true, + "a" -> PCanonicalStruct(true, "c" -> PInt32(true), "d" -> PInt32(true)), + "b" -> PInt32(true), + ), + ) types = IndexedSeq( PCanonicalStruct( @@ -1043,11 +1089,14 @@ class IRSuite extends HailSuite { ) ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct( - true, - "a" -> PCanonicalStruct(true, "c" -> PInt32(false), "d" -> PInt32(true)), - "b" -> PInt32(true), - )) + assertEquals( + res, + PCanonicalStruct( + true, + "a" -> PCanonicalStruct(true, "c" -> PInt32(false), "d" -> PInt32(true)), + "b" -> PInt32(true), + ), + ) types = IndexedSeq( PCanonicalStruct( @@ -1062,11 +1111,14 @@ class IRSuite extends HailSuite { ), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct( - true, - "a" -> PCanonicalStruct(true, "c" -> PInt32(false), "d" -> PInt32(false)), - "b" -> PInt32(true), - )) + assertEquals( + res, + PCanonicalStruct( + true, + "a" -> PCanonicalStruct(true, "c" -> PInt32(false), "d" -> PInt32(false)), + "b" -> PInt32(true), + ), + ) types = IndexedSeq( PCanonicalStruct( @@ -1081,208 +1133,211 @@ class IRSuite extends HailSuite { ), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct( - true, - "a" -> PCanonicalStruct(false, "c" -> PInt32(false), "d" -> PInt32(false)), - "b" -> PInt32(true), - )) + assertEquals( + res, + PCanonicalStruct( + true, + "a" -> PCanonicalStruct(false, "c" -> PInt32(false), "d" -> PInt32(false)), + "b" -> PInt32(true), + ), + ) } - @Test def testGetNestedElementPCanonicalTuple(): Unit = { + test("GetNestedElementPCanonicalTuple") { var 
types = IndexedSeq(PCanonicalTuple(true, PInt32(true))) var res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalTuple(true, PInt32(true))) + assertEquals(res, PCanonicalTuple(true, PInt32(true))) types = IndexedSeq(PCanonicalTuple(false, PInt32(true))) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalTuple(false, PInt32(true))) + assertEquals(res, PCanonicalTuple(false, PInt32(true))) types = IndexedSeq(PCanonicalTuple(true, PInt32(false))) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalTuple(true, PInt32(false))) + assertEquals(res, PCanonicalTuple(true, PInt32(false))) types = IndexedSeq(PCanonicalTuple(false, PInt32(false))) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalTuple(false, PInt32(false))) + assertEquals(res, PCanonicalTuple(false, PInt32(false))) types = IndexedSeq( PCanonicalTuple(true, PInt32(true)), PCanonicalTuple(true, PInt32(true)), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalTuple(true, PInt32(true))) + assertEquals(res, PCanonicalTuple(true, PInt32(true))) types = IndexedSeq( PCanonicalTuple(true, PInt32(true)), PCanonicalTuple(false, PInt32(true)), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalTuple(false, PInt32(true))) + assertEquals(res, PCanonicalTuple(false, PInt32(true))) types = IndexedSeq( PCanonicalTuple(true, PInt32(false)), PCanonicalTuple(false, PInt32(true)), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalTuple(false, PInt32(false))) + assertEquals(res, PCanonicalTuple(false, PInt32(false))) types = IndexedSeq( PCanonicalTuple(true, PCanonicalTuple(true, PInt32(true))), PCanonicalTuple(true, PCanonicalTuple(true, PInt32(false))), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalTuple(true, PCanonicalTuple(true, PInt32(false)))) + assertEquals(res, PCanonicalTuple(true, PCanonicalTuple(true, PInt32(false)))) types = IndexedSeq( PCanonicalTuple(true, 
PCanonicalTuple(false, PInt32(true))), PCanonicalTuple(true, PCanonicalTuple(true, PInt32(false))), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalTuple(true, PCanonicalTuple(false, PInt32(false)))) + assertEquals(res, PCanonicalTuple(true, PCanonicalTuple(false, PInt32(false)))) } - @Test def testGetNestedElementPCanonicalSet(): Unit = { + test("GetNestedElementPCanonicalSet") { var types = IndexedSeq(PCanonicalSet(PInt32(true), true)) var res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalSet(PInt32(true), true)) + assertEquals(res, PCanonicalSet(PInt32(true), true)) types = IndexedSeq(PCanonicalSet(PInt32(true), false)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalSet(PInt32(true), false)) + assertEquals(res, PCanonicalSet(PInt32(true), false)) types = IndexedSeq(PCanonicalSet(PInt32(false), true)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalSet(PInt32(false), true)) + assertEquals(res, PCanonicalSet(PInt32(false), true)) types = IndexedSeq(PCanonicalSet(PInt32(false), false)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalSet(PInt32(false), false)) + assertEquals(res, PCanonicalSet(PInt32(false), false)) types = IndexedSeq( PCanonicalSet(PInt32(true), true), PCanonicalSet(PInt32(true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalSet(PInt32(true), true)) + assertEquals(res, PCanonicalSet(PInt32(true), true)) types = IndexedSeq( PCanonicalSet(PInt32(false), true), PCanonicalSet(PInt32(true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalSet(PInt32(false), true)) + assertEquals(res, PCanonicalSet(PInt32(false), true)) types = IndexedSeq( PCanonicalSet(PInt32(false), true), PCanonicalSet(PInt32(true), false), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalSet(PInt32(false), false)) + assertEquals(res, PCanonicalSet(PInt32(false), false)) types = IndexedSeq( 
PCanonicalSet(PCanonicalSet(PInt32(true), true), true), PCanonicalSet(PCanonicalSet(PInt32(true), true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalSet(PCanonicalSet(PInt32(true), true), true)) + assertEquals(res, PCanonicalSet(PCanonicalSet(PInt32(true), true), true)) types = IndexedSeq( PCanonicalSet(PCanonicalSet(PInt32(true), true), true), PCanonicalSet(PCanonicalSet(PInt32(false), true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalSet(PCanonicalSet(PInt32(false), true), true)) + assertEquals(res, PCanonicalSet(PCanonicalSet(PInt32(false), true), true)) types = IndexedSeq( PCanonicalSet(PCanonicalSet(PInt32(true), false), true), PCanonicalSet(PCanonicalSet(PInt32(false), true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalSet(PCanonicalSet(PInt32(false), false), true)) + assertEquals(res, PCanonicalSet(PCanonicalSet(PInt32(false), false), true)) } - @Test def testGetNestedElementPCanonicalInterval(): Unit = { + test("GetNestedElementPCanonicalInterval") { var types = IndexedSeq(PCanonicalInterval(PInt32(true), true)) var res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalInterval(PInt32(true), true)) + assertEquals(res, PCanonicalInterval(PInt32(true), true)) types = IndexedSeq(PCanonicalInterval(PInt32(true), false)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalInterval(PInt32(true), false)) + assertEquals(res, PCanonicalInterval(PInt32(true), false)) types = IndexedSeq(PCanonicalInterval(PInt32(false), true)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalInterval(PInt32(false), true)) + assertEquals(res, PCanonicalInterval(PInt32(false), true)) types = IndexedSeq(PCanonicalInterval(PInt32(false), false)) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalInterval(PInt32(false), false)) + assertEquals(res, PCanonicalInterval(PInt32(false), false)) types = IndexedSeq( 
PCanonicalInterval(PInt32(true), true), PCanonicalInterval(PInt32(true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalInterval(PInt32(true), true)) + assertEquals(res, PCanonicalInterval(PInt32(true), true)) types = IndexedSeq( PCanonicalInterval(PInt32(false), true), PCanonicalInterval(PInt32(true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalInterval(PInt32(false), true)) + assertEquals(res, PCanonicalInterval(PInt32(false), true)) types = IndexedSeq( PCanonicalInterval(PInt32(true), true), PCanonicalInterval(PInt32(true), false), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalInterval(PInt32(true), false)) + assertEquals(res, PCanonicalInterval(PInt32(true), false)) types = IndexedSeq( PCanonicalInterval(PInt32(false), true), PCanonicalInterval(PInt32(true), false), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalInterval(PInt32(false), false)) + assertEquals(res, PCanonicalInterval(PInt32(false), false)) types = IndexedSeq( PCanonicalInterval(PCanonicalInterval(PInt32(true), true), true), PCanonicalInterval(PCanonicalInterval(PInt32(true), true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalInterval(PCanonicalInterval(PInt32(true), true), true)) + assertEquals(res, PCanonicalInterval(PCanonicalInterval(PInt32(true), true), true)) types = IndexedSeq( PCanonicalInterval(PCanonicalInterval(PInt32(true), false), true), PCanonicalInterval(PCanonicalInterval(PInt32(true), true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalInterval(PCanonicalInterval(PInt32(true), false), true)) + assertEquals(res, PCanonicalInterval(PCanonicalInterval(PInt32(true), false), true)) types = IndexedSeq( PCanonicalInterval(PCanonicalInterval(PInt32(false), true), true), PCanonicalInterval(PCanonicalInterval(PInt32(true), true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == 
PCanonicalInterval(PCanonicalInterval(PInt32(false), true), true)) + assertEquals(res, PCanonicalInterval(PCanonicalInterval(PInt32(false), true), true)) types = IndexedSeq( PCanonicalInterval(PCanonicalInterval(PInt32(true), false), true), PCanonicalInterval(PCanonicalInterval(PInt32(false), true), true), ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalInterval(PCanonicalInterval(PInt32(false), false), true)) + assertEquals(res, PCanonicalInterval(PCanonicalInterval(PInt32(false), false), true)) } - @Test def testMakeStruct(): Unit = { + test("MakeStruct") { assertEvalsTo(MakeStruct(FastSeq()), Row()) assertEvalsTo(MakeStruct(FastSeq("a" -> NA(TInt32), "b" -> 4, "c" -> 0.5)), Row(null, 4, 0.5)) // making sure wide structs get emitted without failure assertEvalsTo(GetField(MakeStruct((0 until 20000).map(i => s"foo$i" -> I32(1))), "foo1"), 1) } - @Test def testMakeArrayWithDifferentRequiredness(): Unit = { + test("MakeArrayWithDifferentRequiredness") { val pt1 = PCanonicalArray(PCanonicalStruct("a" -> PInt32(), "b" -> PCanonicalArray(PInt32()))) val pt2 = PCanonicalArray(PCanonicalStruct( true, @@ -1301,14 +1356,14 @@ class IRSuite extends HailSuite { ) } - @Test def testMakeTuple(): Unit = { + test("MakeTuple") { assertEvalsTo(MakeTuple.ordered(FastSeq()), Row()) assertEvalsTo(MakeTuple.ordered(FastSeq(NA(TInt32), 4, 0.5)), Row(null, 4, 0.5)) // making sure wide structs get emitted without failure assertEvalsTo(GetTupleElement(MakeTuple.ordered((0 until 20000).map(I32)), 1), 1) } - @Test def testGetTupleElement(): Unit = { + test("GetTupleElement") { implicit val execStrats = ExecStrategy.javaOnly val t = MakeTuple.ordered(FastSeq(I32(5), Str("abc"), NA(TInt32))) @@ -1320,7 +1375,7 @@ class IRSuite extends HailSuite { assertEvalsTo(GetTupleElement(na, 0), null) } - @Test def testLetBoundPrunedTuple(): Unit = { + test("LetBoundPrunedTuple") { implicit val execStrats = ExecStrategy.unoptimizedCompileOnly val t2 = MakeTuple(FastSeq((2, I32(5)))) 
@@ -1329,7 +1384,7 @@ class IRSuite extends HailSuite { assertEvalsTo(letBoundTuple, 5) } - @Test def testArrayRef(): Unit = { + test("ArrayRef") { assertEvalsTo( ArrayRef(MakeArray(FastSeq(I32(5), NA(TInt32)), TArray(TInt32)), I32(0), ErrorIDs.NO_ERROR), 5, @@ -1353,13 +1408,13 @@ class IRSuite extends HailSuite { ) } - @Test def testArrayLen(): Unit = { + test("ArrayLen") { assertEvalsTo(ArrayLen(NA(TArray(TInt32))), null) assertEvalsTo(ArrayLen(MakeArray(FastSeq(), TArray(TInt32))), 0) assertEvalsTo(ArrayLen(MakeArray(FastSeq(I32(5), NA(TInt32)), TArray(TInt32))), 2) } - @Test def testArraySort(): Unit = { + test("ArraySort") { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(ArraySort(ToStream(NA(TArray(TInt32)))), null) @@ -1369,7 +1424,7 @@ class IRSuite extends HailSuite { assertEvalsTo(ArraySort(ToStream(a), False()), FastSeq(2, 2, -7, null)) } - @Test def testStreamZip(): Unit = { + test("StreamZip") { val range12 = StreamRange(0, 12, 1) val range6 = StreamRange(0, 12, 2) val range8 = StreamRange(0, 24, 3) @@ -1417,7 +1472,7 @@ class IRSuite extends HailSuite { ) } - @Test def testToSet(): Unit = { + test("ToSet") { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(ToSet(ToStream(NA(TArray(TInt32)))), null) @@ -1427,13 +1482,13 @@ class IRSuite extends HailSuite { assertEvalsTo(ToSet(ToStream(a)), Set(-7, 2, null)) } - @Test def testToArrayFromSet(): Unit = { + test("ToArrayFromSet") { val t = TSet(TInt32) assertEvalsTo(CastToArray(NA(t)), null) assertEvalsTo(CastToArray(In(0, t)), FastSeq((Set(-7, 2, null), t)), FastSeq(-7, 2, null)) } - @Test def testToDict(): Unit = { + test("ToDict") { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(ToDict(ToStream(NA(TArray(TTuple(FastSeq(TInt32, TString): _*))))), null) @@ -1452,7 +1507,7 @@ class IRSuite extends HailSuite { assertEvalsTo(ToDict(ToStream(a)), Map(5 -> "a", (null, "b"), 3 -> null)) } - @Test def testToArrayFromDict(): Unit = { + test("ToArrayFromDict") { val t = 
TDict(TInt32, TString) assertEvalsTo(CastToArray(NA(t)), null) @@ -1465,13 +1520,13 @@ class IRSuite extends HailSuite { ) } - @Test def testToArrayFromArray(): Unit = { + test("ToArrayFromArray") { val t = TArray(TInt32) assertEvalsTo(NA(t), null) assertEvalsTo(In(0, t), FastSeq((FastSeq(-7, 2, null, 2), t)), FastSeq(-7, 2, null, 2)) } - @Test def testSetContains(): Unit = { + test("SetContains") { implicit val execStrats = ExecStrategy.javaOnly val t = TSet(TInt32) @@ -1495,7 +1550,7 @@ class IRSuite extends HailSuite { assertEvalsTo(invoke("contains", TBoolean, In(0, t), I32(7)), FastSeq((Set(-7, 2), t)), false) } - @Test def testDictContains(): Unit = { + test("DictContains") { implicit val execStrats = ExecStrategy.javaOnly val t = TDict(TInt32, TString) @@ -1512,7 +1567,7 @@ class IRSuite extends HailSuite { ) } - @Test def testLowerBoundOnOrderedCollectionArray(): Unit = { + test("LowerBoundOnOrderedCollectionArray") { implicit val execStrats = ExecStrategy.javaOnly val na = NA(TArray(TInt32)) @@ -1538,7 +1593,7 @@ class IRSuite extends HailSuite { ) } - @Test def testLowerBoundOnOrderedCollectionSet(): Unit = { + test("LowerBoundOnOrderedCollectionSet") { implicit val execStrats = ExecStrategy.javaOnly val na = NA(TSet(TInt32)) @@ -1560,7 +1615,7 @@ class IRSuite extends HailSuite { assertEvalsTo(LowerBoundOnOrderedCollection(swna, I32(5), onKey = false), 3) } - @Test def testLowerBoundOnOrderedCollectionDict(): Unit = { + test("LowerBoundOnOrderedCollectionDict") { implicit val execStrats = ExecStrategy.javaOnly val na = NA(TDict(TInt32, TString)) @@ -1580,7 +1635,7 @@ class IRSuite extends HailSuite { assertEvalsTo(LowerBoundOnOrderedCollection(dwoutna, NA(TInt32), onKey = true), 2) } - @Test def testStreamLen(): Unit = { + test("StreamLen") { val a = StreamLen(MakeStream(IndexedSeq(I32(3), NA(TInt32), I32(7)), TStream(TInt32))) assertEvalsTo(a, 3) @@ -1613,7 +1668,7 @@ class IRSuite extends HailSuite { assertEvalsTo(lenOfLet, 7) } - @Test def 
testStreamLenUnconsumedInnerStream(): Unit = + test("StreamLenUnconsumedInnerStream") { assertEvalsTo( StreamLen( mapIR(StreamGrouped(filterIR(rangeIR(10))(x => x.cne(I32(0))), 3))(group => ToArray(group)) @@ -1621,7 +1676,9 @@ class IRSuite extends HailSuite { 3, ) - @Test def testStreamTake(): Unit = { + } + + test("StreamTake") { val naa = NA(TStream(TInt32)) val a = MakeStream(IndexedSeq(I32(3), NA(TInt32), I32(7)), TStream(TInt32)) @@ -1634,7 +1691,7 @@ class IRSuite extends HailSuite { assertEvalsTo(StreamLen(StreamTake(a, 2)), 2) } - @Test def testStreamDrop(): Unit = { + test("StreamDrop") { val naa = NA(TStream(TInt32)) val a = MakeStream(IndexedSeq(I32(3), NA(TInt32), I32(7)), TStream(TInt32)) @@ -1650,7 +1707,7 @@ class IRSuite extends HailSuite { def toNestedArray(stream: IR): IR = ToArray(mapIR(stream)(ToArray)) - @Test def testStreamGrouped(): Unit = { + test("StreamGrouped") { val naa = NA(TStream(TInt32)) val a = MakeStream(IndexedSeq(I32(3), NA(TInt32), I32(7)), TStream(TInt32)) @@ -1689,7 +1746,7 @@ class IRSuite extends HailSuite { assertEvalsTo(toNestedArray(takeFromEach(r, I32(0), I32(5))), FastSeq(FastSeq(), FastSeq())) } - @Test def testStreamGroupByKey(): Unit = { + test("StreamGroupByKey") { val structType = TStruct("a" -> TInt32, "b" -> TInt32) val naa = NA(TStream(structType)) val a = MakeStream( @@ -1748,7 +1805,7 @@ class IRSuite extends HailSuite { ) } - @Test def testStreamMap(): Unit = { + test("StreamMap") { val naa = NA(TStream(TInt32)) val a = MakeStream(IndexedSeq(I32(3), NA(TInt32), I32(7)), TStream(TInt32)) @@ -1756,7 +1813,7 @@ class IRSuite extends HailSuite { assertEvalsTo(ToArray(mapIR(a)(_ + I32(1))), FastSeq(4, null, 8)) } - @Test def testStreamFilter(): Unit = { + test("StreamFilter") { val nsa = NA(TStream(TInt32)) val a = MakeStream(IndexedSeq(I32(3), NA(TInt32), I32(7)), TStream(TInt32)) @@ -1772,7 +1829,7 @@ class IRSuite extends HailSuite { assertEvalsTo(ToArray(filterIR(a)(_ < I32(6))), FastSeq(3)) } - @Test def 
testArrayFlatMap(): Unit = { + test("ArrayFlatMap") { val ta = TArray(TInt32) val ts = TStream(TInt32) val tsa = TStream(ta) @@ -1803,7 +1860,7 @@ class IRSuite extends HailSuite { assertEvalsTo(ToArray(flatMapIR(st)(foo => rangeIR(-1, foo))), expected) } - @Test def testStreamFold(): Unit = { + test("StreamFold") { assertEvalsTo(foldIR(StreamRange(1, 2, 1), NA(TBoolean))((accum, elt) => IsNA(accum)), true) assertEvalsTo(foldIR(IRStream(1, 2, 3), 0)((accum, elt) => accum + elt), 6) assertEvalsTo( @@ -1821,7 +1878,7 @@ class IRSuite extends HailSuite { ) } - @Test def testArrayFold2(): Unit = { + test("ArrayFold2") { implicit val execStrats = ExecStrategy.compileOnly val af = fold2IR( @@ -1842,7 +1899,7 @@ class IRSuite extends HailSuite { assertEvalsTo(af, FastSeq((FastSeq(1, 2, 3), TArray(TInt32))), Row(6, 1)) } - @Test def testArrayScan(): Unit = { + test("ArrayScan") { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo( @@ -1900,7 +1957,7 @@ class IRSuite extends HailSuite { val cubeRowMajor = makeNDArray((0 until 27).map(_.toDouble), FastSeq(3, 3, 3), True()) val cubeColMajor = makeNDArray((0 until 27).map(_.toDouble), FastSeq(3, 3, 3), False()) - @Test def testNDArrayShape(): Unit = { + test("NDArrayShape") { implicit val execStrats = ExecStrategy.compileOnly assertEvalsTo(NDArrayShape(scalarRowMajor), Row()) @@ -1908,7 +1965,7 @@ class IRSuite extends HailSuite { assertEvalsTo(NDArrayShape(cubeRowMajor), Row(3L, 3L, 3L)) } - @Test def testNDArrayRef(): Unit = { + test("NDArrayRef") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly assertEvalsTo(makeNDArrayRef(scalarRowMajor, FastSeq()), 3.0) @@ -1936,7 +1993,7 @@ class IRSuite extends HailSuite { assertEvalsTo(centerColMajor, 13.0) } - @Test def testNDArrayReshape(): Unit = { + test("NDArrayReshape") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val v = NDArrayReshape(matrixRowMajor, MakeTuple.ordered(IndexedSeq(I64(4))), ErrorIDs.NO_ERROR) @@ -1948,7 
+2005,7 @@ class IRSuite extends HailSuite { assertEvalsTo(makeNDArrayRef(mat2, FastSeq(0, 0)), 1.0) } - @Test def testNDArrayConcat(): Unit = { + test("NDArrayConcat") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly def nds(ndData: (IndexedSeq[Int], Long, Long)*): IR = @@ -2027,7 +2084,7 @@ class IRSuite extends HailSuite { assertNDEvals(NDArrayConcat(NA(TArray(TNDArray(TInt32, Nat(2)))), 1), null) } - @Test def testNDArrayMap(): Unit = { + test("NDArrayMap") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val data = 0 until 10 @@ -2061,7 +2118,7 @@ class IRSuite extends HailSuite { assertEvalsTo(zero, 0L) } - @Test def testNDArrayMap2(): Unit = { + test("NDArrayMap2") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val shape = MakeTuple.ordered(FastSeq(2L, 2L).map(I64)) @@ -2085,7 +2142,7 @@ class IRSuite extends HailSuite { assertEvalsTo(twentyTwo, 22.0) } - @Test def testNDArrayReindex(): Unit = { + test("NDArrayReindex") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val transpose = NDArrayReindex(matrixRowMajor, FastSeq(1, 0)) @@ -2108,7 +2165,7 @@ class IRSuite extends HailSuite { assertEvalsTo(makeNDArrayRef(partialTranspose, partialTranposeIdx), 3.0) } - @Test def testNDArrayBroadcasting(): Unit = { + test("NDArrayBroadcasting") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly def sum(nd1: IR, nd2: IR): IR = ndMap2(nd1, nd2)(_ + _) @@ -2141,7 +2198,7 @@ class IRSuite extends HailSuite { assertEvalsTo(NDArrayShape(colVectorWithEmpty), Row(2L, 0L)) } - @Test def testNDArrayAgg(): Unit = { + test("NDArrayAgg") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val empty = makeNDArrayRef( @@ -2165,7 +2222,7 @@ class IRSuite extends HailSuite { assertEvalsTo(twentySeven, 3.0) } - @Test def testNDArrayMatMul(): Unit = { + test("NDArrayMatMul") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val 
dotProduct = NDArrayMatMul(vectorRowMajor, vectorRowMajor, ErrorIDs.NO_ERROR) @@ -2193,7 +2250,7 @@ class IRSuite extends HailSuite { assertEvalsTo(makeNDArrayRef(matMulCube, IndexedSeq(0, 0, 0)), 30.0) } - @Test def testNDArrayInv(): Unit = { + test("NDArrayInv") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val matrixRowMajor = makeNDArray(FastSeq(1.5, 2.0, 4.0, 5.0), FastSeq(2, 2), True()) val inv = NDArrayInv(matrixRowMajor, ErrorIDs.NO_ERROR) @@ -2201,7 +2258,7 @@ class IRSuite extends HailSuite { assertNDEvals(inv, expectedInv) } - @Test def testNDArraySlice(): Unit = { + test("NDArraySlice") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val rightCol = NDArraySlice( @@ -2230,7 +2287,7 @@ class IRSuite extends HailSuite { assertEvalsTo(makeNDArrayRef(scalarSlice, FastSeq()), 3.0) } - @Test def testNDArrayFilter(): Unit = { + test("NDArrayFilter") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly assertNDEvals( @@ -2326,7 +2383,7 @@ class IRSuite extends HailSuite { }) } - @Test def testStreamZipJoin(): Unit = { + test("StreamZipJoin") { def eltType = TStruct("k1" -> TInt32, "k2" -> TString, "idx" -> TInt32) def makeStream(a: IndexedSeq[Integer]): IR = { if (a == null) @@ -2402,7 +2459,7 @@ class IRSuite extends HailSuite { ) } - @Test def testStreamMultiMerge(): Unit = { + test("StreamMultiMerge") { def eltType = TStruct("k1" -> TInt32, "k2" -> TString, "idx" -> TInt32) def makeStream(a: IndexedSeq[Integer]): IR = { if (a == null) @@ -2480,7 +2537,7 @@ class IRSuite extends HailSuite { ) } - @Test def testJoinRightDistinct(): Unit = { + test("JoinRightDistinct") { implicit val execStrats = ExecStrategy.javaOnly def joinRows(left: IndexedSeq[Integer], right: IndexedSeq[Integer], joinType: String): IR = { @@ -2599,7 +2656,7 @@ class IRSuite extends HailSuite { ) } - @Test def testStreamJoin(): Unit = { + test("StreamJoin") { implicit val execStrats = ExecStrategy.javaOnly def 
joinRows(left: IndexedSeq[Integer], right: IndexedSeq[Integer], joinType: String): IR = { @@ -2704,7 +2761,7 @@ class IRSuite extends HailSuite { ) } - @Test def testStreamMerge(): Unit = { + test("StreamMerge") { implicit val execStrats = ExecStrategy.compileOnly def mergeRows(left: IndexedSeq[Integer], right: IndexedSeq[Integer], key: Int): IR = { @@ -2831,12 +2888,12 @@ class IRSuite extends HailSuite { assertEvalsTo(mergeRows(ArraySeq(1, 2, null), null, 1), null) } - @Test def testDie(): Unit = { + test("Die") { assertFatal(Die("mumblefoo", TFloat64), "mble") assertFatal(Die(NA(TString), TFloat64, -1), "message missing") } - @Test def testStreamRange(): Unit = { + test("StreamRange") { def assertEquals(start: Integer, stop: Integer, step: Integer, expected: IndexedSeq[Int]) : Unit = assertEvalsTo( @@ -2850,21 +2907,19 @@ class IRSuite extends HailSuite { assertFatal(ToArray(StreamRange(I32(0), I32(5), I32(0))), "step size") - forAll( - for { - start <- -2 to 2 - stop <- -2 to 8 - step <- 1 to 3 - } yield (start, stop, step) - ) { case (start, stop, step) => + for { + start <- -2 to 2 + stop <- -2 to 8 + step <- 1 to 3 + } { assertEquals(start, stop, step, expected = ArraySeq.range(start, stop, step)) assertEquals(start, stop, -step, expected = ArraySeq.range(start, stop, -step)) - } + } val expected = ArraySeq.range(Int.MinValue, Int.MaxValue, Int.MaxValue / 5) assertEquals(Int.MinValue, Int.MaxValue, Int.MaxValue / 5, expected) } - @Test def testArrayAgg(): Unit = { + test("ArrayAgg") { implicit val execStrats = ExecStrategy.compileOnly assertEvalsTo( @@ -2873,7 +2928,7 @@ class IRSuite extends HailSuite { ) } - @Test def testArrayAggContexts(): Unit = { + test("ArrayAggContexts") { implicit val execStrats = ExecStrategy.compileOnly val ir = bindIR(In(0, TInt32) * In(0, TInt32)) { x => // multiply to prevent forwarding @@ -2906,7 +2961,7 @@ class IRSuite extends HailSuite { ) } - @Test def 
testStreamAggScan(): Unit = { + test("StreamAggScan") { implicit val execStrats = ExecStrategy.compileOnly val eltType = TStruct("x" -> TCall, "y" -> TInt32) @@ -2937,7 +2992,7 @@ class IRSuite extends HailSuite { assertEvalsTo(StreamLen(ir), args = FastSeq(input), 6) } - @Test def testInsertFields(): Unit = { + test("InsertFields") { implicit val execStrats = ExecStrategy.javaOnly val s = TStruct("a" -> TInt64, "b" -> TString) @@ -3036,7 +3091,7 @@ class IRSuite extends HailSuite { } - @Test def testSelectFields(): Unit = { + test("SelectFields") { assertEvalsTo( SelectFields( NA(TStruct("foo" -> TInt32, "bar" -> TFloat64)), @@ -3062,7 +3117,7 @@ class IRSuite extends HailSuite { ) } - @Test def testGetField(): Unit = { + test("GetField") { implicit val execStrats = ExecStrategy.javaOnly val s = MakeStruct(IndexedSeq("a" -> NA(TInt64), "b" -> Str("abc"))) @@ -3073,7 +3128,7 @@ class IRSuite extends HailSuite { assertEvalsTo(GetField(na, "a"), null) } - @Test def testLiteral(): Unit = { + test("Literal") { implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized, ExecStrategy.JvmCompile) val poopEmoji = new String(Array[Char](0xd83d, 0xdca9)) @@ -3096,7 +3151,7 @@ class IRSuite extends HailSuite { assertEvalsTo(Str("hello" + poopEmoji), "hello" + poopEmoji) } - @Test def testSameLiteralsWithDifferentTypes(): Unit = { + test("SameLiteralsWithDifferentTypes") { assertEvalsTo( ApplyComparisonOp( EQ, @@ -3109,13 +3164,13 @@ class IRSuite extends HailSuite { ) } - @Test def testTableCount(): Unit = { + test("TableCount") { implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) assertEvalsTo(TableCount(TableRange(0, 4)), 0L) assertEvalsTo(TableCount(TableRange(7, 4)), 7L) } - @Test def testTableGetGlobals(): Unit = { + test("TableGetGlobals") { implicit val execStrats = ExecStrategy.interpretOnly assertEvalsTo( TableGetGlobals(TableMapGlobals(TableRange(0, 1), Literal(TStruct("a" -> TInt32), Row(1)))), @@ 
-3123,7 +3178,7 @@ class IRSuite extends HailSuite { ) } - @Test def testTableAggregate(): Unit = { + test("TableAggregate") { implicit val execStrats = ExecStrategy.allRelational val table = TableRange(3, 2) @@ -3131,7 +3186,7 @@ class IRSuite extends HailSuite { assertEvalsTo(TableAggregate(table, MakeStruct(IndexedSeq("foo" -> count))), Row(3L)) } - @Test def testMatrixAggregate(): Unit = { + test("MatrixAggregate") { implicit val execStrats = ExecStrategy.interpretOnly val matrix = MatrixIR.range(ctx, 5, 5, None) @@ -3139,7 +3194,7 @@ class IRSuite extends HailSuite { assertEvalsTo(MatrixAggregate(matrix, MakeStruct(IndexedSeq("foo" -> count))), Row(25L)) } - @Test def testGroupByKey(): Unit = { + test("GroupByKey") { implicit val execStrats = Set( ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized, @@ -3168,100 +3223,82 @@ class IRSuite extends HailSuite { assertEvalsTo(groupby(), Map()) } - @DataProvider(name = "compareDifferentTypes") - def compareDifferentTypesData(): Array[Array[Any]] = Array( - Array(FastSeq(0.0, 0.0), TArray(TFloat64), TArray(TFloat64)), - Array(Set(0, 1), TSet(TInt32), TSet(TInt32)), - Array(Map(0L -> 5, 3L -> 20), TDict(TInt64, TInt32), TDict(TInt64, TInt32)), - Array( - Interval(1, 2, includesStart = false, includesEnd = true), - TInterval(TInt32), - TInterval(TInt32), - ), - Array( - Row("foo", 0.0), - TStruct("a" -> TString, "b" -> TFloat64), - TStruct("a" -> TString, "b" -> TFloat64), - ), - Array(Row("foo", 0.0), TTuple(TString, TFloat64), TTuple(TString, TFloat64)), - Array( - Row(FastSeq("foo"), 0.0), - TTuple(TArray(TString), TFloat64), - TTuple(TArray(TString), TFloat64), - ), + object checkComparisonOpDifferentTypes extends TestCases { + def apply(a: Any, t1: Type, t2: Type)(implicit loc: munit.Location): Unit = + test("ComparisonOpDifferentTypes") { + implicit val execStrats = ExecStrategy.javaOnly + + assertEvalsTo(ApplyComparisonOp(EQ, In(0, t1), In(1, t2)), FastSeq(a -> t1, a -> t2), true) + 
assertEvalsTo(ApplyComparisonOp(LT, In(0, t1), In(1, t2)), FastSeq(a -> t1, a -> t2), false) + assertEvalsTo(ApplyComparisonOp(GT, In(0, t1), In(1, t2)), FastSeq(a -> t1, a -> t2), false) + assertEvalsTo( + ApplyComparisonOp(LTEQ, In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + true, + ) + assertEvalsTo( + ApplyComparisonOp(GTEQ, In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + true, + ) + assertEvalsTo( + ApplyComparisonOp(NEQ, In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + false, + ) + assertEvalsTo( + ApplyComparisonOp(EQWithNA, In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + true, + ) + assertEvalsTo( + ApplyComparisonOp(NEQWithNA, In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + false, + ) + assertEvalsTo( + ApplyComparisonOp(Compare, In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + 0, + ) + } + } + + checkComparisonOpDifferentTypes(FastSeq(0.0, 0.0), TArray(TFloat64), TArray(TFloat64)) + checkComparisonOpDifferentTypes(Set(0, 1), TSet(TInt32), TSet(TInt32)) + + checkComparisonOpDifferentTypes( + Map(0L -> 5, 3L -> 20), + TDict(TInt64, TInt32), + TDict(TInt64, TInt32), ) - @Test(dataProvider = "compareDifferentTypes") - def testComparisonOpDifferentTypes(a: Any, t1: Type, t2: Type): Unit = { - implicit val execStrats = ExecStrategy.javaOnly + checkComparisonOpDifferentTypes( + Interval(1, 2, includesStart = false, includesEnd = true), + TInterval(TInt32), + TInterval(TInt32), + ) - assertEvalsTo( - ApplyComparisonOp(EQ, In(0, t1), In(1, t2)), - FastSeq(a -> t1, a -> t2), - true, - ) - assertEvalsTo( - ApplyComparisonOp(LT, In(0, t1), In(1, t2)), - FastSeq(a -> t1, a -> t2), - false, - ) - assertEvalsTo( - ApplyComparisonOp(GT, In(0, t1), In(1, t2)), - FastSeq(a -> t1, a -> t2), - false, - ) - assertEvalsTo( - ApplyComparisonOp(LTEQ, In(0, t1), In(1, t2)), - FastSeq(a -> t1, a -> t2), - true, - ) - assertEvalsTo( - ApplyComparisonOp(GTEQ, In(0, t1), In(1, t2)), - FastSeq(a -> t1, a -> t2), - true, - ) - assertEvalsTo( - 
ApplyComparisonOp(NEQ, In(0, t1), In(1, t2)), - FastSeq(a -> t1, a -> t2), - false, - ) - assertEvalsTo( - ApplyComparisonOp(EQWithNA, In(0, t1), In(1, t2)), - FastSeq(a -> t1, a -> t2), - true, - ) - assertEvalsTo( - ApplyComparisonOp(NEQWithNA, In(0, t1), In(1, t2)), - FastSeq(a -> t1, a -> t2), - false, - ) - assertEvalsTo( - ApplyComparisonOp(Compare, In(0, t1), In(1, t2)), - FastSeq(a -> t1, a -> t2), - 0, - ) - } + checkComparisonOpDifferentTypes( + Row("foo", 0.0), + TStruct("a" -> TString, "b" -> TFloat64), + TStruct("a" -> TString, "b" -> TFloat64), + ) - @DataProvider(name = "valueIRs") - def valueIRs(): Array[Array[Object]] = - valueIRs(ctx) + checkComparisonOpDifferentTypes( + Row("foo", 0.0), + TTuple(TString, TFloat64), + TTuple(TString, TFloat64), + ) - def valueIRs(ctx: ExecuteContext): Array[Array[Object]] = { - val fs = ctx.fs + checkComparisonOpDifferentTypes( + Row(FastSeq("foo"), 0.0), + TTuple(TArray(TString), TFloat64), + TTuple(TArray(TString), TFloat64), + ) - CompileAndEvaluate[Unit]( - ctx, - invoke( - "index_bgen", - TInt64, - ArraySeq(TLocus("GRCh37")), - Str(getTestResource("example.8bits.bgen")), - Str(getTestResource("example.8bits.bgen.idx2")), - Literal(TDict(TString, TString), Map("01" -> "1")), - False(), - I32(1000000), - ), - ) + { + lazy val fs = ctx.fs val b = True() val bin = Ref(freshName(), TBinary) @@ -3299,9 +3336,9 @@ class IRSuite extends HailSuite { val table = TableRange(100, 10) - val mt = MatrixIR.range(ctx, 20, 2, Some(3)) - val vcf = importVCF(ctx, getTestResource("sample.vcf")) - val bgenReader = MatrixBGENReader( + lazy val mt = MatrixIR.range(ctx, 20, 2, Some(3)) + lazy val vcf = importVCF(ctx, getTestResource("sample.vcf")) + lazy val bgenReader = MatrixBGENReader( ctx, FastSeq(getTestResource("example.8bits.bgen")), None, @@ -3310,9 +3347,9 @@ class IRSuite extends HailSuite { None, None, ) - val bgen = MatrixRead(bgenReader.fullMatrixType, false, false, bgenReader) + lazy val bgen = 
MatrixRead(bgenReader.fullMatrixType, false, false, bgenReader) - val blockMatrix = + lazy val blockMatrix = BlockMatrixRead(BlockMatrixNativeReader(fs, getTestResource("blockmatrix_example/0"))) val blockMatrixWriter = BlockMatrixNativeWriter("/path/to/file.bm", false, false, false) val blockMatrixMultiWriter = BlockMatrixBinaryMultiWriter("/path/to/prefix", false) @@ -3326,48 +3363,60 @@ class IRSuite extends HailSuite { def collect(ir: IR): IR = ApplyAggOp(Collect())(ir) - implicit def addEnv(ir: IR): (IR, BindingEnv[Type] => BindingEnv[Type]) = - (ir, env => env) - implicit def liftRefs(refs: IndexedSeq[Ref]): BindingEnv[Type] => BindingEnv[Type] = - env => env.bindEval(refs.map(r => r.name -> r.typ): _*) - - val irs = Array[(IR, BindingEnv[Type] => BindingEnv[Type])]( - i, - I64(5), - F32(3.14f), - F64(3.14), - str, - True(), - False(), - Void(), - UUID4(), - Cast(i, TFloat64), - CastRename(NA(TStruct("a" -> TInt32)), TStruct("b" -> TInt32)), - NA(TInt32), - IsNA(i), - If(b, i, j), - Switch(i, j, 0 until 7 map I32), - Coalesce(FastSeq(i, I32(1))), - bindIR(i)(v => v), - aggBindIR(i)(collect(_)) -> (_.createAgg), { - val x = Ref(freshName(), TInt32) - x -> (_.bindEval(x.name, x.typ)) - }, - ApplyBinaryPrimOp(Add(), i, j), - ApplyUnaryPrimOp(Negate, i), - ApplyComparisonOp(EQ, i, j), - MakeArray(FastSeq(i, NA(TInt32), I32(-3)), TArray(TInt32)), - MakeStream(FastSeq(i, NA(TInt32), I32(-3)), TStream(TInt32)), - nd, + def testParseIR(name: String, ir: => IR, refMap: BindingEnv[Type] = BindingEnv.empty): Unit = { + test(s"$name construction")(ir: Unit) + test(s"$name parsing") { + val ir0 = ir + val s = Pretty.sexprStyle(ir0, elideLiterals = false) + val x2 = IRParser.parse_value_ir(ctx, s, refMap) + assertEquals(x2, ir0) + } + } + + testParseIR("I32", i) + testParseIR("I32", I64(5)) + testParseIR("F32", F32(3.14f)) + testParseIR("F64", F64(3.14)) + testParseIR("Str", str) + testParseIR("True", True()) + testParseIR("False", False()) + testParseIR("Void", Void()) + 
testParseIR("UUID4", UUID4()) + testParseIR("Cast", Cast(i, TFloat64)) + testParseIR("NA", NA(TInt32)) + testParseIR("IsNA", IsNA(i)) + testParseIR("If", If(b, i, j)) + testParseIR("Switch", Switch(i, j, 0 until 7 map I32)) + testParseIR("Coalesce", Coalesce(FastSeq(i, I32(1)))) + testParseIR("Let", bindIR(i)(v => v)) + testParseIR("AggLet", aggBindIR(i)(collect(_)), BindingEnv.empty.createAgg) + + { + val x = Ref(freshName(), TInt32) + testParseIR("Ref", x, BindingEnv.eval(x.name -> x.typ)) + } + + testParseIR("CastRename", CastRename(NA(TStruct("a" -> TInt32)), TStruct("b" -> TInt32))) + testParseIR("ApplyBinaryPrimOp", ApplyBinaryPrimOp(Add(), i, j)) + testParseIR("ApplyUnaryPrimOp", ApplyUnaryPrimOp(Negate, i)) + testParseIR("ApplyComparisonOp", ApplyComparisonOp(EQ, i, j)) + testParseIR("MakeArray", MakeArray(FastSeq(i, NA(TInt32), I32(-3)), TArray(TInt32))) + testParseIR("MakeStream", MakeStream(FastSeq(i, NA(TInt32), I32(-3)), TStream(TInt32))) + testParseIR("MakeNDArray", nd) + testParseIR( + "NDArrayReshape", NDArrayReshape(nd, MakeTuple.ordered(IndexedSeq(I64(4))), ErrorIDs.NO_ERROR), - NDArrayConcat(MakeArray(FastSeq(nd, nd), TArray(nd.typ)), 0), - NDArrayRef(nd, FastSeq(I64(1), I64(2)), -1), - ndMap(nd)(v => -v), - ndMap2(nd, nd)(_ + _), - NDArrayReindex(nd, FastSeq(0, 1)), - NDArrayAgg(nd, FastSeq(0)), - NDArrayWrite(nd, Str("/path/to/ndarray")), - NDArrayMatMul(nd, nd, ErrorIDs.NO_ERROR), + ) + testParseIR("NDArrayConcat", NDArrayConcat(MakeArray(FastSeq(nd, nd), TArray(nd.typ)), 0)) + testParseIR("NDArrayRef", NDArrayRef(nd, FastSeq(I64(1), I64(2)), -1)) + testParseIR("NDArrayMap", ndMap(nd)(v => -v)) + testParseIR("NDArrayMap2", ndMap2(nd, nd)(_ + _)) + testParseIR("NDArrayReindex", NDArrayReindex(nd, FastSeq(0, 1))) + testParseIR("NDArrayAgg", NDArrayAgg(nd, FastSeq(0))) + testParseIR("NDArrayWrite", NDArrayWrite(nd, Str("/path/to/ndarray"))) + testParseIR("NDArrayMatMul", NDArrayMatMul(nd, nd, ErrorIDs.NO_ERROR)) + testParseIR( + "NDArraySlice", 
NDArraySlice( nd, MakeTuple.ordered(FastSeq( @@ -3375,42 +3424,66 @@ class IRSuite extends HailSuite { MakeTuple.ordered(FastSeq(I64(0), I64(2), I64(1))), )), ), - NDArrayFilter(nd, FastSeq(NA(TArray(TInt64)), NA(TArray(TInt64)))), - ArrayRef(a, i) -> ArraySeq(a), - ArrayLen(a) -> ArraySeq(a), - RNGSplit(rngState, MakeTuple.ordered(FastSeq(I64(1), I64(2), I64(3)))), - StreamLen(st) -> ArraySeq(st), - StreamRange(I32(0), I32(5), I32(1)), - StreamRange(I32(0), I32(5), I32(1)), - ArraySort(st, b) -> ArraySeq(st), - ToSet(st) -> ArraySeq(st), - ToDict(std) -> ArraySeq(std), - ToArray(st) -> ArraySeq(st), - CastToArray(NA(TSet(TInt32))), - ToStream(a) -> ArraySeq(a), - LowerBoundOnOrderedCollection(a, i, onKey = false) -> ArraySeq(a), - GroupByKey(std) -> ArraySeq(std), - StreamTake(st, I32(10)) -> ArraySeq(st), - StreamDrop(st, I32(10)) -> ArraySeq(st), - takeWhile(st)(_ < 5) -> ArraySeq(st), - dropWhile(st)(_ < 5) -> ArraySeq(st), - mapIR(st)(v => v) -> ArraySeq(st), - zipIR(FastSeq(st, st), ArrayZipBehavior.TakeMinLength)(_ => True()) -> ArraySeq(st), - filterIR(st)(_ => b) -> ArraySeq(st), - flatMapIR(sta)(ToStream(_)) -> ArraySeq(sta), - foldIR(st, I32(0))((_, v) => v) -> ArraySeq(st), - StreamFold2(foldIR(st, I32(0))((_, v) => v)) -> ArraySeq(st), - streamScanIR(st, I32(0))((_, v) => v) -> ArraySeq(st), - StreamWhiten(whitenStream, "newChunk", "prevWindow", 1, 1, 1, 1, false) -> ArraySeq( - whitenStream - ), + ) + testParseIR("NDArrayFilter", NDArrayFilter(nd, FastSeq(NA(TArray(TInt64)), NA(TArray(TInt64))))) + testParseIR("ArrayRef", ArrayRef(a, i), BindingEnv.eval(a.name -> a.typ)) + testParseIR("ArrayLen", ArrayLen(a), BindingEnv.eval(a.name -> a.typ)) + testParseIR("RNGSplit", RNGSplit(rngState, MakeTuple.ordered(FastSeq(I64(1), I64(2), I64(3))))) + testParseIR("StreamLen", StreamLen(st), BindingEnv.eval(st.name -> st.typ)) + testParseIR("StreamRange", StreamRange(I32(0), I32(5), I32(1))) + testParseIR("StreamRange", StreamRange(I32(0), I32(5), I32(1))) + 
testParseIR("ArraySort", ArraySort(st, b), BindingEnv.eval(st.name -> st.typ)) + testParseIR("ToSet", ToSet(st), BindingEnv.eval(st.name -> st.typ)) + testParseIR("ToDict", ToDict(std), BindingEnv.eval(std.name -> std.typ)) + testParseIR("ToArray", ToArray(st), BindingEnv.eval(st.name -> st.typ)) + testParseIR("CastToArray", CastToArray(NA(TSet(TInt32)))) + testParseIR("ToStream", ToStream(a), BindingEnv.eval(a.name -> a.typ)) + testParseIR( + "LowerBoundOnOrderedCollection", + LowerBoundOnOrderedCollection(a, i, onKey = false), + BindingEnv.eval(a.name -> a.typ), + ) + testParseIR("GroupByKey", GroupByKey(std), BindingEnv.eval(std.name -> std.typ)) + testParseIR("StreamTake", StreamTake(st, I32(10)), BindingEnv.eval(st.name -> st.typ)) + testParseIR("StreamDrop", StreamDrop(st, I32(10)), BindingEnv.eval(st.name -> st.typ)) + testParseIR("takeWhile", takeWhile(st)(_ < 5), BindingEnv.eval(st.name -> st.typ)) + testParseIR("dropWhile", dropWhile(st)(_ < 5), BindingEnv.eval(st.name -> st.typ)) + testParseIR("mapIR", mapIR(st)(v => v), BindingEnv.eval(st.name -> st.typ)) + testParseIR( + "zipIR", + zipIR(FastSeq(st, st), ArrayZipBehavior.TakeMinLength)(_ => True()), + BindingEnv.eval(st.name -> st.typ), + ) + testParseIR("filterIR", filterIR(st)(_ => b), BindingEnv.eval(st.name -> st.typ)) + testParseIR("flatMapIR", flatMapIR(sta)(ToStream(_)), BindingEnv.eval(sta.name -> sta.typ)) + testParseIR("foldIR", foldIR(st, I32(0))((_, v) => v), BindingEnv.eval(st.name -> st.typ)) + testParseIR( + "StreamFold2", + StreamFold2(foldIR(st, I32(0))((_, v) => v)), + BindingEnv.eval(st.name -> st.typ), + ) + testParseIR( + "streamScanIR", + streamScanIR(st, I32(0))((_, v) => v), + BindingEnv.eval(st.name -> st.typ), + ) + testParseIR( + "StreamWhiten", + StreamWhiten(whitenStream, "newChunk", "prevWindow", 1, 1, 1, 1, false), + BindingEnv.eval(whitenStream.name -> whitenStream.typ), + ) + testParseIR( + "joinRightDistinctIR", joinRightDistinctIR( mapIR(rangeIR(2))(x => 
MakeStruct(FastSeq("x" -> x))), mapIR(rangeIR(3))(x => MakeStruct(FastSeq("x" -> x))), FastSeq("x"), FastSeq("x"), "left", - )((_, _) => I32(1)), { + )((_, _) => I32(1)), + ) + testParseIR( + "StreamLeftIntervalJoin", { val left = mapIR(rangeIR(2))(x => MakeStruct(FastSeq("x" -> x))) val right = ToStream(Literal( TArray(TStruct("a" -> TInterval(TInt32))), @@ -3428,8 +3501,15 @@ class IRSuite extends HailSuite { InsertFields(lref, FastSeq("join" -> rref)), ) }, - forIR(st)(_ => Void()) -> ArraySeq(st), - streamAggIR(st)(x => ApplyAggOp(Sum())(Cast(x, TInt64))) -> ArraySeq(st), + ) + testParseIR("forIR", forIR(st)(_ => Void()), BindingEnv.eval(st.name -> st.typ)) + testParseIR( + "streamAggIR", + streamAggIR(st)(x => ApplyAggOp(Sum())(Cast(x, TInt64))), + BindingEnv.eval(st.name -> st.typ), + ) + testParseIR( + "StreamBufferedAggregate", StreamBufferedAggregate( st, Void(), @@ -3438,8 +3518,16 @@ class IRSuite extends HailSuite { l.name, ArraySeq(pCollectSig), 27, - ) -> ArraySeq(st), - streamAggScanIR(st)(x => ApplyScanOp(Sum())(Cast(x, TInt64))) -> ArraySeq(st), + ), + BindingEnv.eval(st.name -> st.typ), + ) + testParseIR( + "streamAggScanIR", + streamAggScanIR(st)(x => ApplyScanOp(Sum())(Cast(x, TInt64))), + BindingEnv.eval(st.name -> st.typ), + ) + testParseIR( + "RunAgg", RunAgg( Begin(FastSeq( InitOp(0, FastSeq(Begin(FastSeq(InitOp(0, FastSeq(), pSumSig)))), groupSignature), @@ -3447,7 +3535,10 @@ class IRSuite extends HailSuite { )), AggStateValue(0, groupSignature.state), FastSeq(groupSignature.state), - ), { + ), + ) + testParseIR( + "RunAggScan", { val foo = Ref(freshName(), TInt32) RunAggScan( StreamRange(I32(0), I32(1), I32(1)), @@ -3458,20 +3549,36 @@ class IRSuite extends HailSuite { FastSeq(groupSignature.state), ) }, - AggFilter(True(), I32(0), false) -> (_.createAgg), - aggExplodeIR(NA(TStream(TInt32)))(_ => I32(0)) -> (_.createAgg), - AggGroupBy(True(), I32(0), false) -> (_.createAgg), - ApplyAggOp(Collect())(I32(0)) -> (_.createAgg), - 
ApplyAggOp(CallStats(), I32(2))(call) -> (_.createAgg.bindAgg( - call.name, - call.typ, - )), - ApplyAggOp(TakeBy(), I32(10))(F64(-2.11), I32(4)) -> (_.createAgg), - AggFold(I32(0), l + I32(1), l + r, l.name, r.name, false) -> (_.createAgg), - InitOp(0, FastSeq(I32(2)), pCallStatsSig), - SeqOp(0, FastSeq(i), pCollectSig), - CombOp(0, 1, pCollectSig), - ResultOp(0, pCollectSig), { + ) + testParseIR("AggFilter", AggFilter(True(), I32(0), false), BindingEnv.empty.createAgg) + testParseIR( + "aggExplodeIR", + aggExplodeIR(NA(TStream(TInt32)))(_ => I32(0)), + BindingEnv.empty.createAgg, + ) + testParseIR("AggGroupBy", AggGroupBy(True(), I32(0), false), BindingEnv.empty.createAgg) + testParseIR("ApplyAggOp Collect", ApplyAggOp(Collect())(I32(0)), BindingEnv.empty.createAgg) + testParseIR( + "ApplyAggOp CallStats", + ApplyAggOp(CallStats(), I32(2))(call), + BindingEnv.empty.createAgg.bindAgg(call.name, call.typ), + ) + testParseIR( + "ApplyAggOp TakeBy", + ApplyAggOp(TakeBy(), I32(10))(F64(-2.11), I32(4)), + BindingEnv.empty.createAgg, + ) + testParseIR( + "AggFold", + AggFold(I32(0), l + I32(1), l + r, l.name, r.name, false), + BindingEnv.empty.createAgg, + ) + testParseIR("InitOp", InitOp(0, FastSeq(I32(2)), pCallStatsSig)) + testParseIR("SeqOp", SeqOp(0, FastSeq(i), pCollectSig)) + testParseIR("CombOp", CombOp(0, 1, pCollectSig)) + testParseIR("ResultOp", ResultOp(0, pCollectSig)) + testParseIR( + "ResultOp FoldStateSig", { val accum = Ref(freshName(), TInt32) ResultOp( 0, @@ -3481,40 +3588,79 @@ class IRSuite extends HailSuite { ), ) }, + ) + testParseIR( + "SerializeAggs", SerializeAggs(0, 0, BufferSpec.default, FastSeq(pCollectSig.state)), + ) + testParseIR( + "DeserializeAggs", DeserializeAggs(0, 0, BufferSpec.default, FastSeq(pCollectSig.state)), - CombOpValue(0, bin, pCollectSig) -> ArraySeq(bin), - AggStateValue(0, pCollectSig.state), - InitFromSerializedValue(0, bin, pCollectSig.state) -> ArraySeq(bin), - Begin(FastSeq(Void())), - MakeStruct(FastSeq("x" -> i)), 
- SelectFields(s, FastSeq("x", "z")) -> ArraySeq(s), - InsertFields(s, FastSeq("x" -> i)) -> ArraySeq(s), - InsertFields(s, FastSeq("* x *" -> i)) -> ArraySeq(s), // Won't parse as a simple identifier - GetField(s, "x") -> ArraySeq(s), - MakeTuple(FastSeq(2 -> i, 4 -> b)), - GetTupleElement(t, 1) -> ArraySeq(t), - Die("mumblefoo", TFloat64), - invoke("land", TBoolean, b, c) -> ArraySeq(c), // ApplySpecial - invoke("toFloat64", TFloat64, i), // Apply - Literal(TStruct("x" -> TInt32), Row(1)), - TableCount(table), - MatrixCount(mt), - TableGetGlobals(table), - TableCollect(TableKeyBy(table, FastSeq())), - TableAggregate(table, MakeStruct(IndexedSeq("foo" -> count))), - TableToValueApply(table, ForceCountTable()), - MatrixToValueApply(mt, ForceCountMatrixTable()), - TableWrite(table, TableNativeWriter("/path/to/data.ht")), - MatrixWrite(mt, MatrixNativeWriter("/path/to/data.mt")), - MatrixWrite(vcf, MatrixVCFWriter("/path/to/sample.vcf")), - MatrixWrite(vcf, MatrixPLINKWriter("/path/to/base")), - MatrixWrite(bgen, MatrixGENWriter("/path/to/base")), + ) + testParseIR( + "CombOpValue", + CombOpValue(0, bin, pCollectSig), + BindingEnv.eval(bin.name -> bin.typ), + ) + testParseIR("AggStateValue", AggStateValue(0, pCollectSig.state)) + testParseIR( + "InitFromSerializedValue", + InitFromSerializedValue(0, bin, pCollectSig.state), + BindingEnv.eval(bin.name -> bin.typ), + ) + testParseIR("Begin", Begin(FastSeq(Void()))) + testParseIR("MakeStruct", MakeStruct(FastSeq("x" -> i))) + testParseIR( + "SelectFields", + SelectFields(s, FastSeq("x", "z")), + BindingEnv.eval(s.name -> s.typ), + ) + testParseIR( + "InsertFields", + InsertFields(s, FastSeq("x" -> i)), + BindingEnv.eval(s.name -> s.typ), + ) + testParseIR( + "InsertFields special name", // Won't parse as a simple identifier + InsertFields(s, FastSeq("* x *" -> i)), + BindingEnv.eval(s.name -> s.typ), + ) + testParseIR("GetField", GetField(s, "x"), BindingEnv.eval(s.name -> s.typ)) + testParseIR("MakeTuple", 
MakeTuple(FastSeq(2 -> i, 4 -> b))) + testParseIR("GetTupleElement", GetTupleElement(t, 1), BindingEnv.eval(t.name -> t.typ)) + testParseIR("Die", Die("mumblefoo", TFloat64)) + testParseIR( + "ApplySpecial land", + invoke("land", TBoolean, b, c), + BindingEnv.eval(c.name -> c.typ), + ) + testParseIR("Apply toFloat64", invoke("toFloat64", TFloat64, i)) + testParseIR("Literal", Literal(TStruct("x" -> TInt32), Row(1))) + testParseIR("TableCount", TableCount(table)) + testParseIR("MatrixCount", MatrixCount(mt)) + testParseIR("TableGetGlobals", TableGetGlobals(table)) + testParseIR("TableCollect", TableCollect(TableKeyBy(table, FastSeq()))) + testParseIR("TableAggregate", TableAggregate(table, MakeStruct(IndexedSeq("foo" -> count)))) + testParseIR("TableToValueApply", TableToValueApply(table, ForceCountTable())) + testParseIR("MatrixToValueApply", MatrixToValueApply(mt, ForceCountMatrixTable())) + testParseIR("TableWrite", TableWrite(table, TableNativeWriter("/path/to/data.ht"))) + testParseIR("MatrixWrite NativeWriter", MatrixWrite(mt, MatrixNativeWriter("/path/to/data.mt"))) + testParseIR("MatrixWrite VCFWriter", MatrixWrite(vcf, MatrixVCFWriter("/path/to/sample.vcf"))) + testParseIR("MatrixWrite PLINKWriter", MatrixWrite(vcf, MatrixPLINKWriter("/path/to/base"))) + testParseIR("MatrixWrite GENWriter", MatrixWrite(bgen, MatrixGENWriter("/path/to/base"))) + testParseIR( + "MatrixWrite BlockMatrixWriter", MatrixWrite(mt, MatrixBlockMatrixWriter("path/to/data/bm", true, "a", 4096)), + ) + testParseIR( + "MatrixMultiWrite", MatrixMultiWrite( ArraySeq(mt, mt), MatrixNativeMultiWriter(IndexedSeq("/path/to/mt1", "/path/to/mt2")), ), + ) + testParseIR( + "TableMultiWrite", TableMultiWrite( ArraySeq(table, table), WrappedMatrixNativeMultiWriter( @@ -3522,12 +3668,21 @@ class IRSuite extends HailSuite { FastSeq("foo"), ), ), - MatrixAggregate(mt, MakeStruct(IndexedSeq("foo" -> count))), - BlockMatrixCollect(blockMatrix), - BlockMatrixWrite(blockMatrix, blockMatrixWriter), + ) + 
testParseIR("MatrixAggregate", MatrixAggregate(mt, MakeStruct(IndexedSeq("foo" -> count)))) + testParseIR("BlockMatrixCollect", BlockMatrixCollect(blockMatrix)) + testParseIR("BlockMatrixWrite", BlockMatrixWrite(blockMatrix, blockMatrixWriter)) + testParseIR( + "BlockMatrixMultiWrite", BlockMatrixMultiWrite(IndexedSeq(blockMatrix, blockMatrix), blockMatrixMultiWriter), + ) + testParseIR( + "BlockMatrixPersistWriter", BlockMatrixWrite(blockMatrix, BlockMatrixPersistWriter("x", "MEMORY_ONLY")), - cdaIR(rangeIR(3), 1, "test", NA(TString))((context, global) => context), + ) + testParseIR("cdaIR", cdaIR(rangeIR(3), 1, "test", NA(TString))((context, global) => context)) + testParseIR( + "ReadPartition", ReadPartition( MakeStruct(ArraySeq("partitionIndex" -> I64(0), "partitionPath" -> Str("foo"))), TStruct("foo" -> TInt32), @@ -3540,6 +3695,9 @@ class IRSuite extends HailSuite { "rowUID", ), ), + ) + testParseIR( + "WritePartition", WritePartition( MakeStream(FastSeq(), TStream(TStruct())), NA(TString), @@ -3551,10 +3709,16 @@ class IRSuite extends HailSuite { None, ), ), + ) + testParseIR( + "WriteMetadata", WriteMetadata( Begin(FastSeq()), RelationalWriter("path", overwrite = false, None), ), + ) + testParseIR( + "ReadValue", ReadValue( Str("foo"), ETypeValueReader(TypedCodecSpec( @@ -3564,335 +3728,306 @@ class IRSuite extends HailSuite { )), TStruct("foo" -> TInt32), ), + ) + testParseIR( + "WriteValue", WriteValue( I32(1), Str("foo"), ETypeValueWriter(TypedCodecSpec(ctx, PInt32(), BufferSpec.default)), ), + ) + testParseIR( + "WriteValue with uid", WriteValue( I32(1), Str("foo"), ETypeValueWriter(TypedCodecSpec(ctx, PInt32(), BufferSpec.default)), Some(Str("/tmp/uid/part")), ), - LiftMeOut(I32(1)), - relationalBindIR(I32(0))(_ => I32(0)), { + ) + testParseIR("LiftMeOut", LiftMeOut(I32(1))) + testParseIR("relationalBindIR", relationalBindIR(I32(0))(_ => I32(0))) + testParseIR( + "TailLoop", { val y = freshName() TailLoop(y, IndexedSeq(freshName() -> I32(0)), TInt32, 
Recur(y, FastSeq(I32(4)), TInt32)) }, ) - val emptyEnv = BindingEnv.empty[Type] - irs.map { case (ir, bind) => Array(ir, bind(emptyEnv)) } } - @DataProvider(name = "tableIRs") - def tableIRs(): Array[Array[TableIR]] = - tableIRs(ctx) + // --- Table IR parser tests --- - def tableIRs(ctx: ExecuteContext): Array[Array[TableIR]] = { - try { - val fs = ctx.fs - - val read = TableIR.read(fs, getTestResource("backward_compatability/1.1.0/table/0.ht")) - val mtRead = - MatrixIR.read(fs, getTestResource("backward_compatability/1.0.0/matrix_table/0.hmt")) - val b = True() - - val xs: Array[TableIR] = Array( - TableDistinct(read), - TableKeyBy(read, ArraySeq("m", "d")), - TableFilter(read, b), - read, - MatrixColsTable(mtRead), - TableAggregateByKey( - read, - MakeStruct(FastSeq( - "a" -> I32(5) - )), - ), - TableKeyByAndAggregate(read, NA(TStruct.empty), NA(TStruct.empty), Some(1), 2), - TableJoin(read, TableRange(100, 10), "inner", 1), - TableLeftJoinRightDistinct(read, TableRange(100, 10), "root"), - TableMultiWayZipJoin(FastSeq(read, read), " * data * ", "globals"), - MatrixEntriesTable(mtRead), - MatrixRowsTable(mtRead), - TableRepartition(read, 10, RepartitionStrategy.COALESCE), - TableHead(read, 10), - TableTail(read, 10), - TableParallelize( - MakeStruct(FastSeq( - "rows" -> MakeArray( - FastSeq( - MakeStruct(FastSeq("a" -> NA(TInt32))), - MakeStruct(FastSeq("a" -> I32(1))), - ), - TArray(TStruct("a" -> TInt32)), - ), - "global" -> MakeStruct(FastSeq()), - )), - None, - ), - TableMapRows( - TableKeyBy(read, FastSeq()), - MakeStruct(FastSeq( - "a" -> GetField(Ref(TableIR.rowName, read.typ.rowType), "f32"), - "b" -> F64(-2.11), - )), - ), { - val rs = Ref(freshName(), TStream(read.typ.rowType)) - TableMapPartitions( - TableKeyBy(read, FastSeq()), - freshName(), - rs.name, - StreamTake(rs, 1), - 0, - 0, - ) - }, - TableMapGlobals( - read, - MakeStruct(FastSeq( - "foo" -> NA(TArray(TInt32)) - )), - ), - TableRange(100, 10), - TableUnion( - FastSeq(TableRange(100, 10), 
TableRange(50, 10)) - ), - TableExplode(read, ArraySeq("mset")), - TableOrderBy( - TableKeyBy(read, FastSeq()), - FastSeq(SortField("m", Ascending), SortField("m", Descending)), - ), - CastMatrixToTable(mtRead, " # entries", " # cols"), - TableRename(read, Map("idx" -> "idx_foo"), Map("global_f32" -> "global_foo")), - TableFilterIntervals( - read, - FastSeq(Interval(IntervalEndpoint(Row(0), -1), IntervalEndpoint(Row(10), 1))), - keep = false, - ), - RelationalLetTable(freshName(), I32(0), read), { - val structs = MakeStream(FastSeq(), TStream(TStruct())) - val partitioner = RVDPartitioner.empty(ctx.stateManager, TStruct()) - TableGen( - structs, - MakeStruct(FastSeq()), - freshName(), - freshName(), - structs, - partitioner, - errorId = 180, - ) - }, - ) - xs.map(x => Array(x)) - } catch { - case t: Throwable => - println(t) - println(t.printStackTrace()) - throw t + object checkTableIRParser extends TestCases { + def apply( + ir: => TableIR + )(implicit loc: munit.Location + ): Unit = test("table IR parser") { + val ir0 = ir + val s = Pretty.sexprStyle(ir0, elideLiterals = false) + val x2 = IRParser.parse_table_ir(ctx, s) + assertEquals(x2, ir0) } } - @DataProvider(name = "matrixIRs") - def matrixIRs(): Array[Array[MatrixIR]] = - matrixIRs(ctx) - - def matrixIRs(ctx: ExecuteContext): Array[Array[MatrixIR]] = { - try { - val fs = ctx.fs + { + lazy val read = TableIR.read(ctx.fs, getTestResource("backward_compatability/1.1.0/table/0.ht")) + lazy val mtRead = + MatrixIR.read(ctx.fs, getTestResource("backward_compatability/1.0.0/matrix_table/0.hmt")) + val b = True() - CompileAndEvaluate[Unit]( - ctx, - invoke( - "index_bgen", - TInt64, - ArraySeq(TLocus("GRCh37")), - Str(getTestResource("example.8bits.bgen")), - Str(getTestResource("example.8bits.bgen.idx2")), - Literal(TDict(TString, TString), Map("01" -> "1")), - False(), - I32(1000000), + checkTableIRParser(TableDistinct(read)) + checkTableIRParser(TableKeyBy(read, ArraySeq("m", "d"))) + 
checkTableIRParser(TableFilter(read, b)) + checkTableIRParser(read) + checkTableIRParser(MatrixColsTable(mtRead)) + checkTableIRParser(TableAggregateByKey( + read, + MakeStruct(FastSeq( + "a" -> I32(5) + )), + )) + checkTableIRParser(TableKeyByAndAggregate( + read, + NA(TStruct.empty), + NA(TStruct.empty), + Some(1), + 2, + )) + checkTableIRParser(TableJoin(read, TableRange(100, 10), "inner", 1)) + checkTableIRParser(TableLeftJoinRightDistinct(read, TableRange(100, 10), "root")) + checkTableIRParser(TableMultiWayZipJoin(FastSeq(read, read), " * data * ", "globals")) + checkTableIRParser(MatrixEntriesTable(mtRead)) + checkTableIRParser(MatrixRowsTable(mtRead)) + checkTableIRParser(TableRepartition(read, 10, RepartitionStrategy.COALESCE)) + checkTableIRParser(TableHead(read, 10)) + checkTableIRParser(TableTail(read, 10)) + checkTableIRParser(TableParallelize( + MakeStruct(FastSeq( + "rows" -> MakeArray( + FastSeq( + MakeStruct(FastSeq("a" -> NA(TInt32))), + MakeStruct(FastSeq("a" -> I32(1))), + ), + TArray(TStruct("a" -> TInt32)), ), + "global" -> MakeStruct(FastSeq()), + )), + None, + )) + checkTableIRParser(TableMapRows( + TableKeyBy(read, FastSeq()), + MakeStruct(FastSeq( + "a" -> GetField(Ref(TableIR.rowName, read.typ.rowType), "f32"), + "b" -> F64(-2.11), + )), + )) + checkTableIRParser({ + val rs = Ref(freshName(), TStream(read.typ.rowType)) + TableMapPartitions( + TableKeyBy(read, FastSeq()), + freshName(), + rs.name, + StreamTake(rs, 1), + 0, + 0, ) - - val tableRead = TableIR.read(fs, getTestResource("backward_compatability/1.1.0/table/0.ht")) - val read = - MatrixIR.read(fs, getTestResource("backward_compatability/1.0.0/matrix_table/0.hmt")) - val range = MatrixIR.range(ctx, 3, 7, None) - val vcf = importVCF(ctx, getTestResource("sample.vcf")) - - val bgenReader = MatrixBGENReader( - ctx, - FastSeq(getTestResource("example.8bits.bgen")), - None, - Map.empty[String, String], - None, - None, - None, + }) + checkTableIRParser(TableMapGlobals( + read, + 
MakeStruct(FastSeq( + "foo" -> NA(TArray(TInt32)) + )), + )) + checkTableIRParser(TableRange(100, 10)) + checkTableIRParser(TableUnion( + FastSeq(TableRange(100, 10), TableRange(50, 10)) + )) + checkTableIRParser(TableExplode(read, ArraySeq("mset"))) + checkTableIRParser(TableOrderBy( + TableKeyBy(read, FastSeq()), + FastSeq(SortField("m", Ascending), SortField("m", Descending)), + )) + checkTableIRParser(CastMatrixToTable(mtRead, " # entries", " # cols")) + checkTableIRParser(TableRename( + read, + Map("idx" -> "idx_foo"), + Map("global_f32" -> "global_foo"), + )) + checkTableIRParser(TableFilterIntervals( + read, + FastSeq(Interval(IntervalEndpoint(Row(0), -1), IntervalEndpoint(Row(10), 1))), + keep = false, + )) + checkTableIRParser(RelationalLetTable(freshName(), I32(0), read)) + checkTableIRParser({ + val structs = MakeStream(FastSeq(), TStream(TStruct())) + val partitioner = RVDPartitioner.empty(ctx.stateManager, TStruct()) + TableGen( + structs, + MakeStruct(FastSeq()), + freshName(), + freshName(), + structs, + partitioner, + errorId = 180, ) - val bgen = MatrixRead(bgenReader.fullMatrixType, false, false, bgenReader) - - val range1 = MatrixIR.range(ctx, 20, 2, Some(3)) - val range2 = MatrixIR.range(ctx, 20, 2, Some(4)) - - val b = True() - - val newCol = MakeStruct(FastSeq( - "col_idx" -> GetField(Ref(MatrixIR.colName, read.typ.colType), "col_idx"), - "new_f32" -> ApplyBinaryPrimOp( - Add(), - GetField(Ref(MatrixIR.colName, read.typ.colType), "col_f32"), - F32(-5.2f), - ), - )) - val newRow = MakeStruct(FastSeq( - "row_idx" -> GetField(Ref(MatrixIR.rowName, read.typ.rowType), "row_idx"), - "new_f32" -> ApplyBinaryPrimOp( - Add(), - GetField(Ref(MatrixIR.rowName, read.typ.rowType), "row_f32"), - F32(-5.2f), - ), - )) + }) + } - val collect = ApplyAggOp(Collect())(I32(0)) - - val newRowAnn = MakeStruct(FastSeq("count_row" -> collect)) - val newColAnn = MakeStruct(FastSeq("count_col" -> collect)) - val newEntryAnn = MakeStruct(FastSeq("count_entry" -> 
collect)) - - val xs = Array[MatrixIR]( - read, - MatrixFilterRows(read, b), - MatrixFilterCols(read, b), - MatrixFilterEntries(read, b), - MatrixChooseCols(read, ArraySeq(0, 0, 0)), - MatrixMapCols(read, newCol, None), - MatrixKeyRowsBy(read, FastSeq("row_m", "row_d"), false), - MatrixMapRows(read, newRow), - MatrixRepartition(read, 10, 0), - MatrixMapEntries( - read, - MakeStruct(FastSeq( - "global_f32" -> ApplyBinaryPrimOp( - Add(), - GetField(Ref(MatrixIR.globalName, read.typ.globalType), "global_f32"), - F32(-5.2f), - ) - )), - ), - MatrixCollectColsByKey(read), - MatrixAggregateColsByKey(read, newEntryAnn, newColAnn), - MatrixAggregateRowsByKey(read, newEntryAnn, newRowAnn), - range, - vcf, - bgen, - MatrixExplodeRows(read, FastSeq("row_mset")), - MatrixUnionRows(FastSeq(range1, range2)), - MatrixDistinctByRow(range1), - MatrixRowsHead(range1, 3), - MatrixColsHead(range1, 3), - MatrixRowsTail(range1, 3), - MatrixColsTail(range1, 3), - MatrixExplodeCols(read, FastSeq("col_mset")), - CastTableToMatrix( - CastMatrixToTable(read, " # entries", " # cols"), - " # entries", - " # cols", - read.typ.colKey, - ), - MatrixAnnotateColsTable(read, tableRead, "uid_123"), - MatrixAnnotateRowsTable(read, tableRead, "uid_123", product = false), - MatrixRename( - read, - Map("global_i64" -> "foo"), - Map("col_i64" -> "bar"), - Map("row_i64" -> "baz"), - Map("entry_i64" -> "quam"), - ), - MatrixFilterIntervals( - read, - FastSeq(Interval(IntervalEndpoint(Row(0), -1), IntervalEndpoint(Row(10), 1))), - keep = false, - ), - RelationalLetMatrixTable(freshName(), I32(0), read), - ) + // --- Matrix IR parser tests --- - xs.map(x => Array(x)) - } catch { - case t: Throwable => - println(t) - println(t.printStackTrace()) - throw t + object checkMatrixIRParser extends TestCases { + def apply( + ir: => MatrixIR + )(implicit loc: munit.Location + ): Unit = test("matrix IR parser") { + val ir0 = ir + val s = Pretty.sexprStyle(ir0, elideLiterals = false) + val x2 = 
IRParser.parse_matrix_ir(ctx, s) + assertEquals(x2, ir0) } } - @DataProvider(name = "blockMatrixIRs") - def blockMatrixIRs(): Array[Array[BlockMatrixIR]] = { - val read = - BlockMatrixRead(BlockMatrixNativeReader(fs, getTestResource("blockmatrix_example/0"))) - val transpose = BlockMatrixBroadcast(read, FastSeq(1, 0), FastSeq(2, 2), 2) - val dot = BlockMatrixDot(read, transpose) - val slice = BlockMatrixSlice(read, FastSeq(FastSeq(0, 2, 1), FastSeq(0, 1, 1))) + { + lazy val tableRead = + TableIR.read(ctx.fs, getTestResource("backward_compatability/1.1.0/table/0.ht")) + lazy val read = + MatrixIR.read(ctx.fs, getTestResource("backward_compatability/1.0.0/matrix_table/0.hmt")) + lazy val range = MatrixIR.range(ctx, 3, 7, None) + lazy val vcf = importVCF(ctx, getTestResource("sample.vcf")) - val sparsify1 = BlockMatrixSparsify(read, RectangleSparsifier(FastSeq(FastSeq(0L, 1L, 5L, 6L)))) - val sparsify2 = BlockMatrixSparsify(read, BandSparsifier(true, -1L, 1L)) - val sparsify3 = BlockMatrixSparsify( - read, - RowIntervalSparsifier(true, FastSeq(0L, 1L, 5L, 6L), FastSeq(5L, 6L, 8L, 9L)), + lazy val bgenReader = MatrixBGENReader( + ctx, + FastSeq(getTestResource("example.8bits.bgen")), + None, + Map.empty[String, String], + None, + None, + None, ) - val densify = BlockMatrixDensify(read) + lazy val bgen = MatrixRead(bgenReader.fullMatrixType, false, false, bgenReader) - val blockMatrixIRs = Array[BlockMatrixIR]( - read, - transpose, - dot, - sparsify1, - sparsify2, - sparsify3, - densify, - slice, - ) + lazy val range1 = MatrixIR.range(ctx, 20, 2, Some(3)) + lazy val range2 = MatrixIR.range(ctx, 20, 2, Some(4)) - blockMatrixIRs.map(ir => Array(ir)) - } + val b = True() - @Test def testIRConstruction(): Unit = { - matrixIRs(): Unit - tableIRs(): Unit - valueIRs(): Unit - blockMatrixIRs(): Unit - } + lazy val newCol = MakeStruct(FastSeq( + "col_idx" -> GetField(Ref(MatrixIR.colName, read.typ.colType), "col_idx"), + "new_f32" -> ApplyBinaryPrimOp( + Add(), + 
GetField(Ref(MatrixIR.colName, read.typ.colType), "col_f32"), + F32(-5.2f), + ), + )) + lazy val newRow = MakeStruct(FastSeq( + "row_idx" -> GetField(Ref(MatrixIR.rowName, read.typ.rowType), "row_idx"), + "new_f32" -> ApplyBinaryPrimOp( + Add(), + GetField(Ref(MatrixIR.rowName, read.typ.rowType), "row_f32"), + F32(-5.2f), + ), + )) - @Test(dataProvider = "valueIRs") - def testValueIRParser(x: IR, refMap: BindingEnv[Type]): Unit = { - val s = Pretty.sexprStyle(x, elideLiterals = false) - val x2 = IRParser.parse_value_ir(ctx, s, refMap) - assert(x2 == x) + val collect = ApplyAggOp(Collect())(I32(0)) + + val newRowAnn = MakeStruct(FastSeq("count_row" -> collect)) + val newColAnn = MakeStruct(FastSeq("count_col" -> collect)) + val newEntryAnn = MakeStruct(FastSeq("count_entry" -> collect)) + + checkMatrixIRParser(read) + checkMatrixIRParser(MatrixFilterRows(read, b)) + checkMatrixIRParser(MatrixFilterCols(read, b)) + checkMatrixIRParser(MatrixFilterEntries(read, b)) + checkMatrixIRParser(MatrixChooseCols(read, ArraySeq(0, 0, 0))) + checkMatrixIRParser(MatrixMapCols(read, newCol, None)) + checkMatrixIRParser(MatrixKeyRowsBy(read, FastSeq("row_m", "row_d"), false)) + checkMatrixIRParser(MatrixMapRows(read, newRow)) + checkMatrixIRParser(MatrixRepartition(read, 10, 0)) + checkMatrixIRParser(MatrixMapEntries( + read, + MakeStruct(FastSeq( + "global_f32" -> ApplyBinaryPrimOp( + Add(), + GetField(Ref(MatrixIR.globalName, read.typ.globalType), "global_f32"), + F32(-5.2f), + ) + )), + )) + checkMatrixIRParser(MatrixCollectColsByKey(read)) + checkMatrixIRParser(MatrixAggregateColsByKey(read, newEntryAnn, newColAnn)) + checkMatrixIRParser(MatrixAggregateRowsByKey(read, newEntryAnn, newRowAnn)) + checkMatrixIRParser(range) + checkMatrixIRParser(vcf) + checkMatrixIRParser(bgen) + checkMatrixIRParser(MatrixExplodeRows(read, FastSeq("row_mset"))) + checkMatrixIRParser(MatrixUnionRows(FastSeq(range1, range2))) + checkMatrixIRParser(MatrixDistinctByRow(range1)) + 
checkMatrixIRParser(MatrixRowsHead(range1, 3)) + checkMatrixIRParser(MatrixColsHead(range1, 3)) + checkMatrixIRParser(MatrixRowsTail(range1, 3)) + checkMatrixIRParser(MatrixColsTail(range1, 3)) + checkMatrixIRParser(MatrixExplodeCols(read, FastSeq("col_mset"))) + checkMatrixIRParser(CastTableToMatrix( + CastMatrixToTable(read, " # entries", " # cols"), + " # entries", + " # cols", + read.typ.colKey, + )) + checkMatrixIRParser(MatrixAnnotateColsTable(read, tableRead, "uid_123")) + checkMatrixIRParser(MatrixAnnotateRowsTable(read, tableRead, "uid_123", product = false)) + checkMatrixIRParser(MatrixRename( + read, + Map("global_i64" -> "foo"), + Map("col_i64" -> "bar"), + Map("row_i64" -> "baz"), + Map("entry_i64" -> "quam"), + )) + checkMatrixIRParser(MatrixFilterIntervals( + read, + FastSeq(Interval(IntervalEndpoint(Row(0), -1), IntervalEndpoint(Row(10), 1))), + keep = false, + )) + checkMatrixIRParser(RelationalLetMatrixTable(freshName(), I32(0), read)) } - @Test(dataProvider = "tableIRs") - def testTableIRParser(x: TableIR): Unit = { - val s = Pretty.sexprStyle(x, elideLiterals = false) - val x2 = IRParser.parse_table_ir(ctx, s) - assert(x2 == x) - } + // --- Block matrix IR parser tests --- - @Test(dataProvider = "matrixIRs") - def testMatrixIRParser(x: MatrixIR): Unit = { - val s = Pretty.sexprStyle(x, elideLiterals = false) - val x2 = IRParser.parse_matrix_ir(ctx, s) - assert(x2 == x) + object checkBlockMatrixIRParser extends TestCases { + def apply( + ir: => BlockMatrixIR + )(implicit loc: munit.Location + ): Unit = test("block matrix IR parser") { + val ir0 = ir + val s = Pretty.sexprStyle(ir0, elideLiterals = false) + val x2 = IRParser.parse_blockmatrix_ir(ctx, s) + assertEquals(x2, ir0) + } } - @Test(dataProvider = "blockMatrixIRs") - def testBlockMatrixIRParser(x: BlockMatrixIR): Unit = { - val s = Pretty.sexprStyle(x, elideLiterals = false) - val x2 = IRParser.parse_blockmatrix_ir(ctx, s) - assert(x2 == x) + { + lazy val read = + 
BlockMatrixRead(BlockMatrixNativeReader(fs, getTestResource("blockmatrix_example/0"))) + lazy val transpose = BlockMatrixBroadcast(read, FastSeq(1, 0), FastSeq(2, 2), 2) + + checkBlockMatrixIRParser(read) + checkBlockMatrixIRParser(transpose) + checkBlockMatrixIRParser(BlockMatrixDot(read, transpose)) + checkBlockMatrixIRParser(BlockMatrixSparsify( + read, + RectangleSparsifier(FastSeq(FastSeq(0L, 1L, 5L, 6L))), + )) + checkBlockMatrixIRParser(BlockMatrixSparsify(read, BandSparsifier(true, -1L, 1L))) + checkBlockMatrixIRParser(BlockMatrixSparsify( + read, + RowIntervalSparsifier(true, FastSeq(0L, 1L, 5L, 6L), FastSeq(5L, 6L, 8L, 9L)), + )) + checkBlockMatrixIRParser(BlockMatrixDensify(read)) + checkBlockMatrixIRParser(BlockMatrixSlice(read, FastSeq(FastSeq(0, 2, 1), FastSeq(0, 1, 1)))) } - @Test def testBlockMatrixIRParserPersist(): Unit = { + test("BlockMatrixIRParserPersist") { val cache = mutable.Map.empty[String, BlockMatrix] val bm = BlockMatrixRandom(0, gaussian = true, shape = ArraySeq(5L, 6L), blockSize = 3) try { @@ -3907,13 +4042,13 @@ class IRSuite extends HailSuite { val persist = BlockMatrixRead(BlockMatrixPersistReader("x", bm.typ)) val s = Pretty.sexprStyle(persist, elideLiterals = false) val x2 = IRParser.parse_blockmatrix_ir(ctx, s) - assert(x2 == persist) + assertEquals(x2, persist) } } finally cache.values.foreach(_.unpersist()) } - @Test def testCachedIR(): Unit = { + test("CachedIR") { val cached = Literal(TSet(TInt32), Set(1)) val s = s"(JavaIR 1)" val x2 = @@ -3921,7 +4056,7 @@ class IRSuite extends HailSuite { assert(x2 eq cached) } - @Test def testCachedTableIR(): Unit = { + test("CachedTableIR") { val cached = TableRange(1, 1) val s = s"(JavaTable 1)" val x2 = @@ -3929,7 +4064,7 @@ class IRSuite extends HailSuite { assert(x2 eq cached) } - @Test def testArrayContinuationDealsWithIfCorrectly(): Unit = { + test("ArrayContinuationDealsWithIfCorrectly") { val ir = ToArray(mapIR( If(IsNA(In(0, TBoolean)), NA(TStream(TInt32)), ToStream(In(1, 
TArray(TInt32)))) )(Cast(_, TInt64))) @@ -3937,7 +4072,7 @@ class IRSuite extends HailSuite { assertEvalsTo(ir, FastSeq(true -> TBoolean, FastSeq(0) -> TArray(TInt32)), FastSeq(0L)) } - @Test def testTableGetGlobalsSimplifyRules(): Unit = { + test("TableGetGlobalsSimplifyRules") { implicit val execStrats = ExecStrategy.interpretOnly val t1 = @@ -3974,7 +4109,7 @@ class IRSuite extends HailSuite { assertEvalsTo(TableGetGlobals(TableRename(tab1, Map.empty, Map("g2" -> "g3"))), Row(1, 1.1)) } - @Test def testAggLet(): Unit = { + test("AggLet") { implicit val execStrats = ExecStrategy.interpretOnly val t = TableRange(2, 2) val ir = TableAggregate( @@ -3989,14 +4124,14 @@ class IRSuite extends HailSuite { assertEvalsTo(ir, 61L) } - @Test def testRelationalLet(): Unit = { + test("RelationalLet") { implicit val execStrats = ExecStrategy.interpretOnly val ir = relationalBindIR(NA(TInt32))(x => x) assertEvalsTo(ir, null) } - @Test def testRelationalLetTable(): Unit = { + test("RelationalLetTable") { implicit val execStrats = ExecStrategy.interpretOnly val t = TArray(TStruct("x" -> TInt32)) @@ -4015,7 +4150,7 @@ class IRSuite extends HailSuite { assertEvalsTo(ir, 1L) } - @Test def testRelationalLetMatrixTable(): Unit = { + test("RelationalLetMatrixTable") { implicit val execStrats = ExecStrategy.interpretOnly val t = TArray(TStruct("x" -> TInt32)) @@ -4045,7 +4180,6 @@ class IRSuite extends HailSuite { assertEvalsTo(ir, 1L) } - @DataProvider(name = "relationalFunctions") def relationalFunctionsData(): Array[Array[Any]] = Array( Array(TableFilterPartitions(ArraySeq(1, 2, 3), keep = true)), Array(VEP(fs, getTestResource("dummy_vep_config.json"), false, 1, true)), @@ -4093,31 +4227,50 @@ class IRSuite extends HailSuite { Array(GetElement(FastSeq(1, 2))), ) - @Test def relationalFunctionsRun(): Unit = relationalFunctionsData(): Unit - - @Test(dataProvider = "relationalFunctions") - def testRelationalFunctionsSerialize(x: Any): Unit = { - implicit val formats = 
RelationalFunctions.formats - - x match { - case x: MatrixToMatrixFunction => - assert(RelationalFunctions.lookupMatrixToMatrix(ctx, Serialization.write(x)) == x) - case x: MatrixToTableFunction => - assert(RelationalFunctions.lookupMatrixToTable(ctx, Serialization.write(x)) == x) - case x: MatrixToValueFunction => - assert(RelationalFunctions.lookupMatrixToValue(ctx, Serialization.write(x)) == x) - case x: TableToTableFunction => - assert(RelationalFunctions.lookupTableToTable(ctx, JsonMethods.compact(x.toJValue)) == x) - case x: TableToValueFunction => - assert(RelationalFunctions.lookupTableToValue(ctx, Serialization.write(x)) == x) - case x: BlockMatrixToTableFunction => - assert(RelationalFunctions.lookupBlockMatrixToTable(ctx, Serialization.write(x)) == x) - case x: BlockMatrixToValueFunction => - assert(RelationalFunctions.lookupBlockMatrixToValue(ctx, Serialization.write(x)) == x) + test("relationalFunctionsRun") { + relationalFunctionsData(): Unit + } + + // --- Relational functions serialize tests (from dataProvider "relationalFunctions") --- + private lazy val relationalFunctionsTestData: Array[Array[Any]] = relationalFunctionsData() + + private val RELATIONAL_FUNCTIONS_COUNT = 19 + + test("relational functions count check") { + assertEquals( + relationalFunctionsTestData.length, + RELATIONAL_FUNCTIONS_COUNT, + "RELATIONAL_FUNCTIONS_COUNT is out of sync with relationalFunctionsData()", + ) + } + + (0 until RELATIONAL_FUNCTIONS_COUNT).foreach { i => + test(s"relational functions serialize $i") { + implicit val formats = RelationalFunctions.formats + val x = relationalFunctionsTestData(i)(0) + x match { + case x: MatrixToMatrixFunction => + assertEquals(RelationalFunctions.lookupMatrixToMatrix(ctx, Serialization.write(x)), x) + case x: MatrixToTableFunction => + assertEquals(RelationalFunctions.lookupMatrixToTable(ctx, Serialization.write(x)), x) + case x: MatrixToValueFunction => + assertEquals(RelationalFunctions.lookupMatrixToValue(ctx, 
Serialization.write(x)), x) + case x: TableToTableFunction => + assertEquals( + RelationalFunctions.lookupTableToTable(ctx, JsonMethods.compact(x.toJValue)), + x, + ) + case x: TableToValueFunction => + assertEquals(RelationalFunctions.lookupTableToValue(ctx, Serialization.write(x)), x) + case x: BlockMatrixToTableFunction => + assertEquals(RelationalFunctions.lookupBlockMatrixToTable(ctx, Serialization.write(x)), x) + case x: BlockMatrixToValueFunction => + assertEquals(RelationalFunctions.lookupBlockMatrixToValue(ctx, Serialization.write(x)), x) + } } } - @Test def testFoldWithSetup(): Unit = { + test("FoldWithSetup") { val v = In(0, TInt32) val cond1 = If( v.ceq(I32(3)), @@ -4131,12 +4284,12 @@ class IRSuite extends HailSuite { ) } - @Test def testNonCanonicalTypeParsing(): Unit = { + test("NonCanonicalTypeParsing") { val t = TTuple(FastSeq(TupleField(1, TInt64))) val lit = Literal(t, Row(1L)) - assert(IRParser.parseType(t.parsableString()) == t) - assert(IRParser.parse_value_ir(ctx, Pretty.sexprStyle(lit, elideLiterals = false)) == lit) + assertEquals(IRParser.parseType(t.parsableString()), t) + assertEquals(IRParser.parse_value_ir(ctx, Pretty.sexprStyle(lit, elideLiterals = false)), lit) } def regressionTestUnifyBug(): Unit = { @@ -4168,7 +4321,7 @@ class IRSuite extends HailSuite { ) } - @Test def testSimpleTailLoop(): Unit = { + test("SimpleTailLoop") { implicit val execStrats = ExecStrategy.compileOnly val triangleSum: IR = tailLoop(TInt32, In(0, TInt32), In(1, TInt32)) { case (recur, Seq(x, accum)) => @@ -4180,7 +4333,7 @@ class IRSuite extends HailSuite { assertEvalsTo(triangleSum, FastSeq((null, TInt32), 0 -> TInt32), null) } - @Test def testNestedTailLoop(): Unit = { + test("NestedTailLoop") { implicit val execStrats = ExecStrategy.compileOnly val triangleSum: IR = tailLoop(TInt32, In(0, TInt32), I32(0)) { case (recur, Seq(x, accum)) => If( @@ -4195,7 +4348,7 @@ class IRSuite extends HailSuite { assertEvalsTo(triangleSum, FastSeq(5 -> TInt32), 15 + 10 + 
5) } - @Test def testTailLoopNDMemory(): Unit = { + test("TailLoopNDMemory") { val ndType = TNDArray(TInt32, Nat(2)) val ndSum: IR = tailLoop( @@ -4222,46 +4375,52 @@ class IRSuite extends HailSuite { eval(ndSum, Env.empty, FastSeq(100 -> TInt32, startingArg -> ndType), None, None, ctx) } - assert(memUsed == memUsed2) + assertEquals(memUsed, memUsed2) } - @Test def testHasIRSharing(): Unit = { + test("HasIRSharing") { val r = Ref(freshName(), TInt32) val ir1 = MakeTuple.ordered(FastSeq(I64(1), r, r, I32(1))) assert(HasIRSharing(ctx)(ir1)) assert(!HasIRSharing(ctx)(ir1.deepCopy())) } - @Test def freeVariables(): Unit = { + test("freeVariables") { val stream = rangeIR(5) val y = Ref(freshName(), TInt32) val z = Ref(freshName(), TInt32) val explodeIR = aggExplodeIR(stream)(x => z + ApplyAggOp(Sum())(x + y)) - assert(FreeVariables(explodeIR, true, true) == BindingEnv[Unit]( - Env((z.name, ())), - Some(Env((y.name, ()))), - Some(Env()), - )) - assert(FreeVariables(explodeIR, false, false) == BindingEnv[Unit](Env((z.name, ())))) + assertEquals( + FreeVariables(explodeIR, true, true), + BindingEnv[Unit]( + Env((z.name, ())), + Some(Env((y.name, ()))), + Some(Env()), + ), + ) + assertEquals(FreeVariables(explodeIR, false, false), BindingEnv[Unit](Env((z.name, ())))) val aggIR = streamAggIR(stream)(x => z + ApplyAggOp(Sum())(x + y)) - assert( - FreeVariables(aggIR, true, true) == BindingEnv[Unit]( + assertEquals( + FreeVariables(aggIR, true, true), + BindingEnv[Unit]( Env((z.name, ()), (y.name, ())), Some(Env()), Some(Env()), - ) + ), ) } - @Test def freeVariablesAggScanBindingEnv(): Unit = { + test("freeVariablesAggScanBindingEnv") { def testFreeVarsHelper(ir: IR): Unit = { val irFreeVarsTrue = FreeVariables.apply(ir, true, true) - assert(irFreeVarsTrue.agg.isDefined && irFreeVarsTrue.scan.isDefined) + assert(irFreeVarsTrue.agg.isDefined) + assert(irFreeVarsTrue.scan.isDefined) val irFreeVarsFalse = FreeVariables.apply(ir, false, false) - assert(irFreeVarsFalse.agg.isEmpty 
&& irFreeVarsFalse.scan.isEmpty) + assert(irFreeVarsFalse.agg.isEmpty) + assert(irFreeVarsFalse.scan.isEmpty) } val aggIR = streamAggIR(mapIR(rangeIR(4))(Cast(_, TInt64)))(ApplyAggOp(Sum())(_)) @@ -4273,7 +4432,6 @@ class IRSuite extends HailSuite { testFreeVarsHelper(scanIR) } - @DataProvider(name = "nonNullTypesAndValues") def nonNullTypesAndValues(): Array[Array[Any]] = Array( Array(Int32SingleCodeType, 1), Array(Int64SingleCodeType, 5L), @@ -4292,43 +4450,50 @@ class IRSuite extends HailSuite { ), ) - @Test(dataProvider = "nonNullTypesAndValues") - def testReadWriteValues(pt: SingleCodeType, value: Any): Unit = { - implicit val execStrats = ExecStrategy.compileOnly - val node = In(0, SingleCodeEmitParamType(true, pt)) - val spec = TypedCodecSpec(ctx, PType.canonical(node.typ), BufferSpec.blockedUncompressed) - val writer = ETypeValueWriter(spec) - val reader = ETypeValueReader(spec) - val prefix = ctx.createTmpPath("test-read-write-values") - val filename = WriteValue(node, Str(prefix) + UUID4(), writer) - forAll(Array(value, null)) { v => - assertEvalsTo(ReadValue(filename, reader, pt.virtualType), FastSeq(v -> pt.virtualType), v) + // --- Read/write values tests (from dataProvider "nonNullTypesAndValues") --- + nonNullTypesAndValues().zipWithIndex.foreach { case (arr, i) => + test(s"ReadWriteValues $i") { + implicit val execStrats = ExecStrategy.compileOnly + val pt = arr(0).asInstanceOf[SingleCodeType] + val value = arr(1) + val node = In(0, SingleCodeEmitParamType(true, pt)) + val spec = TypedCodecSpec(ctx, PType.canonical(node.typ), BufferSpec.blockedUncompressed) + val writer = ETypeValueWriter(spec) + val reader = ETypeValueReader(spec) + val prefix = ctx.createTmpPath("test-read-write-values") + val filename = WriteValue(node, Str(prefix) + UUID4(), writer) + Array(value, null).foreach { v => + assertEvalsTo(ReadValue(filename, reader, pt.virtualType), FastSeq(v -> pt.virtualType), v) + } } } - @Test(dataProvider = "nonNullTypesAndValues") - def 
testReadWriteValueDistributed(pt: SingleCodeType, value: Any): Unit = { - implicit val execStrats = ExecStrategy.compileOnly - val node = In(0, SingleCodeEmitParamType(true, pt)) - val spec = TypedCodecSpec(ctx, PType.canonical(node.typ), BufferSpec.blockedUncompressed) - val writer = ETypeValueWriter(spec) - val reader = ETypeValueReader(spec) - val prefix = ctx.createTmpPath("test-read-write-value-dist") - val readArray = bindIR( - cdaIR( - mapIR(rangeIR(10))(_ => node), - MakeStruct(FastSeq()), - "test", - NA(TString), - )((ctx, _) => WriteValue(ctx, Str(prefix) + UUID4(), writer)) - )(files => mapIR(ToStream(files))(ReadValue(_, reader, pt.virtualType))) - - forAll(ArraySeq(value, null)) { v => - assertEvalsTo(ToArray(readArray), ArraySeq(v -> pt.virtualType), ArraySeq.fill(10)(v)) + nonNullTypesAndValues().zipWithIndex.foreach { case (arr, i) => + test(s"ReadWriteValueDistributed $i") { + implicit val execStrats = ExecStrategy.compileOnly + val pt = arr(0).asInstanceOf[SingleCodeType] + val value = arr(1) + val node = In(0, SingleCodeEmitParamType(true, pt)) + val spec = TypedCodecSpec(ctx, PType.canonical(node.typ), BufferSpec.blockedUncompressed) + val writer = ETypeValueWriter(spec) + val reader = ETypeValueReader(spec) + val prefix = ctx.createTmpPath("test-read-write-value-dist") + val readArray = bindIR( + cdaIR( + mapIR(rangeIR(10))(_ => node), + MakeStruct(FastSeq()), + "test", + NA(TString), + )((ctx, _) => WriteValue(ctx, Str(prefix) + UUID4(), writer)) + )(files => mapIR(ToStream(files))(ReadValue(_, reader, pt.virtualType))) + + ArraySeq(value, null).foreach { v => + assertEvalsTo(ToArray(readArray), ArraySeq(v -> pt.virtualType), ArraySeq.fill(10)(v)) + } } } - @Test def testUUID4(): Unit = { + test("UUID4") { val single = UUID4() val hex = "[0-9a-f]" val format = s"$hex{8}-$hex{4}-$hex{4}-$hex{4}-$hex{12}" @@ -4356,15 +4521,13 @@ class IRSuite extends HailSuite { assertNumDistinct(bindIR(ToArray(stream))(a => selfZip(ToStream(a), 2)), 5) } - 
@Test def testZipDoesntPruneLengthInfo(): Unit = - forAll { - ArraySeq( - ArrayZipBehavior.AssumeSameLength, - ArrayZipBehavior.AssertSameLength, - ArrayZipBehavior.TakeMinLength, - ArrayZipBehavior.ExtendNA, - ) - } { behavior => + test("ZipDoesntPruneLengthInfo") { + ArraySeq( + ArrayZipBehavior.AssumeSameLength, + ArrayZipBehavior.AssertSameLength, + ArrayZipBehavior.TakeMinLength, + ArrayZipBehavior.ExtendNA, + ).foreach { behavior => val zip = zipIR( ArraySeq(StreamRange(0, 10, 1), StreamRange(0, 10, 1)), behavior, @@ -4373,7 +4536,9 @@ class IRSuite extends HailSuite { assertEvalsTo(ToArray(zip), ArraySeq.fill(10)(Row("foo", "bar"))) } - @Test def testStreamDistribute(): Unit = { + } + + test("StreamDistribute") { val data1 = IndexedSeq(0, 1, 1, 2, 4, 7, 7, 7, 9, 11, 15, 20, 22, 28, 50, 100) val pivots1 = IndexedSeq(-10, 1, 7, 7, 15, 22, 50, 200) val pivots2 = IndexedSeq(-10, 1, 1, 7, 9, 28, 50, 200) @@ -4428,16 +4593,16 @@ class IRSuite extends HailSuite { reader, )) val rowsFromDisk = eval(read).asInstanceOf[IndexedSeq[Row]] - assert(rowsFromDisk.size == elementCount) + assertEquals(rowsFromDisk.size, elementCount) assert(rowsFromDisk.forall(interval.contains(kord, _))) rowsFromDisk.foreach { row => - assert(row(0) == data(dataIdx)) + assertEquals(row(0), data(dataIdx)) dataIdx += 1 } } - assert(dataIdx == data.size) + assertEquals(dataIdx, data.size) result.map(_._1).sliding(2).foreach { case IndexedSeq(interval1, interval2) => assert(interval1.isDisjointFrom(kord, interval2)) @@ -4453,9 +4618,9 @@ class IRSuite extends HailSuite { } val expectedStartsAndEnds = intBuilder.result().sliding(2).toIndexedSeq - forAll(result.map(_._1).zip(expectedStartsAndEnds)) { case (interval, splitterPair) => - assert(interval.start.asInstanceOf[Row](0) == splitterPair(0)) - assert(interval.end.asInstanceOf[Row](0) == splitterPair(1)) + result.map(_._1).zip(expectedStartsAndEnds).foreach { case (interval, splitterPair) => + assertEquals(interval.start.asInstanceOf[Row](0), 
splitterPair(0)) + assertEquals(interval.end.asInstanceOf[Row](0), splitterPair(1)) } } } diff --git a/hail/hail/test/src/is/hail/expr/ir/IntervalSuite.scala b/hail/hail/test/src/is/hail/expr/ir/IntervalSuite.scala index c8c4f8b2101..e2ee9086613 100644 --- a/hail/hail/test/src/is/hail/expr/ir/IntervalSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/IntervalSuite.scala @@ -12,9 +12,6 @@ import is.hail.types.virtual._ import is.hail.utils._ import org.apache.spark.sql.Row -import org.scalatest.Inspectors.forAll -import org.testng.ITestContext -import org.testng.annotations.{BeforeMethod, Test} class IntervalSuite extends HailSuite { @@ -43,7 +40,7 @@ class IntervalSuite extends HailSuite { val i4 = interval(NA(tpoint1), point(2), null, false) val i5 = interval(NA(tpoint1), point(2), true, null) - @Test def constructor(): Unit = { + test("constructor") { assertEvalsTo(i1, Interval(Row(1), Row(2), true, false)) assertEvalsTo(i2, Interval(Row(1), null, true, false)) assertEvalsTo(i3, Interval(null, Row(2), true, false)) @@ -51,33 +48,33 @@ class IntervalSuite extends HailSuite { assertEvalsTo(i5, null) } - @Test def start(): Unit = { + test("start") { assertEvalsTo(invoke("start", tpoint1, i1), Row(1)) assertEvalsTo(invoke("start", tpoint1, i2), Row(1)) assertEvalsTo(invoke("start", tpoint1, i3), null) assertEvalsTo(invoke("start", tpoint1, na), null) } - @Test def defaultValueCorrectlyStored(): Unit = { + test("defaultValueCorrectlyStored") { assertEvalsTo(If(GetTupleElement(invoke("start", tpoint1, i1), 0).ceq(1), true, false), true) assertEvalsTo(If(GetTupleElement(invoke("end", tpoint1, i1), 0).ceq(2), true, false), true) } - @Test def end(): Unit = { + test("end") { assertEvalsTo(invoke("end", tpoint1, i1), Row(2)) assertEvalsTo(invoke("end", tpoint1, i2), null) assertEvalsTo(invoke("end", tpoint1, i3), Row(2)) assertEvalsTo(invoke("end", tpoint1, na), null) } - @Test def includeStart(): Unit = { + test("includeStart") { assertEvalsTo(invoke("includesStart", 
TBoolean, i1), true) assertEvalsTo(invoke("includesStart", TBoolean, i2), true) assertEvalsTo(invoke("includesStart", TBoolean, i3), true) assertEvalsTo(invoke("includesStart", TBoolean, na), null) } - @Test def includeEnd(): Unit = { + test("includeEnd") { assertEvalsTo(invoke("includesEnd", TBoolean, i1), false) assertEvalsTo(invoke("includesEnd", TBoolean, i2), false) assertEvalsTo(invoke("includesEnd", TBoolean, i3), false) @@ -107,40 +104,45 @@ class IntervalSuite extends HailSuite { i.includesEnd, ) - @Test def contains(): Unit = - forAll(cartesian(testIntervals, points)) { case (setInterval, p) => + test("contains") { + cartesian(testIntervals, points).foreach { case (setInterval, p) => val interval = toIRInterval(setInterval) - assert(eval(invoke("contains", TBoolean, interval, p)) == setInterval.contains(p)) + assertEquals(eval(invoke("contains", TBoolean, interval, p)), setInterval.contains(p)) } + } - @Test def isEmpty(): Unit = - forAll(testIntervals) { setInterval => + test("isEmpty") { + testIntervals.foreach { setInterval => val interval = toIRInterval(setInterval) - assert(eval( - invoke("isEmpty", TBoolean, ErrorIDs.NO_ERROR, interval) - ) == setInterval.definitelyEmpty()) + assertEquals( + eval(invoke("isEmpty", TBoolean, ErrorIDs.NO_ERROR, interval)), + setInterval.definitelyEmpty(), + ) } + } - @Test def overlaps(): Unit = - forAll(cartesian(testIntervals, testIntervals)) { case (setInterval1, setInterval2) => + test("overlaps") { + cartesian(testIntervals, testIntervals).foreach { case (setInterval1, setInterval2) => val interval1 = toIRInterval(setInterval1) val interval2 = toIRInterval(setInterval2) - assert(eval( - invoke("overlaps", TBoolean, interval1, interval2) - ) == setInterval1.probablyOverlaps(setInterval2)) + assertEquals( + eval(invoke("overlaps", TBoolean, interval1, interval2)), + setInterval1.probablyOverlaps(setInterval2), + ) } + } def intInterval(start: Int, end: Int, includesStart: Boolean = true, includesEnd: Boolean = 
false) : Interval = Interval(start, end, includesStart, includesEnd) - @Test def testIntervalSortAndReduce(): Unit = { + test("IntervalSortAndReduce") { val ord = TInt32.ordering(ctx.stateManager).intervalEndpointOrdering - assert(Interval.union(ArraySeq(), ord) == ArraySeq()) - assert(Interval.union(ArraySeq(intInterval(0, 10)), ord) == ArraySeq(intInterval(0, 10))) + assertEquals(Interval.union(ArraySeq(), ord), ArraySeq()) + assertEquals(Interval.union(ArraySeq(intInterval(0, 10)), ord), ArraySeq(intInterval(0, 10))) - assert( + assertEquals( Interval.union( ArraySeq( intInterval(0, 10), @@ -149,11 +151,12 @@ class IntervalSuite extends HailSuite { intInterval(40, 50), ).reverse, ord, - ) == ArraySeq(intInterval(0, 30), intInterval(40, 50)) + ), + ArraySeq(intInterval(0, 30), intInterval(40, 50)), ) } - @Test def testIntervalIntersection(): Unit = { + test("IntervalIntersection") { val ord = TInt32.ordering(ctx.stateManager).intervalEndpointOrdering val x1 = ArraySeq( @@ -174,15 +177,18 @@ class IntervalSuite extends HailSuite { assert(Interval.intersection(x1, ArraySeq(), ord).isEmpty) assert(Interval.intersection(ArraySeq(), x2, ord).isEmpty) - assert(Interval.intersection(x1, x2, ord) == x1) - assert(Interval.intersection(x1, x2, ord) == x1) - assert(Interval.intersection(x1, x3, ord) == ArraySeq( - intInterval(7, 10), - intInterval(15, 19, includesEnd = true), - )) + assertEquals(Interval.intersection(x1, x2, ord), x1) + assertEquals(Interval.intersection(x1, x2, ord), x1) + assertEquals( + Interval.intersection(x1, x3, ord), + ArraySeq( + intInterval(7, 10), + intInterval(15, 19, includesEnd = true), + ), + ) } - @Test def testsortedNonOverlappingIntervalsContain(): Unit = { + test("sortedNonOverlappingIntervalsContain") { val intervals = Literal( TArray(TInterval(TInt32)), FastSeq( @@ -236,8 +242,8 @@ class IntervalSuite extends HailSuite { val partitionerKType = TStruct("k1" -> TInt32, "k2" -> TInt32, "k3" -> TInt32) var partitioner: Literal = _ - 
@BeforeMethod - def setupRVDPartitioner(context: ITestContext): Unit = { + override def beforeEach(context: BeforeEach): Unit = { + super.beforeEach(context) partitioner = new RVDPartitioner( ctx.stateManager, partitionerKType, @@ -249,7 +255,7 @@ class IntervalSuite extends HailSuite { ).partitionBoundsIRRepresentation } - @Test def testsortedNonOverlappingPartitionIntervalsEqualRange(): Unit = { + test("sortedNonOverlappingPartitionIntervalsEqualRange") { def assertRange(interval: Interval, startIdx: Int, endIdx: Int): Unit = { val resultType = TTuple(TInt32, TInt32) val irInterval = Literal( @@ -268,7 +274,7 @@ class IntervalSuite extends HailSuite { assertRange(Interval(Row(-1, 7), Row(0, 9), true, false), 0, 0) } - @Test def testPointPartitionIntervalEndpointComparison(): Unit = { + test("PointPartitionIntervalEndpointComparison") { def assertComp( point: IndexedSeq[Int], intervalEndpoint: IndexedSeq[Int], diff --git a/hail/hail/test/src/is/hail/expr/ir/LiftLiteralsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/LiftLiteralsSuite.scala index 5eba97863e0..5cd1a0c2d01 100644 --- a/hail/hail/test/src/is/hail/expr/ir/LiftLiteralsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/LiftLiteralsSuite.scala @@ -7,12 +7,11 @@ import is.hail.expr.ir.defs.{ApplyBinaryPrimOp, I64, MakeStruct, TableCount, Tab import is.hail.expr.ir.lowering.ExecuteRelational import org.apache.spark.sql.Row -import org.testng.annotations.Test class LiftLiteralsSuite extends HailSuite { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.interpretOnly - @Test def testNestedGlobalsRewrite(): Unit = { + test("NestedGlobalsRewrite") { val tab = TableLiteral(ExecuteRelational(ctx, TableRange(10, 1)).asTableValue(ctx), theHailClassLoader) val ir = TableGetGlobals( diff --git a/hail/hail/test/src/is/hail/expr/ir/LocusFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/LocusFunctionsSuite.scala index 33d9d89448e..c7ff8a3a9b9 100644 --- 
a/hail/hail/test/src/is/hail/expr/ir/LocusFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/LocusFunctionsSuite.scala @@ -9,7 +9,6 @@ import is.hail.utils.Interval import is.hail.variant.{Locus, ReferenceGenome} import org.apache.spark.sql.Row -import org.testng.annotations.Test class LocusFunctionsSuite extends HailSuite { @@ -24,37 +23,42 @@ class LocusFunctionsSuite extends HailSuite { def locus = Locus("chr22", 1, grch38) - @Test def contig(): Unit = - assertEvalsTo(invoke("contig", TString, locusIR), locus.contig) + test("contig")(assertEvalsTo(invoke("contig", TString, locusIR), locus.contig)) - @Test def position(): Unit = - assertEvalsTo(invoke("position", TInt32, locusIR), locus.position) + test("position")(assertEvalsTo(invoke("position", TInt32, locusIR), locus.position)) - @Test def isAutosomalOrPseudoAutosomal(): Unit = + test("isAutosomalOrPseudoAutosomal") { assertEvalsTo( invoke("isAutosomalOrPseudoAutosomal", TBoolean, locusIR), locus.isAutosomalOrPseudoAutosomal(grch38), ) + } - @Test def isAutosomal(): Unit = + test("isAutosomal") { assertEvalsTo(invoke("isAutosomal", TBoolean, locusIR), locus.isAutosomal(grch38)) + } - @Test def inYNonPar(): Unit = + test("inYNonPar") { assertEvalsTo(invoke("inYNonPar", TBoolean, locusIR), locus.inYNonPar(grch38)) + } - @Test def inXPar(): Unit = + test("inXPar") { assertEvalsTo(invoke("inXPar", TBoolean, locusIR), locus.inXPar(grch38)) + } - @Test def isMitochondrial(): Unit = + test("isMitochondrial") { assertEvalsTo(invoke("isMitochondrial", TBoolean, locusIR), locus.isMitochondrial(grch38)) + } - @Test def inXNonPar(): Unit = + test("inXNonPar") { assertEvalsTo(invoke("inXNonPar", TBoolean, locusIR), locus.inXNonPar(grch38)) + } - @Test def inYPar(): Unit = + test("inYPar") { assertEvalsTo(invoke("inYPar", TBoolean, locusIR), locus.inYPar(grch38)) + } - @Test def minRep(): Unit = { + test("minRep") { val alleles = MakeArray(FastSeq(Str("AA"), Str("AT")), TArray(TString)) assertEvalsTo( 
invoke("min_rep", tvariant, locusIR, alleles), @@ -63,10 +67,11 @@ class LocusFunctionsSuite extends HailSuite { assertEvalsTo(invoke("min_rep", tvariant, locusIR, NA(TArray(TString))), null) } - @Test def globalPosition(): Unit = + test("globalPosition") { assertEvalsTo(invoke("locusToGlobalPos", TInt64, locusIR), grch38.locusToGlobalPos(locus)) + } - @Test def reverseGlobalPosition(): Unit = { + test("reverseGlobalPosition") { val globalPosition = 2824183054L assertEvalsTo( invoke("globalPosToLocus", tlocus, I64(globalPosition)), @@ -74,7 +79,7 @@ class LocusFunctionsSuite extends HailSuite { ) } - @Test def testMultipleReferenceGenomes(): Unit = { + test("MultipleReferenceGenomes") { implicit val execStrats = ExecStrategy.compileOnly val ir = MakeTuple.ordered(FastSeq( @@ -91,7 +96,7 @@ class LocusFunctionsSuite extends HailSuite { ) } - @Test def testMakeInterval(): Unit = { + test("MakeInterval") { // TString, TInt32, TInt32, TBoolean, TBoolean, TBoolean val ir = MakeTuple.ordered(FastSeq( invoke( diff --git a/hail/hail/test/src/is/hail/expr/ir/MathFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/MathFunctionsSuite.scala index 3e050764181..9963f1ca373 100644 --- a/hail/hail/test/src/is/hail/expr/ir/MathFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/MathFunctionsSuite.scala @@ -6,7 +6,6 @@ import is.hail.types.virtual._ import is.hail.utils._ import org.apache.spark.sql.Row -import org.testng.annotations.{DataProvider, Test} class MathFunctionsSuite extends HailSuite { @@ -14,7 +13,7 @@ class MathFunctionsSuite extends HailSuite { val tfloat = TFloat64 - @Test def log2(): Unit = { + test("log2") { assertEvalsTo(invoke("log2", TInt32, I32(2)), 1) assertEvalsTo(invoke("log2", TInt32, I32(32)), 5) assertEvalsTo(invoke("log2", TInt32, I32(33)), 5) @@ -22,7 +21,7 @@ class MathFunctionsSuite extends HailSuite { assertEvalsTo(invoke("log2", TInt32, I32(64)), 6) } - @Test def roundToNextPowerOf2(): Unit = { + test("roundToNextPowerOf2") { 
assertEvalsTo(invoke("roundToNextPowerOf2", TInt32, I32(2)), 2) assertEvalsTo(invoke("roundToNextPowerOf2", TInt32, I32(32)), 32) assertEvalsTo(invoke("roundToNextPowerOf2", TInt32, I32(33)), 64) @@ -30,7 +29,7 @@ class MathFunctionsSuite extends HailSuite { assertEvalsTo(invoke("roundToNextPowerOf2", TInt32, I32(64)), 64) } - @Test def isnan(): Unit = { + test("isnan") { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(invoke("isnan", TBoolean, F32(0)), false) @@ -40,7 +39,7 @@ class MathFunctionsSuite extends HailSuite { assertEvalsTo(invoke("isnan", TBoolean, F64(Double.NaN)), true) } - @Test def is_finite(): Unit = { + test("is_finite") { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(invoke("is_finite", TBoolean, F32(0)), expected = true) @@ -56,7 +55,7 @@ class MathFunctionsSuite extends HailSuite { assertEvalsTo(invoke("is_finite", TBoolean, F64(Double.NegativeInfinity)), expected = false) } - @Test def is_infinite(): Unit = { + test("is_infinite") { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(invoke("is_infinite", TBoolean, F32(0)), expected = false) @@ -72,7 +71,7 @@ class MathFunctionsSuite extends HailSuite { assertEvalsTo(invoke("is_infinite", TBoolean, F64(Double.NegativeInfinity)), expected = true) } - @Test def sign(): Unit = { + test("sign") { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(invoke("sign", TInt32, I32(2)), 1) @@ -96,7 +95,7 @@ class MathFunctionsSuite extends HailSuite { assertEvalsTo(invoke("sign", TFloat64, F64(Double.NegativeInfinity)), -1.0) } - @Test def approxEqual(): Unit = { + test("approxEqual") { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo( @@ -157,7 +156,7 @@ class MathFunctionsSuite extends HailSuite { ) } - @Test def entropy(): Unit = { + test("entropy") { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(invoke("entropy", TFloat64, Str("")), 0.0) @@ -167,120 +166,128 @@ class MathFunctionsSuite extends HailSuite { 
assertEvalsTo(invoke("entropy", TFloat64, Str("accctg")), 1.7924812503605778) } - @DataProvider(name = "chi_squared_test") - def chiSquaredData(): Array[Array[Any]] = Array( - Array(0, 0, 0, 0, Double.NaN, Double.NaN), - Array(0, 1, 1, 1, 0.38647623077123266, 0.0), - Array(1, 0, 1, 1, 0.38647623077123266, Double.PositiveInfinity), - Array(1, 1, 0, 1, 0.38647623077123266, Double.PositiveInfinity), - Array(1, 1, 1, 0, 0.38647623077123266, 0.0), - Array(10, 10, 10, 10, 1.0, 1.0), - Array(51, 43, 22, 92, 1.462626e-7, (51.0 * 92) / (22 * 43)), - ) - - @Test(dataProvider = "chi_squared_test") - def chiSquaredTest(a: Int, b: Int, c: Int, d: Int, pValue: Double, oddsRatio: Double): Unit = { - val r = eval(invoke( - "chi_squared_test", - stats.chisqStruct.virtualType, - ErrorIDs.NO_ERROR, - a, - b, - c, - d, - )).asInstanceOf[Row] - assert(D0_==(pValue, r.getDouble(0))) - assert(D0_==(oddsRatio, r.getDouble(1))) + object checkChiSquaredTest extends TestCases { + def apply( + a: Int, + b: Int, + c: Int, + d: Int, + pValue: Double, + oddsRatio: Double, + )(implicit loc: munit.Location + ): Unit = test("chiSquaredTest") { + val r = eval(invoke( + "chi_squared_test", + stats.chisqStruct.virtualType, + ErrorIDs.NO_ERROR, + a, + b, + c, + d, + )).asInstanceOf[Row] + assert(D0_==(pValue, r.getDouble(0))) + assert(D0_==(oddsRatio, r.getDouble(1))) + } } - @DataProvider(name = "fisher_exact_test") - def fisherExactData(): Array[Array[Any]] = Array( - Array(0, 0, 0, 0, Double.NaN, Double.NaN, Double.NaN, Double.NaN), - Array(10, 10, 10, 10, 1.0, 1.0, 0.243858, 4.100748), - Array(51, 43, 22, 92, 2.1565e-7, 4.918058, 2.565937, 9.677930), - ) - - @Test(dataProvider = "fisher_exact_test") - def fisherExactTest( - a: Int, - b: Int, - c: Int, - d: Int, - pValue: Double, - oddsRatio: Double, - confLower: Double, - confUpper: Double, - ): Unit = { - val r = eval(invoke( - "fisher_exact_test", - stats.fetStruct.virtualType, - ErrorIDs.NO_ERROR, - a, - b, - c, - d, - )).asInstanceOf[Row] - 
assert(D0_==(pValue, r.getDouble(0))) - assert(D0_==(oddsRatio, r.getDouble(1))) - assert(D0_==(confLower, r.getDouble(2))) - assert(D0_==(confUpper, r.getDouble(3))) + checkChiSquaredTest(0, 0, 0, 0, Double.NaN, Double.NaN) + checkChiSquaredTest(0, 1, 1, 1, 0.38647623077123266, 0.0) + checkChiSquaredTest(1, 0, 1, 1, 0.38647623077123266, Double.PositiveInfinity) + checkChiSquaredTest(1, 1, 0, 1, 0.38647623077123266, Double.PositiveInfinity) + checkChiSquaredTest(1, 1, 1, 0, 0.38647623077123266, 0.0) + checkChiSquaredTest(10, 10, 10, 10, 1.0, 1.0) + checkChiSquaredTest(51, 43, 22, 92, 1.462626e-7, (51.0 * 92) / (22 * 43)) + + object checkFisherExactTest extends TestCases { + def apply( + a: Int, + b: Int, + c: Int, + d: Int, + pValue: Double, + oddsRatio: Double, + confLower: Double, + confUpper: Double, + )(implicit loc: munit.Location + ): Unit = test("fisherExactTest") { + val r = eval(invoke( + "fisher_exact_test", + stats.fetStruct.virtualType, + ErrorIDs.NO_ERROR, + a, + b, + c, + d, + )).asInstanceOf[Row] + assert(D0_==(pValue, r.getDouble(0))) + assert(D0_==(oddsRatio, r.getDouble(1))) + assert(D0_==(confLower, r.getDouble(2))) + assert(D0_==(confUpper, r.getDouble(3))) + } } - @DataProvider(name = "contingency_table_test") - def contingencyTableData(): Array[Array[Any]] = Array( - Array(51, 43, 22, 92, 22, 1.462626e-7, 4.95983087), - Array(51, 43, 22, 92, 23, 2.1565e-7, 4.91805817), - ) - - @Test(dataProvider = "contingency_table_test") - def contingencyTableTest( - a: Int, - b: Int, - c: Int, - d: Int, - minCellCount: Int, - pValue: Double, - oddsRatio: Double, - ): Unit = { - val r = eval(invoke( - "contingency_table_test", - stats.chisqStruct.virtualType, - ErrorIDs.NO_ERROR, - a, - b, - c, - d, - minCellCount, - )).asInstanceOf[Row] - assert(D0_==(pValue, r.getDouble(0))) - assert(D0_==(oddsRatio, r.getDouble(1))) + checkFisherExactTest(0, 0, 0, 0, Double.NaN, Double.NaN, Double.NaN, Double.NaN) + checkFisherExactTest(10, 10, 10, 10, 1.0, 1.0, 0.243858, 
4.100748) + checkFisherExactTest(51, 43, 22, 92, 2.1565e-7, 4.918058, 2.565937, 9.677930) + + object checkContingencyTableTest extends TestCases { + def apply( + a: Int, + b: Int, + c: Int, + d: Int, + minCellCount: Int, + pValue: Double, + oddsRatio: Double, + )(implicit loc: munit.Location + ): Unit = test("contingencyTableTest") { + val r = eval(invoke( + "contingency_table_test", + stats.chisqStruct.virtualType, + ErrorIDs.NO_ERROR, + a, + b, + c, + d, + minCellCount, + )).asInstanceOf[Row] + assert(D0_==(pValue, r.getDouble(0))) + assert(D0_==(oddsRatio, r.getDouble(1))) + } } - @DataProvider(name = "hardy_weinberg_test") - def hardyWeinbergData(): Array[Array[Any]] = Array( - Array(0, 0, 0, Double.NaN, 0.5), - Array(1, 2, 1, 0.57142857, 0.65714285), - Array(0, 1, 0, 1.0, 0.5), - Array(100, 200, 100, 0.50062578, 0.96016808), - ) - - @Test(dataProvider = "hardy_weinberg_test") - def hardyWeinbergTest(nHomRef: Int, nHet: Int, nHomVar: Int, pValue: Double, hetFreq: Double) - : Unit = { - val r = eval(invoke( - "hardy_weinberg_test", - stats.hweStruct.virtualType, - ErrorIDs.NO_ERROR, - nHomRef, - nHet, - nHomVar, - false, - )).asInstanceOf[Row] - assert(D0_==(pValue, r.getDouble(0))) - assert(D0_==(hetFreq, r.getDouble(1))) + checkContingencyTableTest(51, 43, 22, 92, 22, 1.462626e-7, 4.95983087) + checkContingencyTableTest(51, 43, 22, 92, 23, 2.1565e-7, 4.91805817) + + object checkHardyWeinbergTest extends TestCases { + def apply( + nHomRef: Int, + nHet: Int, + nHomVar: Int, + pValue: Double, + hetFreq: Double, + )(implicit loc: munit.Location + ): Unit = test("hardyWeinbergTest") { + val r = eval(invoke( + "hardy_weinberg_test", + stats.hweStruct.virtualType, + ErrorIDs.NO_ERROR, + nHomRef, + nHet, + nHomVar, + false, + )).asInstanceOf[Row] + assert(D0_==(pValue, r.getDouble(0))) + assert(D0_==(hetFreq, r.getDouble(1))) + } } - @Test def modulusTest(): Unit = { + checkHardyWeinbergTest(0, 0, 0, Double.NaN, 0.5) + checkHardyWeinbergTest(1, 2, 1, 0.57142857, 
0.65714285) + checkHardyWeinbergTest(0, 1, 0, 1.0, 0.5) + checkHardyWeinbergTest(100, 200, 100, 0.50062578, 0.96016808) + + test("modulusTest") { assertFatal( invoke("mod", TInt32, I32(1), I32(0)), "(modulo by zero)|(error while calling 'mod')", @@ -299,7 +306,7 @@ class MathFunctionsSuite extends HailSuite { ) } - @Test def testMinMax(): Unit = { + test("MinMax") { implicit val execStrats = ExecStrategy.javaOnly assertAllEvalTo( (invoke("min", TFloat32, F32(1.0f), F32(2.0f)), 1.0f), diff --git a/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala index d69831ecaad..459b96c8d8e 100644 --- a/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/MatrixIRSuite.scala @@ -19,16 +19,13 @@ import scala.collection.compat._ import org.apache.spark.sql.Row import org.json4s.jackson.JsonMethods -import org.scalatest.Inspectors.forAll -import org.scalatest.enablers.InspectorAsserting.assertingNatureOfAssertion -import org.testng.annotations.{DataProvider, Test} class MatrixIRSuite extends HailSuite { implicit val execStrats: Set[ExecStrategy] = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized, ExecStrategy.LoweredJVMCompile) - @Test def testMatrixWriteRead(): Unit = { + test("MatrixWriteRead") { val range = MatrixIR.range(ctx, 10, 10, Some(3)) val withEntries = MatrixMapEntries( range, @@ -53,7 +50,7 @@ class MatrixIRSuite extends HailSuite { partitionsTypeStr = partType.parsableString(), ) - forAll(ArraySeq(writer1, writer2)) { writer => + ArraySeq(writer1, writer2).foreach { writer => assertEvalsTo(MatrixWrite(original, writer), ()) val read = MatrixIR.read(fs, path, dropCols = false, dropRows = false, None) @@ -108,7 +105,7 @@ class MatrixIRSuite extends HailSuite { def getCols(mir: MatrixIR): Array[Row] = Interpret(MatrixColsTable(mir), ctx).rdd.collect() - @Test def testScanCountBehavesLikeIndexOnRows(): Unit = { + test("ScanCountBehavesLikeIndexOnRows") { 
val mt = rangeMatrix() val oldRow = Ref(MatrixIR.rowName, mt.typ.rowType) @@ -119,7 +116,7 @@ class MatrixIRSuite extends HailSuite { assert(rows.forall { case Row(row_idx, idx) => row_idx == idx }, rows.toSeq) } - @Test def testScanCollectBehavesLikeRangeOnRows(): Unit = { + test("ScanCollectBehavesLikeRangeOnRows") { val mt = rangeMatrix() val oldRow = Ref(MatrixIR.rowName, mt.typ.rowType) @@ -133,7 +130,7 @@ class MatrixIRSuite extends HailSuite { }) } - @Test def testScanCollectBehavesLikeRangeWithAggregationOnRows(): Unit = { + test("ScanCollectBehavesLikeRangeWithAggregationOnRows") { val mt = rangeMatrix() val oldRow = Ref(MatrixIR.rowName, mt.typ.rowType) @@ -149,7 +146,7 @@ class MatrixIRSuite extends HailSuite { }) } - @Test def testScanCountBehavesLikeIndexOnCols(): Unit = { + test("ScanCountBehavesLikeIndexOnCols") { val mt = rangeMatrix() val oldCol = Ref(MatrixIR.colName, mt.typ.colType) @@ -160,7 +157,7 @@ class MatrixIRSuite extends HailSuite { assert(cols.forall { case Row(col_idx, idx) => col_idx == idx }) } - @Test def testScanCollectBehavesLikeRangeOnCols(): Unit = { + test("ScanCollectBehavesLikeRangeOnCols") { val mt = rangeMatrix() val oldCol = Ref(MatrixIR.colName, mt.typ.colType) @@ -174,7 +171,7 @@ class MatrixIRSuite extends HailSuite { }) } - @Test def testScanCollectBehavesLikeRangeWithAggregationOnCols(): Unit = { + test("ScanCollectBehavesLikeRangeWithAggregationOnCols") { val mt = rangeMatrix() val oldCol = Ref(MatrixIR.colName, mt.typ.colType) @@ -206,58 +203,61 @@ class MatrixIRSuite extends HailSuite { ) } - @DataProvider(name = "unionRowsData") - def unionRowsData(): Array[Array[Any]] = Array( - Array(FastSeq(0 -> 0, 5 -> 7)), - Array(FastSeq(0 -> 1, 5 -> 7)), - Array(FastSeq(0 -> 6, 5 -> 7)), - Array(FastSeq(2 -> 3, 0 -> 1, 5 -> 7)), - Array(FastSeq(2 -> 4, 0 -> 3, 5 -> 7)), - Array(FastSeq(3 -> 6, 0 -> 1, 5 -> 7)), - ) - - @Test(dataProvider = "unionRowsData") - def testMatrixUnionRows(ranges: IndexedSeq[(Int, Int)]): Unit = { 
- val expectedOrdering = ranges.flatMap { case (start, end) => - Array.range(start, end) - }.sorted - - val unioned = MatrixUnionRows(ranges.map { case (start, end) => - rangeRowMatrix(start, end) - }) - val actualOrdering = getRows(unioned).map { case Row(i: Int) => i } - - assert(actualOrdering sameElements expectedOrdering) + object checkMatrixUnionRows extends TestCases { + def apply( + ranges: IndexedSeq[(Int, Int)] + )(implicit loc: munit.Location + ): Unit = test("MatrixUnionRows") { + val expectedOrdering = ranges.flatMap { case (start, end) => + Array.range(start, end) + }.sorted + + val unioned = MatrixUnionRows(ranges.map { case (start, end) => + rangeRowMatrix(start, end) + }) + val actualOrdering = getRows(unioned).map { case Row(i: Int) => i } + + assert(actualOrdering sameElements expectedOrdering) + } } - @DataProvider(name = "explodeRowsData") - def explodeRowsData(): Array[Array[Any]] = Array( - Array(FastSeq("empty"), FastSeq()), - Array(FastSeq("null"), null), - Array(FastSeq("set"), FastSeq(1, 3)), - Array(FastSeq("one"), FastSeq(3)), - Array(FastSeq("na"), FastSeq(null)), - Array(FastSeq("x", "y"), FastSeq(3)), - Array(FastSeq("foo", "bar"), FastSeq(1, 3)), - Array(FastSeq("a", "b", "c"), FastSeq()), - ) - - @Test(dataProvider = "explodeRowsData") - def testMatrixExplode(path: IndexedSeq[String], collection: IndexedSeq[Integer]): Unit = { - val range = rangeMatrix(5, 2, None) - - val field = path.init.foldRight(path.last -> toIRArray(collection))(_ -> IRStruct(_)) - val annotated = - MatrixMapRows(range, InsertFields(Ref(MatrixIR.rowName, range.typ.rowType), FastSeq(field))) - - val q = annotated.typ.rowType.query(path: _*) - val exploded = - getRows(MatrixExplodeRows(annotated, path)).map(q(_).asInstanceOf[Integer]) - - val expected = if (collection == null) Array[Integer]() else Array.fill(5)(collection).flatten - assert(exploded sameElements expected) + checkMatrixUnionRows(FastSeq(0 -> 0, 5 -> 7)) + checkMatrixUnionRows(FastSeq(0 -> 1, 5 
-> 7)) + checkMatrixUnionRows(FastSeq(0 -> 6, 5 -> 7)) + checkMatrixUnionRows(FastSeq(2 -> 3, 0 -> 1, 5 -> 7)) + checkMatrixUnionRows(FastSeq(2 -> 4, 0 -> 3, 5 -> 7)) + checkMatrixUnionRows(FastSeq(3 -> 6, 0 -> 1, 5 -> 7)) + + object checkMatrixExplode extends TestCases { + def apply( + path: IndexedSeq[String], + collection: IndexedSeq[Integer], + )(implicit loc: munit.Location + ): Unit = test("MatrixExplode") { + val range = rangeMatrix(5, 2, None) + + val field = path.init.foldRight(path.last -> toIRArray(collection))(_ -> IRStruct(_)) + val annotated = + MatrixMapRows(range, InsertFields(Ref(MatrixIR.rowName, range.typ.rowType), FastSeq(field))) + + val q = annotated.typ.rowType.query(path: _*) + val exploded = + getRows(MatrixExplodeRows(annotated, path)).map(q(_).asInstanceOf[Integer]) + + val expected = if (collection == null) Array[Integer]() else Array.fill(5)(collection).flatten + assert(exploded sameElements expected) + } } + checkMatrixExplode(FastSeq("empty"), FastSeq()) + checkMatrixExplode(FastSeq("null"), null) + checkMatrixExplode(FastSeq("set"), FastSeq(1, 3)) + checkMatrixExplode(FastSeq("one"), FastSeq(3)) + checkMatrixExplode(FastSeq("na"), FastSeq(null)) + checkMatrixExplode(FastSeq("x", "y"), FastSeq(3)) + checkMatrixExplode(FastSeq("foo", "bar"), FastSeq(1, 3)) + checkMatrixExplode(FastSeq("a", "b", "c"), FastSeq()) + // these two items are helper for UnlocalizedEntries testing, def makeLocalizedTable(rdata: IndexedSeq[Row], cdata: IndexedSeq[Row]): TableIR = { val rowRdd = sc.parallelize(rdata) @@ -278,7 +278,7 @@ class MatrixIRSuite extends HailSuite { TableLiteral(tv, theHailClassLoader) } - @Test def testCastTableToMatrix(): Unit = { + test("CastTableToMatrix") { val rdata = ArraySeq( Row(1, "fish", FastSeq(Row("a", 1.0), Row("x", 2.0))), Row(2, "cat", FastSeq(Row("b", 0.0), Row("y", 0.1))), @@ -307,7 +307,7 @@ class MatrixIRSuite extends HailSuite { assert(localCols sameElements cdata) } - @Test def testCastTableToMatrixErrors(): Unit 
= { + test("CastTableToMatrixErrors") { val rdata = ArraySeq( Row(1, "fish", FastSeq(Row("x", 2.0))), Row(2, "cat", FastSeq(Row("b", 0.0), Row("y", 0.1))), @@ -342,7 +342,7 @@ class MatrixIRSuite extends HailSuite { interceptSpark("missing")(Interpret(mir2, ctx).rvd.count()) } - @Test def testMatrixFiltersWorkWithRandomness(): Unit = { + test("MatrixFiltersWorkWithRandomness") { val range = rangeMatrix(20, 20, Some(4), uids = true) def rand(rng: IR): IR = Apply("rand_bool", FastSeq.empty, FastSeq(RNGSplitStatic(rng, 0), 0.5), TBoolean) @@ -366,7 +366,7 @@ class MatrixIRSuite extends HailSuite { assert(entries < 400 && entries > 0) } - @Test def testMatrixRepartition(): Unit = { + test("MatrixRepartition") { val range = rangeMatrix(11, 3, Some(10)) val params = Array( @@ -378,7 +378,7 @@ class MatrixIRSuite extends HailSuite { 10 -> RepartitionStrategy.COALESCE, ) - forAll(params) { case (n, strat) => + params.foreach { case (n, strat) => unoptimized { ctx => val rvd = Interpret(MatrixRepartition(range, n, strat), ctx).rvd assert(rvd.getNumPartitions == n, n -> strat) @@ -388,12 +388,12 @@ class MatrixIRSuite extends HailSuite { } } - @Test def testMatrixMultiWriteDifferentTypesRaisesError(): Unit = { + test("MatrixMultiWriteDifferentTypesRaisesError") { val vcf = importVCF(ctx, getTestResource("sample.vcf")) val range = rangeMatrix(10, 2, None) val path1 = ctx.createTmpPath("test1") val path2 = ctx.createTmpPath("test2") - assertThrows[HailException] { + intercept[HailException] { TypeCheck( ctx, MatrixMultiWrite(FastSeq(vcf, range), MatrixNativeMultiWriter(IndexedSeq(path1, path2))), diff --git a/hail/hail/test/src/is/hail/expr/ir/MemoryLeakSuite.scala b/hail/hail/test/src/is/hail/expr/ir/MemoryLeakSuite.scala index 6e413d7ca30..8fb4af126eb 100644 --- a/hail/hail/test/src/is/hail/expr/ir/MemoryLeakSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/MemoryLeakSuite.scala @@ -7,10 +7,8 @@ import is.hail.expr.ir.defs.{Literal, ToArray, ToStream} import 
is.hail.types.virtual.{TArray, TBoolean, TSet, TString} import is.hail.utils._ -import org.testng.annotations.Test - class MemoryLeakSuite extends HailSuite { - @Test def testLiteralSetContains(): Unit = { + test("LiteralSetContains") { val litSize = 32000 diff --git a/hail/hail/test/src/is/hail/expr/ir/MissingArrayBuilderSuite.scala b/hail/hail/test/src/is/hail/expr/ir/MissingArrayBuilderSuite.scala index fad2aac8cf4..a18b3f0a557 100644 --- a/hail/hail/test/src/is/hail/expr/ir/MissingArrayBuilderSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/MissingArrayBuilderSuite.scala @@ -1,14 +1,12 @@ package is.hail.expr.ir +import is.hail.TestCaseSupport import is.hail.asm4s.AsmFunction2 import is.hail.collection.FastSeq import scala.reflect.ClassTag -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.{DataProvider, Test} - -class MissingArrayBuilderSuite extends TestNGSuite { +class MissingArrayBuilderSuite extends munit.FunSuite with TestCaseSupport { def ordering[T <: AnyVal](f: (T, T) => Boolean): AsmFunction2[T, T, Boolean] = new AsmFunction2[T, T, Boolean] { override def apply(i: T, j: T): Boolean = f(i, j) @@ -38,108 +36,105 @@ class MissingArrayBuilderSuite extends TestNGSuite { } } - @DataProvider(name = "sortInt") - def integerData(): Array[Array[Any]] = Array( - Array(FastSeq(3, null, 3, 7, null), FastSeq(3, 3, 7, null, null)), - Array(FastSeq(null, null, null, null), FastSeq(null, null, null, null)), - Array(FastSeq(), FastSeq()), - ) - - @Test(dataProvider = "sortInt") - def testSortOnIntArrayBuilder(array: IndexedSeq[Integer], expected: IndexedSeq[Integer]): Unit = { - val ab = new IntMissingArrayBuilder(16) - addToArrayBuilder(ab, array)((iab, i) => iab.add(i)) - - ab.sort(ordering[Int]((i, j) => i < j)) - val result = getResult[IntMissingArrayBuilder, Integer](ab)((iab, i) => Int.box(iab(i))) - assert(result sameElements expected) + object checkSortOnIntArrayBuilder extends TestCases { + def apply( + array: IndexedSeq[Integer], 
+ expected: IndexedSeq[Integer], + )(implicit loc: munit.Location + ): Unit = test("sort int array builder") { + val ab = new IntMissingArrayBuilder(16) + addToArrayBuilder(ab, array)((iab, i) => iab.add(i)) + ab.sort(ordering[Int]((i, j) => i < j)) + val result = getResult[IntMissingArrayBuilder, Integer](ab)((iab, i) => Int.box(iab(i))) + assert(result sameElements expected) + } } - @DataProvider(name = "sortLong") - def longData(): Array[Array[Any]] = Array( - Array(FastSeq(3L, null, 3L, 7L, null), FastSeq(3L, 3L, 7L, null, null)), - Array(FastSeq(null, null, null, null), FastSeq(null, null, null, null)), - Array(FastSeq(), FastSeq()), - ) - - @Test(dataProvider = "sortLong") - def testSortOnLongArrayBuilder( - array: IndexedSeq[java.lang.Long], - expected: IndexedSeq[java.lang.Long], - ): Unit = { - val ab = new LongMissingArrayBuilder(16) - addToArrayBuilder(ab, array)((jab, j) => jab.add(j)) - - ab.sort(ordering[Long]((i, j) => i < j)) - val result = getResult[LongMissingArrayBuilder, java.lang.Long](ab) { (jab, j) => - Long.box(jab(j)) + checkSortOnIntArrayBuilder(FastSeq(3, null, 3, 7, null), FastSeq(3, 3, 7, null, null)) + checkSortOnIntArrayBuilder(FastSeq(null, null, null, null), FastSeq(null, null, null, null)) + checkSortOnIntArrayBuilder(FastSeq(), FastSeq()) + + object checkSortOnLongArrayBuilder extends TestCases { + def apply( + array: IndexedSeq[java.lang.Long], + expected: IndexedSeq[java.lang.Long], + )(implicit loc: munit.Location + ): Unit = test("sort long array builder") { + val ab = new LongMissingArrayBuilder(16) + addToArrayBuilder(ab, array)((jab, j) => jab.add(j)) + ab.sort(ordering[Long]((i, j) => i < j)) + val result = getResult[LongMissingArrayBuilder, java.lang.Long](ab) { (jab, j) => + Long.box(jab(j)) + } + assert(result sameElements expected) } - assert(result sameElements expected) } - @DataProvider(name = "sortFloat") - def floatData(): Array[Array[Any]] = Array( - Array(FastSeq(3f, null, 3f, 7f, null), FastSeq(3f, 3f, 7f, 
null, null)), - Array(FastSeq(null, null, null, null), FastSeq(null, null, null, null)), - Array(FastSeq(), FastSeq()), - ) - - @Test(dataProvider = "sortFloat") - def testSortOnFloatArrayBuilder( - array: IndexedSeq[java.lang.Float], - expected: IndexedSeq[java.lang.Float], - ): Unit = { - val ab = new FloatMissingArrayBuilder(16) - addToArrayBuilder(ab, array)((fab, f) => fab.add(f)) - - ab.sort(ordering[Float]((i, j) => i < j)) - val result = getResult[FloatMissingArrayBuilder, java.lang.Float](ab) { (fab, f) => - Float.box(fab(f)) + checkSortOnLongArrayBuilder(FastSeq(3L, null, 3L, 7L, null), FastSeq(3L, 3L, 7L, null, null)) + checkSortOnLongArrayBuilder(FastSeq(null, null, null, null), FastSeq(null, null, null, null)) + checkSortOnLongArrayBuilder(FastSeq(), FastSeq()) + + object checkSortOnFloatArrayBuilder extends TestCases { + def apply( + array: IndexedSeq[java.lang.Float], + expected: IndexedSeq[java.lang.Float], + )(implicit loc: munit.Location + ): Unit = test("sort float array builder") { + val ab = new FloatMissingArrayBuilder(16) + addToArrayBuilder(ab, array)((fab, f) => fab.add(f)) + ab.sort(ordering[Float]((i, j) => i < j)) + val result = getResult[FloatMissingArrayBuilder, java.lang.Float](ab) { (fab, f) => + Float.box(fab(f)) + } + assert(result sameElements expected) } - assert(result sameElements expected) } - @DataProvider(name = "sortDouble") - def doubleData(): Array[Array[Any]] = Array( - Array(FastSeq(3d, null, 3d, 7d, null), FastSeq(3d, 3d, 7d, null, null)), - Array(FastSeq(null, null, null, null), FastSeq(null, null, null, null)), - Array(FastSeq(), FastSeq()), - ) - - @Test(dataProvider = "sortDouble") - def testSortOnDoubleArrayBuilder( - array: IndexedSeq[java.lang.Double], - expected: IndexedSeq[java.lang.Double], - ): Unit = { - val ab = new DoubleMissingArrayBuilder(16) - addToArrayBuilder(ab, array)((dab, d) => dab.add(d)) + checkSortOnFloatArrayBuilder(FastSeq(3f, null, 3f, 7f, null), FastSeq(3f, 3f, 7f, null, null)) + 
checkSortOnFloatArrayBuilder(FastSeq(null, null, null, null), FastSeq(null, null, null, null)) + checkSortOnFloatArrayBuilder(FastSeq(), FastSeq()) + + object checkSortOnDoubleArrayBuilder extends TestCases { + def apply( + array: IndexedSeq[java.lang.Double], + expected: IndexedSeq[java.lang.Double], + )(implicit loc: munit.Location + ): Unit = test("sort double array builder") { + val ab = new DoubleMissingArrayBuilder(16) + addToArrayBuilder(ab, array)((dab, d) => dab.add(d)) + ab.sort(ordering[Double]((i, j) => i < j)) + val result = getResult[DoubleMissingArrayBuilder, java.lang.Double](ab) { (dab, d) => + Double.box(dab(d)) + } + assert(result sameElements expected) + } + } - ab.sort(ordering[Double]((i, j) => i < j)) - val result = getResult[DoubleMissingArrayBuilder, java.lang.Double](ab) { (dab, d) => - Double.box(dab(d)) + checkSortOnDoubleArrayBuilder(FastSeq(3d, null, 3d, 7d, null), FastSeq(3d, 3d, 7d, null, null)) + checkSortOnDoubleArrayBuilder(FastSeq(null, null, null, null), FastSeq(null, null, null, null)) + checkSortOnDoubleArrayBuilder(FastSeq(), FastSeq()) + + object checkSortOnBooleanArrayBuilder extends TestCases { + def apply( + array: IndexedSeq[java.lang.Boolean], + expected: IndexedSeq[java.lang.Boolean], + )(implicit loc: munit.Location + ): Unit = test("sort boolean array builder") { + val ab = new BooleanMissingArrayBuilder(16) + addToArrayBuilder(ab, array)((bab, b) => bab.add(b)) + ab.sort(ordering[Boolean]((i, j) => i < j)) + val result = getResult[BooleanMissingArrayBuilder, java.lang.Boolean](ab) { (bab, b) => + Boolean.box(bab(b)) + } + assert(result sameElements expected) } - assert(result sameElements expected) } - @DataProvider(name = "sortBoolean") - def booleanData(): Array[Array[Any]] = Array( - Array(FastSeq(true, null, true, false, null), FastSeq(false, true, true, null, null)), - Array(FastSeq(null, null, null, null), FastSeq(null, null, null, null)), - Array(FastSeq(), FastSeq()), + checkSortOnBooleanArrayBuilder( + 
FastSeq(true, null, true, false, null), + FastSeq(false, true, true, null, null), ) - @Test(dataProvider = "sortBoolean") - def testSortOnBooleanArrayBuilder( - array: IndexedSeq[java.lang.Boolean], - expected: IndexedSeq[java.lang.Boolean], - ): Unit = { - val ab = new BooleanMissingArrayBuilder(16) - addToArrayBuilder(ab, array)((bab, b) => bab.add(b)) - - ab.sort(ordering[Boolean]((i, j) => i < j)) - val result = getResult[BooleanMissingArrayBuilder, java.lang.Boolean](ab) { (bab, b) => - Boolean.box(bab(b)) - } - assert(result sameElements expected) - } + checkSortOnBooleanArrayBuilder(FastSeq(null, null, null, null), FastSeq(null, null, null, null)) + checkSortOnBooleanArrayBuilder(FastSeq(), FastSeq()) } diff --git a/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala b/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala index 75f3c2ed4a3..7171aeeefe7 100644 --- a/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/OrderingSuite.scala @@ -21,12 +21,9 @@ import is.hail import org.apache.spark.sql.Row import org.scalacheck.Arbitrary.arbitrary import org.scalacheck.Gen -import org.scalatest -import org.scalatestplus.scalacheck.CheckerAsserting.assertingNatureOfAssertion -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.testng.annotations.{DataProvider, Test} +import org.scalacheck.Prop.forAll -class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { +class OrderingSuite extends HailSuite with munit.ScalaCheckSuite { implicit val execStrats: hail.ExecStrategy.ValueSet = ExecStrategy.values @@ -67,31 +64,31 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r) } - @Test def testMissingNonequalComparisons(): Unit = { - def getStagedOrderingFunctionWithMissingness( - t: PType, - op: CodeOrdering.Op, - r: Region, - sortOrder: SortOrder = Ascending, - ): AsmFunction5[Region, 
Boolean, Long, Boolean, Long, op.ReturnType] = { - implicit val x = op.rtti - val fb = - EmitFunctionBuilder[Region, Boolean, Long, Boolean, Long, op.ReturnType](ctx, "lifted") - fb.emitWithBuilder { cb => - val m1 = fb.getCodeParam[Boolean](2) - val cv1 = t.loadCheapSCode(cb, fb.getCodeParam[Long](3)) - val m2 = fb.getCodeParam[Boolean](4) - val cv2 = t.loadCheapSCode(cb, fb.getCodeParam[Long](5)) - val ev1 = EmitValue(Some(m1), cv1) - val ev2 = EmitValue(Some(m2), cv2) - fb.ecb.getOrderingFunction(ev1.st, ev2.st, op) - .apply(cb, ev1, ev2) - } - fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r) - } - - forAll(genTypeVal[TStruct](ctx)) { case (t, a) => + property("MissingNonequalComparisons") = + forAll(genTypeVal[TStruct](ctx)) { case (t: TStruct, a) => pool.scopedRegion { region => + def getStagedOrderingFunctionWithMissingness( + t: PType, + op: CodeOrdering.Op, + r: Region, + sortOrder: SortOrder = Ascending, + ): AsmFunction5[Region, Boolean, Long, Boolean, Long, op.ReturnType] = { + implicit val x = op.rtti + val fb = + EmitFunctionBuilder[Region, Boolean, Long, Boolean, Long, op.ReturnType](ctx, "lifted") + fb.emitWithBuilder { cb => + val m1 = fb.getCodeParam[Boolean](2) + val cv1 = t.loadCheapSCode(cb, fb.getCodeParam[Long](3)) + val m2 = fb.getCodeParam[Boolean](4) + val cv2 = t.loadCheapSCode(cb, fb.getCodeParam[Long](5)) + val ev1 = EmitValue(Some(m1), cv1) + val ev2 = EmitValue(Some(m2), cv2) + fb.ecb.getOrderingFunction(ev1.st, ev2.st, op) + .apply(cb, ev1, ev2) + } + fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r) + } + val pType = PType.canonical(t).asInstanceOf[PStruct] val v = pType.unstagedStoreJavaObject(sm, a, region) @@ -233,11 +230,11 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { check(eordMNE.gteq(null, a), true) check(eordMNE.gteq(a, null), false) } + true } - } - @Test def testRandomOpsAgainstExtended(): Unit = - forAll(genTypeNonMissingVal2) { case (t, a1, a2) => + 
property("RandomOpsAgainstExtended") = + forAll(genTypeNonMissingVal2) { case (t: Type, a1, a2) => pool.scopedRegion { region => val pType = PType.canonical(t) @@ -279,10 +276,11 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(fgteq(region, v1, v2) == gteq, s"gteq expected: $gteq") } + true } - @Test def testReverseIsSwappedArgumentsOfExtendedOrdering(): Unit = - forAll(genTypeNonMissingVal2) { case (t, a1, a2) => + property("ReverseIsSwappedArgumentsOfExtendedOrdering") = + forAll(genTypeNonMissingVal2) { case (t: Type, a1, a2) => pool.scopedRegion { region => val pType = PType.canonical(t) @@ -323,36 +321,39 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(fgteq(region, v1, v2) == gteq, s"gteq expected: $gteq") } + true } - @Test def testSortOnRandomArray(): Unit = { + property("SortOnRandomArray") { implicit val execStrats = ExecStrategy.javaOnly forAll(genTypeVal[TArray](ctx), arbitrary[Boolean]) { - case ((tarray, a: IndexedSeq[Any]), asc) => + case ((tarray: TArray, a: IndexedSeq[Any]), asc) => val ord = tarray.elementType.ordering(sm) assertEvalsTo( ArraySort(ToStream(In(0, tarray)), Literal.coerce(TBoolean, asc)), FastSeq(a -> tarray), expected = a.sorted((if (asc) ord else ord.reverse).toOrdering), ) + true } } - @Test def testToSetOnRandomDuplicatedArray(): Unit = { + property("ToSetOnRandomDuplicatedArray") { implicit val execStrats = ExecStrategy.javaOnly - forAll(genTypeVal[TArray](ctx)) { case (tarray, a: IndexedSeq[Any]) => + forAll(genTypeVal[TArray](ctx)) { case (tarray: TArray, a: IndexedSeq[Any]) => val array = a ++ a assertEvalsTo( ToArray(ToStream(ToSet(ToStream(In(0, tarray))))), FastSeq(array -> tarray), expected = array.sorted(tarray.elementType.ordering(sm).toOrdering).distinct, ) + true } } - @Test def testToDictOnRandomDuplicatedArray(): Unit = { + property("ToDictOnRandomDuplicatedArray") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.javaOnly @@ 
-366,7 +367,7 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { array = a.asInstanceOf[IndexedSeq[Annotation]] } yield (pelt.virtualType, array ++ array) - forAll(compareGen) { case (telt, array) => + forAll(compareGen) { case (telt: TTuple, array) => assertEvalsTo( ToArray(mapIR(ToStream(ToDict(ToStream(In(0, TArray(telt))))))(GetField(_, "key"))), FastSeq(array -> TArray(telt)), @@ -378,18 +379,19 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { .distinct .sorted(telt.types(0).ordering(sm).toOrdering), ) + true } } - @Test def testSortOnMissingArray(): Unit = { + test("SortOnMissingArray") { implicit val execStrats = ExecStrategy.javaOnly val ts = TStream(TStruct("key" -> TInt32, "value" -> TInt32)) val irs: Array[IR => IR] = Array(ArraySort(_, True()), ToSet(_), ToDict(_)) - scalatest.Inspectors.forAll(irs)(irF => assertEvalsTo(IsNA(irF(NA(ts))), true)) + irs.foreach(irF => assertEvalsTo(IsNA(irF(NA(ts))), true)) } - @Test def testSetContainsOnRandomSet(): Unit = { + property("SetContainsOnRandomSet") { implicit val execStrats = ExecStrategy.javaOnly val compareGen = for { @@ -415,10 +417,11 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { FastSeq(set -> tset, test1 -> telt), expected = set.contains(test1), ) + true } } - @Test def testDictGetOnRandomDict(): Unit = { + property("DictGetOnRandomDict") { implicit val execStrats = ExecStrategy.javaOnly val compareGen = @@ -436,8 +439,7 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { dict.getOrElse(testKey1, null), ) - if (dict.isEmpty) scalatest.Succeeded - else { + if (dict.nonEmpty) { val testKey2 = dict.keys.toSeq.head assertEvalsTo( invoke("get", tdict.valueType, In(0, tdict), In(1, tdict.keyType)), @@ -445,18 +447,18 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { expected = dict(testKey2), ) } + true } } - @Test def testBinarySearchOnSet(): Unit = { - val 
compareGen = + property("BinarySearchOnSet") = + forAll( for { elt <- arbitrary[Type] set <- genNonMissingT[Set[Annotation]](ctx, TSet(elt)) v <- genNonMissing(ctx, elt) } yield (elt, set, v) - - forAll(compareGen) { case (t, set, elem) => + ) { case (t: Type, set, elem) => val pt = PType.canonical(t) val pset = PCanonicalSet(pt) @@ -501,18 +503,17 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(((i - 1 < 0) || ordering.compare(asArray(i - 1), elem) < 0) && ((i >= set.size) || ordering.compare(elem, asArray(i)) <= 0)) } + true } - } - @Test def testBinarySearchOnDict(): Unit = { - val compareGen = + property("BinarySearchOnDict") = + forAll( for { tdict <- arbitrary[TDict] dict <- genNonMissingT[Map[Annotation, Annotation]](ctx, tdict, innerRequired = false) key <- genNonMissing(ctx, tdict.keyType, innerRequired = false) } yield (tdict, dict, key) - - forAll(compareGen) { case (tDict, dict, key) => + ) { case (tDict: TDict, dict, key) => val pDict = PType.canonical(tDict).asInstanceOf[PDict] pool.scopedRegion { region => @@ -559,10 +560,10 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(((i - 1 < 0) || ordering.compare(asArray(i - 1).get(0), key) < 0) && ((i >= asArray.size) || ordering.compare(key, asArray(i).get(0)) <= 0)) } + true } - } - @Test def testContainsWithArrayFold(): Unit = { + test("ContainsWithArrayFold") { implicit val execStrats = ExecStrategy.javaOnly val set1 = ToSet(MakeStream(IndexedSeq(I32(1), I32(4)), TStream(TInt32))) val set2 = ToSet(MakeStream(IndexedSeq(I32(9), I32(1), I32(4)), TStream(TInt32))) @@ -583,64 +584,70 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { ) } - @DataProvider(name = "arrayDoubleOrderingData") - def arrayDoubleOrderingData(): Array[Array[Any]] = { + val arrayDoubleOrderingData: IndexedSeq[(IndexedSeq[Any], IndexedSeq[Any])] = { val xs = Array[Any](null, Double.NegativeInfinity, -0.0, 0.0, 1.0, 
Double.PositiveInfinity, Double.NaN) val as = Array(null: IndexedSeq[Any]) ++ (for (x <- xs) yield IndexedSeq[Any](x)) - for { + (for { a <- as a2 <- as - } yield Array[Any](a, a2) + } yield (a, a2)).toIndexedSeq } - @Test(dataProvider = "arrayDoubleOrderingData") - def testOrderingArrayDouble( - a: IndexedSeq[Any], - a2: IndexedSeq[Any], - ): Unit = { - val t = TArray(TFloat64) - - val args = FastSeq(a -> t, a2 -> t) - - assertEvalSame(ApplyComparisonOp(EQ, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(EQWithNA, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(NEQ, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(NEQWithNA, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(LT, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(LTEQ, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(GT, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(GTEQ, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(Compare, In(0, t), In(1, t)), args) + object checkOrderingArrayDouble extends TestCases { + def apply( + a: IndexedSeq[Any], + a2: IndexedSeq[Any], + )(implicit loc: munit.Location + ): Unit = test("OrderingArrayDouble") { + val t = TArray(TFloat64) + + val args = FastSeq(a -> t, a2 -> t) + + assertEvalSame(ApplyComparisonOp(EQ, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(EQWithNA, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(NEQ, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(NEQWithNA, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(LT, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(LTEQ, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(GT, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(GTEQ, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(Compare, In(0, t), In(1, t)), args) + } } - @Test(dataProvider = "arrayDoubleOrderingData") - def testOrderingSetDouble( 
arrayDoubleOrderingData.foreach { case (a, a2) => checkOrderingArrayDouble(a, a2) }

  /** Data-driven check for comparison ops on `TSet(TFloat64)`: for one input
    * pair, asserts via `assertEvalSame` that every comparison op evaluates to
    * the same result. Registers one munit test per `(a, a2)` pair; the inputs
    * are interpolated into the test name so each registered case is
    * individually identifiable in reports — a fixed name would register every
    * data row under the same label, losing the per-row reporting the TestNG
    * `@DataProvider` version provided.
    *
    * `loc` is the munit source location captured at the registration call
    * site, so a failure points at the data row, not at this helper.
    */
  object checkOrderingSetDouble extends TestCases {
    def apply(
      a: IndexedSeq[Any],
      a2: IndexedSeq[Any],
    )(implicit loc: munit.Location
    ): Unit = test(s"OrderingSetDouble($a, $a2)") {
      val t = TSet(TFloat64)

      // Preserve null containers (missing values) instead of converting them.
      val s = if (a != null) a.toSet else null
      val s2 = if (a2 != null) a2.toSet else null
      val args = FastSeq(s -> t, s2 -> t)

      // Every comparison op must agree across evaluation strategies.
      assertEvalSame(ApplyComparisonOp(EQ, In(0, t), In(1, t)), args)
      assertEvalSame(ApplyComparisonOp(EQWithNA, In(0, t), In(1, t)), args)
      assertEvalSame(ApplyComparisonOp(NEQ, In(0, t), In(1, t)), args)
      assertEvalSame(ApplyComparisonOp(NEQWithNA, In(0, t), In(1, t)), args)
      assertEvalSame(ApplyComparisonOp(LT, In(0, t), In(1, t)), args)
      assertEvalSame(ApplyComparisonOp(LTEQ, In(0, t), In(1, t)), args)
      assertEvalSame(ApplyComparisonOp(GT, In(0, t), In(1, t)), args)
      assertEvalSame(ApplyComparisonOp(GTEQ, In(0, t), In(1, t)), args)
      assertEvalSame(ApplyComparisonOp(Compare, In(0, t), In(1, t)), args)
    }
  }
arrayDoubleOrderingData.foreach { case (a, a2) => checkOrderingSetDouble(a, a2) } + + val rowDoubleOrderingData: IndexedSeq[(Row, Row)] = { val xs = Array[Any](null, Double.NegativeInfinity, -0.0, 0.0, 1.0, Double.PositiveInfinity, Double.NaN) val ss = Array[Any](null, "a", "aa") @@ -650,29 +657,33 @@ class OrderingSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { s <- ss } yield Row(x, s) - for { + (for { r <- rs r2 <- rs - } yield Array[Any](r, r2) + } yield (r, r2)).toIndexedSeq } - @Test(dataProvider = "rowDoubleOrderingData") - def testOrderingRowDouble( - r: Row, - r2: Row, - ): Unit = { - val t = TStruct("x" -> TFloat64, "s" -> TString) - - val args = FastSeq(r -> t, r2 -> t) - - assertEvalSame(ApplyComparisonOp(EQ, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(EQWithNA, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(NEQ, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(NEQWithNA, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(LT, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(LTEQ, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(GT, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(GTEQ, In(0, t), In(1, t)), args) - assertEvalSame(ApplyComparisonOp(Compare, In(0, t), In(1, t)), args) + object checkOrderingRowDouble extends TestCases { + def apply( + r: Row, + r2: Row, + )(implicit loc: munit.Location + ): Unit = test("OrderingRowDouble") { + val t = TStruct("x" -> TFloat64, "s" -> TString) + + val args = FastSeq(r -> t, r2 -> t) + + assertEvalSame(ApplyComparisonOp(EQ, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(EQWithNA, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(NEQ, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(NEQWithNA, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(LT, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(LTEQ, In(0, t), In(1, t)), args) + 
assertEvalSame(ApplyComparisonOp(GT, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(GTEQ, In(0, t), In(1, t)), args) + assertEvalSame(ApplyComparisonOp(Compare, In(0, t), In(1, t)), args) + } } + + rowDoubleOrderingData.foreach { case (r, r2) => checkOrderingRowDouble(r, r2) } } diff --git a/hail/hail/test/src/is/hail/expr/ir/PruneSuite.scala b/hail/hail/test/src/is/hail/expr/ir/PruneSuite.scala index 51909f5ba97..5aea2139085 100644 --- a/hail/hail/test/src/is/hail/expr/ir/PruneSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/PruneSuite.scala @@ -17,12 +17,9 @@ import scala.collection.mutable import org.apache.spark.sql.Row import org.json4s.JValue -import org.scalatest.Inspectors.forAll -import org.scalatest.enablers.InspectorAsserting.assertingNatureOfAssertion -import org.testng.annotations.{DataProvider, Test} class PruneSuite extends HailSuite { - @Test def testUnionType(): Unit = { + test("UnionType") { val base = TStruct( "a" -> TStruct( "aa" -> TInt32, @@ -36,27 +33,35 @@ class PruneSuite extends HailSuite { )), ) - assert(PruneDeadFields.unify(base, TStruct.empty) == TStruct.empty) - assert(PruneDeadFields.unify(base, TStruct("b" -> TInt32)) == TStruct("b" -> TInt32)) - assert( - PruneDeadFields.unify(base, TStruct("a" -> TStruct.empty)) == TStruct("a" -> TStruct.empty) - ) - assert(PruneDeadFields.unify( - base, + assertEquals(PruneDeadFields.unify(base, TStruct.empty), TStruct.empty) + assertEquals(PruneDeadFields.unify(base, TStruct("b" -> TInt32)), TStruct("b" -> TInt32)) + assertEquals( + PruneDeadFields.unify(base, TStruct("a" -> TStruct.empty)), TStruct("a" -> TStruct.empty), - TStruct("b" -> TInt32), - ) == TStruct("a" -> TStruct.empty, "b" -> TInt32)) - assert(PruneDeadFields.unify(base, TStruct("c" -> TArray(TStruct.empty))) == TStruct( - "c" -> TArray(TStruct.empty) - )) - assert(PruneDeadFields.unify( - base, - TStruct("a" -> TStruct("ab" -> TStruct.empty)), + ) + assertEquals( + PruneDeadFields.unify( + base, + TStruct("a" 
-> TStruct.empty), + TStruct("b" -> TInt32), + ), + TStruct("a" -> TStruct.empty, "b" -> TInt32), + ) + assertEquals( + PruneDeadFields.unify(base, TStruct("c" -> TArray(TStruct.empty))), TStruct("c" -> TArray(TStruct.empty)), - ) == TStruct("a" -> TStruct("ab" -> TStruct.empty), "c" -> TArray(TStruct.empty))) + ) + assertEquals( + PruneDeadFields.unify( + base, + TStruct("a" -> TStruct("ab" -> TStruct.empty)), + TStruct("c" -> TArray(TStruct.empty)), + ), + TStruct("a" -> TStruct("ab" -> TStruct.empty), "c" -> TArray(TStruct.empty)), + ) } - @Test def testIsSupertype(): Unit = { + test("IsSupertype") { val emptyTuple = TTuple.empty val tuple1Int = TTuple(TInt32) val tuple2Ints = TTuple(TInt32, TInt32) @@ -67,7 +72,7 @@ class PruneSuite extends HailSuite { assert(PruneDeadFields.isSupertype(tuple2IntsFirstRemoved, tuple2Ints)) } - @Test def testIsSupertypeWithDistinctFieldTypes(): Unit = { + test("IsSupertypeWithDistinctFieldTypes") { val tuple2Ints = TTuple(TInt32, TFloat64) val tuple2IntsFirstRemoved = TTuple(IndexedSeq(TupleField(1, TFloat64))) @@ -97,7 +102,7 @@ class PruneSuite extends HailSuite { PruneDeadFields.memoizeValueIR(ctx, ir, requestedType.asInstanceOf[Type], ms, envStates) } - forAll(irCopy.children.zipWithIndex) { case (child, i) => + irCopy.children.zipWithIndex.foreach { case (child, i) => assert( expected(i) == null || expected(i) == ms.requestedType.lookup(child), s"For base IR $ir\n Child $i with IR $child\n Expected: ${expected(i)}\n Actual: ${ms.requestedType.get(child)}", @@ -343,7 +348,7 @@ class PruneSuite extends HailSuite { t.typ.globalType.fieldNames.map(x => x -> (x + "_")).toMap, ) - @Test def testTableJoinMemo(): Unit = { + test("TableJoinMemo") { val tk1 = TableKeyBy(tab, ArraySeq("1")) val tk2 = mangle(TableKeyBy(tab, ArraySeq("3"))) val tj = TableJoin(tk1, tk2, "inner", 1) @@ -387,7 +392,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableLeftJoinRightDistinctMemo(): Unit = { + 
test("TableLeftJoinRightDistinctMemo") { val tk1 = TableKeyBy(tab, ArraySeq("1")) val tk2 = TableKeyBy(tab, ArraySeq("3")) val tj = TableLeftJoinRightDistinct(tk1, tk2, "foo") @@ -401,7 +406,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableIntervalJoinMemo(): Unit = { + test("TableIntervalJoinMemo") { val tk1 = TableKeyBy(tab, ArraySeq("1")) val tk2 = TableKeyBy(tab, ArraySeq("3")) val tj = TableIntervalJoin(tk1, tk2, "foo", product = false) @@ -415,7 +420,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableMultiWayZipJoinMemo(): Unit = { + test("TableMultiWayZipJoinMemo") { val tk1 = TableKeyBy(tab, ArraySeq("1")) val ts = ArraySeq(tk1, tk1, tk1) val tmwzj = TableMultiWayZipJoin(ts, "data", "gbls") @@ -426,12 +431,12 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableExplodeMemo(): Unit = { + test("TableExplodeMemo") { val te = TableExplode(tab, ArraySeq("2")) checkMemo(te, subsetTable(te.typ), ArraySeq(subsetTable(tab.typ, "row.2"))) } - @Test def testTableFilterMemo(): Unit = { + test("TableFilterMemo") { checkMemo( TableFilter(tab, tableRefBoolean(tab.typ, "row.2")), subsetTable(tab.typ, "row.3"), @@ -444,7 +449,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableKeyByMemo(): Unit = { + test("TableKeyByMemo") { val tk = TableKeyBy(tab, ArraySeq("1")) checkMemo( tk, @@ -457,7 +462,7 @@ class PruneSuite extends HailSuite { } - @Test def testTableMapRowsMemo(): Unit = { + test("TableMapRowsMemo") { val tmr = TableMapRows(tab, tableRefStruct(tab.typ, "row.1", "row.2")) checkMemo( tmr, @@ -473,7 +478,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableMapGlobalsMemo(): Unit = { + test("TableMapGlobalsMemo") { val tmg = TableMapGlobals(tab, tableRefStruct(tab.typ.copy(key = FastSeq()), "global.g1")) checkMemo( tmg, @@ -482,7 +487,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixColsTableMemo(): Unit = { + test("MatrixColsTableMemo") { val mct = MatrixColsTable(mat) 
checkMemo( mct, @@ -491,7 +496,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixRowsTableMemo(): Unit = { + test("MatrixRowsTableMemo") { val mrt = MatrixRowsTable(mat) checkMemo( mrt, @@ -500,7 +505,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixEntriesTableMemo(): Unit = { + test("MatrixEntriesTableMemo") { val met = MatrixEntriesTable(mat) checkMemo( met, @@ -509,7 +514,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableKeyByAndAggregateMemo(): Unit = { + test("TableKeyByAndAggregateMemo") { val tka = TableKeyByAndAggregate( tab, ApplyAggOp(PrevNonnull())(tableRefStruct(tab.typ, "row.2")), @@ -530,7 +535,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableAggregateByKeyMemo(): Unit = { + test("TableAggregateByKeyMemo") { val tabk = TableAggregateByKey( tab, ApplyAggOp(PrevNonnull())(SelectFields( @@ -545,14 +550,15 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableUnionMemo(): Unit = + test("TableUnionMemo") { checkMemo( TableUnion(FastSeq(tab, tab)), subsetTable(tab.typ, "row.1", "global.g1"), ArraySeq(subsetTable(tab.typ, "row.1", "global.g1"), subsetTable(tab.typ, "row.1")), ) + } - @Test def testTableOrderByMemo(): Unit = { + test("TableOrderByMemo") { val tob = TableOrderBy(tab, ArraySeq(SortField("2", Ascending))) checkMemo( tob, @@ -564,7 +570,7 @@ class PruneSuite extends HailSuite { checkMemo(tob2, subsetTable(tob2.typ), ArraySeq(subsetTable(tab.typ))) } - @Test def testCastMatrixToTableMemo(): Unit = { + test("CastMatrixToTableMemo") { val m2t = CastMatrixToTable(mat, "__entries", "__cols") checkMemo( m2t, @@ -573,7 +579,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixFilterColsMemo(): Unit = { + test("MatrixFilterColsMemo") { val mfc = MatrixFilterCols(mat, matrixRefBoolean(mat.typ, "global.g1", "sa.c2")) checkMemo( mfc, @@ -582,7 +588,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixFilterRowsMemo(): Unit = { + 
test("MatrixFilterRowsMemo") { val mfr = MatrixFilterRows(mat, matrixRefBoolean(mat.typ, "global.g1", "va.r2")) checkMemo( mfr, @@ -591,7 +597,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixFilterEntriesMemo(): Unit = { + test("MatrixFilterEntriesMemo") { val mfe = MatrixFilterEntries(mat, matrixRefBoolean(mat.typ, "global.g1", "va.r2", "sa.c2", "g.e2")) checkMemo( @@ -604,7 +610,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixMapColsMemo(): Unit = { + test("MatrixMapColsMemo") { val mmc = MatrixMapCols( mat, ApplyAggOp(PrevNonnull())(matrixRefStruct(mat.typ, "global.g1", "sa.c2", "va.r2", "g.e2")), @@ -633,7 +639,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixKeyRowsByMemo(): Unit = { + test("MatrixKeyRowsByMemo") { val mkr = MatrixKeyRowsBy(mat, FastSeq("rk")) checkMemo( mkr, @@ -642,7 +648,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixMapRowsMemo(): Unit = { + test("MatrixMapRowsMemo") { val mmr = MatrixMapRows( MatrixKeyRowsBy(mat, IndexedSeq.empty), ApplyAggOp(PrevNonnull())(matrixRefStruct(mat.typ, "global.g1", "sa.c2", "va.r2", "g.e2")), @@ -664,7 +670,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixMapGlobalsMemo(): Unit = { + test("MatrixMapGlobalsMemo") { val mmg = MatrixMapGlobals(mat, matrixRefStruct(mat.typ, "global.g1")) checkMemo( mmg, @@ -673,7 +679,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixAnnotateRowsTableMemo(): Unit = { + test("MatrixAnnotateRowsTableMemo") { val tl = TableLiteral(Interpret(MatrixRowsTable(mat), ctx), theHailClassLoader) val mart = MatrixAnnotateRowsTable(mat, tl, "foo", product = false) checkMemo( @@ -683,7 +689,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testCollectColsByKeyMemo(): Unit = { + test("CollectColsByKeyMemo") { val ccbk = MatrixCollectColsByKey(mat) checkMemo( ccbk, @@ -692,7 +698,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixExplodeRowsMemo(): Unit = 
{ + test("MatrixExplodeRowsMemo") { val mer = MatrixExplodeRows(mat, FastSeq("r3")) checkMemo( mer, @@ -701,7 +707,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixRepartitionMemo(): Unit = { + test("MatrixRepartitionMemo") { checkMemo( MatrixRepartition(mat, 10, RepartitionStrategy.SHUFFLE), subsetMatrixTable(mat.typ, "va.r2", "global.g1"), @@ -712,7 +718,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixUnionRowsMemo(): Unit = { + test("MatrixUnionRowsMemo") { checkMemo( MatrixUnionRows(FastSeq(mat, mat)), subsetMatrixTable(mat.typ, "va.r2", "global.g1"), @@ -723,7 +729,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixDistinctByRowMemo(): Unit = { + test("MatrixDistinctByRowMemo") { checkMemo( MatrixDistinctByRow(mat), subsetMatrixTable(mat.typ, "va.r2", "global.g1"), @@ -734,7 +740,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixExplodeColsMemo(): Unit = { + test("MatrixExplodeColsMemo") { val mer = MatrixExplodeCols(mat, FastSeq("c3")) checkMemo( mer, @@ -743,7 +749,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testCastTableToMatrixMemo(): Unit = { + test("CastTableToMatrixMemo") { val m2t = CastMatrixToTable(mat, "__entries", "__cols") val t2m = CastTableToMatrix(m2t, "__entries", "__cols", FastSeq("ck")) checkMemo( @@ -760,7 +766,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixAggregateRowsByKeyMemo(): Unit = { + test("MatrixAggregateRowsByKeyMemo") { val magg = MatrixAggregateRowsByKey( mat, ApplyAggOp(PrevNonnull())(matrixRefStruct(mat.typ, "g.e2", "va.r2", "sa.c2")), @@ -777,7 +783,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixAggregateColsByKeyMemo(): Unit = { + test("MatrixAggregateColsByKeyMemo") { val magg = MatrixAggregateColsByKey( mat, ApplyAggOp(PrevNonnull())(matrixRefStruct(mat.typ, "g.e2", "va.r2", "sa.c2")), @@ -807,26 +813,29 @@ class PruneSuite extends HailSuite { val justARequired = TStruct("a" -> TInt32) val 
justBRequired = TStruct("b" -> TInt32) - @Test def testIfMemo(): Unit = + test("IfMemo") { checkMemo(If(True(), ref, ref), justA, ArraySeq(TBoolean, justA, justA), refEnv) + } - @Test def testSwitchMemo(): Unit = + test("SwitchMemo") { checkMemo( Switch(I32(0), ref, FastSeq(ref)), justA, ArraySeq(TInt32, justA, justA), refEnv, ) + } - @Test def testCoalesceMemo(): Unit = + test("CoalesceMemo") { checkMemo(Coalesce(FastSeq(ref, ref)), justA, ArraySeq(justA, justA), refEnv) + } - @Test def testLetMemo(): Unit = { + test("LetMemo") { checkMemo(bindIR(ref)(x => x), justA, ArraySeq(justA, null), refEnv) checkMemo(bindIR(ref)(_ => True()), TBoolean, ArraySeq(empty, null), refEnv) } - @Test def testAggLetMemo(): Unit = { + test("AggLetMemo") { val env = BindingEnv.empty.createAgg.bindAgg(ref.name -> ref.typ) checkMemo( aggBindIR(ref)(foo => ApplyAggOp(Collect())(SelectFields(foo, IndexedSeq("a")))), @@ -837,46 +846,54 @@ class PruneSuite extends HailSuite { checkMemo(aggBindIR(ref)(_ => True()), TBoolean, ArraySeq(empty, null), env) } - @Test def testMakeArrayMemo(): Unit = + test("MakeArrayMemo") { checkMemo(arr, TArray(justB), ArraySeq(justB, justB), refEnv) + } - @Test def testArrayRefMemo(): Unit = + test("ArrayRefMemo") { checkMemo(ArrayRef(arr, I32(0)), justB, ArraySeq(TArray(justB), null, null), refEnv) + } - @Test def testArrayLenMemo(): Unit = + test("ArrayLenMemo") { checkMemo(ArrayLen(arr), TInt32, ArraySeq(TArray(empty)), refEnv) + } - @Test def testStreamTakeMemo(): Unit = + test("StreamTakeMemo") { checkMemo(StreamTake(st, I32(2)), TStream(justA), ArraySeq(TStream(justA), null), refEnv) + } - @Test def testStreamDropMemo(): Unit = + test("StreamDropMemo") { checkMemo(StreamDrop(st, I32(2)), TStream(justA), ArraySeq(TStream(justA), null), refEnv) + } - @Test def testStreamMapMemo(): Unit = + test("StreamMapMemo") { checkMemo( mapIR(st)(x => x), TStream(justB), ArraySeq(TStream(justB), null), refEnv, ) + } - @Test def testStreamGroupedMemo(): Unit = + 
test("StreamGroupedMemo") { checkMemo( StreamGrouped(st, I32(2)), TStream(TStream(justB)), ArraySeq(TStream(justB), null), refEnv, ) + } - @Test def testStreamGroupByKeyMemo(): Unit = + test("StreamGroupByKeyMemo") { checkMemo( StreamGroupByKey(st, FastSeq("a"), false), TStream(TStream(justB)), ArraySeq(TStream(TStruct("a" -> TInt32, "b" -> TInt32)), null), refEnv, ) + } - @Test def testStreamMergeMemo(): Unit = { + test("StreamMergeMemo") { val st2 = st.deepCopy() checkMemo( StreamMultiMerge( @@ -889,7 +906,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testStreamZipMemo(): Unit = { + test("StreamZipMemo") { val a2 = st.deepCopy() val a3 = st.deepCopy() for ( @@ -922,7 +939,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testStreamFilterMemo(): Unit = { + test("StreamFilterMemo") { checkMemo( filterIR(st)(foo => bindIR(GetField(foo, "b"))(_ => False())), TStream(empty), @@ -933,31 +950,34 @@ class PruneSuite extends HailSuite { checkMemo(filterIR(st)(_ => False()), TStream(justB), ArraySeq(TStream(justB), null), refEnv) } - @Test def testStreamFlatMapMemo(): Unit = + test("StreamFlatMapMemo") { checkMemo( flatMapIR(st)(foo => MakeStream(FastSeq(foo), TStream(ref.typ))), TStream(justA), ArraySeq(TStream(justA), null), refEnv, ) + } - @Test def testStreamFoldMemo(): Unit = + test("StreamFoldMemo") { checkMemo( foldIR(st, I32(0))((_, foo) => GetField(foo, "a")), TInt32, ArraySeq(TStream(justA), null, null), refEnv, ) + } - @Test def testStreamScanMemo(): Unit = + test("StreamScanMemo") { checkMemo( streamScanIR(st, I32(0))((_, foo) => GetField(foo, "a")), TStream(TInt32), ArraySeq(TStream(justA), null, null), refEnv, ) + } - @Test def testStreamJoinRightDistinct(): Unit = { + test("StreamJoinRightDistinct") { checkMemo( joinRightDistinctIR( st, @@ -984,7 +1004,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testStreamLeftIntervalJoin(): Unit = { + test("StreamLeftIntervalJoin") { val leftElemType = TStruct("a" -> TInt32, "b" -> TInt32, 
"c" -> TInt32) val rightElemType = TStruct("interval" -> TInterval(TInt32), "ignored" -> TInt32) @@ -1027,7 +1047,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testStreamForMemo(): Unit = { + test("StreamForMemo") { checkMemo( forIR(st)(foo => Die(invoke("str", TString, GetField(foo, "a")), TVoid, ErrorIDs.NO_ERROR)), TVoid, @@ -1036,7 +1056,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMakeNDArrayMemo(): Unit = { + test("MakeNDArrayMemo") { val x = Ref(freshName(), TArray(TStruct("a" -> TInt32, "b" -> TInt64))) val y = Ref(freshName(), TTuple(TInt64, TInt64)) checkMemo( @@ -1051,15 +1071,16 @@ class PruneSuite extends HailSuite { ) } - @Test def testNDArrayMapMemo(): Unit = + test("NDArrayMapMemo") { checkMemo( ndMap(ndArr)(x => x), TNDArray(justBRequired, Nat(1)), ArraySeq(TNDArray(justBRequired, Nat(1)), null), refEnv, ) + } - @Test def testNDArrayMap2Memo(): Unit = { + test("NDArrayMap2Memo") { checkMemo( ndMap2(ndArr, ndArr)((l, _) => l), TNDArray(justBRequired, Nat(1)), @@ -1082,7 +1103,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMakeStructMemo(): Unit = { + test("MakeStructMemo") { checkMemo( MakeStruct(IndexedSeq("a" -> ref, "b" -> I32(10))), TStruct("a" -> justA), @@ -1097,34 +1118,38 @@ class PruneSuite extends HailSuite { ) } - @Test def testInsertFieldsMemo(): Unit = + test("InsertFieldsMemo") { checkMemo( InsertFields(ref, IndexedSeq("d" -> ref)), justA ++ TStruct("d" -> justB), ArraySeq(justA, justB), refEnv, ) + } - @Test def testSelectFieldsMemo(): Unit = { + test("SelectFieldsMemo") { checkMemo(SelectFields(ref, IndexedSeq("a", "b")), justA, ArraySeq(justA), refEnv) checkMemo(SelectFields(ref, IndexedSeq("b", "a")), bAndA, ArraySeq(aAndB), refEnv) } - @Test def testGetFieldMemo(): Unit = + test("GetFieldMemo") { checkMemo(GetField(ref, "a"), TInt32, ArraySeq(justA), refEnv) + } - @Test def testMakeTupleMemo(): Unit = + test("MakeTupleMemo") { checkMemo(MakeTuple(IndexedSeq(0 -> ref)), TTuple(justA), 
ArraySeq(justA), refEnv) + } - @Test def testGetTupleElementMemo(): Unit = + test("GetTupleElementMemo") { checkMemo( GetTupleElement(MakeTuple.ordered(IndexedSeq(ref, ref)), 1), justB, ArraySeq(TTuple(FastSeq(TupleField(1, justB)))), refEnv, ) + } - @Test def testCastRenameMemo(): Unit = { + test("CastRenameMemo") { val x = Ref(freshName(), TArray(TStruct("x" -> TInt32, "y" -> TString))) checkMemo( CastRename(x, TArray(TStruct("y" -> TInt32, "z" -> TString))), @@ -1134,7 +1159,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testAggFilterMemo(): Unit = { + test("AggFilterMemo") { val t = TStruct("a" -> TInt32, "b" -> TInt64, "c" -> TString) val x = Ref(freshName(), t) checkMemo( @@ -1149,7 +1174,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testAggExplodeMemo(): Unit = { + test("AggExplodeMemo") { val t = TStream(TStruct("a" -> TInt32, "b" -> TInt64)) val x = Ref(freshName(), t) checkMemo( @@ -1160,7 +1185,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testAggArrayPerElementMemo(): Unit = { + test("AggArrayPerElementMemo") { val t = TArray(TStruct("a" -> TInt32, "b" -> TInt64)) val x = Ref(freshName(), t) checkMemo( @@ -1173,7 +1198,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testCDAMemo(): Unit = { + test("CDAMemo") { val ctxT = TStruct("a" -> TInt32, "b" -> TString) val globT = TStruct("c" -> TInt64, "d" -> TFloat64) val x = cdaIR(NA(TStream(ctxT)), NA(globT), "test", NA(TString)) { (ctx, glob) => @@ -1195,66 +1220,75 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableCountMemo(): Unit = + test("TableCountMemo") { checkMemo(TableCount(tab), TInt64, ArraySeq(subsetTable(tab.typ, "NO_KEY"))) + } - @Test def testTableGetGlobalsMemo(): Unit = + test("TableGetGlobalsMemo") { checkMemo( TableGetGlobals(tab), TStruct("g1" -> TInt32), ArraySeq(subsetTable(tab.typ, "global.g1", "NO_KEY")), ) + } - @Test def testTableCollectMemo(): Unit = + test("TableCollectMemo") { checkMemo( TableCollect(TableKeyBy(tab, 
FastSeq())), TStruct("rows" -> TArray(TStruct("3" -> TString)), "global" -> TStruct("g2" -> TInt32)), ArraySeq(subsetTable(tab.typ.copy(key = FastSeq()), "row.3", "global.g2")), ) + } - @Test def testTableHeadMemo(): Unit = + test("TableHeadMemo") { checkMemo( TableHead(tab, 10L), subsetTable(tab.typ.copy(key = FastSeq()), "global.g1"), ArraySeq(subsetTable(tab.typ, "row.3", "global.g1")), ) + } - @Test def testTableTailMemo(): Unit = + test("TableTailMemo") { checkMemo( TableTail(tab, 10L), subsetTable(tab.typ.copy(key = FastSeq()), "global.g1"), ArraySeq(subsetTable(tab.typ, "row.3", "global.g1")), ) + } - @Test def testTableToValueApplyMemo(): Unit = + test("TableToValueApplyMemo") { checkMemo( TableToValueApply(tab, ForceCountTable()), TInt64, ArraySeq(tab.typ), ) + } - @Test def testMatrixToValueApplyMemo(): Unit = + test("MatrixToValueApplyMemo") { checkMemo( MatrixToValueApply(mat, ForceCountMatrixTable()), TInt64, ArraySeq(mat.typ), ) + } - @Test def testTableAggregateMemo(): Unit = + test("TableAggregateMemo") { checkMemo( TableAggregate(tab, tableRefBoolean(tab.typ, "global.g1")), TBoolean, ArraySeq(subsetTable(tab.typ, "global.g1"), null), ) + } - @Test def testMatrixAggregateMemo(): Unit = + test("MatrixAggregateMemo") { checkMemo( MatrixAggregate(mat, matrixRefBoolean(mat.typ, "global.g1")), TBoolean, ArraySeq(subsetMatrixTable(mat.typ, "global.g1", "NO_COL_KEY"), null), ) + } - @Test def testPipelineLetMemo(): Unit = { + test("PipelineLetMemo") { val t = TStruct("a" -> TInt32) checkMemo( relationalBindIR(NA(t))(x => x), @@ -1263,7 +1297,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableFilterRebuild(): Unit = { + test("TableFilterRebuild") { checkRebuild( TableFilter(tr, tableRefBoolean(tr.typ, "row.2")), subsetTable(tr.typ, "row.3"), @@ -1275,7 +1309,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableMapRowsRebuild(): Unit = { + test("TableMapRowsRebuild") { val tmr = TableMapRows(tr, tableRefStruct(tr.typ, "row.2", 
"global.g1")) checkRebuild( tmr, @@ -1306,7 +1340,7 @@ class PruneSuite extends HailSuite { } - @Test def testTableMapGlobalsRebuild(): Unit = { + test("TableMapGlobalsRebuild") { val tmg = TableMapGlobals(tr, tableRefStruct(tr.typ.copy(key = FastSeq()), "global.g1")) checkRebuild( tmg, @@ -1319,7 +1353,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableLeftJoinRightDistinctRebuild(): Unit = { + test("TableLeftJoinRightDistinctRebuild") { val tk1 = TableKeyBy(tab, ArraySeq("1")) val tk2 = TableKeyBy(tab, ArraySeq("3")) val tj = TableLeftJoinRightDistinct(tk1, tk2, "foo") @@ -1332,7 +1366,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableIntervalJoinRebuild(): Unit = { + test("TableIntervalJoinRebuild") { val tk1 = TableKeyBy(tab, ArraySeq("1")) val tk2 = TableKeyBy(tab, ArraySeq("3")) val tj = TableIntervalJoin(tk1, tk2, "foo", product = false) @@ -1345,7 +1379,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableUnionRebuildUnifiesRowTypes(): Unit = { + test("TableUnionRebuildUnifiesRowTypes") { val mapExpr = InsertFields( Ref(TableIR.rowName, tr.typ.rowType), FastSeq("foo" -> tableRefBoolean(tr.typ, "row.3", "global.g1")), @@ -1369,7 +1403,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableMultiWayZipJoinRebuildUnifiesRowTypes(): Unit = { + test("TableMultiWayZipJoinRebuildUnifiesRowTypes") { val t1 = TableKeyBy(tab, ArraySeq("1")) val t2 = TableFilter(t1, tableRefBoolean(t1.typ, "row.2")) val t3 = TableFilter(t1, tableRefBoolean(t1.typ, "row.3")) @@ -1386,7 +1420,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixFilterColsRebuild(): Unit = { + test("MatrixFilterColsRebuild") { val mfc = MatrixFilterCols(mr, matrixRefBoolean(mr.typ, "sa.c2")) checkRebuild( mfc, @@ -1399,7 +1433,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixFilterEntriesRebuild(): Unit = { + test("MatrixFilterEntriesRebuild") { val mfe = MatrixFilterEntries(mr, matrixRefBoolean(mr.typ, "sa.c2", 
"va.r2", "g.e1")) checkRebuild( mfe, @@ -1418,7 +1452,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixMapRowsRebuild(): Unit = { + test("MatrixMapRowsRebuild") { val mmr = MatrixMapRows( MatrixKeyRowsBy(mr, IndexedSeq.empty), matrixRefStruct(mr.typ, "va.r2"), @@ -1436,7 +1470,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixMapColsRebuild(): Unit = { + test("MatrixMapColsRebuild") { val mmc = MatrixMapCols(mr, matrixRefStruct(mr.typ, "sa.c2"), Some(FastSeq("foo"))) checkRebuild( mmc, @@ -1454,7 +1488,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixMapEntriesRebuild(): Unit = { + test("MatrixMapEntriesRebuild") { val mme = MatrixMapEntries(mr, matrixRefStruct(mr.typ, "sa.c2", "va.r2")) checkRebuild( mme, @@ -1472,7 +1506,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixMapGlobalsRebuild(): Unit = { + test("MatrixMapGlobalsRebuild") { val mmg = MatrixMapGlobals(mr, matrixRefStruct(mr.typ, "global.g1")) checkRebuild( mmg, @@ -1490,7 +1524,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixAggregateRowsByKeyRebuild(): Unit = { + test("MatrixAggregateRowsByKeyRebuild") { val ma = MatrixAggregateRowsByKey( mr, matrixRefStruct(mr.typ, "sa.c2"), @@ -1507,7 +1541,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixAggregateColsByKeyRebuild(): Unit = { + test("MatrixAggregateColsByKeyRebuild") { val ma = MatrixAggregateColsByKey( mr, matrixRefStruct(mr.typ, "va.r2"), @@ -1524,7 +1558,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixUnionRowsRebuild(): Unit = { + test("MatrixUnionRowsRebuild") { val mat2 = MatrixLiteral(mType.copy(colKey = FastSeq()), mat.tl) checkRebuild( MatrixUnionRows(FastSeq( @@ -1539,7 +1573,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixUnionColsRebuild(): Unit = { + test("MatrixUnionColsRebuild") { def getColField(name: String) = GetField(Ref(MatrixIR.colName, mat.typ.colType), name) def 
childrenMatch(matrixUnionCols: MatrixUnionCols): Boolean = @@ -1590,7 +1624,7 @@ class PruneSuite extends HailSuite { } - @Test def testMatrixAnnotateRowsTableRebuild(): Unit = { + test("MatrixAnnotateRowsTableRebuild") { val tl = TableLiteral(Interpret(MatrixRowsTable(mat), ctx), theHailClassLoader) val mart = MatrixAnnotateRowsTable(mat, tl, "foo", product = false) checkRebuild( @@ -1609,7 +1643,7 @@ class PruneSuite extends HailSuite { def subsetTS(fields: String*): TStruct = ts.filterSet(fields.toSet)._1 - @Test def testNARebuild(): Unit = { + test("NARebuild") { checkRebuild( NA(ts), subsetTS("b"), @@ -1620,7 +1654,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testIfRebuild(): Unit = { + test("IfRebuild") { checkRebuild( If(True(), NA(ts), NA(ts)), subsetTS("b"), @@ -1631,7 +1665,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testSwitchRebuild(): Unit = + test("SwitchRebuild") { checkRebuild[IR]( Switch(I32(0), NA(ts), FastSeq(NA(ts))), subsetTS("b"), @@ -1641,8 +1675,9 @@ class PruneSuite extends HailSuite { cases(0).typ == subsetTS("b") }, ) + } - @Test def testCoalesceRebuild(): Unit = { + test("CoalesceRebuild") { checkRebuild( Coalesce(FastSeq(NA(ts), NA(ts))), subsetTS("b"), @@ -1651,7 +1686,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testLetRebuild(): Unit = { + test("LetRebuild") { checkRebuild( bindIR(NA(ts))(x => x), subsetTS("b"), @@ -1662,7 +1697,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testAggLetRebuild(): Unit = { + test("AggLetRebuild") { checkRebuild( aggBindIR(NA(ref.typ))(foo => ApplyAggOp(Collect())(SelectFields(foo, IndexedSeq("a")))), TArray(subsetTS("a")), @@ -1675,7 +1710,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMakeArrayRebuild(): Unit = { + test("MakeArrayRebuild") { checkRebuild( MakeArray(IndexedSeq(NA(ts)), TArray(ts)), TArray(subsetTS("b")), @@ -1686,7 +1721,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testStreamTakeRebuild(): Unit = { + 
test("StreamTakeRebuild") { checkRebuild( StreamTake(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), I32(2)), TStream(subsetTS("b")), @@ -1697,7 +1732,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testStreamDropRebuild(): Unit = { + test("StreamDropRebuild") { checkRebuild( StreamDrop(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), I32(2)), TStream(subsetTS("b")), @@ -1708,7 +1743,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testStreamMapRebuild(): Unit = { + test("StreamMapRebuild") { checkRebuild( mapIR(MakeStream(IndexedSeq(NA(ts)), TStream(ts)))(x => x), TStream(subsetTS("b")), @@ -1719,7 +1754,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testStreamGroupedRebuild(): Unit = { + test("StreamGroupedRebuild") { checkRebuild( StreamGrouped(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), I32(2)), TStream(TStream(subsetTS("b"))), @@ -1730,7 +1765,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testStreamGroupByKeyRebuild(): Unit = { + test("StreamGroupByKeyRebuild") { checkRebuild( StreamGroupByKey(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), FastSeq("a"), false), TStream(TStream(subsetTS("b"))), @@ -1741,7 +1776,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testStreamMergeRebuild(): Unit = { + test("StreamMergeRebuild") { checkRebuild( StreamMultiMerge( IndexedSeq( @@ -1755,7 +1790,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testStreamZipRebuild(): Unit = { + test("StreamZipRebuild") { val a2 = st.deepCopy() val a3 = st.deepCopy() for ( @@ -1796,7 +1831,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testStreamFlatmapRebuild(): Unit = { + test("StreamFlatmapRebuild") { checkRebuild( flatMapIR(MakeStream(IndexedSeq(NA(ts)), TStream(ts))) { x => MakeStream(IndexedSeq(x), TStream(ts)) @@ -1809,7 +1844,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMakeStructRebuild(): Unit = { + test("MakeStructRebuild") { checkRebuild( MakeStruct(IndexedSeq("a" -> NA(TInt32), "b" -> NA(TInt64), 
"c" -> NA(TString))), subsetTS("b"), @@ -1818,7 +1853,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testInsertFieldsRebuild(): Unit = { + test("InsertFieldsRebuild") { checkRebuild( InsertFields(NA(TStruct("a" -> TInt32)), IndexedSeq("b" -> NA(TInt64), "c" -> NA(TString))), subsetTS("b"), @@ -1843,7 +1878,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMakeTupleRebuild(): Unit = { + test("MakeTupleRebuild") { checkRebuild( MakeTuple(IndexedSeq(0 -> I32(1), 1 -> F64(1.0), 2 -> NA(TString))), TTuple(FastSeq(TupleField(2, TString))), @@ -1852,7 +1887,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testSelectFieldsRebuild(): Unit = { + test("SelectFieldsRebuild") { checkRebuild( SelectFields(NA(ts), IndexedSeq("a", "b")), subsetTS("b"), @@ -1863,7 +1898,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testCastRenameRebuild(): Unit = { + test("CastRenameRebuild") { checkRebuild( CastRename( NA(TArray(TStruct("x" -> TInt32, "y" -> TString))), @@ -1884,7 +1919,7 @@ class PruneSuite extends HailSuite { ErrorIDs.NO_ERROR, ) - @Test def testNDArrayMapRebuild(): Unit = { + test("NDArrayMapRebuild") { checkRebuild( ndMap(ndArrayTS)(x => x), TNDArray(subsetTS("b"), Nat(1)), @@ -1897,7 +1932,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testNDArrayMap2Rebuild(): Unit = { + test("NDArrayMap2Rebuild") { checkRebuild( ndMap2(ndArrayTS, ndArrayTS)((l, _) => l), TNDArray(subsetTS("b"), Nat(1)), @@ -1918,7 +1953,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testCDARebuild(): Unit = { + test("CDARebuild") { val ctxT = TStruct("a" -> TInt32, "b" -> TString) val globT = TStruct("c" -> TInt64, "d" -> TFloat64) val x = cdaIR( @@ -1947,7 +1982,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testTableAggregateRebuild(): Unit = { + test("TableAggregateRebuild") { val ta = TableAggregate(tr, ApplyAggOp(PrevNonnull())(tableRefBoolean(tr.typ, "row.2"))) checkRebuild( ta, @@ -1959,7 +1994,7 @@ class PruneSuite extends 
HailSuite { ) } - @Test def testTableCollectRebuild(): Unit = { + test("TableCollectRebuild") { val tc = TableCollect(TableKeyBy(tab, FastSeq())) checkRebuild( tc, @@ -1976,7 +2011,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testMatrixAggregateRebuild(): Unit = { + test("MatrixAggregateRebuild") { val ma = MatrixAggregate(mr, ApplyAggOp(Collect())(matrixRefBoolean(mr.typ, "va.r2"))) checkRebuild( ma, @@ -1988,7 +2023,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testPipelineLetRebuild(): Unit = { + test("PipelineLetRebuild") { val t = TStruct("a" -> TInt32) val foo = freshName() checkRebuild( @@ -1999,7 +2034,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testPipelineLetTableRebuild(): Unit = { + test("PipelineLetTableRebuild") { val t = TStruct("a" -> TInt32) val foo = freshName() checkRebuild( @@ -2012,7 +2047,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testPipelineLetMatrixTableRebuild(): Unit = { + test("PipelineLetMatrixTableRebuild") { val t = TStruct("a" -> TInt32) val foo = freshName() checkRebuild( @@ -2025,7 +2060,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testIfUnification(): Unit = { + test("IfUnification") { val pred = False() val t = TStruct("a" -> TInt32, "b" -> TInt32) val pruneT = TStruct("a" -> TInt32) @@ -2047,32 +2082,37 @@ class PruneSuite extends HailSuite { ): Unit } - @DataProvider(name = "supertypePairs") - def supertypePairs: Array[Array[Type]] = Array( - Array(TInt32, TInt32), - Array( - TStruct( - "a" -> TInt32, - "b" -> TArray(TInt64), - ), - TStruct( - "a" -> TInt32, - "b" -> TArray(TInt64), - ), + object checkIsSupertypeRequiredness extends TestCases { + def apply( + t1: Type, + t2: Type, + )(implicit loc: munit.Location + ): Unit = test("IsSupertypeRequiredness") { + assert( + PruneDeadFields.isSupertype(t1, t2), + s"""Failure, supertype relationship not met + | supertype: ${t1.toPrettyString(true)} + | subtype: ${t2.toPrettyString(true)}""".stripMargin, + ) + } + } + 
+ checkIsSupertypeRequiredness(TInt32, TInt32) + + checkIsSupertypeRequiredness( + TStruct( + "a" -> TInt32, + "b" -> TArray(TInt64), + ), + TStruct( + "a" -> TInt32, + "b" -> TArray(TInt64), ), - Array(TSet(TString), TSet(TString)), ) - @Test(dataProvider = "supertypePairs") - def testIsSupertypeRequiredness(t1: Type, t2: Type): Unit = - assert( - PruneDeadFields.isSupertype(t1, t2), - s"""Failure, supertype relationship not met - | supertype: ${t1.toPrettyString(true)} - | subtype: ${t2.toPrettyString(true)}""".stripMargin, - ) + checkIsSupertypeRequiredness(TSet(TString), TSet(TString)) - @Test def testApplyScanOp(): Unit = { + test("ApplyScanOp") { val x = Ref(freshName(), TInt32) val y = Ref(freshName(), TInt32) val env = BindingEnv.empty.createScan.bindScan(x.name -> x.typ, y.name -> y.typ) @@ -2119,7 +2159,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testApplyAggOp(): Unit = { + test("ApplyAggOp") { val x = Ref(freshName(), TInt32) val y = Ref(freshName(), TInt32) val env = BindingEnv.empty.createAgg.bindAgg(x.name -> x.typ, y.name -> y.typ) @@ -2166,7 +2206,7 @@ class PruneSuite extends HailSuite { ) } - @Test def testStreamFold2(): Unit = { + test("StreamFold2") { val eltType = TStruct("a" -> TInt32, "b" -> TInt32) val accum1Type = TStruct("c" -> TInt32, "d" -> TInt32) diff --git a/hail/hail/test/src/is/hail/expr/ir/RandomSuite.scala b/hail/hail/test/src/is/hail/expr/ir/RandomSuite.scala index 9a9b9a77b89..d034bad0beb 100644 --- a/hail/hail/test/src/is/hail/expr/ir/RandomSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/RandomSuite.scala @@ -9,7 +9,6 @@ import is.hail.types.physical.stypes.concrete.{ } import org.apache.commons.math3.distribution.ChiSquaredDistribution -import org.testng.annotations.Test class RandomSuite extends HailSuite { // from skein_golden_kat_short_internals.txt in the skein source @@ -28,7 +27,7 @@ class RandomSuite extends HailSuite { ), ) - @Test def testThreefry(): Unit = { + test("Threefry") { for { (key, tweak, 
input, expected) <- threefryTestCases } { @@ -160,7 +159,7 @@ class RandomSuite extends HailSuite { (Array[Long](100, 101, 102, 103, 104), 30L), ) - @Test def testPMAC(): Unit = { + test("PMAC") { for { (message, staticID) <- pmacTestCases } { @@ -172,7 +171,7 @@ class RandomSuite extends HailSuite { } } - @Test def testPMACHash(): Unit = { + test("PMACHash") { for { (message, _) <- pmacTestCases } { @@ -185,7 +184,7 @@ class RandomSuite extends HailSuite { } } - @Test def testRandomEngine(): Unit = { + test("RandomEngine") { for { (message, staticID) <- pmacTestCases } { @@ -232,7 +231,7 @@ class RandomSuite extends HailSuite { println(s"passed after $numRuns runs with pvalue $geometricMean") } - @Test def testRandomInt(): Unit = { + test("RandomInt") { val n = 1 << 25 val k = 1 << 15 val rand = ThreefryRandomEngine.randState() @@ -241,7 +240,7 @@ class RandomSuite extends HailSuite { } } - @Test def testBoundedUniformInt(): Unit = { + test("BoundedUniformInt") { var n = 1 << 25 var k = 1 << 15 val rand = ThreefryRandomEngine.randState() @@ -256,7 +255,7 @@ class RandomSuite extends HailSuite { } } - @Test def testBoundedUniformLong(): Unit = { + test("BoundedUniformLong") { var n = 1 << 25 var k = 1 << 15 val rand = ThreefryRandomEngine.randState() @@ -271,7 +270,7 @@ class RandomSuite extends HailSuite { } } - @Test def testUniformDouble(): Unit = { + test("UniformDouble") { val n = 1 << 25 val k = 1 << 15 val rand = ThreefryRandomEngine.randState() @@ -282,7 +281,7 @@ class RandomSuite extends HailSuite { } } - @Test def testUniformFloat(): Unit = { + test("UniformFloat") { val n = 1 << 25 val k = 1 << 15 val rand = ThreefryRandomEngine.randState() diff --git a/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala b/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala index db442b3c691..85a82f17cc9 100644 --- a/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/RequirednessSuite.scala @@ -19,9 +19,6 @@ 
import is.hail.types.virtual._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.Row -import org.scalatest.Inspectors.forAll -import org.scalatest.enablers.InspectorAsserting.assertingNatureOfAssertion -import org.testng.annotations.{DataProvider, Test} class RequirednessSuite extends HailSuite { val required: Boolean = true @@ -108,7 +105,6 @@ class RequirednessSuite extends HailSuite { def pinterval(point: PType, r: Boolean): PInterval = PCanonicalInterval(point, r) - @DataProvider(name = "valueIR") def valueIR(): Array[Array[Any]] = { val nodes = Array.newBuilder[Array[Any]] nodes.sizeHint(50) @@ -352,7 +348,6 @@ class RequirednessSuite extends HailSuite { nodes.result() } - @DataProvider(name = "tableIR") def tableIR(): Array[Array[Any]] = { val nodes = Array.newBuilder[Array[Any]] nodes.sizeHint(50) @@ -634,8 +629,7 @@ class RequirednessSuite extends HailSuite { nodes.result() } - @Test - def testDataProviders(): Unit = { + test("DataProviders") { val s = ArrayBuffer.empty[String] valueIR().map(v => v(0) -> v(1)).foreach { case (n: IR, t: PType) => @@ -661,22 +655,25 @@ class RequirednessSuite extends HailSuite { s"${Pretty(ctx, node.t)}: \n$t" }.mkString("\n\n") - @Test(dataProvider = "valueIR") - def testRequiredness(node: IR, expected: Any): Unit = { - TypeCheck(ctx, node) - val et = expected match { - case pt: PType => EmitType(pt.sType, pt.required) - case et: EmitType => et + test("Requiredness") { + valueIR().foreach { v => + val node = v(0).asInstanceOf[IR] + val expected = v(1) + TypeCheck(ctx, node) + val et = expected match { + case pt: PType => EmitType(pt.sType, pt.required) + case et: EmitType => et + } + val res = Requiredness.apply(node, ctx) + val actual = res.r.lookup(node).asInstanceOf[TypeWithRequiredness] + assert( + actual.canonicalEmitType(node.typ) == et, + s"\n\n${Pretty(ctx, node)}: \n$actual\n\n${dump(res.r)}", + ) } - val res = Requiredness.apply(node, ctx) - val actual = 
res.r.lookup(node).asInstanceOf[TypeWithRequiredness] - assert( - actual.canonicalEmitType(node.typ) == et, - s"\n\n${Pretty(ctx, node)}: \n$actual\n\n${dump(res.r)}", - ) } - @Test def sharedNodesWorkCorrectly(): Unit = { + test("sharedNodesWorkCorrectly") { val n2 = bindIR(I32(1))(x => MakeStruct(FastSeq("a" -> x, "b" -> x))) val node = InsertFields(n2, FastSeq("c" -> GetField(n2, "a"), "d" -> GetField(n2, "b"))) val res = Requiredness.apply(node, ctx) @@ -690,21 +687,25 @@ class RequirednessSuite extends HailSuite { )) } - @Test(dataProvider = "tableIR") - def testTableRequiredness(node: TableIR, row: PType, global: PType): Unit = { - val res = Requiredness.apply(node, ctx) - val actual = res.r.lookup(node).asInstanceOf[RTable] - assert( - actual.rowType.canonicalPType(node.typ.rowType) == row, - s"\n\n${Pretty(ctx, node)}: \n$actual\n\n${dump(res.r)}", - ) - assert( - actual.globalType.canonicalPType(node.typ.globalType) == global, - s"\n\n${Pretty(ctx, node)}: \n$actual\n\n${dump(res.r)}", - ) + test("TableRequiredness") { + tableIR().foreach { v => + val node = v(0).asInstanceOf[TableIR] + val row = v(1).asInstanceOf[PType] + val global = v(2).asInstanceOf[PType] + val res = Requiredness.apply(node, ctx) + val actual = res.r.lookup(node).asInstanceOf[RTable] + assert( + actual.rowType.canonicalPType(node.typ.rowType) == row, + s"\n\n${Pretty(ctx, node)}: \n$actual\n\n${dump(res.r)}", + ) + assert( + actual.globalType.canonicalPType(node.typ.globalType) == global, + s"\n\n${Pretty(ctx, node)}: \n$actual\n\n${dump(res.r)}", + ) + } } - @Test def testTableReader(): Unit = { + test("TableReader") { val table = TableParallelize( makestruct( "rows" -> MakeArray(makestruct( @@ -731,12 +732,10 @@ class RequirednessSuite extends HailSuite { } val reader = TableNativeReader(fs, TableNativeReaderParameters(path, None)) - forAll( - Array( - table.typ, - TableType(TStruct("a" -> tnestedarray), FastSeq(), TStruct("z" -> tstruct)), - ) - ) { rType => + Array( + table.typ, + 
TableType(TStruct("a" -> tnestedarray), FastSeq(), TStruct("z" -> tstruct)), + ).foreach { rType => val row = reader.rowRequiredness(ctx, rType) val global = reader.globalRequiredness(ctx, rType) val node = TableRead(rType, dropRows = false, reader) @@ -753,7 +752,7 @@ class RequirednessSuite extends HailSuite { } } - @Test def testSubsettedTuple(): Unit = { + test("SubsettedTuple") { val node = MakeTuple(FastSeq(0 -> I32(0), 4 -> NA(TInt32), 2 -> NA(TArray(TInt32)))) val expected = PCanonicalTuple( FastSeq( diff --git a/hail/hail/test/src/is/hail/expr/ir/SetFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/SetFunctionsSuite.scala index 5b2d6df840f..93e060de028 100644 --- a/hail/hail/test/src/is/hail/expr/ir/SetFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/SetFunctionsSuite.scala @@ -6,15 +6,13 @@ import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.defs.{I32, NA, ToSet, ToStream} import is.hail.types.virtual._ -import org.testng.annotations.Test - class SetFunctionsSuite extends HailSuite { val naa = NA(TArray(TInt32)) val nas = NA(TSet(TInt32)) implicit val execStrats: Set[ExecStrategy] = ExecStrategy.javaOnly - @Test def toSet(): Unit = { + test("toSet") { assertEvalsTo(IRSet(3, 7), Set(3, 7)) assertEvalsTo(IRSet(3, null, 7), Set(null, 3, 7)) assertEvalsTo(nas, null) @@ -24,7 +22,7 @@ class SetFunctionsSuite extends HailSuite { assertEvalsTo(invoke("toSet", TSet(TInt32), naa), null) } - @Test def isEmpty(): Unit = { + test("isEmpty") { assertEvalsTo(invoke("isEmpty", TBoolean, IRSet(3, 7)), false) assertEvalsTo(invoke("isEmpty", TBoolean, IRSet(3, null, 7)), false) assertEvalsTo(invoke("isEmpty", TBoolean, IRSet()), true) @@ -32,7 +30,7 @@ class SetFunctionsSuite extends HailSuite { assertEvalsTo(invoke("isEmpty", TBoolean, nas), null) } - @Test def contains(): Unit = { + test("contains") { val s = IRSet(3, null, 7) val swoutna = IRSet(3, 7) @@ -43,10 +41,10 @@ class SetFunctionsSuite extends HailSuite { 
assertEvalsTo(invoke("contains", TBoolean, s, NA(TInt32)), true) assertEvalsTo(invoke("contains", TBoolean, swoutna, NA(TInt32)), false) assertEvalsTo(invoke("contains", TBoolean, IRSet(3, 7), NA(TInt32)), false) - assert(eval(invoke("contains", TBoolean, IRSet(), 3)) == false) + assertEquals(eval(invoke("contains", TBoolean, IRSet(), 3)), false) } - @Test def remove(): Unit = { + test("remove") { val s = IRSet(3, null, 7) assertEvalsTo(invoke("remove", TSet(TInt32), s, I32(3)), Set(null, 7)) assertEvalsTo(invoke("remove", TSet(TInt32), s, I32(4)), Set(null, 3, 7)) @@ -54,7 +52,7 @@ class SetFunctionsSuite extends HailSuite { assertEvalsTo(invoke("remove", TSet(TInt32), IRSet(3, 7), NA(TInt32)), Set(3, 7)) } - @Test def add(): Unit = { + test("add") { val s = IRSet(3, null, 7) assertEvalsTo(invoke("add", TSet(TInt32), s, I32(3)), Set(null, 3, 7)) assertEvalsTo(invoke("add", TSet(TInt32), s, I32(4)), Set(null, 3, 4, 7)) @@ -63,7 +61,7 @@ class SetFunctionsSuite extends HailSuite { assertEvalsTo(invoke("add", TSet(TInt32), IRSet(3, 7), NA(TInt32)), Set(null, 3, 7)) } - @Test def isSubset(): Unit = { + test("isSubset") { val s = IRSet(3, null, 7) assertEvalsTo(invoke("isSubset", TBoolean, s, invoke("add", TSet(TInt32), s, I32(4))), true) assertEvalsTo( @@ -82,12 +80,12 @@ class SetFunctionsSuite extends HailSuite { ) } - @Test def union(): Unit = { + test("union") { assertEvalsTo(invoke("union", TSet(TInt32), IRSet(3, null, 7), IRSet(3, 8)), Set(null, 3, 7, 8)) assertEvalsTo(invoke("union", TSet(TInt32), IRSet(3, 7), IRSet(3, 8, null)), Set(null, 3, 7, 8)) } - @Test def intersection(): Unit = { + test("intersection") { assertEvalsTo(invoke("intersection", TSet(TInt32), IRSet(3, null, 7), IRSet(3, 8)), Set(3)) assertEvalsTo( invoke("intersection", TSet(TInt32), IRSet(3, null, 7), IRSet(3, 8, null)), @@ -95,12 +93,12 @@ class SetFunctionsSuite extends HailSuite { ) } - @Test def difference(): Unit = { + test("difference") { assertEvalsTo(invoke("difference", 
TSet(TInt32), IRSet(3, null, 7), IRSet(3, 8)), Set(null, 7)) assertEvalsTo(invoke("difference", TSet(TInt32), IRSet(3, null, 7), IRSet(3, 8, null)), Set(7)) } - @Test def median(): Unit = { + test("median") { assertEvalsTo(invoke("median", TInt32, IRSet(5)), 5) assertEvalsTo(invoke("median", TInt32, IRSet(5, null)), 5) assertEvalsTo(invoke("median", TInt32, IRSet(3, 7)), 5) diff --git a/hail/hail/test/src/is/hail/expr/ir/SimplifySuite.scala b/hail/hail/test/src/is/hail/expr/ir/SimplifySuite.scala index 7b5a0062102..f70050e58e0 100644 --- a/hail/hail/test/src/is/hail/expr/ir/SimplifySuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/SimplifySuite.scala @@ -11,49 +11,24 @@ import is.hail.utils.Interval import is.hail.variant.Locus import org.apache.spark.sql.Row -import org.scalactic.{Equivalence, Prettifier} -import org.scalatest.Inspectors.forAll -import org.scalatest.enablers.InspectorAsserting.assertingNatureOfAssertion -import org.scalatest.matchers.{MatchResult, Matcher} -import org.scalatest.matchers.dsl.MatcherFactory1 -import org.scalatest.matchers.must.Matchers.not -import org.scalatest.matchers.should.Matchers.{a, convertToAnyShouldWrapper} -import org.testng.annotations.{DataProvider, Test} class SimplifySuite extends HailSuite { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.interpretOnly - def simplifyTo(expected: BaseIR): MatcherFactory1[BaseIR, Equivalence] = - new MatcherFactory1[BaseIR, Equivalence] { - override def matcher[T <: BaseIR: Equivalence]: Matcher[T] = - Matcher[BaseIR] { input => - val simplified = - Simplify(ctx, input) - - val prettyDiff = - s""" before = ${Pretty(ctx, input)} - | after = ${Pretty(ctx, simplified)} - |expected = ${Pretty(ctx, expected)} - | """.stripMargin - - MatchResult( - matches = - simplified.isAlphaEquiv(ctx, expected), - rawFailureMessage = - s"The simplified IR was not alpha-equivalent:\n$prettyDiff", - rawNegatedFailureMessage = - s"The simplified IR was alpha-equivalent:\n$prettyDiff", - ) - } - 
} - - implicit val irPrettifier: Prettifier = { - case ir: BaseIR => Pretty(ctx, ir) - case x => Prettifier.default(x) + def assertSimplifiesTo(input: BaseIR, expected: BaseIR)(implicit loc: munit.Location): Unit = { + val simplified = Simplify(ctx, input) + assert( + simplified.isAlphaEquiv(ctx, expected), + s"""The simplified IR was not alpha-equivalent: + | before = ${Pretty(ctx, input)} + | after = ${Pretty(ctx, simplified)} + |expected = ${Pretty(ctx, expected)} + | """.stripMargin, + ) } - @Test def testTableMultiWayZipJoinGlobalsRewrite(): Unit = { + test("TableMultiWayZipJoinGlobalsRewrite") { val tmwzj = TableGetGlobals(TableMultiWayZipJoin( ArraySeq(TableRange(10, 10), TableRange(10, 10), TableRange(10, 10)), "rowField", @@ -62,7 +37,7 @@ class SimplifySuite extends HailSuite { assertEvalsTo(tmwzj, Row(FastSeq(Row(), Row(), Row()))) } - @Test def testRepartitionableMapUpdatesForUpstreamOptimizations(): Unit = { + test("RepartitionableMapUpdatesForUpstreamOptimizations") { val range = TableKeyBy(TableRange(10, 3), FastSeq()) val simplifiableIR = If(True(), GetField(Ref(TableIR.rowName, range.typ.rowType), "idx").ceq(0), False()) @@ -77,15 +52,14 @@ class SimplifySuite extends HailSuite { lazy val base = Literal(TStruct("1" -> TInt32, "2" -> TInt32), Row(1, 2)) - @Test def testInsertFieldsRewriteRules(): Unit = { + test("InsertFieldsRewriteRules") { val ir1 = InsertFields(InsertFields(base, FastSeq("1" -> I32(2)), None), FastSeq("1" -> I32(3)), None) - ir1 should simplifyTo(InsertFields( - base, - FastSeq("1" -> I32(3)), - Some(FastSeq("1", "2")), - )) + assertSimplifiesTo( + ir1, + InsertFields(base, FastSeq("1" -> I32(3)), Some(FastSeq("1", "2"))), + ) val ir2 = InsertFields( @@ -94,11 +68,10 @@ class SimplifySuite extends HailSuite { None, ) - ir2 should simplifyTo(InsertFields( - base, - FastSeq("3" -> I32(3)), - Some(FastSeq("3", "1", "2")), - )) + assertSimplifiesTo( + ir2, + InsertFields(base, FastSeq("3" -> I32(3)), Some(FastSeq("3", "1", "2"))), + 
) val ir3 = InsertFields( @@ -107,11 +80,14 @@ class SimplifySuite extends HailSuite { Some(FastSeq("3", "1", "2", "4")), ) - ir3 should simplifyTo(InsertFields( - base, - FastSeq("3" -> I32(2), "4" -> I32(3)), - Some(FastSeq("3", "1", "2", "4")), - )) + assertSimplifiesTo( + ir3, + InsertFields( + base, + FastSeq("3" -> I32(2), "4" -> I32(3)), + Some(FastSeq("3", "1", "2", "4")), + ), + ) val ir4 = InsertFields( @@ -119,17 +95,20 @@ class SimplifySuite extends HailSuite { FastSeq("3" -> I32(5)), ) - ir4 should simplifyTo(InsertFields( - base, - FastSeq("4" -> I32(1), "3" -> I32(5)), - Some(FastSeq("1", "2", "3", "4")), - )) + assertSimplifiesTo( + ir4, + InsertFields( + base, + FastSeq("4" -> I32(1), "3" -> I32(5)), + Some(FastSeq("1", "2", "3", "4")), + ), + ) } lazy val base2 = Literal(TStruct("A" -> TInt32, "B" -> TInt32, "C" -> TInt32, "D" -> TInt32), Row(1, 2, 3, 4)) - @Test def testInsertFieldsWhereFieldBeingInsertedCouldBeSelected(): Unit = { + test("InsertFieldsWhereFieldBeingInsertedCouldBeSelected") { val ir1 = InsertFields( SelectFields(base2, IndexedSeq("A", "B", "C")), @@ -140,21 +119,23 @@ class SimplifySuite extends HailSuite { assert(simplify1.typ == ir1.typ) } - @Test def testInsertSelectRewriteRules(): Unit = { - SelectFields(InsertFields(base, FastSeq("3" -> I32(1)), None), FastSeq("1")) should - simplifyTo(SelectFields(base, FastSeq("1"))) + test("InsertSelectRewriteRules") { + assertSimplifiesTo( + SelectFields(InsertFields(base, FastSeq("3" -> I32(1)), None), FastSeq("1")), + SelectFields(base, FastSeq("1")), + ) - SelectFields(InsertFields(base, FastSeq("3" -> I32(1)), None), FastSeq("3", "1")) should - simplifyTo { - InsertFields( - SelectFields(base, FastSeq("1")), - FastSeq("3" -> I32(1)), - Some(FastSeq("3", "1")), - ) - } + assertSimplifiesTo( + SelectFields(InsertFields(base, FastSeq("3" -> I32(1)), None), FastSeq("3", "1")), + InsertFields( + SelectFields(base, FastSeq("1")), + FastSeq("3" -> I32(1)), + Some(FastSeq("3", "1")), + ), + 
) } - @Test def testContainsRewrites(): Unit = { + test("ContainsRewrites") { assertEvalsTo( invoke("contains", TBoolean, Literal(TArray(TString), FastSeq("a")), In(0, TString)), FastSeq("a" -> TString), @@ -174,7 +155,7 @@ class SimplifySuite extends HailSuite { ) } - @Test def testTableCountExplodeSetRewrite(): Unit = { + test("TableCountExplodeSetRewrite") { var ir: TableIR = TableRange(1, 1) ir = TableMapRows( ir, @@ -187,106 +168,111 @@ class SimplifySuite extends HailSuite { assertEvalsTo(TableCount(ir), 1L) } - @DataProvider(name = "NestedInserts") - def nestedInserts: Array[Array[Any]] = { + object checkNestedInsertsSimplify extends TestCases { + def apply(input: IR, expected: IR)(implicit loc: munit.Location): Unit = + test("nested inserts simplify")(assertSimplifiesTo(input, expected)) + } + + { val unbound = Name("do-not-touch") val r = Ref(Name("unbound-struct"), TStruct("x" -> TInt32)) - Array[Array[Any]]( - Array( - bindIR(InsertFields(r, FastSeq("y" -> F64(0)))) { r2 => - InsertFields(r2, FastSeq("z" -> GetField(r2, "x").toD)) - }, - bindIRs(F64(0), r) { case Seq(x0, r2) => - InsertFields( - r2, - FastSeq("y" -> x0, "z" -> GetField(r2, "x").toD), - Some(FastSeq("x", "y", "z")), - ) - }, - ), - Array( - bindIR(InsertFields(r, FastSeq("y" -> F64(0)))) { r2 => - InsertFields(r2, FastSeq("z" -> (GetField(r2, "x").toD + GetField(r2, "y")))) - }, - bindIRs(F64(0), r) { case Seq(x0, r2) => - InsertFields( - r2, - FastSeq("y" -> x0, "z" -> (GetField(r2, "x").toD + x0)), - Some(FastSeq("x", "y", "z")), - ) - }, - ), - Array( - bindIR(InsertFields(r, FastSeq("y" -> F64(0)))) { r2 => - InsertFields(Ref(unbound, TStruct.empty), FastSeq("z" -> GetField(r2, "y").toI)) - }, - bindIRs(F64(0), r) { case Seq(x0, _) => - InsertFields(Ref(unbound, TStruct.empty), FastSeq("z" -> x0.toI)) - }, - ), - Array.fill(2) { // unrewriteable + checkNestedInsertsSimplify( + bindIR(InsertFields(r, FastSeq("y" -> F64(0)))) { r2 => + InsertFields(r2, FastSeq("z" -> GetField(r2, 
"x").toD)) + }, + bindIRs(F64(0), r) { case Seq(x0, r2) => + InsertFields( + r2, + FastSeq("y" -> x0, "z" -> GetField(r2, "x").toD), + Some(FastSeq("x", "y", "z")), + ) + }, + ) + + checkNestedInsertsSimplify( + bindIR(InsertFields(r, FastSeq("y" -> F64(0)))) { r2 => + InsertFields(r2, FastSeq("z" -> (GetField(r2, "x").toD + GetField(r2, "y")))) + }, + bindIRs(F64(0), r) { case Seq(x0, r2) => + InsertFields( + r2, + FastSeq("y" -> x0, "z" -> (GetField(r2, "x").toD + x0)), + Some(FastSeq("x", "y", "z")), + ) + }, + ) + + checkNestedInsertsSimplify( + bindIR(InsertFields(r, FastSeq("y" -> F64(0)))) { r2 => + InsertFields(Ref(unbound, TStruct.empty), FastSeq("z" -> GetField(r2, "y").toI)) + }, + bindIRs(F64(0), r) { case Seq(x0, _) => + InsertFields(Ref(unbound, TStruct.empty), FastSeq("z" -> x0.toI)) + }, + ) + + { + val unrewriteable = bindIR(InsertFields(r, FastSeq("y" -> Ref(unbound, TFloat64)))) { r2 => InsertFields(r2, FastSeq(("z", invoke("str", TString, r2)))) } + checkNestedInsertsSimplify(unrewriteable, unrewriteable) + } + + checkNestedInsertsSimplify( + IRBuilder.scoped { b => + val a = b.strictMemoize(I32(32)) + val r2 = b.strictMemoize(InsertFields(r, FastSeq("y" -> F64(0)))) + val r3 = b.strictMemoize(InsertFields(r2, FastSeq("w" -> a))) + InsertFields(r3, FastSeq("z" -> (GetField(r3, "x").toD + GetField(r3, "y")))) }, - Array( - IRBuilder.scoped { b => - val a = b.strictMemoize(I32(32)) - val r2 = b.strictMemoize(InsertFields(r, FastSeq("y" -> F64(0)))) - val r3 = b.strictMemoize(InsertFields(r2, FastSeq("w" -> a))) - InsertFields(r3, FastSeq("z" -> (GetField(r3, "x").toD + GetField(r3, "y")))) - }, - IRBuilder.scoped { b => - val a = b.strictMemoize(I32(32)) - val x0 = b.strictMemoize(F64(0)) - val r2 = b.strictMemoize(r) - val x1 = b.strictMemoize(x0) - val x2 = b.strictMemoize(a) - val r3 = b.strictMemoize(r2) - InsertFields( - r3, - FastSeq( - "y" -> x1, - "w" -> x2, - "z" -> (GetField(r3, "x").toD + x1), - ), - Some(FastSeq("x", "y", "w", "z")), - 
) - }, - ), - Array( - IRBuilder.scoped { outer => - val ins = - outer.strictMemoize { - IRBuilder.scoped { in => - val a = in.strictMemoize(I32(1) + Ref(unbound, TInt32)) - InsertFields(r, FastSeq("field0" -> a, "field1" -> (I32(1) + a))) - } + IRBuilder.scoped { b => + val a = b.strictMemoize(I32(32)) + val x0 = b.strictMemoize(F64(0)) + val r2 = b.strictMemoize(r) + val x1 = b.strictMemoize(x0) + val x2 = b.strictMemoize(a) + val r3 = b.strictMemoize(r2) + InsertFields( + r3, + FastSeq( + "y" -> x1, + "w" -> x2, + "z" -> (GetField(r3, "x").toD + x1), + ), + Some(FastSeq("x", "y", "w", "z")), + ) + }, + ) + + checkNestedInsertsSimplify( + IRBuilder.scoped { outer => + val ins = + outer.strictMemoize { + IRBuilder.scoped { in => + val a = in.strictMemoize(I32(1) + Ref(unbound, TInt32)) + InsertFields(r, FastSeq("field0" -> a, "field1" -> (I32(1) + a))) } + } - InsertFields(ins, FastSeq("field2" -> (I32(1) + GetField(ins, "field1")))) - }, - IRBuilder.scoped { ib => - val a = ib.strictMemoize(I32(1) + Ref(unbound, TInt32)) - val x0 = ib.strictMemoize(a) - val x1 = ib.strictMemoize(I32(1) + a) - val s = ib.strictMemoize(r) - InsertFields( - s, - FastSeq("field0" -> x0, "field1" -> x1, "field2" -> (I32(1) + x1)), - Some(FastSeq("x", "field0", "field1", "field2")), - ) - }, - ), + InsertFields(ins, FastSeq("field2" -> (I32(1) + GetField(ins, "field1")))) + }, + IRBuilder.scoped { ib => + val a = ib.strictMemoize(I32(1) + Ref(unbound, TInt32)) + val x0 = ib.strictMemoize(a) + val x1 = ib.strictMemoize(I32(1) + a) + val s = ib.strictMemoize(r) + InsertFields( + s, + FastSeq("field0" -> x0, "field1" -> x1, "field2" -> (I32(1) + x1)), + Some(FastSeq("x", "field0", "field1", "field2")), + ) + }, ) } - @Test(dataProvider = "NestedInserts") - def testNestedInsertsSimplify(input: IR, expected: IR): Unit = - input should simplifyTo(expected) - - @Test def testArrayAggNoAggRewrites(): Unit = { + test("ArrayAggNoAggRewrites") { val doesRewrite: Array[StreamAgg] = { val x = 
Ref(freshName(), TInt32) Array( @@ -297,7 +283,7 @@ class SimplifySuite extends HailSuite { ) } - forAll(doesRewrite)(a => a should simplifyTo(a.query)) + doesRewrite.foreach(a => assertSimplifiesTo(a, a.query)) val doesNotRewrite: Array[StreamAgg] = Array( streamAggIR(ToStream(In(0, TArray(TInt32))))(ApplyAggOp(Sum())(_)), @@ -306,10 +292,10 @@ class SimplifySuite extends HailSuite { }, ) - forAll(doesNotRewrite)(a => a should simplifyTo(a)) + doesNotRewrite.foreach(a => assertSimplifiesTo(a, a)) } - @Test def testArrayAggScanNoAggRewrites(): Unit = { + test("ArrayAggScanNoAggRewrites") { val doesRewrite: Array[StreamAggScan] = Array( streamAggScanIR(ToStream(In(0, TArray(TInt32))))(_ => Ref(freshName(), TInt32)), streamAggScanIR(ToStream(In(0, TArray(TInt32)))) { _ => @@ -317,7 +303,7 @@ class SimplifySuite extends HailSuite { }, ) - forAll(doesRewrite)(ir => Simplify(ctx, ir) should not be a[StreamAggScan]) + doesRewrite.foreach(ir => assert(!Simplify(ctx, ir).isInstanceOf[StreamAggScan])) val doesNotRewrite: Array[StreamAggScan] = Array( streamAggScanIR(ToStream(In(0, TArray(TInt32))))(foo => ApplyScanOp(Sum())(foo)), @@ -326,10 +312,10 @@ class SimplifySuite extends HailSuite { }, ) - forAll(doesNotRewrite)(a => Simplify(ctx, a) shouldBe a) + doesNotRewrite.foreach(a => assertSimplifiesTo(a, a)) } - @Test def testArrayLenCollectToTableCount(): Unit = { + test("ArrayLenCollectToTableCount") { val tr = TableRange(10, 10) val a = ArrayLen(GetField(TableCollect(tr), "rows")) assert(a.typ == TInt32) @@ -338,7 +324,7 @@ class SimplifySuite extends HailSuite { assert(s.typ == TInt32) } - @Test def testMatrixColsTableMatrixMapColsWithAggLetDoesNotSimplify(): Unit = { + test("MatrixColsTableMatrixMapColsWithAggLetDoesNotSimplify") { val reader = MatrixRangeReader(ctx, 1, 1, None) var mir: MatrixIR = MatrixRead(reader.fullMatrixType, false, false, reader) val colType = reader.fullMatrixType.colType @@ -351,19 +337,17 @@ class SimplifySuite extends HailSuite { ) val tir 
= MatrixColsTable(mir) - tir should simplifyTo(tir) + assertSimplifiesTo(tir, tir) } - @Test def testFilterParallelize(): Unit = - forAll( - Array( - MakeStruct(FastSeq( - ("rows", In(0, TArray(TStruct("x" -> TInt32)))), - ("global", In(1, TStruct.empty)), - )), - In(0, TStruct("rows" -> TArray(TStruct("x" -> TInt32)), "global" -> TStruct.empty)), - ) - ) { rowsAndGlobals => + test("FilterParallelize") { + Array( + MakeStruct(FastSeq( + ("rows", In(0, TArray(TStruct("x" -> TInt32)))), + ("global", In(1, TStruct.empty)), + )), + In(0, TStruct("rows" -> TArray(TStruct("x" -> TInt32)), "global" -> TStruct.empty)), + ).foreach { rowsAndGlobals => val tp = TableParallelize(rowsAndGlobals, None) val tf = TableFilter(tp, GetField(Ref(TableIR.rowName, tp.typ.rowType), "x") < 100) @@ -371,8 +355,9 @@ class SimplifySuite extends HailSuite { TypeCheck(ctx, rw) assert(!Exists(rw, _.isInstanceOf[TableFilter])) } + } - @Test def testStreamLenSimplifications(): Unit = { + test("StreamLenSimplifications") { val rangeIR = StreamRange(I32(0), I32(10), I32(1)) val mapOfRange = mapIR(rangeIR)(range_element => range_element + 5) val mapBlockedByLet = @@ -384,7 +369,7 @@ class SimplifySuite extends HailSuite { }) } - @Test def testNestedFilterIntervals(): Unit = { + test("NestedFilterIntervals") { var tir: TableIR = TableRange(10, 5) def r = Ref(TableIR.rowName, tir.typ.rowType) tir = TableMapRows(tir, InsertFields(r, FastSeq("idx2" -> GetField(r, "idx")))) @@ -397,7 +382,7 @@ class SimplifySuite extends HailSuite { )) } - @Test def testSimplifyReadFilterIntervals(): Unit = { + test("SimplifyReadFilterIntervals") { val src = getTestResource("sample-indexed-0.2.52.mt") val mnr = MatrixNativeReader(fs, src, None) @@ -435,9 +420,10 @@ class SimplifySuite extends HailSuite { ), ) - TableFilterIntervals(tr, intervals1, true) should simplifyTo(exp1) + assertSimplifiesTo(TableFilterIntervals(tr, intervals1, true), exp1) - TableFilterIntervals(exp1, intervals2, true) should simplifyTo { + 
assertSimplifiesTo( + TableFilterIntervals(exp1, intervals2, true), TableRead( tnr.fullType, false, @@ -448,29 +434,31 @@ class SimplifySuite extends HailSuite { Some(NativeReaderOptions(intersection, tnr.fullType.keyType, true)), ), ), - ) - } + ), + ) val ztfi1 = TableFilterIntervals(tzr, intervals1, true) - ztfi1 should simplifyTo { + assertSimplifiesTo( + ztfi1, TableRead( tzr.typ, false, tzrr.copy(options = Some(NativeReaderOptions(intervals1, tnr.fullType.keyType, true))), - ) - } + ), + ) - TableFilterIntervals(ztfi1, intervals2, true) should simplifyTo { + assertSimplifiesTo( + TableFilterIntervals(ztfi1, intervals2, true), TableRead( tzr.typ, false, tzrr.copy(options = Some(NativeReaderOptions(intersection, tnr.fullType.keyType, true))), - ) - } + ), + ) } - @Test(enabled = false) def testFilterIntervalsKeyByToFilter(): Unit = { + test("FilterIntervalsKeyByToFilter".ignore) { var t: TableIR = TableRange(100, 10) t = TableMapRows( t, @@ -494,7 +482,7 @@ class SimplifySuite extends HailSuite { }) } - @Test def testSimplifyArraySlice(): Unit = { + test("SimplifyArraySlice") { val stream = StreamRange(I32(0), I32(10), I32(1)) val streamSlice1 = Simplify(ctx, ArraySlice(ToArray(stream), I32(0), Some(I32(7)))) assert(streamSlice1 match { @@ -527,159 +515,180 @@ class SimplifySuite extends HailSuite { def ref(typ: Type) = Ref(Name("#undefined"), typ) - @DataProvider(name = "unaryBooleanArithmetic") - def unaryBooleanArithmetic: Array[Array[Any]] = - Array( - Array(ApplyUnaryPrimOp(Bang, ApplyUnaryPrimOp(Bang, ref(TBoolean))), ref(TBoolean)) - ).asInstanceOf[Array[Array[Any]]] + object checkUnaryBooleanSimplification extends TestCases { + def apply(input: IR, expected: IR)(implicit loc: munit.Location): Unit = + test("unary boolean simplification")(assertSimplifiesTo(input, expected)) + } - @Test(dataProvider = "unaryBooleanArithmetic") - def testUnaryBooleanSimplification(input: IR, expected: IR): Unit = - input should simplifyTo(expected) + 
checkUnaryBooleanSimplification( + ApplyUnaryPrimOp(Bang, ApplyUnaryPrimOp(Bang, ref(TBoolean))), + ref(TBoolean), + ) - @DataProvider(name = "unaryIntegralArithmetic") - def unaryIntegralArithmetic: Array[Array[Any]] = - Array(TInt32, TInt64).flatMap { typ => - Array( - Array(ApplyUnaryPrimOp(Negate, ApplyUnaryPrimOp(Negate, ref(typ))), ref(typ)), - Array(ApplyUnaryPrimOp(BitNot, ApplyUnaryPrimOp(BitNot, ref(typ))), ref(typ)), - Array( - ApplyUnaryPrimOp(Negate, ApplyUnaryPrimOp(BitNot, ref(typ))), - ApplyUnaryPrimOp(Negate, ApplyUnaryPrimOp(BitNot, ref(typ))), - ), - Array( - ApplyUnaryPrimOp(BitNot, ApplyUnaryPrimOp(Negate, ref(typ))), - ApplyUnaryPrimOp(BitNot, ApplyUnaryPrimOp(Negate, ref(typ))), - ), - ).asInstanceOf[Array[Array[Any]]] - } + object checkUnaryIntegralSimplification extends TestCases { + def apply(input: IR, expected: IR)(implicit loc: munit.Location): Unit = + test("unary integral simplification")(assertSimplifiesTo(input, expected)) + } - @Test(dataProvider = "unaryIntegralArithmetic") - def testUnaryIntegralSimplification(input: IR, expected: IR): Unit = - input should simplifyTo(expected) - - @DataProvider(name = "binaryIntegralArithmetic") - def binaryIntegralArithmetic: Array[Array[Any]] = - Array((Literal.coerce(TInt32, _)) -> TInt32, (Literal.coerce(TInt64, _)) -> TInt64).flatMap { - case (pure, typ) => - Array.concat( - Array( - // Addition - Array( - ApplyBinaryPrimOp(Add(), ref(typ), ref(typ)), - ApplyBinaryPrimOp(Multiply(), pure(2), ref(typ)), - ), - Array(ApplyBinaryPrimOp(Add(), pure(0), ref(typ)), ref(typ)), - Array(ApplyBinaryPrimOp(Add(), ref(typ), pure(0)), ref(typ)), - - // Subtraction - Array(ApplyBinaryPrimOp(Subtract(), ref(typ), ref(typ)), pure(0)), - Array( - ApplyBinaryPrimOp(Subtract(), pure(0), ref(typ)), - ApplyUnaryPrimOp(Negate, ref(typ)), - ), - Array(ApplyBinaryPrimOp(Subtract(), ref(typ), pure(0)), ref(typ)), - - // Multiplication - Array(ApplyBinaryPrimOp(Multiply(), pure(0), ref(typ)), pure(0)), - 
Array(ApplyBinaryPrimOp(Multiply(), ref(typ), pure(0)), pure(0)), - Array(ApplyBinaryPrimOp(Multiply(), pure(1), ref(typ)), ref(typ)), - Array(ApplyBinaryPrimOp(Multiply(), ref(typ), pure(1)), ref(typ)), - Array( - ApplyBinaryPrimOp(Multiply(), pure(-1), ref(typ)), - ApplyUnaryPrimOp(Negate, ref(typ)), - ), - Array( - ApplyBinaryPrimOp(Multiply(), ref(typ), pure(-1)), - ApplyUnaryPrimOp(Negate, ref(typ)), - ), - - // Div (truncated to -Inf) - Array(ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), ref(typ)), pure(1)), - Array(ApplyBinaryPrimOp(RoundToNegInfDivide(), pure(0), ref(typ)), pure(0)), - Array( - ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(0)), - Die("division by zero", typ), - ), - Array(ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(1)), ref(typ)), - Array( - ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(-1)), - ApplyUnaryPrimOp(Negate, ref(typ)), - ), - - // Bitwise And - Array(ApplyBinaryPrimOp(BitAnd(), pure(0), ref(typ)), pure(0)), - Array(ApplyBinaryPrimOp(BitAnd(), ref(typ), pure(0)), pure(0)), - Array(ApplyBinaryPrimOp(BitAnd(), pure(-1), ref(typ)), ref(typ)), - Array(ApplyBinaryPrimOp(BitAnd(), ref(typ), pure(-1)), ref(typ)), - - // Bitwise Or - Array(ApplyBinaryPrimOp(BitOr(), pure(0), ref(typ)), ref(typ)), - Array(ApplyBinaryPrimOp(BitOr(), ref(typ), pure(0)), ref(typ)), - Array(ApplyBinaryPrimOp(BitOr(), pure(-1), ref(typ)), pure(-1)), - Array(ApplyBinaryPrimOp(BitOr(), ref(typ), pure(-1)), pure(-1)), - - // Bitwise Xor - Array(ApplyBinaryPrimOp(BitXOr(), ref(typ), ref(typ)), pure(0)), - Array(ApplyBinaryPrimOp(BitXOr(), ref(typ), pure(0)), ref(typ)), - Array(ApplyBinaryPrimOp(BitXOr(), pure(0), ref(typ)), ref(typ)), - ).asInstanceOf[Array[Array[Any]]], - // Shifts - Array(LeftShift(), RightShift(), LogicalRightShift()).flatMap { shift => - Array( - Array(ApplyBinaryPrimOp(shift, pure(0), ref(TInt32)), pure(0)), - Array(ApplyBinaryPrimOp(shift, ref(typ), I32(0)), ref(typ)), - ) - 
}.asInstanceOf[Array[Array[Any]]], - ) - } + Array(TInt32, TInt64).foreach { typ => + checkUnaryIntegralSimplification( + ApplyUnaryPrimOp(Negate, ApplyUnaryPrimOp(Negate, ref(typ))), + ref(typ), + ) + checkUnaryIntegralSimplification( + ApplyUnaryPrimOp(BitNot, ApplyUnaryPrimOp(BitNot, ref(typ))), + ref(typ), + ) + checkUnaryIntegralSimplification( + ApplyUnaryPrimOp(Negate, ApplyUnaryPrimOp(BitNot, ref(typ))), + ApplyUnaryPrimOp(Negate, ApplyUnaryPrimOp(BitNot, ref(typ))), + ) + checkUnaryIntegralSimplification( + ApplyUnaryPrimOp(BitNot, ApplyUnaryPrimOp(Negate, ref(typ))), + ApplyUnaryPrimOp(BitNot, ApplyUnaryPrimOp(Negate, ref(typ))), + ) + } - @Test(dataProvider = "binaryIntegralArithmetic") - def testBinaryIntegralSimplification(input: IR, expected: IR): Unit = - input should simplifyTo(expected) + object checkBinaryIntegralSimplification extends TestCases { + def apply(input: IR, expected: IR)(implicit loc: munit.Location): Unit = + test("binary integral simplification")(assertSimplifiesTo(input, expected)) + } - @DataProvider(name = "floatingIntegralArithmetic") - def binaryFloatingArithmetic: Array[Array[Any]] = - Array( - (Literal.coerce(TFloat32, _)) -> TFloat32, - (Literal.coerce(TFloat64, _)) -> TFloat64, - ).flatMap { case (pure, typ) => - Array( - // Addition - Array(ApplyBinaryPrimOp(Add(), pure(0), ref(typ)), ref(typ)), - Array(ApplyBinaryPrimOp(Add(), ref(typ), pure(0)), ref(typ)), - - // Subtraction - Array(ApplyBinaryPrimOp(Subtract(), pure(0), ref(typ)), ApplyUnaryPrimOp(Negate, ref(typ))), - Array(ApplyBinaryPrimOp(Subtract(), ref(typ), pure(0)), ref(typ)), - - // Multiplication - Array(ApplyBinaryPrimOp(Multiply(), pure(1), ref(typ)), ref(typ)), - Array(ApplyBinaryPrimOp(Multiply(), ref(typ), pure(1)), ref(typ)), - Array( - ApplyBinaryPrimOp(Multiply(), pure(-1), ref(typ)), - ApplyUnaryPrimOp(Negate, ref(typ)), - ), - Array( - ApplyBinaryPrimOp(Multiply(), ref(typ), pure(-1)), - ApplyUnaryPrimOp(Negate, ref(typ)), - ), + Array[(Any => IR, 
Type)]( + (Literal.coerce(TInt32, _), TInt32), + (Literal.coerce(TInt64, _), TInt64), + ).foreach { case (pure, typ) => + // Addition + checkBinaryIntegralSimplification( + ApplyBinaryPrimOp(Add(), ref(typ), ref(typ)), + ApplyBinaryPrimOp(Multiply(), pure(2), ref(typ)), + ) + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(Add(), pure(0), ref(typ)), ref(typ)) + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(Add(), ref(typ), pure(0)), ref(typ)) + + // Subtraction + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(Subtract(), ref(typ), ref(typ)), pure(0)) + checkBinaryIntegralSimplification( + ApplyBinaryPrimOp(Subtract(), pure(0), ref(typ)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ) + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(Subtract(), ref(typ), pure(0)), ref(typ)) + + // Multiplication + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(Multiply(), pure(0), ref(typ)), pure(0)) + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(Multiply(), ref(typ), pure(0)), pure(0)) + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(Multiply(), pure(1), ref(typ)), ref(typ)) + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(Multiply(), ref(typ), pure(1)), ref(typ)) + checkBinaryIntegralSimplification( + ApplyBinaryPrimOp(Multiply(), pure(-1), ref(typ)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ) + checkBinaryIntegralSimplification( + ApplyBinaryPrimOp(Multiply(), ref(typ), pure(-1)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ) - // Div (truncated to -Inf) - Array(ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(1)), ref(typ)), - Array( - ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(-1)), - ApplyUnaryPrimOp(Negate, ref(typ)), - ), - ).asInstanceOf[Array[Array[Any]]] + // Div (truncated to -Inf) + checkBinaryIntegralSimplification( + ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), ref(typ)), + pure(1), + ) + checkBinaryIntegralSimplification( + ApplyBinaryPrimOp(RoundToNegInfDivide(), pure(0), ref(typ)), + pure(0), + ) + 
checkBinaryIntegralSimplification( + ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(0)), + Die("division by zero", typ), + ) + checkBinaryIntegralSimplification( + ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(1)), + ref(typ), + ) + checkBinaryIntegralSimplification( + ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(-1)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ) + + // Bitwise And + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(BitAnd(), pure(0), ref(typ)), pure(0)) + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(BitAnd(), ref(typ), pure(0)), pure(0)) + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(BitAnd(), pure(-1), ref(typ)), ref(typ)) + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(BitAnd(), ref(typ), pure(-1)), ref(typ)) + + // Bitwise Or + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(BitOr(), pure(0), ref(typ)), ref(typ)) + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(BitOr(), ref(typ), pure(0)), ref(typ)) + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(BitOr(), pure(-1), ref(typ)), pure(-1)) + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(BitOr(), ref(typ), pure(-1)), pure(-1)) + + // Bitwise Xor + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(BitXOr(), ref(typ), ref(typ)), pure(0)) + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(BitXOr(), ref(typ), pure(0)), ref(typ)) + checkBinaryIntegralSimplification(ApplyBinaryPrimOp(BitXOr(), pure(0), ref(typ)), ref(typ)) + + // Shifts + Array(LeftShift(), RightShift(), LogicalRightShift()).foreach { shift => + checkBinaryIntegralSimplification( + ApplyBinaryPrimOp(shift, pure(0), ref(TInt32)), + pure(0), + ) + checkBinaryIntegralSimplification( + ApplyBinaryPrimOp(shift, ref(typ), I32(0)), + ref(typ), + ) } + } + + object checkBinaryFloatingSimplification extends TestCases { + def apply(input: IR, expected: IR)(implicit loc: munit.Location): Unit = + test("binary floating simplification")(assertSimplifiesTo(input, expected)) + } 
- @Test(dataProvider = "binaryIntegralArithmetic") - def testBinaryFloatingSimplification(input: IR, expected: IR): Unit = - input should simplifyTo(expected) + Array[(Any => IR, Type)]( + (Literal.coerce(TFloat32, _), TFloat32), + (Literal.coerce(TFloat64, _), TFloat64), + ).foreach { case (pure, typ) => + // Addition + checkBinaryFloatingSimplification(ApplyBinaryPrimOp(Add(), pure(0), ref(typ)), ref(typ)) + checkBinaryFloatingSimplification(ApplyBinaryPrimOp(Add(), ref(typ), pure(0)), ref(typ)) + + // Subtraction + checkBinaryFloatingSimplification( + ApplyBinaryPrimOp(Subtract(), pure(0), ref(typ)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ) + checkBinaryFloatingSimplification(ApplyBinaryPrimOp(Subtract(), ref(typ), pure(0)), ref(typ)) + + // Multiplication + checkBinaryFloatingSimplification(ApplyBinaryPrimOp(Multiply(), pure(1), ref(typ)), ref(typ)) + checkBinaryFloatingSimplification(ApplyBinaryPrimOp(Multiply(), ref(typ), pure(1)), ref(typ)) + checkBinaryFloatingSimplification( + ApplyBinaryPrimOp(Multiply(), pure(-1), ref(typ)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ) + checkBinaryFloatingSimplification( + ApplyBinaryPrimOp(Multiply(), ref(typ), pure(-1)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ) - @DataProvider(name = "blockMatrixRules") - def blockMatrixRules: Array[Array[Any]] = { + // Div (truncated to -Inf) + checkBinaryFloatingSimplification( + ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(1)), + ref(typ), + ) + checkBinaryFloatingSimplification( + ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(-1)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ) + } + + object checkBlockMatrixSimplification extends TestCases { + def apply(input: BlockMatrixIR, expected: BlockMatrixIR)(implicit loc: munit.Location): Unit = + test("block matrix simplification")(assertSimplifiesTo(input, expected)) + } + + { val matrix = ValueToBlockMatrix( MakeArray((1 to 4).map(i => F64(i.toDouble)), TArray(TFloat64)), @@ -687,73 +696,87 @@ class SimplifySuite extends 
HailSuite { 10, ) - Array( - Array(BlockMatrixBroadcast(matrix, 0 to 1, matrix.shape, matrix.blockSize), matrix), - Array(bmMap(matrix, true)(x => x), matrix), - Array( - bmMap(matrix, true)(_ => F64(2356)), - BlockMatrixBroadcast( - ValueToBlockMatrix(F64(2356), FastSeq(1, 1), matrix.blockSize), - FastSeq(), - matrix.shape, - matrix.blockSize, - ), + checkBlockMatrixSimplification( + BlockMatrixBroadcast(matrix, 0 to 1, matrix.shape, matrix.blockSize), + matrix, + ) + checkBlockMatrixSimplification(bmMap(matrix, true)(x => x), matrix) + checkBlockMatrixSimplification( + bmMap(matrix, true)(_ => F64(2356)), + BlockMatrixBroadcast( + ValueToBlockMatrix(F64(2356), FastSeq(1, 1), matrix.blockSize), + FastSeq(), + matrix.shape, + matrix.blockSize, ), - ).asInstanceOf[Array[Array[Any]]] + ) } - @Test(dataProvider = "blockMatrixRules") - def testBlockMatrixSimplification(input: BlockMatrixIR, expected: BlockMatrixIR): Unit = - input should simplifyTo(expected) - - @DataProvider(name = "SwitchRules") - def switchRules: Array[Array[Any]] = - Array( - Array(I32(-1), I32(-1), IndexedSeq.tabulate(5)(I32), I32(-1)), - Array(I32(1), I32(-1), IndexedSeq.tabulate(5)(I32), I32(1)), - Array( - ref(TInt32), - I32(-1), - IndexedSeq.tabulate(5)(I32), - Switch(ref(TInt32), I32(-1), IndexedSeq.tabulate(5)(I32)), - ), - Array(I32(256), I32(-1), IndexedSeq.empty[IR], I32(-1)), - Array( - ref(TInt32), - I32(-1), - IndexedSeq.empty[IR], - Switch(ref(TInt32), I32(-1), IndexedSeq.empty[IR]), - ), // missingness - ) + object checkSwitchSimplification extends TestCases { + def apply( + x: IR, + default: IR, + cases: IndexedSeq[IR], + expected: BaseIR, + )(implicit loc: munit.Location + ): Unit = + test("switch simplification")(assertSimplifiesTo(Switch(x, default, cases), expected)) + } - @Test(dataProvider = "SwitchRules") - def testTestSwitchSimplification(x: IR, default: IR, cases: IndexedSeq[IR], expected: BaseIR) - : Unit = - Switch(x, default, cases) should simplifyTo(expected) + 
checkSwitchSimplification(I32(-1), I32(-1), IndexedSeq.tabulate(5)(I32), I32(-1)) + checkSwitchSimplification(I32(1), I32(-1), IndexedSeq.tabulate(5)(I32), I32(1)) + + checkSwitchSimplification( + ref(TInt32), + I32(-1), + IndexedSeq.tabulate(5)(I32), + Switch(ref(TInt32), I32(-1), IndexedSeq.tabulate(5)(I32)), + ) + + checkSwitchSimplification(I32(256), I32(-1), IndexedSeq.empty[IR], I32(-1)) + + checkSwitchSimplification( + ref(TInt32), + I32(-1), + IndexedSeq.empty[IR], + Switch(ref(TInt32), I32(-1), IndexedSeq.empty[IR]), + ) + + object checkIfSimplification extends TestCases { + def apply( + pred: IR, + cnsq: IR, + altr: IR, + expected: BaseIR, + )(implicit loc: munit.Location + ): Unit = + test("if simplification")(assertSimplifiesTo(If(pred, cnsq, altr), expected)) + } - @DataProvider(name = "IfRules") - def ifRules: Array[Array[Any]] = { + { val x = Ref(freshName(), TInt32) val y = Ref(freshName(), TInt32) val c = Ref(freshName(), TBoolean) - Array( - Array(True(), x, Die("Failure", x.typ), x), - Array(False(), Die("Failure", x.typ), x, x), - Array(IsNA(x), NA(x.typ), x, x), - Array(ApplyUnaryPrimOp(Bang, c), x, y, If(c, y, x)), - Array(c, If(c, x, y), y, If(c, x, y)), - Array(c, x, If(c, x, y), If(c, x, y)), - Array(c, x, x, If(IsNA(c), NA(x.typ), x)), - ) + checkIfSimplification(True(), x, Die("Failure", x.typ), x) + checkIfSimplification(False(), Die("Failure", x.typ), x, x) + checkIfSimplification(IsNA(x), NA(x.typ), x, x) + checkIfSimplification(ApplyUnaryPrimOp(Bang, c), x, y, If(c, y, x)) + checkIfSimplification(c, If(c, x, y), y, If(c, x, y)) + checkIfSimplification(c, x, If(c, x, y), If(c, x, y)) + checkIfSimplification(c, x, x, If(IsNA(c), NA(x.typ), x)) } - @Test(dataProvider = "IfRules") - def testIfSimplification(pred: IR, cnsq: IR, altr: IR, expected: BaseIR): Unit = - If(pred, cnsq, altr) should simplifyTo(expected) + object checkMakeStruct extends TestCases { + def apply( + fields: IndexedSeq[(String, IR)], + expected: IR, + )(implicit loc: 
munit.Location + ): Unit = + test("make struct")(assertSimplifiesTo(MakeStruct(fields), expected)) + } - @DataProvider(name = "MakeStructRules") - def makeStructRules: Array[Array[Any]] = { + { val s = ref(TStruct( "a" -> TInt32, "b" -> TInt64, @@ -762,46 +785,42 @@ class SimplifySuite extends HailSuite { def get(name: String) = GetField(s, name) - Array( - Array( - FastSeq("x" -> get("a")), - CastRename(SelectFields(s, FastSeq("a")), TStruct("x" -> TInt32)), - ), - Array( - FastSeq("x" -> get("a"), "y" -> get("b")), - CastRename(SelectFields(s, FastSeq("a", "b")), TStruct("x" -> TInt32, "y" -> TInt64)), - ), - Array( - FastSeq("a" -> get("a"), "b" -> get("b")), - SelectFields(s, FastSeq("a", "b")), - ), - Array( - FastSeq("a" -> get("a"), "b" -> get("b"), "c" -> get("c")), - s, - ), + checkMakeStruct( + FastSeq("x" -> get("a")), + CastRename(SelectFields(s, FastSeq("a")), TStruct("x" -> TInt32)), ) - } - - @Test(dataProvider = "MakeStructRules") - def testMakeStruct(fields: IndexedSeq[(String, IR)], expected: IR): Unit = - MakeStruct(fields) should simplifyTo(expected) - - @DataProvider(name = "CastRules") - def castRules: Array[Array[Any]] = { - Array( - Array(TInt32, TFloat32, false), - Array(TInt32, TInt64, true), - Array(TInt64, TInt32, false), - Array(TInt32, TFloat64, true), - Array(TFloat32, TFloat64, true), - Array(TFloat64, TFloat32, false), + checkMakeStruct( + FastSeq("x" -> get("a"), "y" -> get("b")), + CastRename(SelectFields(s, FastSeq("a", "b")), TStruct("x" -> TInt32, "y" -> TInt64)), + ) + checkMakeStruct( + FastSeq("a" -> get("a"), "b" -> get("b")), + SelectFields(s, FastSeq("a", "b")), + ) + checkMakeStruct( + FastSeq("a" -> get("a"), "b" -> get("b"), "c" -> get("c")), + s, ) } - @Test(dataProvider = "CastRules") - def testCastSimplify(t1: Type, t2: Type, simplifies: Boolean): Unit = { - val x = ref(t1) - val ir = Cast(Cast(x, t2), t1) - ir should simplifyTo(if (simplifies) x else ir) + object checkCastSimplify extends TestCases { + def apply( + 
t1: Type, + t2: Type, + simplifies: Boolean, + )(implicit loc: munit.Location + ): Unit = + test("cast simplify") { + val x = ref(t1) + val ir = Cast(Cast(x, t2), t1) + assertSimplifiesTo(ir, if (simplifies) x else ir) + } } + + checkCastSimplify(TInt32, TFloat32, false) + checkCastSimplify(TInt32, TInt64, true) + checkCastSimplify(TInt64, TInt32, false) + checkCastSimplify(TInt32, TFloat64, true) + checkCastSimplify(TFloat32, TFloat64, true) + checkCastSimplify(TFloat64, TFloat32, false) } diff --git a/hail/hail/test/src/is/hail/expr/ir/StagedBTreeSuite.scala b/hail/hail/test/src/is/hail/expr/ir/StagedBTreeSuite.scala index dd545371bff..88f8c44605e 100644 --- a/hail/hail/test/src/is/hail/expr/ir/StagedBTreeSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/StagedBTreeSuite.scala @@ -14,12 +14,12 @@ import is.hail.types.physical.stypes.interfaces.primitive import is.hail.types.physical.stypes.primitives.SInt64 import scala.collection.mutable +import scala.concurrent.duration.Duration import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import org.scalacheck.Gen._ -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.testng.annotations.Test +import org.scalacheck.Prop.forAll class TestBTreeKey(mb: EmitMethodBuilder[_]) extends BTreeKey { private val comp = mb.ecb.getOrderingFunction(SInt64, SInt64, CodeOrdering.Compare()) @@ -236,45 +236,58 @@ class TestSet { def getElements: Array[java.lang.Long] = map.toArray } -class StagedBTreeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { - - @Test def testBTree(): Unit = { - pool.scopedRegion { region => - val refSet = new TestSet() - val nodeSizeParams = Array( - 2 -> choose(-10, 10), - 3 -> choose(-10, 10), - 5 -> choose(-30, 30), - 6 -> choose(-30, 30), - 22 -> choose(-3, 3), - ) +class StagedBTreeSuite extends HailSuite with munit.ScalaCheckSuite { + var region: Region = _ + val refSet = new TestSet() + + override def beforeAll(): Unit = { + super.beforeAll() + region = 
pool.getRegion() + } - for ((n, values) <- nodeSizeParams) { - val testSet = new BTreeBackedSet(ctx, region, n) + override def afterAll(): Unit = { + region.close() + region = null + super.afterAll() + } - val sets = containerOf[Array, java.lang.Long](zip(prob(.1), values) - .map { case (m, v) => if (m) null else Long.box(v.longValue()) }) + override val munitTimeout = Duration(60, "s") - val lt = { (l1: java.lang.Long, l2: java.lang.Long) => - !(l1 == null) && ((l2 == null) || (l1 < l2)) - } + property("BTree") { + val refSet = new TestSet() + val nodeSizeParams = Array( + 2 -> choose(-10, 10), + 3 -> choose(-10, 10), + 5 -> choose(-30, 30), + 6 -> choose(-30, 30), + 22 -> choose(-3, 3), + ) - forAll(sets) { set => - refSet.clear() - testSet.clear() - assert(refSet.getElements sameElements testSet.getElements) - - set.forall { v => - refSet.getOrElseInsert(v) - testSet.getOrElseInsert(v) - refSet.getElements.sortWith(lt) sameElements testSet.getElements.sortWith(lt) - } && { - val serialized = testSet.bulkStore - val testSet2 = BTreeBackedSet.bulkLoad(ctx, region, serialized, n) - refSet.getElements.sortWith(lt) sameElements testSet2.getElements.sortWith(lt) - } + nodeSizeParams.map { case (n, values) => + val testSet = new BTreeBackedSet(ctx, region, n) + + val sets = containerOf[Array, java.lang.Long](zip(prob(.1), values) + .map { case (m, v) => if (m) null else Long.box(v.longValue()) }) + + val lt = { (l1: java.lang.Long, l2: java.lang.Long) => + !(l1 == null) && ((l2 == null) || (l1 < l2)) + } + + forAll(sets) { set => + refSet.clear() + testSet.clear() + assert(refSet.getElements sameElements testSet.getElements) + + set.forall { v => + refSet.getOrElseInsert(v) + testSet.getOrElseInsert(v) + refSet.getElements.sortWith(lt) sameElements testSet.getElements.sortWith(lt) + } && { + val serialized = testSet.bulkStore + val testSet2 = BTreeBackedSet.bulkLoad(ctx, region, serialized, n) + refSet.getElements.sortWith(lt) sameElements 
testSet2.getElements.sortWith(lt) } } - } + }.reduce(_ ++ _) } } diff --git a/hail/hail/test/src/is/hail/expr/ir/StagedMinHeapSuite.scala b/hail/hail/test/src/is/hail/expr/ir/StagedMinHeapSuite.scala index 5ae18935e03..d7bc95c3f1e 100644 --- a/hail/hail/test/src/is/hail/expr/ir/StagedMinHeapSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/StagedMinHeapSuite.scala @@ -17,13 +17,9 @@ import is.hail.variant.{Locus, ReferenceGenome} import org.scalacheck.Arbitrary.arbitrary import org.scalacheck.Gen -import org.scalatest -import org.scalatest.matchers.should.Matchers.{be, convertToAnyShouldWrapper} -import org.scalatestplus.scalacheck.CheckerAsserting.assertingNatureOfAssertion -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.testng.annotations.Test +import org.scalacheck.Prop.forAll -class StagedMinHeapSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { +class StagedMinHeapSuite extends HailSuite with munit.ScalaCheckSuite { implicit object StagedIntCoercions extends StagedCoercions[Int] { override def ti: TypeInfo[Int] = implicitly @@ -36,28 +32,25 @@ class StagedMinHeapSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { sa.asInt.value } - @Test def testSorting(): Unit = - forAll((xs: IndexedSeq[Int]) => assert(sort(xs) == xs.sorted)) - - @Test def testHeapProperty(): Unit = - forAll { (xs: IndexedSeq[Int]) => - val heap = heapify(xs) - scalatest.Inspectors.forAll(0 until heap.size / 2) { i => - assert( - ((2 * i + 1) >= heap.size || heap(i) <= heap(2 * i + 1)) && - ((2 * i + 2) >= heap.size || heap(i) <= heap(2 * i + 2)) - ) - } + property("Sorting") = forAll((xs: IndexedSeq[Int]) => assertEquals(sort(xs), xs.sorted)) + + property("HeapProperty") = forAll { (xs: IndexedSeq[Int]) => + val heap = heapify(xs) + (0 until heap.size / 2).foreach { i => + assert((2 * i + 1) >= heap.size || heap(i) <= heap(2 * i + 1)) + assert((2 * i + 2) >= heap.size || heap(i) <= heap(2 * i + 2)) } + } - @Test def testNonEmpty(): Unit 
= + test("NonEmpty") { gen(ctx, "NonEmpty") { (heap: IntHeap) => - heap.nonEmpty should be(false) + assertEquals(heap.nonEmpty, false) for (i <- 0 to 10) heap.push(i) - heap.nonEmpty should be(true) + assertEquals(heap.nonEmpty, true) for (_ <- 0 to 10) heap.pop() - heap.nonEmpty should be(false) + assertEquals(heap.nonEmpty, false) } + } val loci: Gen[(ReferenceGenome, IndexedSeq[Locus])] = for { @@ -65,21 +58,20 @@ class StagedMinHeapSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { loci <- Gen.containerOf[IndexedSeq, Locus](genLocus(rg)) } yield (rg, loci) - @Test def testLocus(): Unit = - forAll(loci) { case (rg: ReferenceGenome, loci: IndexedSeq[Locus]) => - ctx.local(references = Map(rg.name -> rg)) { ctx => - implicit val coercions: StagedCoercions[Locus] = - stagedLocusCoercions(rg) + property("Locus") = forAll(loci) { case (rg: ReferenceGenome, loci: IndexedSeq[Locus]) => + ctx.local(references = Map(rg.name -> rg)) { ctx => + implicit val coercions: StagedCoercions[Locus] = + stagedLocusCoercions(rg) - val sortedLoci = - gen[Locus, LocusHeap, IndexedSeq[Locus]](ctx, "Locus") { (heap: LocusHeap) => - loci.foreach(heap.push) - IndexedSeq.fill(loci.size)(heap.pop()) - } + val sortedLoci = + gen[Locus, LocusHeap, IndexedSeq[Locus]](ctx, "Locus") { (heap: LocusHeap) => + loci.foreach(heap.push) + IndexedSeq.fill(loci.size)(heap.pop()) + } - assert(sortedLoci == loci.sorted(rg.locusOrdering)) - } + assertEquals(sortedLoci, loci.sorted(rg.locusOrdering)) } + } def sort(xs: IndexedSeq[Int]): IndexedSeq[Int] = gen(ctx, "Sort") { (heap: IntHeap) => diff --git a/hail/hail/test/src/is/hail/expr/ir/StringFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/StringFunctionsSuite.scala index f6b72e140a7..dcf8407bb15 100644 --- a/hail/hail/test/src/is/hail/expr/ir/StringFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/StringFunctionsSuite.scala @@ -8,12 +8,11 @@ import is.hail.expr.ir.defs.{F32, I32, I64, MakeTuple, NA, Str} import 
is.hail.types.virtual._ import org.json4s.jackson.JsonMethods -import org.testng.annotations.{DataProvider, Test} class StringFunctionsSuite extends HailSuite { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.javaOnly - @Test def testRegexMatch(): Unit = { + test("RegexMatch") { assertEvalsTo(invoke("regexMatch", TBoolean, Str("a"), NA(TString)), null) assertEvalsTo(invoke("regexMatch", TBoolean, NA(TString), Str("b")), null) @@ -24,27 +23,27 @@ class StringFunctionsSuite extends HailSuite { assertEvalsTo(invoke("regexMatch", TBoolean, Str("[a-z][0-9]"), Str("3x")), false) } - @Test def testLength(): Unit = { + test("Length") { assertEvalsTo(invoke("length", TInt32, Str("ab")), 2) assertEvalsTo(invoke("length", TInt32, Str("")), 0) assertEvalsTo(invoke("length", TInt32, NA(TString)), null) } - @Test def testSubstring(): Unit = { + test("Substring") { assertEvalsTo(invoke("substring", TString, Str("ab"), 0, 1), "a") assertEvalsTo(invoke("substring", TString, Str("ab"), NA(TInt32), 1), null) assertEvalsTo(invoke("substring", TString, Str("ab"), 0, NA(TInt32)), null) assertEvalsTo(invoke("substring", TString, NA(TString), 0, 1), null) } - @Test def testConcat(): Unit = { + test("Concat") { assertEvalsTo(invoke("concat", TString, Str("a"), NA(TString)), null) assertEvalsTo(invoke("concat", TString, NA(TString), Str("b")), null) assertEvalsTo(invoke("concat", TString, Str("a"), Str("b")), "ab") } - @Test def testSplit(): Unit = { + test("Split") { assertEvalsTo(invoke("split", TArray(TString), NA(TString), Str(",")), null) assertEvalsTo(invoke("split", TArray(TString), Str("a,b,c"), NA(TString)), null) @@ -61,7 +60,7 @@ class StringFunctionsSuite extends HailSuite { ) } - @Test def testReplace(): Unit = { + test("Replace") { assertEvalsTo(invoke("replace", TString, NA(TString), Str(","), Str(".")), null) assertEvalsTo(invoke("replace", TString, Str("a,b,c"), NA(TString), Str(".")), null) assertEvalsTo(invoke("replace", TString, Str("a,b,c"), Str(","), 
NA(TString)), null) @@ -69,7 +68,7 @@ class StringFunctionsSuite extends HailSuite { assertEvalsTo(invoke("replace", TString, Str("a,b,c"), Str(","), Str(".")), "a.b.c") } - @Test def testArrayMkString(): Unit = { + test("ArrayMkString") { assertEvalsTo(invoke("mkString", TString, IRStringArray("a", "b", "c"), NA(TString)), null) assertEvalsTo(invoke("mkString", TString, NA(TArray(TString)), Str(",")), null) assertEvalsTo(invoke("mkString", TString, IRStringArray("a", "b", "c"), Str(",")), "a,b,c") @@ -78,7 +77,7 @@ class StringFunctionsSuite extends HailSuite { assertEvalsTo(invoke("mkString", TString, IRStringArray("a", null, "c"), Str(",")), "a,null,c") } - @Test def testSetMkString(): Unit = { + test("SetMkString") { assertEvalsTo(invoke("mkString", TString, IRStringSet("a", "b", "c"), NA(TString)), null) assertEvalsTo(invoke("mkString", TString, NA(TSet(TString)), Str(",")), null) assertEvalsTo(invoke("mkString", TString, IRStringSet("a", "b", "c"), Str(",")), "a,b,c") @@ -87,7 +86,7 @@ class StringFunctionsSuite extends HailSuite { assertEvalsTo(invoke("mkString", TString, IRStringSet("a", null, "c"), Str(",")), "a,c,null") } - @Test def testFirstMatchIn(): Unit = { + test("FirstMatchIn") { assertEvalsTo(invoke("firstMatchIn", TArray(TString), Str("""([a-zA-Z]+)"""), Str("1")), null) assertEvalsTo( invoke("firstMatchIn", TArray(TString), Str("Hello world!"), Str("""([a-zA-Z]+)""")), @@ -99,39 +98,53 @@ class StringFunctionsSuite extends HailSuite { ) } - @Test def testHammingDistance(): Unit = { + test("HammingDistance") { assertEvalsTo(invoke("hamming", TInt32, Str("foo"), NA(TString)), null) assertEvalsTo(invoke("hamming", TInt32, Str("foo"), Str("fool")), null) assertEvalsTo(invoke("hamming", TInt32, Str("foo"), Str("fol")), 1) } - @DataProvider(name = "str") - def strData(): Array[Array[Any]] = Array( - Array(NA(TString), TString), - Array(NA(TStruct("x" -> TInt32)), TStruct("x" -> TInt32)), - Array(F32(3.14f), TFloat32), - Array(I64(7), TInt64), - 
Array(IRArray(1, null, 5), TArray(TInt32)), - Array(MakeTuple.ordered(FastSeq(1, NA(TInt32), 5.7)), TTuple(TInt32, TInt32, TFloat64)), + val strData: Array[(IR, Type)] = Array( + (NA(TString), TString), + (NA(TStruct("x" -> TInt32)), TStruct("x" -> TInt32)), + (F32(3.14f), TFloat32), + (I64(7), TInt64), + (IRArray(1, null, 5), TArray(TInt32)), + (MakeTuple.ordered(FastSeq(1, NA(TInt32), 5.7)), TTuple(TInt32, TInt32, TFloat64)), ) - @Test(dataProvider = "str") - def str(annotation: IR, typ: Type): Unit = - assertEvalsTo( - invoke("str", TString, annotation), { - val a = eval(annotation); if (a == null) null else typ.str(a) - }, - ) + object checkStr extends TestCases { + def apply( + annotation: IR, + typ: Type, + )(implicit loc: munit.Location + ): Unit = test("str") { + assertEvalsTo( + invoke("str", TString, annotation), { + val a = eval(annotation); if (a == null) null else typ.str(a) + }, + ) + } + } - @Test(dataProvider = "str") - def json(annotation: IR, typ: Type): Unit = - assertEvalsTo( - invoke("json", TString, annotation), - JsonMethods.compact(typ.export(eval(annotation))), - ) + strData.foreach { case (a, t) => checkStr(a, t) } + + object checkJson extends TestCases { + def apply( + annotation: IR, + typ: Type, + )(implicit loc: munit.Location + ): Unit = test("json") { + assertEvalsTo( + invoke("json", TString, annotation), + JsonMethods.compact(typ.export(eval(annotation))), + ) + } + } - @DataProvider(name = "time") - def timeData(): Array[Array[Any]] = Array( + strData.foreach { case (a, t) => checkJson(a, t) } + + val timeData: Array[(String, String, Long)] = Array( // □ = untested // ■ = tested // ⊗ = unimplemented @@ -139,53 +152,63 @@ class StringFunctionsSuite extends HailSuite { // % A a B b C c D d e F G g H I j k l M m n p R r S s T t U u V v W w X x Y y Z z // ■ ■ ■ ■ ■ ⊗ ⊗ ■ ■ ■ ■ ⊗ ⊗ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ⊗ ⊗ ⊗ ■ ■ ⊗ ■ - Array("%t%%%n%s", "\t%\n123456789", 123456789), - Array("%m/%d/%y %I:%M:%S %p", "10/10/97 11:45:23 
PM", 876541523), - Array("%m/%d/%y %I:%M:%S %p", "07/08/19 03:00:01 AM", 1562569201), - Array("%Y.%m.%d %H:%M:%S %z", "1997.10.10 23:45:23 -04:00", 876541523), - Array("%Y.%m.%d %H:%M:%S %Z", "2019.07.08 03:00:01 America/New_York", 1562569201), - Array("day %j of %Y. %R:%S", "day 283 of 1997. 23:45:23", 876541523), - Array("day %j of %Y. %R:%S", "day 189 of 2019. 03:00:01", 1562569201), - Array("day %j of %Y. %R:%S", "day 001 of 1970. 22:46:40", 100000), - Array("%v %T", "10-Oct-1997 23:45:23", 876541523), - Array("%v %T", " 8-Jul-2019 03:00:01", 1562569201), - Array("%A, %B %e, %Y. %r", "Friday, October 10, 1997. 11:45:23 PM", 876541523), - Array("%A, %B %e, %Y. %r", "Monday, July 8, 2019. 03:00:01 AM", 1562569201), - Array("%a, %b %e, '%y. %I:%M:%S %p", "Fri, Oct 10, '97. 11:45:23 PM", 876541523), - Array("%a, %b %e, '%y. %I:%M:%S %p", "Mon, Jul 8, '19. 03:00:01 AM", 1562569201), - Array("%D %l:%M:%S %p", "10/10/97 11:45:23 PM", 876541523), - Array("%D %l:%M:%S %p", "07/08/19 3:00:01 AM", 1562569201), - Array("%F %k:%M:%S", "1997-10-10 23:45:23", 876541523), - Array("%F %k:%M:%S", "2019-07-08 3:00:01", 1562569201), - Array( + ("%t%%%n%s", "\t%\n123456789", 123456789L), + ("%m/%d/%y %I:%M:%S %p", "10/10/97 11:45:23 PM", 876541523L), + ("%m/%d/%y %I:%M:%S %p", "07/08/19 03:00:01 AM", 1562569201L), + ("%Y.%m.%d %H:%M:%S %z", "1997.10.10 23:45:23 -04:00", 876541523L), + ("%Y.%m.%d %H:%M:%S %Z", "2019.07.08 03:00:01 America/New_York", 1562569201L), + ("day %j of %Y. %R:%S", "day 283 of 1997. 23:45:23", 876541523L), + ("day %j of %Y. %R:%S", "day 189 of 2019. 03:00:01", 1562569201L), + ("day %j of %Y. %R:%S", "day 001 of 1970. 22:46:40", 100000L), + ("%v %T", "10-Oct-1997 23:45:23", 876541523L), + ("%v %T", " 8-Jul-2019 03:00:01", 1562569201L), + ("%A, %B %e, %Y. %r", "Friday, October 10, 1997. 11:45:23 PM", 876541523L), + ("%A, %B %e, %Y. %r", "Monday, July 8, 2019. 03:00:01 AM", 1562569201L), + ("%a, %b %e, '%y. %I:%M:%S %p", "Fri, Oct 10, '97. 
11:45:23 PM", 876541523L), + ("%a, %b %e, '%y. %I:%M:%S %p", "Mon, Jul 8, '19. 03:00:01 AM", 1562569201L), + ("%D %l:%M:%S %p", "10/10/97 11:45:23 PM", 876541523L), + ("%D %l:%M:%S %p", "07/08/19 3:00:01 AM", 1562569201L), + ("%F %k:%M:%S", "1997-10-10 23:45:23", 876541523L), + ("%F %k:%M:%S", "2019-07-08 3:00:01", 1562569201L), + ( "ISO 8601 week day %u. %Y.%m.%d %H:%M:%S", "ISO 8601 week day 4. 1970.01.01 22:46:40", - 100000, + 100000L, ), - Array( + ( "Week number %U of %Y. %Y.%m.%d %H:%M:%S", "Week number 00 of 1973. 1973.01.01 10:33:20", - 94750400, - ), - Array( - "ISO 8601 week #%V. %Y.%m.%d %H:%M:%S", - "ISO 8601 week #53. 2005.01.02 00:00:00", - 1104642000, - ), - Array( - "ISO 8601 week #%V. %Y.%m.%d %H:%M:%S", - "ISO 8601 week #01. 2005.01.03 00:00:00", - 1104728400, + 94750400L, ), - Array("Monday week #%W. %Y.%m.%d %H:%M:%S", "Monday week #00. 2005.01.02 00:00:00", 1104642000), - Array("Monday week #%W. %Y.%m.%d %H:%M:%S", "Monday week #01. 2005.01.03 00:00:00", 1104728400), + ("ISO 8601 week #%V. %Y.%m.%d %H:%M:%S", "ISO 8601 week #53. 2005.01.02 00:00:00", 1104642000L), + ("ISO 8601 week #%V. %Y.%m.%d %H:%M:%S", "ISO 8601 week #01. 2005.01.03 00:00:00", 1104728400L), + ("Monday week #%W. %Y.%m.%d %H:%M:%S", "Monday week #00. 2005.01.02 00:00:00", 1104642000L), + ("Monday week #%W. %Y.%m.%d %H:%M:%S", "Monday week #01. 
2005.01.03 00:00:00", 1104728400L), ) - @Test(dataProvider = "time") - def strftime(fmt: String, s: String, t: Long): Unit = - assertEvalsTo(invoke("strftime", TString, Str(fmt), I64(t), Str("America/New_York")), s) + object checkStrftime extends TestCases { + def apply( + fmt: String, + s: String, + t: Long, + )(implicit loc: munit.Location + ): Unit = test("strftime") { + assertEvalsTo(invoke("strftime", TString, Str(fmt), I64(t), Str("America/New_York")), s) + } + } + + timeData.foreach { case (fmt, s, t) => checkStrftime(fmt, s, t) } + + object checkStrptime extends TestCases { + def apply( + fmt: String, + s: String, + t: Long, + )(implicit loc: munit.Location + ): Unit = test("strptime") { + assertEvalsTo(invoke("strptime", TInt64, Str(s), Str(fmt), Str("America/New_York")), t) + } + } - @Test(dataProvider = "time") - def strptime(fmt: String, s: String, t: Long): Unit = - assertEvalsTo(invoke("strptime", TInt64, Str(s), Str(fmt), Str("America/New_York")), t) + timeData.foreach { case (fmt, s, t) => checkStrptime(fmt, s, t) } } diff --git a/hail/hail/test/src/is/hail/expr/ir/StringLengthSuite.scala b/hail/hail/test/src/is/hail/expr/ir/StringLengthSuite.scala index 6c7c8563bb9..8b80f1cb3b0 100644 --- a/hail/hail/test/src/is/hail/expr/ir/StringLengthSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/StringLengthSuite.scala @@ -5,12 +5,10 @@ import is.hail.ExecStrategy.ExecStrategy import is.hail.expr.ir.defs.Str import is.hail.types.virtual.TInt32 -import org.testng.annotations.Test - class StringLengthSuite extends HailSuite { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.javaOnly - @Test def sameAsJavaStringLength(): Unit = { + test("sameAsJavaStringLength") { val strings = Array("abc", "", "\uD83D\uDCA9") for (s <- strings) assertEvalsTo(invoke("length", TInt32, Str(s)), s.length) diff --git a/hail/hail/test/src/is/hail/expr/ir/StringSliceSuite.scala b/hail/hail/test/src/is/hail/expr/ir/StringSliceSuite.scala index ea5882b3aa5..c23b1185f9c 
100644 --- a/hail/hail/test/src/is/hail/expr/ir/StringSliceSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/StringSliceSuite.scala @@ -6,13 +6,10 @@ import is.hail.collection.FastSeq import is.hail.expr.ir.defs.{I32, In, NA, Str} import is.hail.types.virtual.TString -import org.scalatest.Inspectors.forEvery -import org.testng.annotations.Test - class StringSliceSuite extends HailSuite { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.javaOnly - @Test def unicodeSlicingSlicesCodePoints(): Unit = { + test("unicodeSlicingSlicesCodePoints") { val poopEmoji = "\uD83D\uDCA9" val s = s"abc${poopEmoji}def" @@ -27,23 +24,24 @@ class StringSliceSuite extends HailSuite { assertEvalsTo(invoke("slice", TString, Str(s), I32(0), I32(5)), s"abc$poopEmoji") } - @Test def zeroToLengthIsIdentity(): Unit = + test("zeroToLengthIsIdentity") { assertEvalsTo(invoke("slice", TString, Str("abc"), I32(0), I32(3)), "abc") + } - @Test def simpleSlicesMatchIntuition(): Unit = { + test("simpleSlicesMatchIntuition") { assertEvalsTo(invoke("slice", TString, Str("abc"), I32(3), I32(3)), "") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(1), I32(3)), "bc") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(2), I32(3)), "c") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(0), I32(2)), "ab") } - @Test def sizeZeroSliceIsEmptyString(): Unit = { + test("sizeZeroSliceIsEmptyString") { assertEvalsTo(invoke("slice", TString, Str("abc"), I32(2), I32(2)), "") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(1), I32(1)), "") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(0), I32(0)), "") } - @Test def substringMatchesJavaStringSubstring(): Unit = { + test("substringMatchesJavaStringSubstring") { assertEvalsTo( invoke("substring", TString, Str("abc"), I32(0), I32(2)), "abc".substring(0, 2), @@ -54,36 +52,37 @@ class StringSliceSuite extends HailSuite { ) } - @Test def isStrict(): Unit = { + test("isStrict") { assertEvalsTo(invoke("slice", TString, NA(TString), 
I32(0), I32(2)), null) assertEvalsTo(invoke("slice", TString, NA(TString), I32(-5), I32(-10)), null) } - @Test def leftSliceMatchesIntuition(): Unit = { + test("leftSliceMatchesIntuition") { assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(2)), "c") assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(1)), "bc") } - @Test def rightSliceMatchesIntuition(): Unit = { + test("rightSliceMatchesIntuition") { assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(2)), "ab") assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(1)), "a") } - @Test def bothSideSliceMatchesIntuition(): Unit = + test("bothSideSliceMatchesIntuition") { assertEvalsTo(invoke("slice", TString, Str("abc"), I32(0), I32(2)), "ab") - // assertEvalsTo(invoke("slice", TString, Str("abc"), I32(1), I32(3)), "bc") + // assertEvalsTo(invoke("slice", TString, Str("abc"), I32(1), I32(3)), "bc") + } - @Test def leftSliceIsPythony(): Unit = { + test("leftSliceIsPythony") { assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(-1)), "c") assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(-2)), "bc") } - @Test def rightSliceIsPythony(): Unit = { + test("rightSliceIsPythony") { assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(-1)), "ab") assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(-2)), "a") } - @Test def sliceIsPythony(): Unit = { + test("sliceIsPythony") { assertEvalsTo(invoke("slice", TString, Str("abc"), I32(-3), I32(-1)), "ab") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(-3), I32(-2)), "a") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(-2), I32(-1)), "b") @@ -92,7 +91,7 @@ class StringSliceSuite extends HailSuite { assertEvalsTo(invoke("slice", TString, Str("abc"), I32(1), I32(-1)), "b") } - @Test def bothSidesSliceFunctionOutOfBoundsNotFatal(): Unit = { + test("bothSidesSliceFunctionOutOfBoundsNotFatal") { assertEvalsTo(invoke("slice", TString, Str("abc"), I32(4), I32(4)), "") assertEvalsTo(invoke("slice", TString, 
Str("abc"), I32(3), I32(2)), "") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(-1), I32(2)), "") @@ -103,7 +102,7 @@ class StringSliceSuite extends HailSuite { assertEvalsTo(invoke("slice", TString, Str("abc"), I32(-10), I32(-1)), "ab") } - @Test def leftSliceFunctionOutOfBoundsNotFatal(): Unit = { + test("leftSliceFunctionOutOfBoundsNotFatal") { assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(15)), "") assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(4)), "") assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(3)), "") @@ -112,7 +111,7 @@ class StringSliceSuite extends HailSuite { assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(-100)), "abc") } - @Test def rightSliceFunctionOutOfBoundsNotFatal(): Unit = { + test("rightSliceFunctionOutOfBoundsNotFatal") { assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(15)), "abc") assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(4)), "abc") assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(3)), "abc") @@ -121,7 +120,7 @@ class StringSliceSuite extends HailSuite { assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(-100)), "") } - @Test def testStringIndex(): Unit = { + test("StringIndex") { assertEvalsTo(invoke("index", TString, In(0, TString), I32(0)), FastSeq("Baz" -> TString), "B") assertEvalsTo(invoke("index", TString, In(0, TString), I32(1)), FastSeq("Baz" -> TString), "a") assertEvalsTo(invoke("index", TString, In(0, TString), I32(2)), FastSeq("Baz" -> TString), "z") @@ -129,7 +128,7 @@ class StringSliceSuite extends HailSuite { assertEvalsTo(invoke("index", TString, In(0, TString), I32(-2)), FastSeq("Baz" -> TString), "a") assertEvalsTo(invoke("index", TString, In(0, TString), I32(-3)), FastSeq("Baz" -> TString), "B") - forEvery(execStrats) { implicit strat => + execStrats.foreach { implicit strat => interceptFatal("string index out of bounds") { evaluate( ctx, @@ -139,7 +138,7 @@ class StringSliceSuite extends HailSuite { } } - 
forEvery(execStrats) { implicit strat => + execStrats.foreach { implicit strat => interceptFatal("string index out of bounds") { evaluate( ctx, diff --git a/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala index 01f244c8648..ee6a9f8cc41 100644 --- a/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/TableIRSuite.scala @@ -20,16 +20,13 @@ import is.hail.variant.Locus import scala.collection.compat._ import org.apache.spark.sql.Row -import org.scalatest.{Failed, Succeeded} -import org.scalatest.Inspectors.forAll -import org.testng.annotations.{DataProvider, Test} class TableIRSuite extends HailSuite { implicit val execStrats: Set[ExecStrategy] = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized, ExecStrategy.LoweredJVMCompile) - @Test def testRangeCount(): Unit = { + test("RangeCount") { val node1 = TableCount(TableRange(10, 2)) val node2 = TableCount(TableRange(15, 5)) val node = ApplyBinaryPrimOp(Add(), node1, node2) @@ -38,14 +35,14 @@ class TableIRSuite extends HailSuite { assertEvalsTo(node, 25L) } - @Test def testForceCount(): Unit = { + test("ForceCount") { implicit val execStrats = ExecStrategy.interpretOnly val tableRangeSize = Int.MaxValue / 20 val forceCountRange = TableToValueApply(TableRange(tableRangeSize, 2), ForceCountTable()) assertEvalsTo(forceCountRange, tableRangeSize.toLong) } - @Test def testRangeRead(): Unit = { + test("RangeRead") { implicit val execStrats = ExecStrategy.lowering val original = TableKeyBy( TableMapGlobals(TableRange(10, 3), MakeStruct(FastSeq("foo" -> I32(57)))), @@ -69,13 +66,13 @@ class TableIRSuite extends HailSuite { assertEvalsTo(TableCollect(droppedRows), Row(FastSeq(), expectedGlobals)) } - @Test def testCountRead(): Unit = { + test("CountRead") { implicit val execStrats = ExecStrategy.lowering val tir: TableIR = TableRead.native(fs, getTestResource("three_key.ht")) assertEvalsTo(TableCount(tir), 120L) } 
- @Test def testRangeCollect(): Unit = { + test("RangeCollect") { implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) val t = TableRange(10, 2) val row = Ref(TableIR.rowName, t.typ.rowType) @@ -84,7 +81,7 @@ class TableIRSuite extends HailSuite { assertEvalsTo(node, Row(ArraySeq.tabulate(10)(i => Row(i, i)), Row())) } - @Test def testNestedRangeCollect(): Unit = { + test("NestedRangeCollect") { implicit val execStrats = ExecStrategy.allRelational val r = TableRange(2, 2) @@ -106,7 +103,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testRangeSum(): Unit = { + test("RangeSum") { implicit val execStrats = ExecStrategy.interpretOnly val t = TableRange(10, 2) val row = Ref(TableIR.rowName, t.typ.rowType) @@ -123,7 +120,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testGetGlobals(): Unit = { + test("GetGlobals") { implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) val t = TableRange(10, 2) val newGlobals = @@ -132,7 +129,7 @@ class TableIRSuite extends HailSuite { assertEvalsTo(node, Row(Row(ArraySeq.tabulate(10)(i => Row(i)), Row()))) } - @Test def testCollectGlobals(): Unit = { + test("CollectGlobals") { implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) val t = TableRange(10, 2) val newGlobals = @@ -151,7 +148,7 @@ class TableIRSuite extends HailSuite { assertEvalsTo(collect(node), Row(expected, Row(collectedT))) } - @Test def testRangeExplode(): Unit = { + test("RangeExplode") { implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) val t = TableRange(10, 2) val row = Ref(TableIR.rowName, t.typ.rowType) @@ -178,7 +175,7 @@ class TableIRSuite extends HailSuite { assertEvalsTo(collect(node2), Row(expected2, Row())) } - @Test def testFilter(): Unit = { + test("Filter") { implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) val t = TableRange(10, 2) val node = 
TableFilter( @@ -198,7 +195,7 @@ class TableIRSuite extends HailSuite { assertEvalsTo(collect(node), Row(expected, Row(4))) } - @Test def testFilterIntervals(): Unit = { + test("FilterIntervals") { implicit val execStrats = ExecStrategy.allRelational def assertFilterIntervals( @@ -266,7 +263,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testTableMapWithLiterals(): Unit = { + test("TableMapWithLiterals") { implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) val t = TableRange(10, 2) val node = TableMapRows( @@ -284,7 +281,7 @@ class TableIRSuite extends HailSuite { assertEvalsTo(collect(node), Row(expected, Row())) } - @Test def testScanCountBehavesLikeIndex(): Unit = { + test("ScanCountBehavesLikeIndex") { implicit val execStrats = ExecStrategy.interpretOnly val t = rangeKT val oldRow = Ref(TableIR.rowName, t.typ.rowType) @@ -301,7 +298,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testScanCollectBehavesLikeRange(): Unit = { + test("ScanCollectBehavesLikeRange") { implicit val execStrats = ExecStrategy.interpretOnly val t = rangeKT val oldRow = Ref(TableIR.rowName, t.typ.rowType) @@ -530,153 +527,166 @@ class TableIRSuite extends HailSuite { ("inner", (row: Row) => !row.isNullAt(1) && !row.isNullAt(3)), ) - @DataProvider(name = "join") - def joinData(): Array[Array[Any]] = { + object checkTableJoin extends TestCases { + def apply( + lParts: Int, + rParts: Int, + joinType: String, + pred: Row => Boolean, + leftProject: Set[Int], + rightProject: Set[Int], + )(implicit loc: munit.Location + ): Unit = test("TableJoin") { + val (leftType, leftProjectF) = rowType.filter(f => !leftProject.contains(f.index)) + val left = TableKeyBy( + TableParallelize( + Literal( + TStruct("rows" -> TArray(leftType), "global" -> TStruct.empty), + Row(leftData.map(leftProjectF.asInstanceOf[Row => Row]), Row()), + ), + Some(lParts), + ), + if (!leftProject.contains(1)) FastSeq("A", "B") else FastSeq("A"), + ) + + val 
(rightType, rightProjectF) = rowType.filter(f => !rightProject.contains(f.index)) + val right = TableKeyBy( + TableParallelize( + Literal( + TStruct("rows" -> TArray(rightType), "global" -> TStruct.empty), + Row(rightData.map(rightProjectF.asInstanceOf[Row => Row]), Row()), + ), + Some(rParts), + ), + if (!rightProject.contains(1)) FastSeq("A", "B") else FastSeq("A"), + ) + + val (_, joinProjectF) = + joinedType.filter(f => + !leftProject.contains(f.index) && !rightProject.contains(f.index - 2) + ) + val joined = collect( + TableJoin( + left, + TableRename( + right, + Array("A", "B", "C") + .filter(right.typ.rowType.hasField) + .map(a => a -> (a + "_")) + .toMap, + Map.empty, + ), + joinType, + 1, + ) + ) + + assertEvalsTo(joined, Row(expectedOuterJoin.filter(pred).map(joinProjectF), Row())) + } + } + + { val defaultLParts = 2 val defaultRParts = 2 val defaultLeftProject = Set(1, 2) val defaultRightProject = Set(1, 2) - val ab = Array.newBuilder[Array[Any]] for ((j, p) <- joinTypes) { for { lParts <- Array[Integer](1, 2, 3) rParts <- Array[Integer](1, 2, 3) } - ab += Array[Any](lParts, rParts, j, p, defaultLeftProject, defaultRightProject) + checkTableJoin(lParts, rParts, j, p, defaultLeftProject, defaultRightProject) for { leftProject <- Seq[Set[Int]](Set(), Set(1), Set(2), Set(1, 2)) rightProject <- Seq[Set[Int]](Set(), Set(1), Set(2), Set(1, 2)) if !leftProject.contains(1) || rightProject.contains(1) } - ab += Array[Any](defaultLParts, defaultRParts, j, p, leftProject, rightProject) + checkTableJoin(defaultLParts, defaultRParts, j, p, leftProject, rightProject) } - ab.result() } - @Test(dataProvider = "join") - def testTableJoin( - lParts: Int, - rParts: Int, - joinType: String, - pred: Row => Boolean, - leftProject: Set[Int], - rightProject: Set[Int], - ): Unit = { - val (leftType, leftProjectF) = rowType.filter(f => !leftProject.contains(f.index)) - val left = TableKeyBy( - TableParallelize( - Literal( - TStruct("rows" -> TArray(leftType), "global" -> 
TStruct.empty), - Row(leftData.map(leftProjectF.asInstanceOf[Row => Row]), Row()), - ), - Some(lParts), - ), - if (!leftProject.contains(1)) FastSeq("A", "B") else FastSeq("A"), - ) - - val (rightType, rightProjectF) = rowType.filter(f => !rightProject.contains(f.index)) - val right = TableKeyBy( - TableParallelize( - Literal( - TStruct("rows" -> TArray(rightType), "global" -> TStruct.empty), - Row(rightData.map(rightProjectF.asInstanceOf[Row => Row]), Row()), - ), - Some(rParts), - ), - if (!rightProject.contains(1)) FastSeq("A", "B") else FastSeq("A"), - ) - - val (_, joinProjectF) = - joinedType.filter(f => !leftProject.contains(f.index) && !rightProject.contains(f.index - 2)) - val joined = collect( - TableJoin( - left, - TableRename( - right, - Array("A", "B", "C") - .filter(right.typ.rowType.hasField) - .map(a => a -> (a + "_")) - .toMap, - Map.empty, + val unionData: Array[(Int, Int)] = + (for { + lParts <- Array(1, 2, 3) + rParts <- Array(1, 2, 3) + } yield (lParts, rParts)) + + object checkTableUnion extends TestCases { + def apply( + lParts: Int, + rParts: Int, + )(implicit loc: munit.Location + ): Unit = test("TableUnion") { + val left = TableKeyBy( + TableParallelize( + Literal( + TStruct("rows" -> TArray(rowType), "global" -> TStruct.empty), + Row(leftData, Row()), + ), + Some(lParts), ), - joinType, - 1, + FastSeq("A", "B"), ) - ) - assertEvalsTo(joined, Row(expectedOuterJoin.filter(pred).map(joinProjectF), Row())) - } - - @DataProvider(name = "union") - def unionData(): Array[Array[Any]] = - for { - lParts <- Array[Integer](1, 2, 3) - rParts <- Array[Integer](1, 2, 3) - } yield Array[Any](lParts, rParts) - - @Test(dataProvider = "union") - def testTableUnion(lParts: Int, rParts: Int): Unit = { - val left = TableKeyBy( - TableParallelize( - Literal( - TStruct("rows" -> TArray(rowType), "global" -> TStruct.empty), - Row(leftData, Row()), - ), - Some(lParts), - ), - FastSeq("A", "B"), - ) - - val right = TableKeyBy( - TableParallelize( - Literal( - 
TStruct("rows" -> TArray(rowType), "global" -> TStruct.empty), - Row(rightData, Row()), + val right = TableKeyBy( + TableParallelize( + Literal( + TStruct("rows" -> TArray(rowType), "global" -> TStruct.empty), + Row(rightData, Row()), + ), + Some(rParts), ), - Some(rParts), - ), - FastSeq("A", "B"), - ) + FastSeq("A", "B"), + ) - val merged = collect(TableUnion(FastSeq(left, right))) + val merged = collect(TableUnion(FastSeq(left, right))) - assertEvalsTo(merged, Row(expectedUnion, Row())) + assertEvalsTo(merged, Row(expectedUnion, Row())) + } } - @Test(dataProvider = "union") - def testTableMultiWayZipJoin(lParts: Int, rParts: Int): Unit = { - implicit val execStrats = Set(ExecStrategy.LoweredJVMCompile) - val left = TableKeyBy( - TableParallelize( - Literal( - TStruct("rows" -> TArray(rowType), "global" -> TStruct.empty), - Row(leftData, Row()), + object checkTableMultiWayZipJoin extends TestCases { + def apply( + lParts: Int, + rParts: Int, + )(implicit loc: munit.Location + ): Unit = test("TableMultiWayZipJoin") { + implicit val execStrats = Set(ExecStrategy.LoweredJVMCompile) + val left = TableKeyBy( + TableParallelize( + Literal( + TStruct("rows" -> TArray(rowType), "global" -> TStruct.empty), + Row(leftData, Row()), + ), + Some(lParts), ), - Some(lParts), - ), - FastSeq("A", "B"), - ) + FastSeq("A", "B"), + ) - val right = TableKeyBy( - TableParallelize( - Literal( - TStruct("rows" -> TArray(rowType), "global" -> TStruct.empty), - Row(rightData, Row()), + val right = TableKeyBy( + TableParallelize( + Literal( + TStruct("rows" -> TArray(rowType), "global" -> TStruct.empty), + Row(rightData, Row()), + ), + Some(rParts), ), - Some(rParts), - ), - FastSeq("A", "B"), - ) + FastSeq("A", "B"), + ) - val merged = collect(TableMultiWayZipJoin(FastSeq(left, right), "row", "global")) + val merged = collect(TableMultiWayZipJoin(FastSeq(left, right), "row", "global")) - assertEvalsTo(merged, Row(expectedZipJoin, Row(FastSeq(Row(), Row())))) + assertEvalsTo(merged, 
Row(expectedZipJoin, Row(FastSeq(Row(), Row())))) + } } + unionData.foreach { case (l, r) => checkTableUnion(l, r) } + unionData.foreach { case (l, r) => checkTableMultiWayZipJoin(l, r) } + // Catches a bug in the partitioner created by the importer. - @Test def testTableJoinOfImport(): Unit = { + test("TableJoinOfImport") { val mnr = MatrixNativeReader(fs, getTestResource("sample.vcf.mt")) val mt2 = MatrixRead(mnr.fullMatrixType, false, false, mnr) val t2 = MatrixRowsTable(mt2) @@ -690,7 +700,7 @@ class TableIRSuite extends HailSuite { assertEvalsTo(TableCount(join), 346L) } - @Test def testNativeReaderWithOverlappingPartitions(): Unit = { + test("NativeReaderWithOverlappingPartitions") { val path = getTestResource("sample.vcf-20-partitions-with-overlap.mt/rows") // i1 overlaps the first two partitions val i1 = Interval(Row(Locus("20", 10200000)), Row(Locus("20", 10500000)), true, true) @@ -709,7 +719,7 @@ class TableIRSuite extends HailSuite { test(true, 2) } - @Test def testTableKeyBy(): Unit = { + test("TableKeyBy") { implicit val execStrats = ExecStrategy.interpretOnly val data = ArraySeq(ArraySeq("A", 1), ArraySeq("A", 2), ArraySeq("B", 1)) val rdd = sc.parallelize(data.map(Row.fromSeq(_))) @@ -732,7 +742,7 @@ class TableIRSuite extends HailSuite { assertEvalsTo(distinctCount, 2L) } - @Test def testTableKeyByLowering(): Unit = { + test("TableKeyByLowering") { implicit val execStrats = ExecStrategy.lowering val t = TStruct( "rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), @@ -747,13 +757,13 @@ class TableIRSuite extends HailSuite { assertEvalsTo(TableCount(keyed), length.toLong) } - @Test def testTableParallelize(): Unit = { + test("TableParallelize") { implicit val execStrats = ExecStrategy.allRelational val t = TStruct( "rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), "global" -> TStruct("x" -> TString), ) - forAll(Array(1, 10, 17, 34, 103)) { length => + Array(1, 10, 17, 34, 103).foreach { length => val value = Row(FastSeq(0 until length: 
_*).map(i => Row(i, "row" + i)), Row("global")) assertEvalsTo( collectNoKey( @@ -769,7 +779,7 @@ class TableIRSuite extends HailSuite { } } - @Test def testTableParallelizeCount(): Unit = { + test("TableParallelizeCount") { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.allRelational val t = TStruct( "rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), @@ -790,7 +800,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testTableHead(): Unit = { + test("TableHead") { val t = TStruct( "rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), "global" -> TStruct("x" -> TString), @@ -802,9 +812,9 @@ class TableIRSuite extends HailSuite { val initialDataLength = 10 val initialData = makeData(initialDataLength) - forAll(numRowsToTakeArray) { howManyRowsToTake => + numRowsToTakeArray.foreach { howManyRowsToTake => val headData = makeData(Math.min(howManyRowsToTake, initialDataLength)) - forAll(numInitialPartitionsArray) { howManyInitialPartitions => + numInitialPartitionsArray.foreach { howManyInitialPartitions => assertEvalsTo( collectNoKey( TableHead( @@ -821,7 +831,7 @@ class TableIRSuite extends HailSuite { } } - @Test def testTableTail(): Unit = { + test("TableTail") { val t = TStruct( "rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), "global" -> TStruct("x" -> TString), @@ -838,9 +848,9 @@ class TableIRSuite extends HailSuite { ) val initialData = makeData(initialDataLength) - forAll(numRowsToTakeArray) { howManyRowsToTake => + numRowsToTakeArray.foreach { howManyRowsToTake => val headData = makeData(Math.min(howManyRowsToTake, initialDataLength)) - forAll(numInitialPartitionsArray) { howManyInitialPartitions => + numInitialPartitionsArray.foreach { howManyInitialPartitions => assertEvalsTo( collectNoKey( TableTail( @@ -857,7 +867,7 @@ class TableIRSuite extends HailSuite { } } - @Test def testShuffleAndJoinDoesntMemoryLeak(): Unit = { + test("ShuffleAndJoinDoesntMemoryLeak") { implicit val execStrats = 
Set(ExecStrategy.LoweredJVMCompile, ExecStrategy.Interpret) val row = Ref(TableIR.rowName, TStruct("idx" -> TInt32)) val t1 = TableRename(TableRange(1, 1), Map("idx" -> "idx_"), Map.empty) @@ -873,7 +883,7 @@ class TableIRSuite extends HailSuite { assertEvalsTo(TableCount(TableJoin(t1, t2, "left")), 1L) } - @Test def testTableRename(): Unit = { + test("TableRename") { implicit val execStrats = ExecStrategy.lowering val t = TStruct( "rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), @@ -915,7 +925,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testTableMapGlobals(): Unit = { + test("TableMapGlobals") { val t = TStruct( "rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), "global" -> TStruct("x" -> TString), @@ -948,7 +958,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testTableWrite(): Unit = { + test("TableWrite") { val table = TableRange(5, 4) val path = ctx.createTmpPath("test-table-write", "ht") Interpret[Unit](ctx, TableWrite(table, TableNativeWriter(path))) @@ -960,7 +970,7 @@ class TableIRSuite extends HailSuite { assert(before.rdd.collect().toFastSeq == after.rdd.collect().toFastSeq) } - @Test def testWriteKeyDistinctness(): Unit = { + test("WriteKeyDistinctness") { val rt = TableRange(40, 4) val idxRef = GetField(Ref(TableIR.rowName, rt.typ.rowType), "idx") val at = TableMapRows( @@ -1011,7 +1021,7 @@ class TableIRSuite extends HailSuite { assert(!readTwoMissing.isDistinctlyKeyed) } - @Test def testPartitionCountsWithDropRows(): Unit = { + test("PartitionCountsWithDropRows") { val tr = new FakeTableReader { override def pathsUsed: Seq[String] = Seq.empty override def partitionCounts: Option[IndexedSeq[Long]] = Some(FastSeq(1, 2, 3, 4)) @@ -1021,7 +1031,7 @@ class TableIRSuite extends HailSuite { assert(PartitionCounts(tir).forall(_.sum == 0)) } - @Test def testScanInAggInMapRows(): Unit = { + test("ScanInAggInMapRows") { implicit val execStrats = ExecStrategy.interpretOnly var tr: TableIR = TableRange(10, 3) tr = 
TableKeyBy(tr, FastSeq(), false) @@ -1051,7 +1061,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testScanInAggInScanInMapRows(): Unit = { + test("ScanInAggInScanInMapRows") { implicit val execStrats = ExecStrategy.interpretOnly var tr: TableIR = TableRange(10, 3) tr = TableKeyBy(tr, FastSeq(), false) @@ -1084,7 +1094,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testTableAggregateByKey(): Unit = { + test("TableAggregateByKey") { implicit val execStrats = ExecStrategy.allRelational var tir: TableIR = TableRead.native(fs, getTestResource("three_key.ht")) tir = TableKeyBy(tir, FastSeq("x", "y"), true) @@ -1107,7 +1117,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testTableDistinct(): Unit = { + test("TableDistinct") { val tir: TableIR = TableRead.native(fs, getTestResource("three_key.ht")) val keyedByX = TableKeyBy(tir, FastSeq("x"), true) val distinctByX = TableDistinct(keyedByX) @@ -1122,7 +1132,7 @@ class TableIRSuite extends HailSuite { assertEvalsTo(TableCount(distinctByAll), 120L) } - @Test def testRangeOrderByDescending(): Unit = { + test("RangeOrderByDescending") { var tir: TableIR = TableRange(10, 3) tir = TableOrderBy(tir, FastSeq(SortField("idx", Descending))) val x = GetField(TableCollect(tir), "rows") @@ -1130,8 +1140,8 @@ class TableIRSuite extends HailSuite { assertEvalsTo(x, (0 until 10).reverse.map(i => Row(i)))(ExecStrategy.allRelational) } - @Test def testTableLeftJoinRightDistinctRangeTables(): Unit = { - forAll(IndexedSeq((1, 1), (3, 2), (10, 5), (5, 10))) { case (nParts1, nParts2) => + test("TableLeftJoinRightDistinctRangeTables") { + IndexedSeq((1, 1), (3, 2), (10, 5), (5, 10)).foreach { case (nParts1, nParts2) => val rangeTable1 = TableRange(10, nParts1) var rangeTable2: TableIR = TableRange(5, nParts2) val row = Ref(TableIR.rowName, rangeTable2.typ.rowType) @@ -1150,7 +1160,7 @@ class TableIRSuite extends HailSuite { } } - @Test def testNestedStreamInTable(): Unit = { + test("NestedStreamInTable") { var 
tir: TableIR = TableRange(1, 1) var ir: IR = rangeIR(5) ir = StreamGrouped(ir, 2) @@ -1193,7 +1203,7 @@ class TableIRSuite extends HailSuite { val table2KeyedByA = TableKeyBy(table2, IndexedSeq("a2")) val joinedParKeyedByA = TableLeftJoinRightDistinct(table1KeyedByA, table2KeyedByA, "joinRoot") - @Test def testTableLeftJoinRightDistinctParallelizeSameKey(): Unit = { + test("TableLeftJoinRightDistinctParallelizeSameKey") { assertEvalsTo(TableCount(table1KeyedByA), parTable1Length.toLong) assertEvalsTo(TableCount(table2KeyedByA), parTable2Length.toLong) @@ -1209,7 +1219,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testTableLeftJoinRightDistinctParallelizePrefixKey(): Unit = { + test("TableLeftJoinRightDistinctParallelizePrefixKey") { val table1KeyedByAAndB = TableKeyBy(table1, IndexedSeq("a1", "b1")) val joinedParKeyedByAAndB = TableLeftJoinRightDistinct(table1KeyedByAAndB, table2KeyedByA, "joinRoot") @@ -1226,7 +1236,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testTableIntervalJoin(): Unit = { + test("TableIntervalJoin") { val intervals: IndexedSeq[Interval] = for { (start, end, includesStart, includesEnd) <- FastSeq( @@ -1285,7 +1295,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testTableKeyByAndAggregate(): Unit = { + test("TableKeyByAndAggregate") { val tir: TableIR = TableRead.native(fs, getTestResource("three_key.ht")) val unkeyed = TableKeyBy(tir, IndexedSeq[String]()) val rowRef = Ref(TableIR.rowName, unkeyed.typ.rowType) @@ -1365,7 +1375,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testTableAggregateCollectAndTake(): Unit = { + test("TableAggregateCollectAndTake") { implicit val execStrats = ExecStrategy.allRelational var tir: TableIR = TableRange(10, 3) tir = @@ -1390,7 +1400,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testNDArrayMultiplyAddAggregator(): Unit = { + test("NDArrayMultiplyAddAggregator") { implicit val execStrats = ExecStrategy.allRelational var tir: TableIR = 
TableRange(6, 3) val nDArray1 = Literal( @@ -1418,7 +1428,7 @@ class TableIRSuite extends HailSuite { assertEvalsTo(x, SafeNDArray(Vector(2, 2), IndexedSeq(24.0, 24.0, 24.0, 24.0))) } - @Test def testTableScanCollect(): Unit = { + test("TableScanCollect") { implicit val execStrats = ExecStrategy.allRelational var tir: TableIR = TableRange(5, 3) tir = TableMapRows( @@ -1445,7 +1455,7 @@ class TableIRSuite extends HailSuite { ) } - @Test def testIssue9016(): Unit = { + test("Issue9016") { val rows = mapIR(ToStream(MakeArray(makestruct("a" -> MakeTuple.ordered(FastSeq(I32(0), I32(1))))))) { row => @@ -1466,7 +1476,7 @@ class TableIRSuite extends HailSuite { assertEvalsTo(TableCollect(table), Row(FastSeq(Row(Row(1))), Row())) } - @Test def testTableNativeZippedReaderWithPrefixKey(): Unit = { + test("TableNativeZippedReaderWithPrefixKey") { /* This test is important because it tests that we can handle lowering with a * TableNativeZippedReader when elements of the original key get pruned away (so I copy key to * only be "locus" instead of "locus", "alleles") */ @@ -1493,7 +1503,7 @@ class TableIRSuite extends HailSuite { LowerTableIR(optimized, DArrayLowering.All, ctx, analyses): Unit } - @Test def testTableMapPartitions(): Unit = { + test("TableMapPartitions") { val table = TableKeyBy( @@ -1561,38 +1571,30 @@ class TableIRSuite extends HailSuite { ) } - @Test def testRepartitionCostEstimate(): Unit = { + test("RepartitionCostEstimate") { val empty = RVDPartitioner.empty(ctx.stateManager, TStruct.empty) val some = RVDPartitioner.unkeyed(ctx.stateManager, _) - val data = IndexedSeq( - (empty, empty, Succeeded, Failed("Repartitioning from an empty partitioner should be free")), - ( - empty, - some(1), - Succeeded, - Failed("Repartitioning from an empty partitioner should be free"), - ), - (some(1), empty, Succeeded, Failed("Repartitioning to an empty partitioner should be free")), - ( - some(5), - some(1), - Succeeded, - Failed("Combining multiple partitions into one 
should not incur a reload"), - ), - ( - some(1), - some(60), - Failed("Recomputing the same partition multiple times should be replaced with a reload"), - Succeeded, - ), + assert( + LowerTableIR.isRepartitioningCheap(empty, empty), + "Repartitioning from an empty partitioner should be free", + ) + assert( + LowerTableIR.isRepartitioningCheap(empty, some(1)), + "Repartitioning from an empty partitioner should be free", + ) + assert( + LowerTableIR.isRepartitioningCheap(some(1), empty), + "Repartitioning to an empty partitioner should be free", + ) + assert( + LowerTableIR.isRepartitioningCheap(some(5), some(1)), + "Combining multiple partitions into one should not incur a reload", + ) + assert( + !LowerTableIR.isRepartitioningCheap(some(1), some(60)), + "Recomputing the same partition multiple times should be replaced with a reload", ) - - forAll(data) { case (a, b, t, f) => - (if (LowerTableIR.isRepartitioningCheap(a, b)) t else f).toSucceeded.asInstanceOf[ - Unit - ] - } } } diff --git a/hail/hail/test/src/is/hail/expr/ir/TakeByAggregatorSuite.scala b/hail/hail/test/src/is/hail/expr/ir/TakeByAggregatorSuite.scala index bc51df922a6..9676f97e92e 100644 --- a/hail/hail/test/src/is/hail/expr/ir/TakeByAggregatorSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/TakeByAggregatorSuite.scala @@ -11,13 +11,9 @@ import is.hail.types.VirtualTypeWithReq import is.hail.types.physical._ import is.hail.types.physical.stypes.primitives.SInt32Value -import org.scalatest.Inspectors.forAll -import org.scalatest.enablers.InspectorAsserting.assertingNatureOfAssertion -import org.testng.annotations.Test - class TakeByAggregatorSuite extends HailSuite { - @Test def testPointers(): Unit = { - forAll(Array((1000, 100), (1, 10), (100, 10000), (1000, 10000))) { case (size, n) => + test("Pointers") { + Array((1000, 100), (1, 10), (100, 10000), (1000, 10000)).foreach { case (size, n) => val fb = EmitFunctionBuilder[Region, Long](ctx, "test_pointers") val cb = fb.ecb val stringPT = 
PCanonicalString(true) @@ -50,8 +46,9 @@ class TakeByAggregatorSuite extends HailSuite { val o = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r)(r) val result = SafeRow.read(rt, o) - assert( - result == ((n - 1) to 0 by -1) + assertEquals( + result, + ((n - 1) to 0 by -1) .iterator .map(i => s"str$i") .take(size) @@ -62,7 +59,7 @@ class TakeByAggregatorSuite extends HailSuite { } } - @Test def testMissing(): Unit = { + test("Missing") { val fb = EmitFunctionBuilder[Region, Long](ctx, "take_by_test_missing") val cb = fb.ecb val tba = @@ -88,12 +85,12 @@ class TakeByAggregatorSuite extends HailSuite { val o = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r)(r) val result = SafeRow.read(rt, o) - assert(result == FastSeq(0, 1, 2, 3, null, null, null)) + assertEquals(result, FastSeq(0, 1, 2, 3, null, null, null)) } } - @Test def testRandom(): Unit = - forAll(Array(1, 2, 10, 100, 1000, 10000, 100000, 1000000)) { n => + test("Random") { + Array(1, 2, 10, 100, 1000, 10000, 100000, 1000000).foreach { n => val nToTake = 1025 val fb = EmitFunctionBuilder[Region, Long](ctx, "take_by_test_random") val kb = fb.ecb @@ -137,7 +134,8 @@ class TakeByAggregatorSuite extends HailSuite { val collOffset = Region.loadAddress(o + 8) val collected = SafeRow.read(ab.eltArray, collOffset).asInstanceOf[IndexedSeq[Int]].take(n) val minValues = collected.sorted.take(nToTake) - assert(pq == minValues, s"n=$n") + assertEquals(pq, minValues, s"n=$n") } } + } } diff --git a/hail/hail/test/src/is/hail/expr/ir/UtilFunctionsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/UtilFunctionsSuite.scala index f8a08f8f70b..4f878f7b9f2 100644 --- a/hail/hail/test/src/is/hail/expr/ir/UtilFunctionsSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/UtilFunctionsSuite.scala @@ -5,8 +5,6 @@ import is.hail.ExecStrategy.ExecStrategy import is.hail.expr.ir.defs.{Die, False, MakeStream, NA, Str, True} import is.hail.types.virtual.{TBoolean, TInt32, TStream} -import 
org.testng.annotations.Test - class UtilFunctionsSuite extends HailSuite { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.javaOnly @@ -15,23 +13,23 @@ class UtilFunctionsSuite extends HailSuite { val folded = foldIR(MakeStream(IndexedSeq(true), TStream(TBoolean)), die)(_ || _) - @Test def shortCircuitOr(): Unit = { + test("shortCircuitOr") { assertEvalsTo(True() || True(), true) assertEvalsTo(True() || False(), true) assertEvalsTo(False() || True(), true) assertEvalsTo(False() || False(), false) } - @Test def shortCircuitOrHandlesMissingness(): Unit = { + test("shortCircuitOrHandlesMissingness") { assertEvalsTo(na || na, null) assertEvalsTo(na || True(), true) assertEvalsTo(True() || na, true) assertEvalsTo(na || False(), null) assertEvalsTo(False() || na, null) } - @Test def shortCircuitOrHandlesErrors(): Unit = { + test("shortCircuitOrHandlesErrors") { // FIXME: interpreter evaluates args for ApplySpecial before invoking the function :-| assertCompiledFatal(na || die, "it ded") assertCompiledFatal(False() || die, "it ded") @@ -48,14 +46,14 @@ class UtilFunctionsSuite extends HailSuite { assert(eval(True() || folded) == true) } - @Test def shortCircuitAnd(): Unit = { + test("shortCircuitAnd") { assertEvalsTo(True() && True(), true) assertEvalsTo(True() && False(), false) assertEvalsTo(False() && True(), false) assertEvalsTo(False() && False(), false) } - @Test def shortCircuitAndHandlesMissingness(): Unit = { + test("shortCircuitAndHandlesMissingness") { assertEvalsTo(na && na, null) assertEvalsTo(True() && na, null) assertEvalsTo(na && True(), null) @@ -63,7 +61,7 @@ class UtilFunctionsSuite extends HailSuite { assertEvalsTo(na && False(), false) } - @Test def shortCircuitAndHandlesErroes(): Unit = { + test("shortCircuitAndHandlesErrors") { // FIXME: interpreter evaluates args for ApplySpecial before invoking the function :-| assertCompiledFatal(na && die, "it ded") assertCompiledFatal(True() && die, "it ded") @@ -79,7 +77,7 @@ class
UtilFunctionsSuite extends HailSuite { assert(eval(False() && folded) == false) } - @Test def testParseFunctionRequiredness(): Unit = { + test("ParseFunctionRequiredness") { assertEvalsTo(invoke("toInt32OrMissing", TInt32, Str("123")), 123) assertEvalsTo(invoke("toInt32OrMissing", TInt32, Str("foo")), null) } diff --git a/hail/hail/test/src/is/hail/expr/ir/agg/DownsampleSuite.scala b/hail/hail/test/src/is/hail/expr/ir/agg/DownsampleSuite.scala index dc8e260806c..86274a6d551 100644 --- a/hail/hail/test/src/is/hail/expr/ir/agg/DownsampleSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/agg/DownsampleSuite.scala @@ -8,11 +8,13 @@ import is.hail.types.VirtualTypeWithReq import is.hail.types.physical.{PCanonicalArray, PCanonicalString} import is.hail.types.physical.stypes.primitives.SFloat64Value -import org.testng.annotations.Test +import scala.concurrent.duration.Duration class DownsampleSuite extends HailSuite { - @Test def testLargeRandom(): Unit = { + override val munitTimeout = Duration(120, "s") + + test("LargeRandom") { val lt = PCanonicalArray(PCanonicalString()) val fb = EmitFunctionBuilder[RegionPool, Unit](ctx, "foo") val cb = fb.ecb diff --git a/hail/hail/test/src/is/hail/expr/ir/agg/StagedBlockLinkedListSuite.scala b/hail/hail/test/src/is/hail/expr/ir/agg/StagedBlockLinkedListSuite.scala index 25b6ca18bfa..f94ef440d2b 100644 --- a/hail/hail/test/src/is/hail/expr/ir/agg/StagedBlockLinkedListSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/agg/StagedBlockLinkedListSuite.scala @@ -9,9 +9,6 @@ import is.hail.collection.compat.mutable.GrowableCompat import is.hail.expr.ir.{EmitCode, EmitFunctionBuilder} import is.hail.types.physical._ -import org.testng.Assert._ -import org.testng.annotations.Test - class StagedBlockLinkedListSuite extends HailSuite { class BlockLinkedList[E](region: Region, val elemPType: PType, initImmediately: Boolean = true) @@ -141,14 +138,15 @@ class StagedBlockLinkedListSuite extends HailSuite { } } - @Test def 
testPushIntsRequired(): Unit = + test("PushIntsRequired") { pool.scopedRegion { region => val b = new BlockLinkedList[Int](region, PInt32Required) for (i <- 1 to 100) b += i assertEquals(b.toIndexedSeq, IndexedSeq.tabulate(100)(_ + 1)) } + } - @Test def testPushStrsMissing(): Unit = { + test("PushStrsMissing") { pool.scopedRegion { region => val a = ArraySeq.newBuilder[String] val b = new BlockLinkedList[String](region, PCanonicalString()) @@ -161,7 +159,7 @@ class StagedBlockLinkedListSuite extends HailSuite { } } - @Test def testAppendAnother(): Unit = { + test("AppendAnother") { pool.scopedRegion { region => val b1 = new BlockLinkedList[String](region, PCanonicalString()) val b2 = new BlockLinkedList[String](region, PCanonicalString()) @@ -174,7 +172,7 @@ class StagedBlockLinkedListSuite extends HailSuite { } } - @Test def testDeepCopy(): Unit = { + test("DeepCopy") { pool.scopedRegion { region => val b1 = new BlockLinkedList[Double](region, PFloat64()) b1 ++= Seq(1.0, 2.0, 3.0) diff --git a/hail/hail/test/src/is/hail/expr/ir/analyses/SemanticHashSuite.scala b/hail/hail/test/src/is/hail/expr/ir/analyses/SemanticHashSuite.scala index 2187ba3ad32..5973c0e673e 100644 --- a/hail/hail/test/src/is/hail/expr/ir/analyses/SemanticHashSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/analyses/SemanticHashSuite.scala @@ -14,167 +14,183 @@ import java.io.FileNotFoundException import java.lang import org.json4s.JValue -import org.testng.annotations.{DataProvider, Test} class SemanticHashSuite extends HailSuite { - def isTriviallySemanticallyEquivalent: Array[Array[Any]] = - Array( - Array(True(), True(), true, "Refl"), - Array(False(), False(), true, "Refl"), - Array(True(), False(), false, "Refl"), - Array(I32(0), I32(0), true, "Refl"), - Array(I32(0), I32(1), false, "Refl"), - Array(I64(0), I64(0), true, "Refl"), - Array(I64(0), I64(1), false, "Refl"), - Array(F32(0), F32(0), true, "Refl"), - Array(F32(0), F32(1), false, "Refl"), - Array(Void(), Void(), true, "Refl"), - 
Array(Str("a"), Str("a"), true, "Refl"), - Array(Str("a"), Str("b"), false, "Refl"), - Array(NA(TInt32), NA(TInt32), true, "Refl"), - Array(NA(TInt32), NA(TFloat64), false, "Refl"), - ) + private[this] val fakeFs: FS = + new FakeFS { + override def eTag(url: FakeURL): Option[String] = + Some(url.path) + + override def glob(url: FakeURL): IndexedSeq[FileListEntry] = + ArraySeq(new FileListEntry { + override def getPath: String = url.path + override def getActualUrl: String = url.path + override def getModificationTime: lang.Long = ??? + override def getLen: Long = ??? + override def isDirectory: Boolean = ??? + override def isSymlink: Boolean = ??? + override def isFile: Boolean = true + override def getOwner: String = ??? + }) + } def mkRelationalLet(bindings: IndexedSeq[(Name, IR)], body: IR): IR = bindings.foldRight(body) { case ((name, value), body) => RelationalLet(name, value, body) } - def isLetSemanticallyEquivalent: Array[Array[Any]] = { - val x = freshName() - val y = freshName() - Array((Let(_, _), Ref), (mkRelationalLet _, RelationalRef)).flatMap { case (let, ref) => - Array( - Array( - let(FastSeq(x -> I32(0)), ref(x, TInt32)), - let(FastSeq(y -> I32(0)), ref(y, TInt32)), - true, - "names used in let-bindings do not change semantics", - ), - Array( - let(FastSeq(x -> I32(0), y -> I32(0)), ref(x, TInt32)), - let(FastSeq(y -> I32(0), x -> I32(0)), ref(y, TInt32)), - true, - "names of let-bindings do not change semantics", - ), - Array( - let(FastSeq(x -> I32(0)), ref(x, TInt32)), - let(FastSeq(x -> I64(0)), ref(x, TInt64)), - false, - "different IRs", - ), - Array( - let(FastSeq(x -> I32(0), y -> I32(0)), ref(x, TInt32)), - let(FastSeq(y -> I32(0), x -> I32(0)), ref(x, TInt32)), - false, - "Different binding being referenced", - ), - /* `SemanticHash` does not perform or recognise opportunities for simplification. - * The following examples demonstrate some of its limitations as a consequence. 
*/ - Array( - let(FastSeq(x -> I32(0)), ref(x, TInt32)), - let(FastSeq(x -> let(FastSeq(freshName() -> I32(0)), I32(0))), ref(x, TInt32)), - false, - "SemanticHash does not simplify", - ), - Array( - let(FastSeq(x -> I32(0)), ref(x, TInt32)), - let(FastSeq(x -> I32(0), y -> I32(0)), ref(x, TInt32)), - false, - "SemanticHash does not simplify", - ), - ) + object checkSemanticEquivalence extends TestCases { + def apply( + a: BaseIR, + b: BaseIR, + isEqual: Boolean, + comment: String, + )(implicit loc: munit.Location + ): Unit = test("SemanticEquivalence") { + ctx.local(fs = fakeFs) { ctx => + assertEquals( + SemanticHash(ctx, a) == SemanticHash(ctx, b), + isEqual, + s"expected semhash($a) ${if (isEqual) "==" else "!="} semhash($b), $comment", + ) + } } } - def isBaseStructSemanticallyEquivalent: Array[Array[Any]] = - Array.concat( - Array( - Array( - MakeStruct(ArraySeq.empty), - MakeStruct(ArraySeq.empty), - true, - "empty structs", - ), - Array( - MakeStruct(ArraySeq(genUID() -> I32(0))), - MakeStruct(ArraySeq(genUID() -> I32(0))), - true, - "field names do not affect MakeStruct semantics", - ), - Array( - MakeTuple(ArraySeq.empty), - MakeTuple(ArraySeq.empty), - true, - "empty tuples", - ), - Array( - MakeTuple(ArraySeq(0 -> I32(0))), - MakeTuple(ArraySeq(0 -> I32(0))), - true, - "identical tuples", - ), - Array( - MakeTuple(ArraySeq(0 -> I32(0))), - MakeTuple(ArraySeq(1 -> I32(0))), - false, - "tuple indices affect MakeTuple semantics", - ), - ), { - - def f(mkType: Int => Type, get: (IR, Int) => IR, isSame: Boolean, reason: String) = - Array.tabulate(2)(idx => bindIR(NA(mkType(idx)))(get(_, idx))) ++ Array(isSame, reason) - - Array( - f( - mkType = i => TStruct(i.toString -> TInt32), - get = (ir, i) => GetField(ir, i.toString), - isSame = true, - "field names do not affect GetField semantics", - ), - f( - mkType = _ => TTuple(TInt32), - get = (ir, _) => GetTupleElement(ir, 0), - isSame = true, - "GetTupleElement of same index", - ), - f( - mkType = i => 
TTuple(ArraySeq(TupleField(i, TInt32))), - get = (ir, i) => GetTupleElement(ir, i), - isSame = false, - "GetTupleElement on different index", - ), - ) - }, - ) + // trivial value semantics + checkSemanticEquivalence(True(), True(), true, "Refl") + checkSemanticEquivalence(False(), False(), true, "Refl") + checkSemanticEquivalence(True(), False(), false, "Refl") + checkSemanticEquivalence(I32(0), I32(0), true, "Refl") + checkSemanticEquivalence(I32(0), I32(1), false, "Refl") + checkSemanticEquivalence(I64(0), I64(0), true, "Refl") + checkSemanticEquivalence(I64(0), I64(1), false, "Refl") + checkSemanticEquivalence(F32(0), F32(0), true, "Refl") + checkSemanticEquivalence(F32(0), F32(1), false, "Refl") + checkSemanticEquivalence(Void(), Void(), true, "Refl") + checkSemanticEquivalence(Str("a"), Str("a"), true, "Refl") + checkSemanticEquivalence(Str("a"), Str("b"), false, "Refl") + checkSemanticEquivalence(NA(TInt32), NA(TInt32), true, "Refl") + checkSemanticEquivalence(NA(TInt32), NA(TFloat64), false, "Refl") + + // let-binding semantics + { + val x = freshName() + val y = freshName() - def isTreeStructureSemanticallyEquivalent: Array[Array[Any]] = - Array( - Array( - MakeArray( - MakeArray(I32(0)), - MakeArray(I32(0)), - ), - MakeArray( - MakeArray( - MakeArray(I32(0), I32(0)) - ) - ), + def letCases(let: (IndexedSeq[(Name, IR)], IR) => IR, ref: (Name, Type) => IR): Unit = { + checkSemanticEquivalence( + let(FastSeq(x -> I32(0)), ref(x, TInt32)), + let(FastSeq(y -> I32(0)), ref(y, TInt32)), + true, + "names used in let-bindings do not change semantics", + ) + checkSemanticEquivalence( + let(FastSeq(x -> I32(0), y -> I32(0)), ref(x, TInt32)), + let(FastSeq(y -> I32(0), x -> I32(0)), ref(y, TInt32)), + true, + "names of let-bindings do not change semantics", + ) + checkSemanticEquivalence( + let(FastSeq(x -> I32(0)), ref(x, TInt32)), + let(FastSeq(x -> I64(0)), ref(x, TInt64)), false, - "Tree structure contributes to semantics", + "different IRs", ) - ) + 
checkSemanticEquivalence( + let(FastSeq(x -> I32(0), y -> I32(0)), ref(x, TInt32)), + let(FastSeq(y -> I32(0), x -> I32(0)), ref(x, TInt32)), + false, + "Different binding being referenced", + ) + /* `SemanticHash` does not perform or recognise opportunities for simplification. + * The following examples demonstrate some of its limitations as a consequence. */ + checkSemanticEquivalence( + let(FastSeq(x -> I32(0)), ref(x, TInt32)), + let(FastSeq(x -> let(FastSeq(freshName() -> I32(0)), I32(0))), ref(x, TInt32)), + false, + "SemanticHash does not simplify", + ) + checkSemanticEquivalence( + let(FastSeq(x -> I32(0)), ref(x, TInt32)), + let(FastSeq(x -> I32(0), y -> I32(0)), ref(x, TInt32)), + false, + "SemanticHash does not simplify", + ) + } - def isValueIRSemanticallyEquivalent: Array[Array[Any]] = - Array.concat( - isTriviallySemanticallyEquivalent, - isLetSemanticallyEquivalent, - isBaseStructSemanticallyEquivalent, - isTreeStructureSemanticallyEquivalent, - ) + letCases(Let(_, _), Ref) + letCases(mkRelationalLet _, RelationalRef) + } - def isTableIRSemanticallyEquivalent: Array[Array[Any]] = { + // struct/tuple semantics + checkSemanticEquivalence( + MakeStruct(ArraySeq.empty), + MakeStruct(ArraySeq.empty), + true, + "empty structs", + ) + + checkSemanticEquivalence( + MakeStruct(ArraySeq(genUID() -> I32(0))), + MakeStruct(ArraySeq(genUID() -> I32(0))), + true, + "field names do not affect MakeStruct semantics", + ) + + checkSemanticEquivalence( + MakeTuple(ArraySeq.empty), + MakeTuple(ArraySeq.empty), + true, + "empty tuples", + ) + + checkSemanticEquivalence( + MakeTuple(ArraySeq(0 -> I32(0))), + MakeTuple(ArraySeq(0 -> I32(0))), + true, + "identical tuples", + ) + + checkSemanticEquivalence( + MakeTuple(ArraySeq(0 -> I32(0))), + MakeTuple(ArraySeq(1 -> I32(0))), + false, + "tuple indices affect MakeTuple semantics", + ) + + checkSemanticEquivalence( + bindIR(NA(TStruct("0" -> TInt32)))(GetField(_, "0")), + bindIR(NA(TStruct("1" -> TInt32)))(GetField(_, "1")), 
+ true, + "field names do not affect GetField semantics", + ) + + checkSemanticEquivalence( + bindIR(NA(TTuple(TInt32)))(GetTupleElement(_, 0)), + bindIR(NA(TTuple(TInt32)))(GetTupleElement(_, 0)), + true, + "GetTupleElement of same index", + ) + + checkSemanticEquivalence( + bindIR(NA(TTuple(ArraySeq(TupleField(0, TInt32)))))(GetTupleElement(_, 0)), + bindIR(NA(TTuple(ArraySeq(TupleField(1, TInt32)))))(GetTupleElement(_, 1)), + false, + "GetTupleElement on different index", + ) + + // tree structure semantics + checkSemanticEquivalence( + MakeArray(MakeArray(I32(0)), MakeArray(I32(0))), + MakeArray(MakeArray(MakeArray(I32(0), I32(0)))), + false, + "Tree structure contributes to semantics", + ) + + // table IR semantics + { val ttype = TableType(TStruct("a" -> TInt32, "b" -> TStruct()), IndexedSeq("a"), TStruct()) val ttypeb = TableType(TStruct("c" -> TInt32, "d" -> TStruct()), IndexedSeq(), TStruct()) @@ -189,120 +205,102 @@ class SemanticHashSuite extends HailSuite { val tir = mkTableIR(ttype, "/fake/table") - Array.concat( - Array( - Array(tir, tir, true, "TableRead same table"), - Array(tir, mkTableIR(ttype, "/another/fake/table"), false, "TableRead different table"), - Array( - TableKeyBy(tir, IndexedSeq("a")), - TableKeyBy(tir, IndexedSeq("a")), - true, - "TableKeyBy same key", - ), - Array( - TableKeyBy(tir, IndexedSeq("a")), - TableKeyBy(tir, IndexedSeq("b")), - false, - "TableKeyBy different key", - ), - ), - Array[String => TableReader]( - path => - new StringTableReader( - StringTableReaderParameters(ArraySeq(path), None, false, false, false), - fakeFs.glob(path), - ), - path => - TableNativeZippedReader( - path + ".left", - path + ".right", - None, - mkFakeTableSpec(ttype), - mkFakeTableSpec(ttypeb), - ), - ) - .map(mkTableRead _ compose _) - .flatMap { reader => - Array( - Array(reader("/fake/table"), reader("/fake/table"), true, "read same table"), - Array( - reader("/fake/table"), - reader("/another/fake/table"), - false, - "read different table", - 
), - ) - }, - Array( - TableGetGlobals, - TableAggregate(_, I32(0)), - TableAggregateByKey(_, MakeStruct(FastSeq())), - TableKeyByAndAggregate( - _, - MakeStruct(FastSeq()), - MakeStruct(FastSeq("idx" -> I32(0))), - None, - 256, - ), - (ir: TableIR) => TableCollect(TableKeyBy(ir, FastSeq())), - TableCount, - TableDistinct, - TableFilter(_, True()), - TableMapGlobals(_, MakeStruct(IndexedSeq.empty)), - TableMapRows(_, MakeStruct(FastSeq("a" -> I32(0)))), - TableRename(_, Map.empty, Map.empty), - ).map(wrap => Array(wrap(tir), wrap(tir), true, "")), + checkSemanticEquivalence(tir, tir, true, "TableRead same table") + checkSemanticEquivalence( + tir, + mkTableIR(ttype, "/another/fake/table"), + false, + "TableRead different table", + ) + checkSemanticEquivalence( + TableKeyBy(tir, IndexedSeq("a")), + TableKeyBy(tir, IndexedSeq("a")), + true, + "TableKeyBy same key", ) + checkSemanticEquivalence( + TableKeyBy(tir, IndexedSeq("a")), + TableKeyBy(tir, IndexedSeq("b")), + false, + "TableKeyBy different key", + ) + + Array[String => TableIR]( + path => + mkTableRead(new StringTableReader( + StringTableReaderParameters(ArraySeq(path), None, false, false, false), + fakeFs.glob(path), + )), + path => + mkTableRead(TableNativeZippedReader( + path + ".left", + path + ".right", + None, + mkFakeTableSpec(ttype), + mkFakeTableSpec(ttypeb), + )), + ).foreach { reader => + checkSemanticEquivalence( + reader("/fake/table"), + reader("/fake/table"), + true, + "read same table", + ) + checkSemanticEquivalence( + reader("/fake/table"), + reader("/another/fake/table"), + false, + "read different table", + ) + } + + Array[TableIR => BaseIR]( + TableGetGlobals, + TableAggregate(_, I32(0)), + TableAggregateByKey(_, MakeStruct(FastSeq())), + TableKeyByAndAggregate( + _, + MakeStruct(FastSeq()), + MakeStruct(FastSeq("idx" -> I32(0))), + None, + 256, + ), + (ir: TableIR) => TableCollect(TableKeyBy(ir, FastSeq())), + TableCount, + TableDistinct, + TableFilter(_, True()), + TableMapGlobals(_, 
MakeStruct(IndexedSeq.empty)), + TableMapRows(_, MakeStruct(FastSeq("a" -> I32(0)))), + TableRename(_, Map.empty, Map.empty), + ).foreach(wrap => checkSemanticEquivalence(wrap(tir), wrap(tir), true, "")) } - def isBlockMatrixIRSemanticallyEquivalent: Array[Array[Any]] = - Array[String => BlockMatrixReader]( - path => BlockMatrixBinaryReader(path, ArraySeq(1L, 1L), 1), + // block matrix IR semantics + { + Array[String => BlockMatrixIR]( + path => BlockMatrixRead(BlockMatrixBinaryReader(path, ArraySeq(1L, 1L), 1)), path => - new BlockMatrixNativeReader( + BlockMatrixRead(new BlockMatrixNativeReader( BlockMatrixNativeReaderParameters(path), BlockMatrixMetadata(1, 1, 1, None, IndexedSeq.empty), - ), - ) - .map(BlockMatrixRead compose _) - .flatMap { reader => - Array( - Array( - reader("/fake/block-matrix"), - reader("/fake/block-matrix"), - true, - "Read same block matrix", - ), - Array( - reader("/fake/block-matrix"), - reader("/another/fake/block-matrix"), - false, - "Read different block matrix", - ), - ) - } - - @DataProvider(name = "isBaseIRSemanticallyEquivalent") - def isBaseIRSemanticallyEquivalent: Array[Array[Any]] = - Array.concat( - isValueIRSemanticallyEquivalent, - isTableIRSemanticallyEquivalent, - isBlockMatrixIRSemanticallyEquivalent, - ) - - @Test(dataProvider = "isBaseIRSemanticallyEquivalent") - def testSemanticEquivalence(a: BaseIR, b: BaseIR, isEqual: Boolean, comment: String): Unit = - ctx.local(fs = fakeFs) { ctx => - assertResult( - isEqual, - s"expected semhash($a) ${if (isEqual) "==" else "!="} semhash($b), $comment", - )( - SemanticHash(ctx, a) == SemanticHash(ctx, b) + )), + ).foreach { reader => + checkSemanticEquivalence( + reader("/fake/block-matrix"), + reader("/fake/block-matrix"), + true, + "Read same block matrix", + ) + checkSemanticEquivalence( + reader("/fake/block-matrix"), + reader("/another/fake/block-matrix"), + false, + "Read different block matrix", ) } + } - @Test - def testFileNotFoundExceptions(): Unit = { + 
test("FileNotFoundExceptions") { val fs = new FakeFS { override def eTag(url: FakeURL): Option[String] = @@ -312,30 +310,14 @@ class SemanticHashSuite extends HailSuite { val ir = importMatrix("gs://fake-bucket/fake-matrix") ctx.local(fs = fs) { ctx => - assertResult(None, "SemHash should be resilient to FileNotFoundExceptions.")( - SemanticHash(ctx, ir) + assertEquals( + SemanticHash(ctx, ir), + None, + "SemHash should be resilient to FileNotFoundExceptions.", ) } } - private[this] val fakeFs: FS = - new FakeFS { - override def eTag(url: FakeURL): Option[String] = - Some(url.path) - - override def glob(url: FakeURL): IndexedSeq[FileListEntry] = - ArraySeq(new FileListEntry { - override def getPath: String = url.path - override def getActualUrl: String = url.path - override def getModificationTime: lang.Long = ??? - override def getLen: Long = ??? - override def isDirectory: Boolean = ??? - override def isSymlink: Boolean = ??? - override def isFile: Boolean = true - override def getOwner: String = ??? 
- }) - } - def importMatrix(path: String): MatrixIR = { val ty = MatrixType( diff --git a/hail/hail/test/src/is/hail/expr/ir/defs/EncodedLiteralSuite.scala b/hail/hail/test/src/is/hail/expr/ir/defs/EncodedLiteralSuite.scala index 53b5ad147db..1026df06e97 100644 --- a/hail/hail/test/src/is/hail/expr/ir/defs/EncodedLiteralSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/defs/EncodedLiteralSuite.scala @@ -2,12 +2,9 @@ package is.hail.expr.ir.defs import is.hail.HailSuite -import org.testng.annotations.Test - class EncodedLiteralSuite extends HailSuite { - @Test - def testWrappedByteArrayEquality(): Unit = { + test("WrappedByteArrayEquality") { val byteArray1 = Array[Byte](1, 2, 1, 1) val byteArray2 = Array[Byte](1, 2, 1, 1) val byteArray3 = Array[Byte](0, 0, 1, 0) diff --git a/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala b/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala index 81ca537f439..9b6619ddebf 100644 --- a/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala @@ -1,6 +1,6 @@ package is.hail.expr.ir.lowering -import is.hail.{ExecStrategy, HailSuite, TestUtils} +import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy import is.hail.collection.FastSeq import is.hail.expr.ir.{ @@ -14,13 +14,15 @@ import is.hail.expr.ir.lowering.LowerDistributedSort.samplePartition import is.hail.types.RTable import is.hail.types.virtual.{TArray, TInt32, TStruct} +import scala.concurrent.duration.Duration + import org.apache.spark.sql.Row -import org.testng.annotations.Test -class LowerDistributedSortSuite extends HailSuite with TestUtils { +class LowerDistributedSortSuite extends HailSuite { + override val munitTimeout = Duration(60, "s") implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly - @Test def testSamplePartition(): Unit = { + test("SamplePartition") { val dataKeys = 
IndexedSeq( (0, 0), (0, -1), @@ -107,7 +109,7 @@ class LowerDistributedSortSuite extends HailSuite with TestUtils { assert(res == scalaSorted) } - @Test def testDistributedSort(): Unit = { + test("DistributedSort") { val tableRange = TableRange(100, 10) val rangeRow = Ref(TableIR.rowName, tableRange.typ.rowType) val tableWithExtraField = TableMapRows( @@ -140,7 +142,7 @@ class LowerDistributedSortSuite extends HailSuite with TestUtils { ) } - @Test def testDistributedSortEmpty(): Unit = { + test("DistributedSortEmpty") { val tableRange = TableRange(0, 1) testDistributedSortHelper(tableRange, IndexedSeq(SortField("idx", Ascending))) } diff --git a/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala b/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala index 3294c639139..7da3c1118e4 100644 --- a/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala +++ b/hail/hail/test/src/is/hail/expr/ir/table/TableGenSuite.scala @@ -16,26 +16,22 @@ import is.hail.utils.{HailException, Interval} import org.apache.spark.SparkException import org.apache.spark.sql.Row -import org.scalatest.matchers.should.Matchers._ -import org.testng.annotations.Test class TableGenSuite extends HailSuite { implicit val execStrategy: Set[ExecStrategy] = ExecStrategy.lowering - @Test(groups = Array("construction", "typecheck")) - def testWithInvalidContextsType(): Unit = { + test("WithInvalidContextsType") { val ex = intercept[IllegalArgumentException] { TypeCheck(ctx, mkTableGen(contexts = Some(Str("oh noes :'(")))) } - ex.getMessage should include("contexts") - ex.getMessage should include(s"Expected: ${classOf[TStream].getName}") - ex.getMessage should include(s"Actual: ${TString.getClass.getName}") + assert(ex.getMessage.contains("contexts")) + assert(ex.getMessage.contains(s"Expected: ${classOf[TStream].getName}")) + assert(ex.getMessage.contains(s"Actual: ${TString.getClass.getName}")) } - @Test(groups = Array("construction", "typecheck")) - def 
testWithInvalidGlobalsType(): Unit = { + test("WithInvalidGlobalsType") { val ex = intercept[HailException] { TypeCheck( ctx, @@ -45,23 +41,21 @@ class TableGenSuite extends HailSuite { ), ) } - ex.getCause.getMessage should include("globals") - ex.getCause.getMessage should include(s"Expected: ${classOf[TStruct].getName}") - ex.getCause.getMessage should include(s"Actual: ${TString.getClass.getName}") + assert(ex.getCause.getMessage.contains("globals")) + assert(ex.getCause.getMessage.contains(s"Expected: ${classOf[TStruct].getName}")) + assert(ex.getCause.getMessage.contains(s"Actual: ${TString.getClass.getName}")) } - @Test(groups = Array("construction", "typecheck")) - def testWithInvalidBodyType(): Unit = { + test("WithInvalidBodyType") { val ex = intercept[HailException] { TypeCheck(ctx, mkTableGen(body = Some((_, _) => Str("oh noes :'(")))) } - ex.getCause.getMessage should include("body") - ex.getCause.getMessage should include(s"Expected: ${classOf[TStream].getName}") - ex.getCause.getMessage should include(s"Actual: ${TString.getClass.getName}") + assert(ex.getCause.getMessage.contains("body")) + assert(ex.getCause.getMessage.contains(s"Expected: ${classOf[TStream].getName}")) + assert(ex.getCause.getMessage.contains(s"Actual: ${TString.getClass.getName}")) } - @Test(groups = Array("construction", "typecheck")) - def testWithInvalidBodyElementType(): Unit = { + test("WithInvalidBodyElementType") { val ex = intercept[HailException] { TypeCheck( ctx, @@ -70,13 +64,12 @@ class TableGenSuite extends HailSuite { ), ) } - ex.getCause.getMessage should include("body.elementType") - ex.getCause.getMessage should include(s"Expected: ${classOf[TStruct].getName}") - ex.getCause.getMessage should include(s"Actual: ${TString.getClass.getName}") + assert(ex.getCause.getMessage.contains("body.elementType")) + assert(ex.getCause.getMessage.contains(s"Expected: ${classOf[TStruct].getName}")) + assert(ex.getCause.getMessage.contains(s"Actual: ${TString.getClass.getName}")) 
} - @Test(groups = Array("construction", "typecheck")) - def testWithInvalidPartitionerKeyType(): Unit = { + test("WithInvalidPartitionerKeyType") { val ex = intercept[HailException] { TypeCheck( ctx, @@ -85,11 +78,10 @@ class TableGenSuite extends HailSuite { ), ) } - ex.getCause.getMessage should include("partitioner") + assert(ex.getCause.getMessage.contains("partitioner")) } - @Test(groups = Array("construction", "typecheck")) - def testWithTooLongPartitionerKeyType(): Unit = { + test("WithTooLongPartitionerKeyType") { val ex = intercept[HailException] { TypeCheck( ctx, @@ -98,26 +90,23 @@ class TableGenSuite extends HailSuite { ), ) } - ex.getCause.getMessage should include("partitioner") + assert(ex.getCause.getMessage.contains("partitioner")) } - @Test(groups = Array("requiredness")) - def testRequiredness(): Unit = { + test("Requiredness") { val table = mkTableGen() val analysis = Requiredness(table, ctx) - analysis.lookup(table).required shouldBe true - analysis.states.m.isEmpty shouldBe true + assertEquals(analysis.lookup(table).required, true) + assertEquals(analysis.states.m.isEmpty, true) } - @Test(groups = Array("lowering")) - def testLowering(): Unit = { + test("Lowering") { val table = collect(mkTableGen()) val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) assertEvalsTo(lowered, Row(FastSeq(0, 0).map(Row(_)), Row(0))) } - @Test(groups = Array("lowering")) - def testNumberOfContextsMatchesPartitions(): Unit = { + test("NumberOfContextsMatchesPartitions") { val errorId = 42 val table = collect(mkTableGen( partitioner = Some(RVDPartitioner.unkeyed(ctx.stateManager, 0)), @@ -127,12 +116,11 @@ class TableGenSuite extends HailSuite { val ex = intercept[HailException] { loweredExecute(ctx, lowered, Env.empty, FastSeq(), None) } - ex.errorId shouldBe errorId - ex.getMessage should include("partitioner contains 0 partitions, got 2 contexts.") + assertEquals(ex.errorId, errorId) + 
assert(ex.getMessage.contains("partitioner contains 0 partitions, got 2 contexts.")) } - @Test(groups = Array("lowering")) - def testRowsAreCorrectlyKeyed(): Unit = { + test("RowsAreCorrectlyKeyed") { val errorId = 56 val table = collect(mkTableGen( partitioner = Some(new RVDPartitioner( @@ -150,19 +138,17 @@ class TableGenSuite extends HailSuite { loweredExecute(ctx, lowered, Env.empty, FastSeq(), None) }.getCause.asInstanceOf[HailException] - ex.errorId shouldBe errorId - ex.getMessage should include("TableGen: Unexpected key in partition") + assertEquals(ex.errorId, errorId) + assert(ex.getMessage.contains("TableGen: Unexpected key in partition")) } - @Test(groups = Array("optimization", "prune")) - def testPruneNoUnusedFields(): Unit = { + test("PruneNoUnusedFields") { val start = mkTableGen() val pruned = PruneDeadFields(ctx, start) - pruned.typ shouldBe start.typ + assertEquals(pruned.typ, start.typ) } - @Test(groups = Array("optimization", "prune")) - def testPruneGlobals(): Unit = { + test("PruneGlobals") { val start = mkTableGen( body = Some { (c, _) => val elem = MakeStruct(IndexedSeq("a" -> c)) @@ -176,17 +162,16 @@ class TableGenSuite extends HailSuite { TableAggregate(start, IRAggCollect(Ref(TableIR.rowName, start.typ.rowType))), ) - pruned.typ should not be start.typ - pruned.typ.globalType shouldBe TStruct() - pruned.asInstanceOf[TableGen].globals shouldBe MakeStruct(IndexedSeq()) + assertNotEquals(pruned.typ, start.typ) + assertEquals(pruned.typ.globalType, TStruct()) + assertEquals(pruned.asInstanceOf[TableGen].globals, MakeStruct(IndexedSeq())) } - @Test(groups = Array("optimization", "prune")) - def testPruneContexts(): Unit = { + test("PruneContexts") { val start = mkTableGen() val TableGetGlobals(pruned) = PruneDeadFields(ctx, TableGetGlobals(start)) - pruned.typ should not be start.typ - pruned.typ.rowType shouldBe TStruct() + assertNotEquals(pruned.typ, start.typ) + assertEquals(pruned.typ.rowType, TStruct()) } def mkTableGen( diff --git 
a/hail/hail/test/src/is/hail/io/ArrayImpexSuite.scala b/hail/hail/test/src/is/hail/io/ArrayImpexSuite.scala index 96d7b13b25f..10f75d86654 100644 --- a/hail/hail/test/src/is/hail/io/ArrayImpexSuite.scala +++ b/hail/hail/test/src/is/hail/io/ArrayImpexSuite.scala @@ -2,17 +2,15 @@ package is.hail.io import is.hail.HailSuite -import org.testng.annotations.Test - class ArrayImpexSuite extends HailSuite { - @Test def testArrayImpex(): Unit = { + test("ArrayImpex") { val file = ctx.createTmpPath("test") val a = Array.fill[Double](100)(util.Random.nextDouble()) val a2 = new Array[Double](100) ArrayImpex.exportToDoubles(fs, file, a, bufSize = 32) ArrayImpex.importFromDoubles(fs, file, a2, bufSize = 16) - assert(a === a2) + assertEquals(a.toSeq, a2.toSeq) interceptFatal("Premature") { ArrayImpex.importFromDoubles(fs, file, new Array[Double](101), bufSize = 64) diff --git a/hail/hail/test/src/is/hail/io/AvroReaderSuite.scala b/hail/hail/test/src/is/hail/io/AvroReaderSuite.scala index d16e303b2f7..d65117a20b1 100644 --- a/hail/hail/test/src/is/hail/io/AvroReaderSuite.scala +++ b/hail/hail/test/src/is/hail/io/AvroReaderSuite.scala @@ -12,7 +12,6 @@ import org.apache.avro.SchemaBuilder import org.apache.avro.file.DataFileWriter import org.apache.avro.generic.{GenericDatumWriter, GenericRecord, GenericRecordBuilder} import org.apache.spark.sql.Row -import org.testng.annotations.Test class AvroReaderSuite extends HailSuite { implicit val execStrats: Set[ExecStrategy] = Set(ExecStrategy.LoweredJVMCompile) @@ -63,7 +62,7 @@ class AvroReaderSuite extends HailSuite { avroFile } - @Test def avroReaderWorks(): Unit = { + test("avroReaderWorks") { val avroFile = makeTestFile() val ir = ToArray(ReadPartition( MakeStruct(ArraySeq("partitionPath" -> Str(avroFile), "partitionIndex" -> I64(0))), @@ -76,7 +75,7 @@ class AvroReaderSuite extends HailSuite { assertEvalsTo(ir, testValueWithUIDs) } - @Test def testSmallerRequestedType(): Unit = { + test("SmallerRequestedType") { val avroFile = 
makeTestFile() val ir = ToArray(ReadPartition( MakeStruct(ArraySeq("partitionPath" -> Str(avroFile), "partitionIndex" -> I64(0))), diff --git a/hail/hail/test/src/is/hail/io/ByteArrayReaderSuite.scala b/hail/hail/test/src/is/hail/io/ByteArrayReaderSuite.scala index 7355ce7eb36..dffce62f251 100644 --- a/hail/hail/test/src/is/hail/io/ByteArrayReaderSuite.scala +++ b/hail/hail/test/src/is/hail/io/ByteArrayReaderSuite.scala @@ -1,27 +1,21 @@ package is.hail.io -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test - -class ByteArrayReaderSuite extends TestNGSuite { - @Test - def readLongReadsALong(): Unit = { +class ByteArrayReaderSuite extends munit.FunSuite { + test("readLong reads a long") { val a = Array(0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff) .map(_.toByte) - assert(new ByteArrayReader(a).readLong() == -1L) + assertEquals(new ByteArrayReader(a).readLong(), -1L) } - @Test - def readLongReadsALong2(): Unit = { + test("readLong reads a long 2") { val a = Array(0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8) .map(_.toByte) - assert(new ByteArrayReader(a).readLong() == 0xf8f7f6f5f4f3f2f1L) + assertEquals(new ByteArrayReader(a).readLong(), 0xf8f7f6f5f4f3f2f1L) } - @Test - def readLongReadsALong3(): Unit = { + test("readLong reads a long 3") { val a = Array(0xf8, 0xf7, 0xf6, 0xf5, 0xf4, 0xf3, 0xf2, 0xf1) .map(_.toByte) - assert(new ByteArrayReader(a).readLong() == 0xf1f2f3f4f5f6f7f8L) + assertEquals(new ByteArrayReader(a).readLong(), 0xf1f2f3f4f5f6f7f8L) } } diff --git a/hail/hail/test/src/is/hail/io/IndexBTreeSuite.scala b/hail/hail/test/src/is/hail/io/IndexBTreeSuite.scala index 895ff89060a..a902ff3a297 100644 --- a/hail/hail/test/src/is/hail/io/IndexBTreeSuite.scala +++ b/hail/hail/test/src/is/hail/io/IndexBTreeSuite.scala @@ -1,15 +1,15 @@ package is.hail.io import is.hail.HailSuite +import is.hail.annotations.UnsafeUtils import scala.collection.mutable.ArrayBuffer import org.scalacheck.Gen import org.scalacheck.Gen._ -import 
org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.testng.annotations.Test +import org.scalacheck.Prop.forAll -class IndexBTreeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { +class IndexBTreeSuite extends HailSuite with munit.ScalaCheckSuite { val genStarts: Gen[(Int, Array[Long])] = for { @@ -27,7 +27,7 @@ class IndexBTreeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { } yield (depth, starts) - @Test def queryGivesSameAnswerAsArray(): Unit = + property("queryGivesSameAnswerAsArray") = forAll(genStarts) { case (depth: Int, arrayRandomStarts: Array[Long]) => val index = ctx.createTmpPath("testBtree", "idx") @@ -36,10 +36,10 @@ class IndexBTreeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val btree = new IndexBTree(index, fs) val indexSize = fs.getFileSize(index) - val padding = 1024 - (arrayRandomStarts.length % 1024) - val numEntries = arrayRandomStarts.length + padding + (1 until depth).map { i => - math.pow(1024, i.toDouble).toInt - }.sum + val numEntries = + UnsafeUtils.roundUpAlignment(arrayRandomStarts.length, 1024) + (1 until depth).map { i => + math.pow(1024, i.toDouble).toInt + }.sum // make sure index size is correct val indexCorrectSize = if (indexSize == (numEntries * 8)) true else false @@ -64,10 +64,10 @@ class IndexBTreeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { println(s"depth=$depthCorrect indexCorrect=$indexCorrectSize queryCorrect=$queryCorrect") btree.close() - assert(depthCorrect && indexCorrectSize && queryCorrect) + depthCorrect && indexCorrectSize && queryCorrect } - @Test def oneVariant(): Unit = { + test("oneVariant") { val index = Array(24.toLong) val fileSize = 30 // made-up value greater than index val idxFile = ctx.createTmpPath("testBtree_1variant", "idx") @@ -76,9 +76,9 @@ class IndexBTreeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { IndexBTree.write(index, idxFile, fs) val btree = new IndexBTree(idxFile, fs) - 
assertThrows[IllegalArgumentException] { + intercept[IllegalArgumentException] { btree.queryIndex(-5) - } + }: Unit assert(btree.queryIndex(0).contains(24)) assert(btree.queryIndex(10).contains(24)) @@ -88,15 +88,16 @@ class IndexBTreeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(btree.queryIndex(fileSize.toLong - 1).isEmpty) } - @Test def zeroVariants(): Unit = - assertThrows[IllegalArgumentException] { + test("zeroVariants") { + intercept[IllegalArgumentException] { val index = Array[Long]() val idxFile = ctx.createTmpPath("testBtree_0variant", "idx") fs.delete(idxFile, recursive = true) IndexBTree.write(index, idxFile, fs) - } + }: Unit + } - @Test def testMultipleOfBranchingFactorDoesNotAddUnnecessaryElements(): Unit = { + test("MultipleOfBranchingFactorDoesNotAddUnnecessaryElements") { val in = Array[Long](10, 9, 8, 7, 6, 5, 4, 3) val bigEndianBytes = Array[Byte]( 0, 0, 0, 0, 0, 0, 0, 10, @@ -111,7 +112,7 @@ class IndexBTreeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { sameElements bigEndianBytes) } - @Test def writeReadMultipleOfBranchingFactorDoesNotError(): Unit = { + test("writeReadMultipleOfBranchingFactorDoesNotError") { val idxFile = ctx.createTmpPath("btree") IndexBTree.write( Array.tabulate(1024)(i => i.toLong), @@ -122,7 +123,7 @@ class IndexBTreeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(index.queryIndex(33).contains(33L)) } - @Test def queryArrayPositionAndFileOffsetIsCorrectSmallArray(): Unit = { + test("queryArrayPositionAndFileOffsetIsCorrectSmallArray") { val f = ctx.createTmpPath("btree") val v = Array[Long](1, 2, 3, 40, 50, 60, 70) val branchingFactor = 1024 @@ -140,7 +141,7 @@ class IndexBTreeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(bt.queryArrayPositionAndFileOffset(71).isEmpty) } - @Test def queryArrayPositionAndFileOffsetIsCorrectTwoLevelsArray(): Unit = { + test("queryArrayPositionAndFileOffsetIsCorrectTwoLevelsArray") { def sqr(x: Long) = x * x 
val f = ctx.createTmpPath("btree") val v = Array.tabulate(1025)(x => sqr(x.toLong)) @@ -167,7 +168,7 @@ class IndexBTreeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(bt.queryArrayPositionAndFileOffset(5).contains((3, sqr(3)))) } - @Test def queryArrayPositionAndFileOffsetIsCorrectThreeLevelsArray(): Unit = { + test("queryArrayPositionAndFileOffsetIsCorrectThreeLevelsArray") { def sqr(x: Long) = x * x val f = ctx.createTmpPath("btree") val v = Array.tabulate(1024 * 1024 + 1)(x => sqr(x.toLong)) @@ -214,7 +215,7 @@ class IndexBTreeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(bt.queryArrayPositionAndFileOffset(sqr(1024 * 1024) + 1).isEmpty) } - @Test def onDiskBTreeIndexToValueSmallCorrect(): Unit = { + test("onDiskBTreeIndexToValueSmallCorrect") { val f = ctx.createTmpPath("btree") val v = Array[Long](1, 2, 3, 4, 5, 6, 7) val branchingFactor = 3 @@ -240,36 +241,30 @@ class IndexBTreeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { } } - @Test def onDiskBTreeIndexToValueRandomized(): Unit = { - val g = - for { - longs <- nonEmptyContainerOf[Array, Long](choose(0L, Long.MaxValue)) - indices <- containerOf[Array, Int](choose(0, longs.length - 1)) - branchingFactor <- choose(2, 1024) - } yield (indices, longs, branchingFactor) - - forAll(g) { case (indices, longs, branchingFactor) => - val f = ctx.createTmpPath("test") - try { - IndexBTree.write(longs, f, fs, branchingFactor) - val bt = new OnDiskBTreeIndexToValue(f, fs, branchingFactor) - val actual = bt.positionOfVariants(indices) - val expected = indices.sorted.map(longs) - assert( - actual sameElements expected, - s"${actual.toSeq} not same elements as expected ${expected.toSeq}", + property("onDiskBTreeIndexToValueRandomized") = forAll( + for { + longs <- nonEmptyContainerOf[Array, Long](choose(0L, Long.MaxValue)) + indices <- containerOf[Array, Int](choose(0, longs.length - 1)) + branchingFactor <- choose(2, 1024) + } yield (indices, longs, 
branchingFactor) + ) { case (indices, longs, branchingFactor) => + val f = ctx.createTmpPath("test") + try { + IndexBTree.write(longs, f, fs, branchingFactor) + val bt = new OnDiskBTreeIndexToValue(f, fs, branchingFactor) + val actual = bt.positionOfVariants(indices) + val expected = indices.sorted.map(longs) + actual sameElements expected + } catch { + case t: Throwable => + throw new RuntimeException( + "exception while checking BTree: " + IndexBTree.toString(longs, branchingFactor), + t, ) - } catch { - case t: Throwable => - throw new RuntimeException( - "exception while checking BTree: " + IndexBTree.toString(longs, branchingFactor), - t, - ) - } } } - @Test def onDiskBTreeIndexToValueFourLayers(): Unit = { + test("onDiskBTreeIndexToValueFourLayers") { val longs = Array.tabulate(3 * 3 * 3 * 3)(x => x.toLong) val indices = Array(0, 3, 10, 20, 26, 27, 34, 55, 79, 80) val f = ctx.createTmpPath("btree") @@ -292,26 +287,26 @@ class IndexBTreeSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { } } - @Test def calcDepthIsCorrect(): Unit = { + test("calcDepthIsCorrect") { def sqr(x: Long) = x * x def cube(x: Long) = x * x * x def f(x: Long) = IndexBTree.calcDepth(x, 1024) - assert(f(1) == 1) - assert(f(1023) == 1) - assert(f(1024) == 1) - assert(f(1025) == 2) - - assert(f(sqr(1024) - 1) == 2) - assert(f(sqr(1024)) == 2) - assert(f(sqr(1024) + 1024) == 2) - assert(f(sqr(1024) + 1024 + 1) == 3) - - assert(f(cube(1024) - 1) == 3) - assert(f(cube(1024)) == 3) - assert(f(cube(1024) + sqr(1024)) == 3) - assert(f(cube(1024) + sqr(1024) + 1024) == 3) - assert(f(cube(1024) + sqr(1024) + 1024 + 1) == 4) + assertEquals(f(1), 1) + assertEquals(f(1023), 1) + assertEquals(f(1024), 1) + assertEquals(f(1025), 2) + + assertEquals(f(sqr(1024) - 1), 2) + assertEquals(f(sqr(1024)), 2) + assertEquals(f(sqr(1024) + 1024), 2) + assertEquals(f(sqr(1024) + 1024 + 1), 3) + + assertEquals(f(cube(1024) - 1), 3) + assertEquals(f(cube(1024)), 3) + assertEquals(f(cube(1024) + sqr(1024)), 
3) + assertEquals(f(cube(1024) + sqr(1024) + 1024), 3) + assertEquals(f(cube(1024) + sqr(1024) + 1024 + 1), 4) } } diff --git a/hail/hail/test/src/is/hail/io/IndexSuite.scala b/hail/hail/test/src/is/hail/io/IndexSuite.scala index 5997d6f1b8f..3b7cc7e7350 100644 --- a/hail/hail/test/src/is/hail/io/IndexSuite.scala +++ b/hail/hail/test/src/is/hail/io/IndexSuite.scala @@ -10,7 +10,6 @@ import is.hail.types.virtual._ import is.hail.utils._ import org.apache.spark.sql.Row -import org.testng.annotations.{DataProvider, Test} class IndexSuite extends HailSuite { val strings = ArraySeq( @@ -31,10 +30,6 @@ class IndexSuite extends HailSuite { LeafChild(s, i.toLong, Row()) } - @DataProvider(name = "elements") - def data(): Array[Array[ArraySeq[String]]] = - (1 to strings.length).map(i => Array(strings.take(i))).toArray - def writeIndex( file: String, data: IndexedSeq[Any], @@ -99,53 +94,58 @@ class IndexSuite extends HailSuite { attributes, ) - @Test(dataProvider = "elements") - def writeReadGivesSameAsInput(data: IndexedSeq[String]): Unit = { - val file = ctx.createTmpPath("test", "idx") - val attributes: Map[String, Any] = Map("foo" -> true, "bar" -> 5) - - val a: (Int) => Annotation = (i: Int) => Row(i % 2 == 0) - - for (branchingFactor <- 2 to 5) { - writeIndex( - file, - data, - data.indices.map(i => a(i)), - TStruct("a" -> TBoolean), - branchingFactor, - attributes, - ) - assert(fs.getFileSize(file + "/index") != 0) - assert(fs.getFileSize(file + "/metadata.json.gz") != 0) - - val index = indexReader(file, TStruct("a" -> TBoolean)) - - assert(index.attributes == attributes) - - data.zipWithIndex.foreach { case (s, i) => - assert({ + object checkWriteReadGivesSameAsInput extends TestCases { + def apply( + data: IndexedSeq[String] + )(implicit loc: munit.Location + ): Unit = test("writeReadGivesSameAsInput") { + val file = ctx.createTmpPath("test", "idx") + val attributes: Map[String, Any] = Map("foo" -> true, "bar" -> 5) + + val a: (Int) => Annotation = (i: Int) => Row(i 
% 2 == 0) + + for (branchingFactor <- 2 to 5) { + writeIndex( + file, + data, + data.indices.map(i => a(i)), + TStruct("a" -> TBoolean), + branchingFactor, + attributes, + ) + assert(fs.getFileSize(file + "/index") != 0) + assert(fs.getFileSize(file + "/metadata.json.gz") != 0) + + val index = indexReader(file, TStruct("a" -> TBoolean)) + + assertEquals(index.attributes, attributes) + + data.zipWithIndex.foreach { case (s, i) => val result = index.queryByIndex(i.toLong) - result.key == s && result.annotation == a(i) - }) - } + assertEquals(result.key, s) + assertEquals(result.annotation, a(i)) + } - index.close() + index.close() + } } } - @Test def testEmptyKeys(): Unit = { + (1 to strings.length).foreach(i => checkWriteReadGivesSameAsInput(strings.take(i))) + + test("EmptyKeys") { val file = ctx.createTmpPath("empty", "idx") writeIndex(file, ArraySeq.empty, ArraySeq.empty, TStruct("a" -> TBoolean), 2) assert(fs.getFileSize(file + "/index") != 0) assert(fs.getFileSize(file + "/metadata.json.gz") != 0) val index = indexReader(file, TStruct("a" -> TBoolean)) - assertThrows[IllegalArgumentException](index.queryByIndex(0L)) + intercept[IllegalArgumentException](index.queryByIndex(0L)): Unit assert(index.queryByKey("moo").isEmpty) assert(index.queryByInterval("bear", "cat", includesStart = true, includesEnd = true).isEmpty) index.close() } - @Test def testLowerBound(): Unit = { + test("LowerBound") { for (branchingFactor <- 2 to 5) { val file = ctx.createTmpPath("lowerBound", "idx") writeIndex( @@ -173,12 +173,12 @@ class IndexSuite extends HailSuite { ) expectedResult.foreach { case (s, expectedIdx) => - assert(index.lowerBound(s) == expectedIdx) // test full b-tree search works + assertEquals(index.lowerBound(s), expectedIdx.toLong) } } } - @Test def testUpperBound(): Unit = { + test("UpperBound") { for (branchingFactor <- 2 to 5) { val file = ctx.createTmpPath("upperBound", "idx") writeIndex( @@ -206,12 +206,12 @@ class IndexSuite extends HailSuite { ) 
expectedResult.foreach { case (s, expectedIdx) => - assert(index.upperBound(s) == expectedIdx) // test full b-tree search works + assertEquals(index.upperBound(s), expectedIdx.toLong) } } } - @Test def testRangeIterator(): Unit = { + test("RangeIterator") { for (branchingFactor <- 2 to 5) { val file = ctx.createTmpPath("range", "idx") val a = { (i: Int) => Row() } @@ -236,7 +236,7 @@ class IndexSuite extends HailSuite { } } - @Test def testQueryByKey(): Unit = { + test("QueryByKey") { for (branchingFactor <- 2 to 5) { val file = ctx.createTmpPath("key", "idx") writeIndex( @@ -258,7 +258,7 @@ class IndexSuite extends HailSuite { } } - @Test def testIntervalIterator(): Unit = { + test("IntervalIterator") { for (branchingFactor <- 2 to 5) { val file = ctx.createTmpPath("interval", "idx") writeIndex( @@ -270,89 +270,116 @@ class IndexSuite extends HailSuite { ) val index = indexReader(file, TStruct.empty) // intervals with endpoint in list - assert(index.queryByInterval( - "bear", - "bear", - includesStart = true, - includesEnd = true, - ).toFastSeq == index.iterator(0, 2).toFastSeq) + assertEquals( + index.queryByInterval( + "bear", + "bear", + includesStart = true, + includesEnd = true, + ).toFastSeq, + index.iterator(0, 2).toFastSeq, + ) - assert(index.queryByInterval( - "bear", - "cat", - includesStart = true, - includesEnd = false, - ).toFastSeq == index.iterator(0, 2).toFastSeq) - assert(index.queryByInterval( - "bear", - "cat", - includesStart = true, - includesEnd = true, - ).toFastSeq == index.iterator(0, 9).toFastSeq) - assert(index.queryByInterval( - "bear", - "cat", - includesStart = false, - includesEnd = true, - ).toFastSeq == index.iterator(2, 9).toFastSeq) - assert(index.queryByInterval( - "bear", - "cat", - includesStart = false, - includesEnd = false, - ).toFastSeq == index.iterator(2, 2).toFastSeq) + assertEquals( + index.queryByInterval( + "bear", + "cat", + includesStart = true, + includesEnd = false, + ).toFastSeq, + index.iterator(0, 
2).toFastSeq, + ) + assertEquals( + index.queryByInterval( + "bear", + "cat", + includesStart = true, + includesEnd = true, + ).toFastSeq, + index.iterator(0, 9).toFastSeq, + ) + assertEquals( + index.queryByInterval( + "bear", + "cat", + includesStart = false, + includesEnd = true, + ).toFastSeq, + index.iterator(2, 9).toFastSeq, + ) + assertEquals( + index.queryByInterval( + "bear", + "cat", + includesStart = false, + includesEnd = false, + ).toFastSeq, + index.iterator(2, 2).toFastSeq, + ) // intervals with endpoint(s) not in list - assert(index.queryByInterval( - "cat", - "snail", - includesStart = true, - includesEnd = false, - ).toFastSeq == index.iterator(2, 15).toFastSeq) - assert(index.queryByInterval( - "cat", - "snail", - includesStart = true, - includesEnd = true, - ).toFastSeq == index.iterator(2, 15).toFastSeq) - assert(index.queryByInterval( - "aardvark", - "cat", - includesStart = true, - includesEnd = true, - ).toFastSeq == index.iterator(0, 9).toFastSeq) - assert(index.queryByInterval( - "aardvark", - "cat", - includesStart = false, - includesEnd = true, - ).toFastSeq == index.iterator(0, 9).toFastSeq) + assertEquals( + index.queryByInterval( + "cat", + "snail", + includesStart = true, + includesEnd = false, + ).toFastSeq, + index.iterator(2, 15).toFastSeq, + ) + assertEquals( + index.queryByInterval( + "cat", + "snail", + includesStart = true, + includesEnd = true, + ).toFastSeq, + index.iterator(2, 15).toFastSeq, + ) + assertEquals( + index.queryByInterval( + "aardvark", + "cat", + includesStart = true, + includesEnd = true, + ).toFastSeq, + index.iterator(0, 9).toFastSeq, + ) + assertEquals( + index.queryByInterval( + "aardvark", + "cat", + includesStart = false, + includesEnd = true, + ).toFastSeq, + index.iterator(0, 9).toFastSeq, + ) // illegal interval queries - assertThrows[IllegalArgumentException](index.queryByInterval( + intercept[IllegalArgumentException](index.queryByInterval( "bear", "bear", includesStart = false, includesEnd = 
false, - ).toFastSeq) - assertThrows[IllegalArgumentException](index.queryByInterval( + ).toFastSeq): Unit + intercept[IllegalArgumentException](index.queryByInterval( "bear", "bear", includesStart = false, includesEnd = true, - ).toFastSeq) - assertThrows[IllegalArgumentException](index.queryByInterval( + ).toFastSeq): Unit + intercept[IllegalArgumentException](index.queryByInterval( "bear", "bear", includesStart = true, includesEnd = false, - ).toFastSeq) - assertThrows[IllegalArgumentException](index.queryByInterval( + ).toFastSeq): Unit + intercept[IllegalArgumentException](index.queryByInterval( "cat", "bear", includesStart = true, includesEnd = true, - ).toFastSeq) + ).toFastSeq): Unit val endPoints = (stringsWithDups.distinct ++ Array("aardvark", "boar", "elk", "oppossum", "snail", "zoo")).combinations(2) @@ -377,20 +404,26 @@ class IndexSuite extends HailSuite { else stringsWithDups.lastIndexWhere(bounds(1) > _) + 1 // want last index before transition point and then want to include that value so add 1 - assert(index.queryByInterval( - bounds(0), - bounds(1), - includesStart, - includesEnd, - ).toFastSeq == - index.iterator(lowerBoundIdx.toLong, upperBoundIdx.toLong).toFastSeq) + assertEquals( + index.queryByInterval( + bounds(0), + bounds(1), + includesStart, + includesEnd, + ).toFastSeq, + index.iterator(lowerBoundIdx.toLong, upperBoundIdx.toLong).toFastSeq, + ) if (includesStart) - assert(index.iterateFrom(bounds(0)).toFastSeq == - leafsWithDups.slice(lowerBoundIdx, stringsWithDups.length)) + assertEquals( + index.iterateFrom(bounds(0)).toFastSeq, + leafsWithDups.slice(lowerBoundIdx, stringsWithDups.length), + ) if (!includesEnd) - assert(index.iterateUntil(bounds(1)).toFastSeq == - leafsWithDups.slice(0, upperBoundIdx)) + assertEquals( + index.iterateUntil(bounds(1)).toFastSeq, + leafsWithDups.slice(0, upperBoundIdx), + ) } else intercept[IllegalArgumentException](index.queryByInterval( bounds(0), @@ -404,7 +437,7 @@ class IndexSuite extends HailSuite { 
} } - @Test def testIntervalIteratorWorksWithGeneralEndpoints(): Unit = { + test("IntervalIteratorWorksWithGeneralEndpoints") { for (branchingFactor <- 2 to 5) { val keyType = PCanonicalStruct("a" -> PCanonicalString(), "b" -> PInt32()) val file = ctx.createTmpPath("from", "idx") @@ -427,45 +460,55 @@ class IndexSuite extends HailSuite { +PCanonicalStruct(), keyPType = PCanonicalStruct("a" -> PCanonicalString(), "b" -> PInt32()), ) - assert(index.queryByInterval( - Row("cat", 3), - Row("cat", 5), - includesStart = true, - includesEnd = false, - ).toFastSeq == - leafChildren.slice(3, 5)) - assert(index.queryByInterval( - Row("cat"), - Row("cat", 5), - includesStart = true, - includesEnd = false, - ).toFastSeq == - leafChildren.slice(2, 5)) - assert(index.queryByInterval( - Row(), - Row(), - includesStart = true, - includesEnd = true, - ).toFastSeq == - leafChildren) - assert(index.queryByInterval( - Row(), - Row("cat"), - includesStart = true, - includesEnd = false, - ).toFastSeq == - leafChildren.take(2)) - assert(index.queryByInterval( - Row("zebra"), - Row(), - includesStart = true, - includesEnd = true, - ).toFastSeq == - leafChildren.takeRight(3)) + assertEquals( + index.queryByInterval( + Row("cat", 3), + Row("cat", 5), + includesStart = true, + includesEnd = false, + ).toFastSeq, + leafChildren.slice(3, 5), + ) + assertEquals( + index.queryByInterval( + Row("cat"), + Row("cat", 5), + includesStart = true, + includesEnd = false, + ).toFastSeq, + leafChildren.slice(2, 5), + ) + assertEquals( + index.queryByInterval( + Row(), + Row(), + includesStart = true, + includesEnd = true, + ).toFastSeq, + leafChildren, + ) + assertEquals( + index.queryByInterval( + Row(), + Row("cat"), + includesStart = true, + includesEnd = false, + ).toFastSeq, + leafChildren.take(2), + ) + assertEquals( + index.queryByInterval( + Row("zebra"), + Row(), + includesStart = true, + includesEnd = true, + ).toFastSeq, + leafChildren.takeRight(3), + ) } } - @Test def testIterateFromUntil(): 
Unit = { + test("IterateFromUntil") { for (branchingFactor <- 2 to 5) { val file = ctx.createTmpPath("from", "idx") writeIndex( @@ -483,13 +526,16 @@ class IndexSuite extends HailSuite { var start = stringsWithDups.indexWhere(s <= _) if (start == -1) start = stringsWithDups.length - assert(index.iterateFrom(s).toFastSeq == leafsWithDups.slice( - start, - stringsWithDups.length, - )) + assertEquals( + index.iterateFrom(s).toFastSeq, + leafsWithDups.slice( + start, + stringsWithDups.length, + ), + ) val end = stringsWithDups.lastIndexWhere(s > _) + 1 - assert(index.iterateUntil(s).toFastSeq == leafsWithDups.slice(0, end)) + assertEquals(index.iterateUntil(s).toFastSeq, leafsWithDups.slice(0, end)) } } } diff --git a/hail/hail/test/src/is/hail/io/TabixSuite.scala b/hail/hail/test/src/is/hail/io/TabixSuite.scala index 86dfd6da168..d31f7fc04b5 100644 --- a/hail/hail/test/src/is/hail/io/TabixSuite.scala +++ b/hail/hail/test/src/is/hail/io/TabixSuite.scala @@ -5,11 +5,6 @@ import is.hail.io.tabix._ import is.hail.io.vcf.TabixVCF import htsjdk.tribble.readers.{TabixReader => HtsjdkTabixReader} -import org.scalatest.Inspectors.forAll -import org.scalatest.enablers.InspectorAsserting.assertingNatureOfAssertion -import org.scalatest.matchers.must.Matchers.contain -import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper -import org.testng.annotations.Test class TabixSuite extends HailSuite { // use .gz for several tests and .bgz for another to test handling of both @@ -20,42 +15,42 @@ class TabixSuite extends HailSuite { lazy val reader = new TabixReader(vcfGzFile, fs) - @Test def testLargeNumberOfSequences(): Unit = { + test("LargeNumberOfSequences") { val tbx = new TabixReader(null, fs, Some(getTestResource("large-tabix.tbi"))) // known length of sequences - assert(tbx.index.seqs.length == 3366) + assertEquals(tbx.index.seqs.length, 3366) } - @Test def testSequenceNames(): Unit = { - val expectedSeqNames = new Array[String](24); + test("SequenceNames") { + 
val expectedSeqNames = new Array[String](24) for (i <- 1 to 22) expectedSeqNames(i - 1) = i.toString - expectedSeqNames(22) = "X"; - expectedSeqNames(23) = "Y"; + expectedSeqNames(22) = "X" + expectedSeqNames(23) = "Y" val sequenceNames = reader.index.chr2tid.keySet - assert(expectedSeqNames.length == sequenceNames.size) - sequenceNames should contain theSameElementsAs expectedSeqNames + assertEquals(sequenceNames.size, expectedSeqNames.length) + assertEquals(sequenceNames, expectedSeqNames.toSet) } - @Test def testSequenceSet(): Unit = { + test("SequenceSet") { val chrs = reader.index.chr2tid.keySet assert(chrs.nonEmpty) assert(chrs.contains("1")) assert(!chrs.contains("MT")) } - @Test def testLineIterator(): Unit = { + test("LineIterator") { val htsjdkrdr = new HtsjdkTabixReader(vcfGzFile) // In range access - forAll(Seq("1", "19", "X")) { chr => + Seq("1", "19", "X").foreach { chr => val tid = reader.chr2tid(chr) - val pairs = reader.queryPairs(tid, 1, 400); + val pairs = reader.queryPairs(tid, 1, 400) val hailIter = new TabixLineIterator(fs, reader.filePath, pairs) - val htsIter = htsjdkrdr.query(chr, 1, 400); + val htsIter = htsjdkrdr.query(chr, 1, 400) val hailStr = hailIter.next() val htsStr = htsIter.next() - assert(hailStr == htsStr) + assertEquals(hailStr, htsStr) assert(hailIter.next() == null) } @@ -63,28 +58,28 @@ class TabixSuite extends HailSuite { // NOTE: We use the larger interval for the htsjdk iterator because the // hail iterator may return the one record that is contained in each of the // chromosomes we check - forAll(Seq("1", "19", "X")) { chr => + Seq("1", "19", "X").foreach { chr => val tid = reader.chr2tid(chr) - val pairs = reader.queryPairs(tid, 350, 400); + val pairs = reader.queryPairs(tid, 350, 400) val hailIter = new TabixLineIterator(fs, reader.filePath, pairs) - val htsIter = htsjdkrdr.query(chr, 1, 400); + val htsIter = htsjdkrdr.query(chr, 1, 400) val hailStr = hailIter.next() val htsStr = htsIter.next() if (hailStr != null) - 
assert(hailStr == htsStr) + assertEquals(hailStr, htsStr) assert(hailIter.next() == null) } // beg == end - forAll(Seq("1", "19", "X")) { chr => + Seq("1", "19", "X").foreach { chr => val tid = reader.chr2tid(chr) - val pairs = reader.queryPairs(tid, 100, 100); + val pairs = reader.queryPairs(tid, 100, 100) val hailIter = new TabixLineIterator(fs, reader.filePath, pairs) - val htsIter = htsjdkrdr.query(chr, 100, 100); + val htsIter = htsjdkrdr.query(chr, 100, 100) val hailStr = hailIter.next() val htsStr = htsIter.next() assert(hailStr == null) - assert(hailStr == htsStr) + assertEquals(hailStr, htsStr) } } @@ -94,18 +89,16 @@ class TabixSuite extends HailSuite { val hailrdr = new TabixReader(vcfFile, fs) val tid = hailrdr.chr2tid(chr) - forAll( - Seq( - (12990058, 12990059), // Small interval, containing just one locus at end - (10570000, 13000000), // Open interval - (10019093, 16360860), // Closed interval - (11000000, 13029764), // Half open (beg, end] - (17434340, 18000000), // Half open [beg, end) - (13943975, 14733634), // Some random intervals - (11578765, 15291865), - (12703588, 16751726), - ) - ) { case (start, end) => + Seq( + (12990058, 12990059), // Small interval, containing just one locus at end + (10570000, 13000000), // Open interval + (10019093, 16360860), // Closed interval + (11000000, 13029764), // Half open (beg, end] + (17434340, 18000000), // Half open [beg, end) + (13943975, 14733634), // Some random intervals + (11578765, 15291865), + (12703588, 16751726), + ).foreach { case (start, end) => val pairs = hailrdr.queryPairs(tid, start, end) val htsIter = htsjdkrdr.query(chr, start, end) val hailIter = new TabixLineIterator(fs, hailrdr.filePath, pairs) @@ -127,10 +120,9 @@ class TabixSuite extends HailSuite { } } - @Test def testLineIterator2(): Unit = - _testLineIterator2(getTestResource("sample.vcf.bgz")) + test("LineIterator2")(_testLineIterator2(getTestResource("sample.vcf.bgz"))) - @Test def testWriter(): Unit = { + test("Writer") { val 
vcfFile = getTestResource("sample.vcf.bgz") val path = ctx.createTmpPath("test-tabix-write", "bgz") fs.copy(vcfFile, path) diff --git a/hail/hail/test/src/is/hail/io/compress/BGzipCodecSuite.scala b/hail/hail/test/src/is/hail/io/compress/BGzipCodecSuite.scala index cdcda06dd25..2189914bf68 100644 --- a/hail/hail/test/src/is/hail/io/compress/BGzipCodecSuite.scala +++ b/hail/hail/test/src/is/hail/io/compress/BGzipCodecSuite.scala @@ -17,12 +17,7 @@ import org.apache.{hadoop => hd} import org.apache.commons.io.IOUtils import org.apache.spark.sql.Row import org.scalacheck.Gen._ -import org.scalatest -import org.scalatest.matchers.must.Matchers.contain -import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, equal} -import org.scalatestplus.scalacheck.CheckerAsserting.assertingNatureOfAssertion -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.testng.annotations.Test +import org.scalacheck.Prop class TestFileInputFormat extends hd.mapreduce.lib.input.TextInputFormat { override def getSplits(job: hd.mapreduce.JobContext): java.util.List[hd.mapreduce.InputSplit] = { @@ -59,7 +54,7 @@ class TestFileInputFormat extends hd.mapreduce.lib.input.TextInputFormat { } } -class BGzipCodecSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { +class BGzipCodecSuite extends HailSuite with munit.ScalaCheckSuite { val uncompPath = getTestResource("sample.vcf") // is actually a bgz file @@ -73,33 +68,33 @@ class BGzipCodecSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { * - concatenated the chunks */ val compPath = getTestResource("bgz.test.sample.vcf.bgz") - @Test def testGenericLinesSimpleUncompressed(): Unit = { + test("GenericLinesSimpleUncompressed") { val lines = Source.fromFile(uncompPath).getLines().toFastSeq val uncompStatus = fs.fileStatus(uncompPath) - scalatest.Inspectors.forAll(0 until 16) { i => + (0 until 16).foreach { i => val lines2 = GenericLines.collect( fs, GenericLines.read(fs, 
ArraySeq(uncompStatus), Some(i), None, None, false, false), ) - lines2 should equal(lines) + assertEquals(lines2, lines) } } - @Test def testGenericLinesSimpleBGZ(): Unit = { + test("GenericLinesSimpleBGZ") { val lines = Source.fromFile(uncompPath).getLines().toFastSeq val compStatus = fs.fileStatus(compPath) - scalatest.Inspectors.forAll(0 until 16) { i => + (0 until 16).foreach { i => val lines2 = GenericLines.collect( fs, GenericLines.read(fs, ArraySeq(compStatus), Some(i), None, None, false, false), ) - lines2 should equal(lines) + assertEquals(lines2, lines) } } - @Test def testGenericLinesSimpleGZ(): Unit = { + test("GenericLinesSimpleGZ") { val lines = Source.fromFile(uncompPath).getLines().toFastSeq // won't split, just run once @@ -108,16 +103,17 @@ class BGzipCodecSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { fs, GenericLines.read(fs, ArraySeq(gzStatus), Some(7), None, None, false, true), ) - lines2 should equal(lines) + assertEquals(lines2, lines) } - @Test def testGenericLinesRefuseGZ(): Unit = + test("GenericLinesRefuseGZ") { interceptFatal("Cowardly refusing") { val gzStatus = fs.fileStatus(gzPath) GenericLines.read(fs, ArraySeq(gzStatus), Some(7), None, None, false, false) } + } - @Test def testGenericLinesRandom(): Unit = { + property("GenericLinesRandom") { val lines = Source.fromFile(uncompPath).getLines().toFastSeq val compLength = 195353L @@ -136,18 +132,19 @@ class BGzipCodecSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { ) } yield (Array(0L, compLength) ++ rawSplits).distinct.sorted - forAll(g) { splits => + Prop.forAll(g) { splits => val contexts = (0 until (splits.length - 1)) .map { i => val end = makeVirtualOffset(splits(i + 1), 0) Row(i, 0, compPath, splits(i), end, true) } val lines2 = GenericLines.collect(fs, GenericLines.read(fs, contexts, false, false)) - lines2 should equal(lines) + assertEquals(lines2, lines) + true } } - @Test def test(): Unit = { + property("BGzipCodec") { 
sc.hadoopConfiguration.setLong("mapreduce.input.fileinputformat.split.minsize", 1L) val uncompIS = fs.open(uncompPath) @@ -162,12 +159,11 @@ class BGzipCodecSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val lines = Source.fromBytes(uncomp).getLines().toArray - assert(sc.textFile(uncompPath).collectOrdered() - .sameElements(lines)) + assert(sc.textFile(uncompPath).collectOrdered().sameElements(lines)) - scalatest.Inspectors.forAll(1 until 20) { i => + (1 until 20).foreach { i => val linesRDD = sc.textFile(compPath, i) - assert(linesRDD.partitions.length == i) + assertEquals(linesRDD.partitions.length, i) assert(linesRDD.collectOrdered().sameElements(lines)) } @@ -190,7 +186,7 @@ class BGzipCodecSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { ) } yield (Array(0L, compLength) ++ rawSplits).distinct.sorted - forAll(g) { splits => + Prop.forAll(g) { splits => val jobConf = new hd.conf.Configuration(sc.hadoopConfiguration) jobConf.set("bgz.test.splits", splits.mkString(",")) val rdd = sc.newAPIHadoopFile[hd.io.LongWritable, hd.io.Text, TestFileInputFormat]( @@ -202,11 +198,12 @@ class BGzipCodecSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { ) val rddLines = rdd.map(_._2.toString).collectOrdered() - rddLines should contain theSameElementsAs (lines) + assertEquals(rddLines.sorted.toSeq, lines.sorted.toSeq) + true } } - @Test def testVirtualSeek(): Unit = { + test("VirtualSeek") { // real offsets of the start of some blocks, paired with the offset to the next block val blockStarts = Array[(Long, Long)]( (0, 14653), @@ -241,7 +238,7 @@ class BGzipCodecSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val vOff = BlockCompressedFilePointerUtil.makeFilePointer(cOff, extra) decompIS.virtualSeek(vOff) - assert(decompIS.getVirtualOffset() == vOff); + assert(decompIS.getVirtualOffset() == vOff) uncompIS.seek(uOff.toLong + extra) val decompRead = decompIS.readNBytes(decompData, 0, decompData.length) @@ -276,24 +273,24 @@ class 
BGzipCodecSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { )) // Trying to seek to the end of a block should fail - assertThrows[java.io.IOException] { + intercept[java.io.IOException] { val vOff = BlockCompressedFilePointerUtil.makeFilePointer(blockStarts(1)._1, maxBlockSize) decompIS.virtualSeek(vOff) - } + }: Unit // Trying to seek past the end of a block should fail - assertThrows[java.io.IOException] { + intercept[java.io.IOException] { val vOff = BlockCompressedFilePointerUtil.makeFilePointer(blockStarts(0)._1, maxBlockSize + 1) decompIS.virtualSeek(vOff) - } + }: Unit // Trying to seek to the end of the last block should fail - assertThrows[java.io.IOException] { + intercept[java.io.IOException] { val vOff = BlockCompressedFilePointerUtil.makeFilePointer(blockStarts.last._1, lastBlockLen) decompIS.virtualSeek(vOff) - } + }: Unit // trying to seek to the end of file should succeed decompIS.virtualSeek(0) @@ -303,7 +300,7 @@ class BGzipCodecSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { // seeking past end of file directly should fail decompIS.virtualSeek(0) - assertThrows[java.io.IOException] { + intercept[java.io.IOException] { val vOff = BlockCompressedFilePointerUtil.makeFilePointer(blockStarts.last._2, 1) decompIS.virtualSeek(vOff) } diff --git a/hail/hail/test/src/is/hail/io/fs/AzureStorageFSSuite.scala b/hail/hail/test/src/is/hail/io/fs/AzureStorageFSSuite.scala index 3afb0f257c6..44414cb7223 100644 --- a/hail/hail/test/src/is/hail/io/fs/AzureStorageFSSuite.scala +++ b/hail/hail/test/src/is/hail/io/fs/AzureStorageFSSuite.scala @@ -2,19 +2,16 @@ package is.hail.io.fs import is.hail.services.oauth2.AzureCloudCredentials -import org.scalatestplus.testng.TestNGSuite -import org.testng.SkipException -import org.testng.annotations.{BeforeClass, Test} - -class AzureStorageFSSuite extends TestNGSuite with FSSuite { - @BeforeClass - def beforeclass(): Unit = - if ( - System.getenv("HAIL_CLOUD") != "azure" || - root == null || - 
fsResourcesRoot == null +class AzureStorageFSSuite extends FSSuite { + override def beforeAll(): Unit = { + super.beforeAll() + assume( + System.getenv("HAIL_CLOUD") == "azure" && + root != null && + fsResourcesRoot != null, + "not in Azure", ) - throw new SkipException("skip") + } override lazy val fs: FS = new AzureStorageFS( @@ -22,12 +19,12 @@ class AzureStorageFSSuite extends TestNGSuite with FSSuite { .scoped(AzureStorageFS.RequiredOAuthScopes) ) - @Test def testMakeQualified(): Unit = { + test("MakeQualified") { val qualifiedFileName = "https://account.blob.core.windows.net/container/path" - assert(fs.makeQualified(qualifiedFileName) == qualifiedFileName) + assertEquals(fs.makeQualified(qualifiedFileName), qualifiedFileName) val unqualifiedFileName = "https://account/container/path" - assertThrows[IllegalArgumentException] { + intercept[IllegalArgumentException] { fs.makeQualified(unqualifiedFileName) } } diff --git a/hail/hail/test/src/is/hail/io/fs/FSSuite.scala b/hail/hail/test/src/is/hail/io/fs/FSSuite.scala index 5ef2a1d6c67..4d99fc3b1a0 100644 --- a/hail/hail/test/src/is/hail/io/fs/FSSuite.scala +++ b/hail/hail/test/src/is/hail/io/fs/FSSuite.scala @@ -9,13 +9,8 @@ import java.io.FileNotFoundException import org.apache.commons.io.IOUtils import org.apache.hadoop.fs.FileAlreadyExistsException -import org.scalatest.Inspectors.forAll -import org.scalatest.enablers.InspectorAsserting.assertingNatureOfAssertion -import org.scalatestplus.testng.TestNGSuiteLike -import org.testng.SkipException -import org.testng.annotations.Test -trait FSSuite extends TestNGSuiteLike with TestUtils { +trait FSSuite extends munit.FunSuite with TestUtils { val root: String = System.getenv("HAIL_TEST_STORAGE_URI") def fsResourcesRoot: String = System.getenv("HAIL_FS_TEST_CLOUD_RESOURCES_URI") def tmpdir: String = System.getenv("HAIL_TEST_STORAGE_URI") @@ -38,7 +33,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { def pathsRelResourcesRoot(statuses: 
IndexedSeq[FileListEntry]): Set[String] = pathsRelRoot(fsResourcesRoot, statuses) - @Test def testExistsOnDirectory(): Unit = { + test("ExistsOnDirectory") { assert(fs.exists(r("/dir"))) assert(fs.exists(r("/dir/"))) @@ -46,14 +41,14 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { assert(!fs.exists(r("/does_not_exist_dir/"))) } - @Test def testExistsOnFile(): Unit = { + test("ExistsOnFile") { assert(fs.exists(r("/a"))) assert(fs.exists(r("/zzz"))) assert(!fs.exists(r("/z"))) // prefix } - @Test def testFileStatusOnFile(): Unit = { + test("FileStatusOnFile") { // file val f = r("/a") val s = fs.fileStatus(f) @@ -61,7 +56,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { assert(s.getLen == 12) } - @Test def testFileListEntryOnFile(): Unit = { + test("FileListEntryOnFile") { // file val f = r("/a") val s = fs.fileListEntry(f) @@ -71,14 +66,14 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { assert(s.getLen == 12) } - @Test def testFileStatusOnDirIsFailure(): Unit = { + test("FileStatusOnDirIsFailure") { val f = r("/dir") interceptException[FileNotFoundException](f)( fs.fileStatus(f) ) } - @Test def testFileListEntryOnDir(): Unit = { + test("FileListEntryOnDir") { // file val f = r("/dir") val s = fs.fileListEntry(f) @@ -87,7 +82,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { assert(s.isDirectory) } - @Test def testFileListEntryOnDirWithSlash(): Unit = { + test("FileListEntryOnDirWithSlash") { // file val f = r("/dir/") val s = fs.fileListEntry(f) @@ -96,23 +91,24 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { assert(s.isDirectory) } - @Test def testFileListEntryOnMissingFile(): Unit = - assertThrows[FileNotFoundException] { + test("FileListEntryOnMissingFile") { + intercept[FileNotFoundException] { fs.fileListEntry(r("/does_not_exist")) } + } - @Test def testFileListEntryRoot(): Unit = { + test("FileListEntryRoot") { val s = fs.fileListEntry(root) assert(s.getPath == root) } - @Test def 
testFileListEntryRootWithSlash(): Unit = { - if (root.endsWith("/")) throw new SkipException("skipped") + test("FileListEntryRootWithSlash") { + assume(!root.endsWith("/"), "skipped") val s = fs.fileListEntry(s"$root/") assert(s.getPath == root) } - @Test def testDeleteRecursive(): Unit = { + test("DeleteRecursive") { val d = t() fs.mkDir(d) fs.touch(s"$d/x") @@ -133,34 +129,34 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { assert(!fs.exists(s"$d/subdir/z")) } - @Test def testDeleteFileDoesntExist(): Unit = { + test("DeleteFileDoesntExist") { val d = t() fs.mkDir(d) fs.delete(s"$d/foo", recursive = false) fs.delete(s"$d/foo", recursive = true) } - @Test def testListDirectory(): Unit = { + test("ListDirectory") { val statuses = fs.listDirectory(r("")) assert(pathsRelResourcesRoot(statuses) == Set("/a", "/adir", "/az", "/dir", "/zzz")) } - @Test def testListDirectoryWithSlash(): Unit = { + test("ListDirectoryWithSlash") { val statuses = fs.listDirectory(r("/")) assert(pathsRelResourcesRoot(statuses) == Set("/a", "/adir", "/az", "/dir", "/zzz")) } - @Test def testGlobOnDir(): Unit = { + test("GlobOnDir") { val statuses = fs.glob(r("")) assert(pathsRelResourcesRoot(statuses) == Set("")) } - @Test def testGlobMissingFile(): Unit = { + test("GlobMissingFile") { val statuses = fs.glob(r("/does_not_exist_dir/does_not_exist")) assert(pathsRelResourcesRoot(statuses) == Set()) } - @Test def testGlobFilename(): Unit = { + test("GlobFilename") { val statuses = fs.glob(r("/a*")) assert( pathsRelResourcesRoot(statuses) == Set("/a", "/adir", "/az"), @@ -168,7 +164,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { ) } - @Test def testGlobFilenameMatchSingleCharacter(): Unit = { + test("GlobFilenameMatchSingleCharacter") { val statuses = fs.glob(r("/a?")) assert( pathsRelResourcesRoot(statuses) == Set("/az"), @@ -176,7 +172,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { ) } - @Test def testGlobFilenameMatchSingleCharacterInMiddleOfName(): Unit = { + 
test("GlobFilenameMatchSingleCharacterInMiddleOfName") { val statuses = fs.glob(r("/a?ir")) assert( pathsRelResourcesRoot(statuses) == Set("/adir"), @@ -184,7 +180,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { ) } - @Test def testGlobDirnameMatchSingleCharacterInMiddleOfName(): Unit = { + test("GlobDirnameMatchSingleCharacterInMiddleOfName") { val statuses = fs.glob(r("/a?ir/x")) assert( pathsRelResourcesRoot(statuses) == Set("/adir/x"), @@ -192,7 +188,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { ) } - @Test def testGlobMatchDir(): Unit = { + test("GlobMatchDir") { val statuses = fs.glob(r("/*dir/x")) assert( pathsRelResourcesRoot(statuses) == Set("/adir/x", "/dir/x"), @@ -200,13 +196,13 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { ) } - @Test def testGlobRoot(): Unit = { + test("GlobRoot") { val statuses = fs.glob(root) // empty with respect to root (self) assert(pathsRelRoot(root, statuses) == Set("")) } - @Test def testFileEndingWithPeriod(): Unit = { + test("FileEndingWithPeriod") { val f = fs.makeQualified(t()) fs.touch(f + "/foo.") val statuses = fs.listDirectory(f) @@ -220,13 +216,13 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { } } - @Test def testGlobRootWithSlash(): Unit = { - if (root.endsWith("/")) throw new SkipException("skipped") + test("GlobRootWithSlash") { + assume(!root.endsWith("/"), "skipped") val statuses = fs.glob(s"$root/") assert(pathsRelRoot(root, statuses) == Set("")) } - @Test def testWriteRead(): Unit = { + test("WriteRead") { val s = "this is a test string" val f = t() @@ -247,7 +243,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { assert(!fs.exists(f)) } - @Test def testWriteReadCompressed(): Unit = { + test("WriteReadCompressed") { val s = "this is a test string" val f = t(extension = ".bgz") @@ -268,7 +264,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { assert(!fs.exists(f)) } - @Test def testWritePreexisting(): Unit = { + test("WritePreexisting") { val s1 
= "first" val s2 = "second" val f = t() @@ -288,10 +284,11 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { } } - @Test def testGetCodecExtension(): Unit = + test("GetCodecExtension") { assert(fs.getCodecExtension("foo.vcf.bgz") == ".bgz") + } - @Test def testReadWriteBytes(): Unit = { + test("ReadWriteBytes") { val f = t() using(fs.create(f)) { os => @@ -313,7 +310,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { assert(!fs.exists(f)) } - @Test def testReadWriteBytesLargerThanBuffer(): Unit = { + test("ReadWriteBytesLargerThanBuffer") { val f = t() val numWrites = 1000000 @@ -349,7 +346,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { assert(!fs.exists(f)) } - @Test def testDropTrailingSlash(): Unit = { + test("DropTrailingSlash") { assert(dropTrailingSlash("") == "") assert(dropTrailingSlash("/foo/bar") == "/foo/bar") assert(dropTrailingSlash("foo/bar/") == "foo/bar") @@ -357,7 +354,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { assert(dropTrailingSlash("///") == "") } - @Test def testSeekMoreThanMaxInt(): Unit = { + test("SeekMoreThanMaxInt") { val f = t() using(fs.create(f)) { os => val eight_mib = 8 * 1024 * 1024 @@ -388,7 +385,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { assert(!fs.exists(f)) } - @Test def testSeekAndReadStraddlingBufferSize(): Unit = { + test("SeekAndReadStraddlingBufferSize") { val data = Array.tabulate(251)(_.toByte) val f = t() using(fs.create(f)) { os => @@ -411,11 +408,11 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { val toRead = new Array[Byte](512) is.readFully(toRead) - forAll(toRead.indices)(i => assert(toRead(i) == ((seekPos + i) % 251).toByte)) + toRead.indices.foreach(i => assert(toRead(i) == ((seekPos + i) % 251).toByte)) } } - @Test def largeDirectoryOperations(): Unit = { + test("largeDirectoryOperations") { val prefix = s"$tmpdir/fs-suite/delete-many-files/${java.util.UUID.randomUUID()}" for (i <- 0 until 2000) fs.touch(s"$prefix/$i.suffix") @@ -430,7 
+427,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { fail(s"files not deleted:\n${fs.listDirectory(prefix).map(_.getPath).mkString("\n")}") } - @Test def testSeekAfterEOF(): Unit = { + test("SeekAfterEOF") { val prefix = s"$tmpdir/fs-suite/delete-many-files/${java.util.UUID.randomUUID()}" val p = s"$prefix/seek_file" using(fs.createNoCompression(p)) { os => @@ -450,25 +447,26 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { } } - @Test def fileAndDirectoryIsError(): Unit = { + test("fileAndDirectoryIsError") { val d = t() fs.mkDir(d) fs.touch(s"$d/x/file") - intercept[Exception] { + val caught = intercept[Exception] { fs.touch(s"$d/x") fs.fileListEntry(s"$d/x") - } match { + } + caught match { /* Hadoop, in particular, errors when you touch an object whose name is a prefix of another * object. */ case exc: FileAndDirectoryException if exc.getMessage() == s"$d/x appears as both file $d/x and directory $d/x/." => case exc: FileNotFoundException if exc.getMessage() == s"$d/x (Is a directory)" => - case other => fail(other) + case other => fail("unexpected exception", other) } } - @Test def testETag(): Unit = { + test("ETag") { val etag = fs.eTag(s"$fsResourcesRoot/a") if (fs.parseUrl(fsResourcesRoot).toString.startsWith("file:")) { // only the local file system should lack etags. 
@@ -478,7 +476,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { } } - @Test def fileAndDirectoryIsErrorEvenIfPrefixedFileIsNotLexicographicallyFirst(): Unit = { + test("fileAndDirectoryIsErrorEvenIfPrefixedFileIsNotLexicographicallyFirst") { val d = t() fs.mkDir(d) fs.touch(s"$d/x") @@ -514,21 +512,22 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { fs.touch(s"$d/x,") fs.touch(s"$d/x-") // fs.touch(s"$d/x.") // https://github.com/Azure/azure-sdk-for-java/issues/36674 - intercept[Exception] { + val caught = intercept[Exception] { fs.touch(s"$d/x/file") fs.fileListEntry(s"$d/x") - } match { + } + caught match { /* Hadoop, in particular, errors when you touch an object whose name is a prefix of another * object. */ case exc: FileAndDirectoryException if exc.getMessage() == s"$d/x appears as both file $d/x and directory $d/x/." => case exc: FileAlreadyExistsException if exc.getMessage() == s"Destination exists and is not a directory: $d/x" => - case other => fail(other) + case other => fail("unexpected exception", other) } } - @Test def fileListEntrySeesDirectoryEvenIfPrefixedFileIsNotLexicographicallyFirst(): Unit = { + test("fileListEntrySeesDirectoryEvenIfPrefixedFileIsNotLexicographicallyFirst") { val d = t() fs.mkDir(d) // fs.touch(s"$d/x ") // Hail does not support spaces in path names @@ -570,8 +569,7 @@ trait FSSuite extends TestNGSuiteLike with TestUtils { assert(!fle.isFile) } - @Test def fileListEntrySeesFileEvenWithPeersPreceedingThePositionOfANonPresentDirectoryEntry() - : Unit = { + test("fileListEntrySeesFileEvenWithPeersPreceedingThePositionOfANonPresentDirectoryEntry") { val d = t() fs.mkDir(d) fs.touch(s"$d/x") diff --git a/hail/hail/test/src/is/hail/io/fs/GoogleStorageFSSuite.scala b/hail/hail/test/src/is/hail/io/fs/GoogleStorageFSSuite.scala index 3db57f215f0..d4f1a378d5a 100644 --- a/hail/hail/test/src/is/hail/io/fs/GoogleStorageFSSuite.scala +++ b/hail/hail/test/src/is/hail/io/fs/GoogleStorageFSSuite.scala @@ -2,19 +2,16 @@ 
package is.hail.io.fs import is.hail.services.oauth2.GoogleCloudCredentials -import org.scalatestplus.testng.TestNGSuite -import org.testng.SkipException -import org.testng.annotations.{BeforeClass, Test} - -class GoogleStorageFSSuite extends TestNGSuite with FSSuite { - @BeforeClass - def beforeclass(): Unit = - if ( - System.getenv("HAIL_CLOUD") != "gcp" || - root == null || - fsResourcesRoot == null +class GoogleStorageFSSuite extends FSSuite { + override def beforeAll(): Unit = { + super.beforeAll() + assume( + System.getenv("HAIL_CLOUD") == "gcp" && + root != null && + fsResourcesRoot != null, + "not in GCP", ) - throw new SkipException("skip") + } override lazy val fs: FS = new GoogleStorageFS( @@ -23,13 +20,12 @@ class GoogleStorageFSSuite extends TestNGSuite with FSSuite { None, ) - @Test def testMakeQualified(): Unit = { + test("MakeQualified") { val qualifiedFileName = "gs://bucket/path" - assert(fs.makeQualified(qualifiedFileName) == qualifiedFileName) + assertEquals(fs.makeQualified(qualifiedFileName), qualifiedFileName) - val unqualifiedFileName = "not-gs://bucket/path" - assertThrows[IllegalArgumentException] { - fs.makeQualified(unqualifiedFileName) + intercept[IllegalArgumentException] { + fs.makeQualified("not-gs://bucket/path") } } } diff --git a/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala b/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala index 1f49d39bab6..179f352c4e7 100644 --- a/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala +++ b/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala @@ -15,13 +15,10 @@ import breeze.linalg.{*, diag, DenseMatrix, DenseVector => BDV} import org.apache.spark.sql.Row import org.scalacheck._ import org.scalacheck.Arbitrary._ -import org.scalacheck.Gen.{size, _} -import org.scalatest -import org.scalatestplus.scalacheck.CheckerAsserting.assertingNatureOfAssertion -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.testng.annotations.Test +import 
org.scalacheck.Gen._ +import org.scalacheck.Prop.forAll -class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { +class BlockMatrixSuite extends HailSuite with munit.ScalaCheckSuite { val interestingPosInt: Gen[Int] = oneOf( @@ -112,8 +109,8 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { x.rows == y.rows && x.cols == y.cols, s"dimension mismatch: ${x.rows} x ${x.cols} vs ${y.rows} x ${y.cols}", ) - scalatest.Inspectors.forAll(0 until x.cols) { j => - scalatest.Inspectors.forAll(0 until x.rows) { i => + (0 until x.cols).foreach { j => + (0 until x.rows).foreach { i => assert( !(D_==(x(i, j) - y(i, j), relTolerance) && !(x(i, j).isNaN && y(i, j).isNaN)), s"x=${x.toString(1000, 1000)}\n" ++ @@ -124,8 +121,7 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { } } - @Test - def pointwiseSubtractCorrect(): Unit = { + test("pointwiseSubtractCorrect") { val m = toBM( 4, 4, @@ -147,11 +143,10 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { ) val actual = (m - m.T).toBreezeMatrix() - assert(actual == expected) + assertEquals(actual, expected) } - @Test - def multiplyByLocalMatrix(): Unit = { + test("multiplyByLocalMatrix") { val ll = toLM( 4, 4, @@ -174,66 +169,50 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { ), ) - assert(ll * lr === l.dot(ctx, lr).toBreezeMatrix()) + assertEquals(ll * lr, l.dot(ctx, lr).toBreezeMatrix()) } - @Test - def randomMultiplyByLocalMatrix(): Unit = - forAll(genMultipliableDenseMatrices) { case (ll, lr) => - val l = toBM(ll) - assertDoubleMatrixNaNEqualsNaN(ll * lr, l.dot(ctx, lr).toBreezeMatrix()) - } + property("randomMultiplyByLocalMatrix") = forAll(genMultipliableDenseMatrices) { case (ll, lr) => + val l = toBM(ll) + assertDoubleMatrixNaNEqualsNaN(ll * lr, l.dot(ctx, lr).toBreezeMatrix()) + } - @Test - def multiplySameAsBreeze(): Unit = { + property("multiplySameAsBreeze") { 
forAll(genDenseMatrix(4, 4), genDenseMatrix(4, 4)) { (ll, lr) => val l = toBM(ll, 2) val r = toBM(lr, 2) - - assertDoubleMatrixNaNEqualsNaN(l.dot(r).toBreezeMatrix(), ll * lr) - } - - forAll(genDenseMatrix(9, 9), genDenseMatrix(9, 9)) { (ll, lr) => - val l = toBM(ll, 3) - val r = toBM(lr, 3) - - assertDoubleMatrixNaNEqualsNaN(l.dot(r).toBreezeMatrix(), ll * lr) - } - - forAll(genDenseMatrix(9, 9), genDenseMatrix(9, 9)) { (ll, lr) => - val l = toBM(ll, 2) - val r = toBM(lr, 2) - - assertDoubleMatrixNaNEqualsNaN(l.dot(r).toBreezeMatrix(), ll * lr) - } - - forAll(genDenseMatrix(2, 10), genDenseMatrix(10, 2)) { (ll, lr) => - val l = toBM(ll, 3) - val r = toBM(lr, 3) - assertDoubleMatrixNaNEqualsNaN(l.dot(r).toBreezeMatrix(), ll * lr) - } - - forAll(genMultipliableDenseMatrices, interestingPosInt) { case ((ll, lr), blockSize) => - val l = toBM(ll, blockSize) - val r = toBM(lr, blockSize) - - assertDoubleMatrixNaNEqualsNaN(l.dot(r).toBreezeMatrix(), ll * lr) - } + } ++ + forAll(genDenseMatrix(9, 9), genDenseMatrix(9, 9)) { (ll, lr) => + val l = toBM(ll, 3) + val r = toBM(lr, 3) + assertDoubleMatrixNaNEqualsNaN(l.dot(r).toBreezeMatrix(), ll * lr) + } ++ + forAll(genDenseMatrix(9, 9), genDenseMatrix(9, 9)) { (ll, lr) => + val l = toBM(ll, 2) + val r = toBM(lr, 2) + assertDoubleMatrixNaNEqualsNaN(l.dot(r).toBreezeMatrix(), ll * lr) + } ++ + forAll(genDenseMatrix(2, 10), genDenseMatrix(10, 2)) { (ll, lr) => + val l = toBM(ll, 3) + val r = toBM(lr, 3) + assertDoubleMatrixNaNEqualsNaN(l.dot(r).toBreezeMatrix(), ll * lr) + } ++ + forAll(genMultipliableDenseMatrices, interestingPosInt) { case ((ll, lr), blockSize) => + val l = toBM(ll, blockSize) + val r = toBM(lr, blockSize) + assertDoubleMatrixNaNEqualsNaN(l.dot(r).toBreezeMatrix(), ll * lr) + } } - @Test - def multiplySameAsBreezeRandomized(): Unit = { - forAll(twoMultipliableBlockMatrices) { - case (l: BlockMatrix, r: BlockMatrix) => - val actual = l.dot(r).toBreezeMatrix() - val expected = l.toBreezeMatrix() * 
r.toBreezeMatrix() - assertDoubleMatrixNaNEqualsNaN(actual, expected) - } + property("multiplySameAsBreezeRandomized") = forAll(twoMultipliableBlockMatrices) { + case (l: BlockMatrix, r: BlockMatrix) => + val actual = l.dot(r).toBreezeMatrix() + val expected = l.toBreezeMatrix() * r.toBreezeMatrix() + assertDoubleMatrixNaNEqualsNaN(actual, expected) } - @Test - def rowwiseMultiplication(): Unit = { + test("rowwiseMultiplication") { val l = toBM( 4, 4, @@ -256,28 +235,23 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { 13, 28, 45, 64), ) - assert(l.rowVectorMul(ctx, v).toBreezeMatrix() == result) + assertEquals(l.rowVectorMul(ctx, v).toBreezeMatrix(), result) } - @Test - def rowwiseMultiplicationRandom(): Unit = { - val g = for { + property("rowwiseMultiplicationRandom") = Prop.forAll( + for { l <- blockMatrixGen() v <- containerOfN[Array, Double](l.nCols.toInt, arbitrary[Double]) } yield (l, v) - - forAll(g) { case (l: BlockMatrix, v: Array[Double]) => - val actual = l.rowVectorMul(ctx, v).toBreezeMatrix() - val repeatedR = (0 until l.nRows.toInt).flatMap(_ => v).toArray - val repeatedRMatrix = new DenseMatrix(v.length, l.nRows.toInt, repeatedR).t - val expected = l.toBreezeMatrix() *:* repeatedRMatrix - - assertDoubleMatrixNaNEqualsNaN(actual, expected) - } + ) { case (l: BlockMatrix, v: Array[Double]) => + val actual = l.rowVectorMul(ctx, v).toBreezeMatrix() + val repeatedR = (0 until l.nRows.toInt).flatMap(_ => v).toArray + val repeatedRMatrix = new DenseMatrix(v.length, l.nRows.toInt, repeatedR).t + val expected = l.toBreezeMatrix() *:* repeatedRMatrix + assertDoubleMatrixNaNEqualsNaN(actual, expected) } - @Test - def colwiseMultiplication(): Unit = { + test("colwiseMultiplication") { val l = toBM( 4, 4, @@ -300,27 +274,23 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { 52, 56, 60, 64), ) - assert(l.colVectorMul(ctx, v).toBreezeMatrix() == result) + assertEquals(l.colVectorMul(ctx, 
v).toBreezeMatrix(), result) } - @Test - def colwiseMultiplicationRandom(): Unit = { - val g = for { + property("colwiseMultiplicationRandom") = forAll( + for { l <- blockMatrixGen() v <- containerOfN[Array, Double](l.nRows.toInt, arbitrary[Double]) } yield (l, v) - - forAll(g) { case (l: BlockMatrix, v: Array[Double]) => - val actual = l.colVectorMul(ctx, v).toBreezeMatrix() - val repeatedR = (0 until l.nCols.toInt).flatMap(_ => v).toArray - val repeatedRMatrix = new DenseMatrix(v.length, l.nCols.toInt, repeatedR) - val expected = l.toBreezeMatrix() *:* repeatedRMatrix - assertDoubleMatrixNaNEqualsNaN(actual, expected) - } + ) { case (l: BlockMatrix, v: Array[Double]) => + val actual = l.colVectorMul(ctx, v).toBreezeMatrix() + val repeatedR = (0 until l.nCols.toInt).flatMap(_ => v).toArray + val repeatedRMatrix = new DenseMatrix(v.length, l.nCols.toInt, repeatedR) + val expected = l.toBreezeMatrix() *:* repeatedRMatrix + assertDoubleMatrixNaNEqualsNaN(actual, expected) } - @Test - def colwiseAddition(): Unit = { + test("colwiseAddition") { val l = toBM( 4, 4, @@ -343,11 +313,10 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { 17, 18, 19, 20), ) - assert(l.colVectorAdd(ctx, v).toBreezeMatrix() == result) + assertEquals(l.colVectorAdd(ctx, v).toBreezeMatrix(), result) } - @Test - def rowwiseAddition(): Unit = { + test("rowwiseAddition") { val l = toBM( 4, 4, @@ -370,11 +339,10 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { 14, 16, 18, 20), ) - assert(l.rowVectorAdd(ctx, v).toBreezeMatrix() == result) + assertEquals(l.rowVectorAdd(ctx, v).toBreezeMatrix(), result) } - @Test - def diagonalTestTiny(): Unit = { + test("diagonalTestTiny") { val lm = toLM( 3, 4, @@ -386,34 +354,32 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val m = toBM(lm, blockSize = 2) - assert(m.diagonal().toSeq == Seq(1, 6, 11)) - assert(m.T.diagonal().toSeq == Seq(1, 6, 11)) - 
assert(m.dot(m.T).diagonal().toSeq == Seq(30, 174, 446)) + assertEquals(m.diagonal().toSeq, Seq[Double](1, 6, 11)) + assertEquals(m.T.diagonal().toSeq, Seq[Double](1, 6, 11)) + assertEquals(m.dot(m.T).diagonal().toSeq, Seq[Double](30, 174, 446)) } - @Test - def diagonalTestRandomized(): Unit = - forAll(squareBlockMatrixGen) { (m: BlockMatrix) => - val lm = m.toBreezeMatrix() - val diagonalLength = math.min(lm.rows, lm.cols) - val diagonal = Array.tabulate(diagonalLength)(i => lm(i, i)) + property("diagonalTestRandomized") = forAll(squareBlockMatrixGen) { (m: BlockMatrix) => + val lm = m.toBreezeMatrix() + val diagonalLength = math.min(lm.rows, lm.cols) + val diagonal = Array.tabulate(diagonalLength)(i => lm(i, i)) - assert( - m.diagonal().toSeq == diagonal.toSeq, - s"lm: $lm\n${m.diagonal().toSeq} != ${diagonal.toSeq}", - ) - } + assertEquals( + m.diagonal().toSeq, + diagonal.toSeq, + s"lm: $lm\n${m.diagonal().toSeq} != ${diagonal.toSeq}", + ) + } - @Test - def fromLocalTest(): Unit = - forAll(arbitrary[DenseMatrix[Double]].flatMap { m => + property("fromLocalTest") = forAll( + arbitrary[DenseMatrix[Double]].flatMap { m => Gen.zip(Gen.const(m), Gen.choose(math.sqrt(m.rows.toDouble).toInt, m.rows + 16)) - }) { case (lm, blockSize) => - assert(lm === BlockMatrix.fromBreezeMatrix(ctx, lm, blockSize).toBreezeMatrix()) } + ) { case (lm, blockSize) => + assertEquals(lm, BlockMatrix.fromBreezeMatrix(ctx, lm, blockSize).toBreezeMatrix()) + } - @Test - def readWriteIdentityTrivial(): Unit = { + test("readWriteIdentityTrivial") { val m = toBM( 4, 4, @@ -426,15 +392,14 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val fname = ctx.createTmpPath("test") m.write(ctx, fname) - assert(m.toBreezeMatrix() == BlockMatrix.read(ctx, fname).toBreezeMatrix()) + assertEquals(m.toBreezeMatrix(), BlockMatrix.read(ctx, fname).toBreezeMatrix()) val fname2 = ctx.createTmpPath("test2") m.write(ctx, fname2, forceRowMajor = true) - assert(m.toBreezeMatrix() == 
BlockMatrix.read(ctx, fname2).toBreezeMatrix()) + assertEquals(m.toBreezeMatrix(), BlockMatrix.read(ctx, fname2).toBreezeMatrix()) } - @Test - def readWriteIdentityTrivialTransposed(): Unit = { + test("readWriteIdentityTrivialTransposed") { val m = toBM( 4, 4, @@ -447,73 +412,60 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val fname = ctx.createTmpPath("test") m.T.write(ctx, fname) - assert(m.T.toBreezeMatrix() == BlockMatrix.read(ctx, fname).toBreezeMatrix()) + assertEquals(m.T.toBreezeMatrix(), BlockMatrix.read(ctx, fname).toBreezeMatrix()) val fname2 = ctx.createTmpPath("test2") m.T.write(ctx, fname2, forceRowMajor = true) - assert(m.T.toBreezeMatrix() == BlockMatrix.read(ctx, fname2).toBreezeMatrix()) + assertEquals(m.T.toBreezeMatrix(), BlockMatrix.read(ctx, fname2).toBreezeMatrix()) } - @Test - def readWriteIdentityRandom(): Unit = { - forAll(blockMatrixGen()) { (m: BlockMatrix) => - val fname = ctx.createTmpPath("test") - m.write(ctx, fname) - assertDoubleMatrixNaNEqualsNaN( - m.toBreezeMatrix(), - BlockMatrix.read(ctx, fname).toBreezeMatrix(), - ) - } + property("readWriteIdentityRandom") = forAll(blockMatrixGen()) { (m: BlockMatrix) => + val fname = ctx.createTmpPath("test") + m.write(ctx, fname) + assertDoubleMatrixNaNEqualsNaN( + m.toBreezeMatrix(), + BlockMatrix.read(ctx, fname).toBreezeMatrix(), + ) } - @Test - def transpose(): Unit = { - forAll(blockMatrixGen()) { (m: BlockMatrix) => - val transposed = m.toBreezeMatrix().t - assert(transposed.rows == m.nCols) - assert(transposed.cols == m.nRows) - assert(transposed === m.T.toBreezeMatrix()) - } + property("transpose") = forAll(blockMatrixGen()) { (m: BlockMatrix) => + val transposed = m.toBreezeMatrix().t + assertEquals(transposed.rows.toLong, m.nCols) + assertEquals(transposed.cols.toLong, m.nRows) + assertEquals(transposed, m.T.toBreezeMatrix()) } - @Test - def doubleTransposeIsIdentity(): Unit = { - forAll(blockMatrixGen(element = nonExtremeDouble)) { (m: 
BlockMatrix) => + property("doubleTransposeIsIdentity") = forAll(blockMatrixGen(element = nonExtremeDouble)) { + (m: BlockMatrix) => val mt = m.T.cache() val mtt = m.T.T.cache() - assert(mtt.nRows == m.nRows) - assert(mtt.nCols == m.nCols) + assertEquals(mtt.nRows, m.nRows) + assertEquals(mtt.nCols, m.nCols) assertDoubleMatrixNaNEqualsNaN(mtt.toBreezeMatrix(), m.toBreezeMatrix()) assertDoubleMatrixNaNEqualsNaN(mt.dot(mtt).toBreezeMatrix(), mt.dot(m).toBreezeMatrix()) - } } - @Test - def cachedOpsOK(): Unit = - forAll(twoMultipliableBlockMatrices) { - case (l: BlockMatrix, r: BlockMatrix) => - l.cache() - r.cache() - - val actual = l.dot(r).toBreezeMatrix() - val expected = l.toBreezeMatrix() * r.toBreezeMatrix() - assertDoubleMatrixNaNEqualsNaN(actual, expected) - assertDoubleMatrixNaNEqualsNaN(l.T.cache().T.toBreezeMatrix(), l.toBreezeMatrix()) - } + property("cachedOpsOK") = forAll(twoMultipliableBlockMatrices) { + case (l: BlockMatrix, r: BlockMatrix) => + l.cache() + r.cache() + + val actual = l.dot(r).toBreezeMatrix() + val expected = l.toBreezeMatrix() * r.toBreezeMatrix() + assertDoubleMatrixNaNEqualsNaN(actual, expected) + assertDoubleMatrixNaNEqualsNaN(l.T.cache().T.toBreezeMatrix(), l.toBreezeMatrix()) + } - @Test - def toIRMToHBMIdentity(): Unit = - forAll(blockMatrixGen()) { (m: BlockMatrix) => - val roundtrip = m.toIndexedRowMatrix().toHailBlockMatrix(m.blockSize) + property("toIRMToHBMIdentity") = forAll(blockMatrixGen()) { (m: BlockMatrix) => + val roundtrip = m.toIndexedRowMatrix().toHailBlockMatrix(m.blockSize) - val roundtriplm = roundtrip.toBreezeMatrix() - val lm = m.toBreezeMatrix() + val roundtriplm = roundtrip.toBreezeMatrix() + val lm = m.toBreezeMatrix() - assert(roundtriplm == lm) - } + assertEquals(roundtriplm, lm) + } - @Test - def map2RespectsTransposition(): Unit = { + test("map2RespectsTransposition") { val lm = toLM( 4, 2, @@ -534,15 +486,14 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val m = 
toBM(lm) val mt = toBM(lmt) - assert(m.map2(mt.T, _ + _).toBreezeMatrix() === lm + lm) + assertEquals(m.map2(mt.T, _ + _).toBreezeMatrix(), lm + lm) assert( - mt.T.map2(m, _ + _).toBreezeMatrix() === lm + lm, + mt.T.map2(m, _ + _).toBreezeMatrix() == lm + lm, s"${mt.toBreezeMatrix()}\n${mt.T.toBreezeMatrix()}\n${m.toBreezeMatrix()}", ) } - @Test - def map4RespectsTransposition(): Unit = { + test("map4RespectsTransposition") { val lm = toLM( 4, 2, @@ -563,12 +514,14 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val m = toBM(lm) val mt = toBM(lmt) - assert(m.map4(m, mt.T, mt.T.T.T, _ + _ + _ + _).toBreezeMatrix() === lm + lm + lm + lm) - assert(mt.map4(mt, m.T, mt.T.T, _ + _ + _ + _).toBreezeMatrix() === lm.t + lm.t + lm.t + lm.t) + assertEquals(m.map4(m, mt.T, mt.T.T.T, _ + _ + _ + _).toBreezeMatrix(), lm + lm + lm + lm) + assertEquals( + mt.map4(mt, m.T, mt.T.T, _ + _ + _ + _).toBreezeMatrix(), + lm.t + lm.t + lm.t + lm.t, + ) } - @Test - def mapRespectsTransposition(): Unit = { + test("mapRespectsTransposition") { val lm = toLM( 4, 2, @@ -589,13 +542,12 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val m = toBM(lm) val mt = toBM(lmt) - assert(m.T.map(_ * 4).toBreezeMatrix() === lm.t.map(_ * 4)) - assert(m.T.T.map(_ * 4).toBreezeMatrix() === lm.map(_ * 4)) - assert(mt.T.map(_ * 4).toBreezeMatrix() === lm.map(_ * 4)) + assertEquals(m.T.map(_ * 4).toBreezeMatrix(), lm.t.map(_ * 4)) + assertEquals(m.T.T.map(_ * 4).toBreezeMatrix(), lm.map(_ * 4)) + assertEquals(mt.T.map(_ * 4).toBreezeMatrix(), lm.map(_ * 4)) } - @Test - def mapWithIndexRespectsTransposition(): Unit = { + test("mapWithIndexRespectsTransposition") { val lm = toLM( 4, 2, @@ -616,22 +568,20 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val m = toBM(lm) val mt = toBM(lmt) - assert(m.T.mapWithIndex((_, _, x) => x * 4).toBreezeMatrix() === lm.t.map(_ * 4)) - assert(m.T.T.mapWithIndex((_, _, x) => x * 
4).toBreezeMatrix() === lm.map(_ * 4)) - assert(mt.T.mapWithIndex((_, _, x) => x * 4).toBreezeMatrix() === lm.map(_ * 4)) - - assert(m.T.mapWithIndex((i, j, x) => i * 10 + j + x).toBreezeMatrix() === - mt.mapWithIndex((i, j, x) => i * 10 + j + x).toBreezeMatrix()) - assert(m.T.mapWithIndex((i, j, x) => x + j * 2 + i + 1).toBreezeMatrix() === - lm.t + lm.t) - assert(mt.mapWithIndex((i, j, x) => x + j * 2 + i + 1).toBreezeMatrix() === - lm.t + lm.t) - assert(mt.T.mapWithIndex((i, j, x) => x + i * 2 + j + 1).toBreezeMatrix() === - lm + lm) + assertEquals(m.T.mapWithIndex((_, _, x) => x * 4).toBreezeMatrix(), lm.t.map(_ * 4)) + assertEquals(m.T.T.mapWithIndex((_, _, x) => x * 4).toBreezeMatrix(), lm.map(_ * 4)) + assertEquals(mt.T.mapWithIndex((_, _, x) => x * 4).toBreezeMatrix(), lm.map(_ * 4)) + + assertEquals( + m.T.mapWithIndex((i, j, x) => i * 10 + j + x).toBreezeMatrix(), + mt.mapWithIndex((i, j, x) => i * 10 + j + x).toBreezeMatrix(), + ) + assertEquals(m.T.mapWithIndex((i, j, x) => x + j * 2 + i + 1).toBreezeMatrix(), lm.t + lm.t) + assertEquals(mt.mapWithIndex((i, j, x) => x + j * 2 + i + 1).toBreezeMatrix(), lm.t + lm.t) + assertEquals(mt.T.mapWithIndex((i, j, x) => x + i * 2 + j + 1).toBreezeMatrix(), lm + lm) } - @Test - def map2WithIndexRespectsTransposition(): Unit = { + test("map2WithIndexRespectsTransposition") { val lm = toLM( 4, 2, @@ -652,156 +602,149 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val m = toBM(lm) val mt = toBM(lmt) - assert(m.map2WithIndex(mt.T, (_, _, x, y) => x + y).toBreezeMatrix() === lm + lm) - assert(mt.map2WithIndex(m.T, (_, _, x, y) => x + y).toBreezeMatrix() === lm.t + lm.t) - assert(mt.T.map2WithIndex(m, (_, _, x, y) => x + y).toBreezeMatrix() === lm + lm) - assert(m.T.T.map2WithIndex(mt.T, (_, _, x, y) => x + y).toBreezeMatrix() === lm + lm) - - assert(m.T.map2WithIndex(mt, (i, j, x, y) => i * 10 + j + x + y).toBreezeMatrix() === - mt.map2WithIndex(m.T, (i, j, x, y) => i * 10 + j + x + 
y).toBreezeMatrix()) - assert(m.T.map2WithIndex(m.T, (i, j, x, y) => i * 10 + j + x + y).toBreezeMatrix() === - mt.map2WithIndex(mt, (i, j, x, y) => i * 10 + j + x + y).toBreezeMatrix()) - assert(m.T.map2WithIndex(mt, (i, j, x, y) => x + 2 * y + j * 2 + i + 1).toBreezeMatrix() === - 4.0 * lm.t) - assert(mt.map2WithIndex(m.T, (i, j, x, y) => x + 2 * y + j * 2 + i + 1).toBreezeMatrix() === - 4.0 * lm.t) - assert(mt.T.map2WithIndex( - m.T.T, - (i, j, x, y) => 3 * x + 5 * y + i * 2 + j + 1, - ).toBreezeMatrix() === - 9.0 * lm) + assertEquals(m.map2WithIndex(mt.T, (_, _, x, y) => x + y).toBreezeMatrix(), lm + lm) + assertEquals(mt.map2WithIndex(m.T, (_, _, x, y) => x + y).toBreezeMatrix(), lm.t + lm.t) + assertEquals(mt.T.map2WithIndex(m, (_, _, x, y) => x + y).toBreezeMatrix(), lm + lm) + assertEquals(m.T.T.map2WithIndex(mt.T, (_, _, x, y) => x + y).toBreezeMatrix(), lm + lm) + + assertEquals( + m.T.map2WithIndex(mt, (i, j, x, y) => i * 10 + j + x + y).toBreezeMatrix(), + mt.map2WithIndex(m.T, (i, j, x, y) => i * 10 + j + x + y).toBreezeMatrix(), + ) + assertEquals( + m.T.map2WithIndex(m.T, (i, j, x, y) => i * 10 + j + x + y).toBreezeMatrix(), + mt.map2WithIndex(mt, (i, j, x, y) => i * 10 + j + x + y).toBreezeMatrix(), + ) + assertEquals( + m.T.map2WithIndex(mt, (i, j, x, y) => x + 2 * y + j * 2 + i + 1).toBreezeMatrix(), + 4.0 * lm.t, + ) + assertEquals( + mt.map2WithIndex(m.T, (i, j, x, y) => x + 2 * y + j * 2 + i + 1).toBreezeMatrix(), + 4.0 * lm.t, + ) + assertEquals( + mt.T.map2WithIndex( + m.T.T, + (i, j, x, y) => 3 * x + 5 * y + i * 2 + j + 1, + ).toBreezeMatrix(), + 9.0 * lm, + ) } - @Test - def filterCols(): Unit = { + test("filterCols") { val lm = new DenseMatrix[Double](9, 10, (0 until 90).map(_.toDouble).toArray) - scalatest.Inspectors.forAll(Seq(1, 2, 3, 5, 10, 11)) { blockSize => + Seq(1, 2, 3, 5, 10, 11).foreach { blockSize => val bm = BlockMatrix.fromBreezeMatrix(ctx, lm, blockSize) - scalatest.Inspectors.forAll { - Seq( - ArraySeq(0), - ArraySeq(1), - 
ArraySeq(9), - ArraySeq(0, 3, 4, 5, 7), - ArraySeq(1, 4, 5, 7, 8, 9), - ArraySeq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), - ) - } { keep => + Seq( + ArraySeq(0), + ArraySeq(1), + ArraySeq(9), + ArraySeq(0, 3, 4, 5, 7), + ArraySeq(1, 4, 5, 7, 8, 9), + ArraySeq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + ).foreach { keep => val filteredViaBlock = bm.filterCols(keep.map(_.toLong)).toBreezeMatrix() val filteredViaBreeze = lm(::, keep).copy - assert(filteredViaBlock === filteredViaBreeze) + assertEquals(filteredViaBlock, filteredViaBreeze) } } } - @Test - def filterColsTranspose(): Unit = { + test("filterColsTranspose") { val lm = new DenseMatrix[Double](9, 10, (0 until 90).map(_.toDouble).toArray) val lmt = lm.t - scalatest.Inspectors.forAll(Seq(2, 3)) { blockSize => + Seq(2, 3).foreach { blockSize => val bm = BlockMatrix.fromBreezeMatrix(ctx, lm, blockSize).transpose() - scalatest.Inspectors.forAll { - Seq( - ArraySeq(0), - ArraySeq(1, 4, 5, 7, 8), - ArraySeq(0, 1, 2, 3, 4, 5, 6, 7, 8), - ) - } { keep => + Seq( + ArraySeq(0), + ArraySeq(1, 4, 5, 7, 8), + ArraySeq(0, 1, 2, 3, 4, 5, 6, 7, 8), + ).foreach { keep => val filteredViaBlock = bm.filterCols(keep.map(_.toLong)).toBreezeMatrix() val filteredViaBreeze = lmt(::, keep).copy - assert(filteredViaBlock === filteredViaBreeze) + assertEquals(filteredViaBlock, filteredViaBreeze) } } } - @Test - def filterRows(): Unit = { + test("filterRows") { val lm = new DenseMatrix[Double](9, 10, (0 until 90).map(_.toDouble).toArray) - scalatest.Inspectors.forAll(Seq(2, 3)) { blockSize => + Seq(2, 3).foreach { blockSize => val bm = BlockMatrix.fromBreezeMatrix(ctx, lm, blockSize) - scalatest.Inspectors.forAll { - Seq( - ArraySeq(0), - ArraySeq(1, 4, 5, 7, 8), - ArraySeq(0, 1, 2, 3, 4, 5, 6, 7, 8), - ) - } { keep => + Seq( + ArraySeq(0), + ArraySeq(1, 4, 5, 7, 8), + ArraySeq(0, 1, 2, 3, 4, 5, 6, 7, 8), + ).foreach { keep => val filteredViaBlock = bm.filterRows(keep.map(_.toLong)).toBreezeMatrix() val filteredViaBreeze = lm(keep, ::).copy - 
assert(filteredViaBlock === filteredViaBreeze) + assertEquals(filteredViaBlock, filteredViaBreeze) } } } - @Test - def filterSymmetric(): Unit = { + test("filterSymmetric") { val lm = new DenseMatrix[Double](10, 10, (0 until 100).map(_.toDouble).toArray) - scalatest.Inspectors.forAll(Seq(1, 2, 3, 5, 10, 11)) { blockSize => + Seq(1, 2, 3, 5, 10, 11).foreach { blockSize => val bm = BlockMatrix.fromBreezeMatrix(ctx, lm, blockSize) - scalatest.Inspectors.forAll { - Seq( - ArraySeq(0), - ArraySeq(1), - ArraySeq(9), - ArraySeq(0, 3, 4, 5, 7), - ArraySeq(1, 4, 5, 7, 8, 9), - ArraySeq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), - ) - } { keep => + Seq( + ArraySeq(0), + ArraySeq(1), + ArraySeq(9), + ArraySeq(0, 3, 4, 5, 7), + ArraySeq(1, 4, 5, 7, 8, 9), + ArraySeq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + ).foreach { keep => val filteredViaBlock = bm.filter(keep.map(_.toLong), keep.map(_.toLong)).toBreezeMatrix() val filteredViaBreeze = lm(keep, keep).copy - assert(filteredViaBlock === filteredViaBreeze) + assertEquals(filteredViaBlock, filteredViaBreeze) } } } - @Test - def filter(): Unit = { + test("filter") { val lm = new DenseMatrix[Double](9, 10, (0 until 90).map(_.toDouble).toArray) - scalatest.Inspectors.forAll(Seq(1, 2, 3, 5, 10, 11)) { blockSize => + Seq(1, 2, 3, 5, 10, 11).foreach { blockSize => val bm = BlockMatrix.fromBreezeMatrix(ctx, lm, blockSize) - scalatest.Inspectors.forAll { - for { - keepRows <- Seq( - ArraySeq(1), - ArraySeq(0, 3, 4, 5, 7), - ArraySeq(0, 1, 2, 3, 4, 5, 6, 7, 8), - ) - keepCols <- Seq( - ArraySeq(2), - ArraySeq(1, 4, 5, 7, 8, 9), - ArraySeq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), - ) - } yield (keepRows, keepCols) - } { case (keepRows, keepCols) => + (for { + keepRows <- Seq( + ArraySeq(1), + ArraySeq(0, 3, 4, 5, 7), + ArraySeq(0, 1, 2, 3, 4, 5, 6, 7, 8), + ) + keepCols <- Seq( + ArraySeq(2), + ArraySeq(1, 4, 5, 7, 8, 9), + ArraySeq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + ) + } yield (keepRows, keepCols)).foreach { case (keepRows, keepCols) => val filteredViaBlock = 
bm.filter(keepRows.map(_.toLong), keepCols.map(_.toLong)).toBreezeMatrix() val filteredViaBreeze = lm(keepRows, keepCols).copy - assert(filteredViaBlock === filteredViaBreeze) + assertEquals(filteredViaBlock, filteredViaBreeze) } } } - @Test - def writeLocalAsBlockTest(): Unit = { + test("writeLocalAsBlockTest") { val lm = new DenseMatrix[Double](10, 10, (0 until 100).map(_.toDouble).toArray) - scalatest.Inspectors.forAll(Seq(1, 2, 3, 5, 10, 11)) { blockSize => + Seq(1, 2, 3, 5, 10, 11).foreach { blockSize => val fname = ctx.createTmpPath("test") lm.writeBlockMatrix(fs, fname, blockSize) - assert(lm === BlockMatrix.read(ctx, fname).toBreezeMatrix()) + assertEquals(lm, BlockMatrix.read(ctx, fname).toBreezeMatrix()) } } - @Test - def randomTest(): Unit = { + test("randomTest") { var lm1 = BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 1, gaussian = false).toBreezeMatrix() var lm2 = @@ -812,10 +755,10 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 2, gaussian = false).toBreezeMatrix() println(lm1) - assert(lm1 === lm2) - assert(lm1 !== lm3) - assert(lm1 !== lm4) - assert(lm3 !== lm4) + assertEquals(lm1, lm2) + assert(lm1 != lm3) + assert(lm1 != lm4) + assert(lm3 != lm4) assert(lm1.data.forall(x => x >= 0 && x <= 1)) lm1 = BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 1, gaussian = true).toBreezeMatrix() @@ -823,36 +766,35 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { lm3 = BlockMatrix.random(5, 10, 2, staticUID = 2, nonce = 1, gaussian = true).toBreezeMatrix() lm4 = BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 2, gaussian = true).toBreezeMatrix() - assert(lm1 === lm2) - assert(lm1 !== lm3) - assert(lm1 !== lm4) - assert(lm3 !== lm4) + assertEquals(lm1, lm2) + assert(lm1 != lm3) + assert(lm1 != lm4) + assert(lm3 != lm4) } - @Test - def testEntriesTable(): Unit = { + test("EntriesTable") { val data = (0 until 
90).map(_.toDouble).toArray val lm = new DenseMatrix[Double](9, 10, data) val expectedEntries = data.map(x => ((x % 9).toLong, (x / 9).toLong, x)).toSet val expectedSignature = TStruct("i" -> TInt64, "j" -> TInt64, "entry" -> TFloat64) - scalatest.Inspectors.forAll(Seq(1, 4, 10)) { blockSize => + Seq(1, 4, 10).foreach { blockSize => val entriesLiteral = TableLiteral(toBM(lm, blockSize).entriesTable(ctx), theHailClassLoader) - assert(entriesLiteral.typ.rowType == expectedSignature) + assertEquals(entriesLiteral.typ.rowType, expectedSignature) val rows = CompileAndEvaluate[IndexedSeq[Row]]( ctx, GetField(TableCollect(entriesLiteral), "rows"), lower = LoweringPipeline.relationalLowerer, ) - val entries = rows.map(row => (row.get(0), row.get(1), row.get(2))).toSet + val entries = + rows.map(row => (row.getAs[Long](0), row.getAs[Long](1), row.getAs[Double](2))).toSet // block size affects order of rows in table, but sets will be the same - assert(entries === expectedEntries) + assertEquals(entries, expectedEntries) } } - @Test - def testEntriesTableWhenKeepingOnlySomeBlocks(): Unit = { + test("EntriesTableWhenKeepingOnlySomeBlocks") { val data = (0 until 50).map(_.toDouble).toArray val lm = new DenseMatrix[Double](5, 10, data) val bm = toBM(lm, blockSize = 2) @@ -877,8 +819,7 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(expected sameElements Array[Double](0, 5, 20, 25, 1, 6, 21, 26, 2, 7, 3, 8)) } - @Test - def testPowSqrt(): Unit = { + test("PowSqrt") { val lm = new DenseMatrix[Double](2, 3, Array(0.0, 1.0, 4.0, 9.0, 16.0, 25.0)) val bm = BlockMatrix.fromBreezeMatrix(ctx, lm, blockSize = 2) val expected = new DenseMatrix[Double](2, 3, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) @@ -891,8 +832,7 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { def filteredEquals(bm1: BlockMatrix, bm2: BlockMatrix): Boolean = bm1.blocks.collect() sameElements bm2.blocks.collect() - @Test - def testSparseFilterEdges(): 
Unit = { + test("SparseFilterEdges") { val lm = new DenseMatrix[Double](12, 12, (0 to 143).map(_.toDouble).toArray) val bm = toBM(lm, blockSize = 5) @@ -905,23 +845,27 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { 143).map(_.toDouble)) assert(onlyEightColEleven.toArray sameElements Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 142, 143).map(_.toDouble)) - assert(onlyEightCornerFour == new DenseMatrix[Double](2, 2, Array(130.0, 131.0, 142.0, 143.0))) + assertEquals( + onlyEightCornerFour, + new DenseMatrix[Double](2, 2, Array(130.0, 131.0, 142.0, 143.0)), + ) } - @Test - def testSparseTransposeMaybeBlocks(): Unit = { + test("SparseTransposeMaybeBlocks") { val lm = new DenseMatrix[Double](9, 12, (0 to 107).map(_.toDouble).toArray) val bm = toBM(lm, blockSize = 3) val sparse = bm.filterBand(0, 0, true) - assert(sparse.transpose().gp.partitionIndexToBlockIndex.get.toIndexedSeq == IndexedSeq( - 0, - 5, - 10, - )) + assertEquals( + sparse.transpose().gp.partitionIndexToBlockIndex.get.toIndexedSeq, + IndexedSeq( + 0, + 5, + 10, + ), + ) } - @Test - def filterRowsRectangleSum(): Unit = { + test("filterRowsRectangleSum") { val nRows = 10 val nCols = 50 val bm = BlockMatrix.fill(nRows.toLong, nCols.toLong, 2, 1) @@ -935,8 +879,7 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(summed sameElements expected) } - @Test - def testFilterBlocks(): Unit = { + test("FilterBlocks") { val lm = toLM( 4, 4, @@ -961,10 +904,10 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val localBlocks = Array(lm(0 to 1, 0 to 1), lm(2 to 3, 0 to 1), lm(0 to 1, 2 to 3), lm(2 to 3, 2 to 3)) - scalatest.Inspectors.forAll(keepArray) { keep => + keepArray.foreach { keep => val fbm = bm.filterBlocks(keep) - assert(fbm.blocks.count() == keep.length) + assertEquals(fbm.blocks.count(), keep.length.toLong) assert(fbm.blocks.collect().forall { case ((i, j), block) => block == 
localBlocks(fbm.gp.coordinatesBlock(i, j)) }) @@ -986,8 +929,7 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { )) } - @Test - def testSparseBlockMatrixIO(): Unit = { + test("SparseBlockMatrixIO") { val lm = toLM( 4, 4, @@ -1022,25 +964,24 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { } // test toBlockMatrix, toIndexedRowMatrix, toRowMatrix, read/write identity - scalatest.Inspectors.forAll(keepArray) { keep => + keepArray.foreach { keep => val fbm = bm.filterBlocks(keep) val flm = filterBlocks(keep) - assert(fbm.toBreezeMatrix() === flm) + assertEquals(fbm.toBreezeMatrix(), flm) - assert(flm === fbm.toIndexedRowMatrix().toHailBlockMatrix().toBreezeMatrix()) + assertEquals(flm, fbm.toIndexedRowMatrix().toHailBlockMatrix().toBreezeMatrix()) val fname = ctx.createTmpPath("test") fbm.write(ctx, fname, forceRowMajor = true) - assert(RowMatrix.readBlockMatrix(ctx, fname, 3).toBreezeMatrix() === flm) + assertEquals(RowMatrix.readBlockMatrix(ctx, fname, 3).toBreezeMatrix(), flm) assert(filteredEquals(fbm, BlockMatrix.read(ctx, fname))) } } - @Test - def testSparseBlockMatrixMathAndFilter(): Unit = { + test("SparseBlockMatrixMathAndFilter") { val lm = toLM( 4, 4, @@ -1079,7 +1020,7 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val v = Array(1.0, 2.0, 3.0, 4.0) // test transpose, diagonal, math ops, filter ops - scalatest.Inspectors.forAll(keepArray) { keep => + keepArray.foreach { keep => println(s"Test says keep block: $keep") val fbm = bm.filterBlocks(keep) val flm = filterBlocks(keep) @@ -1112,26 +1053,26 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(filteredEquals(fbm.sqrt(), bm.sqrt().filterBlocks(keep))) assert(filteredEquals(fbm.pow(3), bm.pow(3).filterBlocks(keep))) - assert(fbm.dot(fbm).toBreezeMatrix() === flm * flm) + assertEquals(fbm.dot(fbm).toBreezeMatrix(), flm * flm) // densifying ops - assert((fbm + 
2).toBreezeMatrix() === flm + 2.0) - assert((2 + fbm).toBreezeMatrix() === flm + 2.0) - assert((fbm - 2).toBreezeMatrix() === flm - 2.0) - assert((2 - fbm).toBreezeMatrix() === 2.0 - flm) + assertEquals((fbm + 2).toBreezeMatrix(), flm + 2.0) + assertEquals((2 + fbm).toBreezeMatrix(), flm + 2.0) + assertEquals((fbm - 2).toBreezeMatrix(), flm - 2.0) + assertEquals((2 - fbm).toBreezeMatrix(), 2.0 - flm) - assert(fbm.rowVectorAdd(ctx, v).toBreezeMatrix() === flm(*, ::) + BDV(v)) - assert(fbm.rowVectorSub(ctx, v).toBreezeMatrix() === flm(*, ::) - BDV(v)) - assert(fbm.reverseRowVectorSub(ctx, v).toBreezeMatrix() === -(flm(*, ::) - BDV(v))) + assertEquals(fbm.rowVectorAdd(ctx, v).toBreezeMatrix(), flm(*, ::) + BDV(v)) + assertEquals(fbm.rowVectorSub(ctx, v).toBreezeMatrix(), flm(*, ::) - BDV(v)) + assertEquals(fbm.reverseRowVectorSub(ctx, v).toBreezeMatrix(), -(flm(*, ::) - BDV(v))) - assert(fbm.colVectorAdd(ctx, v).toBreezeMatrix() === flm(::, *) + BDV(v)) - assert(fbm.colVectorSub(ctx, v).toBreezeMatrix() === flm(::, *) - BDV(v)) - assert(fbm.reverseColVectorSub(ctx, v).toBreezeMatrix() === -(flm(::, *) - BDV(v))) + assertEquals(fbm.colVectorAdd(ctx, v).toBreezeMatrix(), flm(::, *) + BDV(v)) + assertEquals(fbm.colVectorSub(ctx, v).toBreezeMatrix(), flm(::, *) - BDV(v)) + assertEquals(fbm.reverseColVectorSub(ctx, v).toBreezeMatrix(), -(flm(::, *) - BDV(v))) // filter ops - assert(fbm.filterRows(ArraySeq(1, 2)).toBreezeMatrix() === flm(1 to 2, ::)) - assert(fbm.filterCols(ArraySeq(1, 2)).toBreezeMatrix() === flm(::, 1 to 2)) - assert(fbm.filter(ArraySeq(1, 2), ArraySeq(1, 2)).toBreezeMatrix() === flm(1 to 2, 1 to 2)) + assertEquals(fbm.filterRows(ArraySeq(1, 2)).toBreezeMatrix(), flm(1 to 2, ::)) + assertEquals(fbm.filterCols(ArraySeq(1, 2)).toBreezeMatrix(), flm(::, 1 to 2)) + assertEquals(fbm.filter(ArraySeq(1, 2), ArraySeq(1, 2)).toBreezeMatrix(), flm(1 to 2, 1 to 2)) } val bm0 = bm.filterBlocks(ArraySeq(0)) @@ -1147,26 +1088,26 @@ class BlockMatrixSuite extends 
HailSuite with ScalaCheckDrivenPropertyChecks { // test +/- with mismatched blocks assert(filteredEquals(bm0 + bm13, bm.filterBlocks(ArraySeq(0, 1, 3)))) - assert((bm0 + bm).toBreezeMatrix() === lm0 + lm) - assert((bm + bm0).toBreezeMatrix() === lm + lm0) - assert( - (bm0 + 2.0 * bm13 + 3.0 * bm23 + 5.0 * bm123).toBreezeMatrix() === - lm0 + 2.0 * lm13 + 3.0 * lm23 + 5.0 * lm123 + assertEquals((bm0 + bm).toBreezeMatrix(), lm0 + lm) + assertEquals((bm + bm0).toBreezeMatrix(), lm + lm0) + assertEquals( + (bm0 + 2.0 * bm13 + 3.0 * bm23 + 5.0 * bm123).toBreezeMatrix(), + lm0 + 2.0 * lm13 + 3.0 * lm23 + 5.0 * lm123, ) - assert( - (bm123 + 2.0 * bm13 + 3.0 * bm23 + 5.0 * bm0).toBreezeMatrix() === - lm123 + 2.0 * lm13 + 3.0 * lm23 + 5.0 * lm0 + assertEquals( + (bm123 + 2.0 * bm13 + 3.0 * bm23 + 5.0 * bm0).toBreezeMatrix(), + lm123 + 2.0 * lm13 + 3.0 * lm23 + 5.0 * lm0, ) - assert((bm0 - bm).toBreezeMatrix() === lm0 - lm) - assert((bm - bm0).toBreezeMatrix() === lm - lm0) - assert( - (bm0 - 2.0 * bm13 - 3.0 * bm23 - 5.0 * bm123).toBreezeMatrix() === - lm0 - 2.0 * lm13 - 3.0 * lm23 - 5.0 * lm123 + assertEquals((bm0 - bm).toBreezeMatrix(), lm0 - lm) + assertEquals((bm - bm0).toBreezeMatrix(), lm - lm0) + assertEquals( + (bm0 - 2.0 * bm13 - 3.0 * bm23 - 5.0 * bm123).toBreezeMatrix(), + lm0 - 2.0 * lm13 - 3.0 * lm23 - 5.0 * lm123, ) - assert( - (bm123 - 2.0 * bm13 - 3.0 * bm23 - 5.0 * bm0).toBreezeMatrix() === - lm123 - 2.0 * lm13 - 3.0 * lm23 - 5.0 * lm0 + assertEquals( + (bm123 - 2.0 * bm13 - 3.0 * bm23 - 5.0 * bm0).toBreezeMatrix(), + lm123 - 2.0 * lm13 - 3.0 * lm23 - 5.0 * lm0, ) // test * with mismatched blocks @@ -1193,8 +1134,7 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { interceptFatal(notSupported)(bm0.pow(-1)) } - @Test - def testRealizeBlocks(): Unit = { + test("RealizeBlocks") { val lm = toLM( 4, 4, @@ -1231,7 +1171,7 @@ class BlockMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { 
assert(filteredEquals(bm.densify(), bm)) assert(filteredEquals(bm.realizeBlocks(None), bm)) - scalatest.Inspectors.forAll(keepArray) { keep => + keepArray.foreach { keep => val fbm = bm.filterBlocks(keep) val flm = filterBlocks(keep) diff --git a/hail/hail/test/src/is/hail/linalg/GridPartitionerSuite.scala b/hail/hail/test/src/is/hail/linalg/GridPartitionerSuite.scala index a7f3cf91cec..8d2d25e3936 100644 --- a/hail/hail/test/src/is/hail/linalg/GridPartitionerSuite.scala +++ b/hail/hail/test/src/is/hail/linalg/GridPartitionerSuite.scala @@ -2,26 +2,22 @@ package is.hail.linalg import is.hail.collection.compat.immutable.ArraySeq -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test - -class GridPartitionerSuite extends TestNGSuite { +class GridPartitionerSuite extends munit.FunSuite { private def assertLayout(hg: GridPartitioner, layout: ((Int, Int), Int)*): Unit = { layout.foreach { case ((i, j), p) => - assert(hg.coordinatesBlock(i, j) === p, s"at coordinates ${(i, j)}") + assertEquals(hg.coordinatesBlock(i, j), p, s"at coordinates ${(i, j)}") } layout.foreach { case ((i, j), p) => - assert(hg.blockCoordinates(p) === ((i, j)), s"at pid $p") + assertEquals(hg.blockCoordinates(p), ((i, j)), s"at pid $p") } } - @Test - def squareIsColumnMajor(): Unit = + test("squareIsColumnMajor") { assertLayout(GridPartitioner(2, 4, 4), (0, 0) -> 0, (1, 0) -> 1, (0, 1) -> 2, (1, 1) -> 3) + } - @Test - def rectangleMoreRowsIsColumnMajor(): Unit = { + test("rectangleMoreRowsIsColumnMajor") { assertLayout( GridPartitioner(2, 6, 4), (0, 0) -> 0, @@ -33,8 +29,7 @@ class GridPartitionerSuite extends TestNGSuite { ) } - @Test - def rectangleMoreColsIsColumnMajor(): Unit = { + test("rectangleMoreColsIsColumnMajor") { assertLayout( GridPartitioner(2, 4, 6), (0, 0) -> 0, @@ -46,8 +41,7 @@ class GridPartitionerSuite extends TestNGSuite { ) } - @Test - def bandedBlocksTest(): Unit = { + test("bandedBlocks") { // 0 3 6 9 // 1 4 7 10 // 2 5 8 11 @@ -76,8 +70,7 @@ class 
GridPartitionerSuite extends TestNGSuite { } } - @Test - def rectangularBlocksTest(): Unit = { + test("rectangularBlocks") { // 0 3 6 9 // 1 4 7 10 // 2 5 8 11 @@ -85,23 +78,28 @@ class GridPartitionerSuite extends TestNGSuite { val gp2 = GridPartitioner(10, 21, 31) for (gp <- Seq(gp1, gp2)) { - assert(gp.rectangleBlocks(ArraySeq[Long](0, 1, 0, 1)) == ArraySeq(0)) - assert(gp.rectanglesBlocks(ArraySeq(ArraySeq[Long](0, 1, 0, 1))) == ArraySeq(0)) - - assert(gp.rectangleBlocks(ArraySeq[Long](0, 10, 0, 10)) == ArraySeq(0)) - - assert(gp.rectangleBlocks(ArraySeq[Long](9, 11, 9, 11)) == ArraySeq(0, 1, 3, 4)) - assert(gp.rectanglesBlocks(ArraySeq(ArraySeq[Long](9, 11, 9, 11))) == ArraySeq(0, 1, 3, 4)) - - assert(gp.rectangleBlocks(ArraySeq[Long](10, 20, 10, 30)) == ArraySeq(4, 7)) - - assert(gp.rectanglesBlocks(ArraySeq( - ArraySeq[Long](9, 11, 9, 11), - ArraySeq(10, 20, 10, 30), - ArraySeq(0, 1, 20, 21), - ArraySeq(20, 21, 20, 31), - )) - == ArraySeq(0, 1, 3, 4, 6, 7, 8, 11)) + assertEquals(gp.rectangleBlocks(ArraySeq[Long](0, 1, 0, 1)), ArraySeq(0)) + assertEquals(gp.rectanglesBlocks(ArraySeq(ArraySeq[Long](0, 1, 0, 1))), ArraySeq(0)) + + assertEquals(gp.rectangleBlocks(ArraySeq[Long](0, 10, 0, 10)), ArraySeq(0)) + + assertEquals(gp.rectangleBlocks(ArraySeq[Long](9, 11, 9, 11)), ArraySeq(0, 1, 3, 4)) + assertEquals( + gp.rectanglesBlocks(ArraySeq(ArraySeq[Long](9, 11, 9, 11))), + ArraySeq(0, 1, 3, 4), + ) + + assertEquals(gp.rectangleBlocks(ArraySeq[Long](10, 20, 10, 30)), ArraySeq(4, 7)) + + assertEquals( + gp.rectanglesBlocks(ArraySeq( + ArraySeq[Long](9, 11, 9, 11), + ArraySeq(10, 20, 10, 30), + ArraySeq(0, 1, 20, 21), + ArraySeq(20, 21, 20, 31), + )), + ArraySeq(0, 1, 3, 4, 6, 7, 8, 11), + ) assert(gp.rectangleBlocks(ArraySeq[Long](0, 21, 0, 31)) == (0 until 12)) } diff --git a/hail/hail/test/src/is/hail/linalg/MatrixSparsitySuite.scala b/hail/hail/test/src/is/hail/linalg/MatrixSparsitySuite.scala index 09b1321b537..c481739b343 100644 --- 
a/hail/hail/test/src/is/hail/linalg/MatrixSparsitySuite.scala +++ b/hail/hail/test/src/is/hail/linalg/MatrixSparsitySuite.scala @@ -1,11 +1,9 @@ package is.hail.linalg +import is.hail.TestCaseSupport import is.hail.collection.compat.immutable.ArraySeq -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.{DataProvider, Test} - -class MatrixSparsitySuite extends TestNGSuite { +class MatrixSparsitySuite extends munit.FunSuite with TestCaseSupport { def newToOldReference(from: MatrixSparsity, to: MatrixSparsity): IndexedSeq[Integer] = to.definedCoords.map { coords => val i = from.definedCoords.indexOf(coords) @@ -23,50 +21,79 @@ class MatrixSparsitySuite extends TestNGSuite { MatrixSparsity.apply(4, 3, ArraySeq()), ) - @DataProvider(name = "sparsity_pairs_4_3") - def sparsityPairs43(): Array[Array[Object]] = ( + val sparsitySubsetPairs43: Array[(MatrixSparsity, MatrixSparsity.Sparse)] = ( for { s1 <- sparsities_4_3 s2 <- sparsities_4_3 if s2.isInstanceOf[MatrixSparsity.Sparse] - } yield Array[Object](s1, s2) + if isSubset(s1, s2) + } yield (s1, s2.asInstanceOf[MatrixSparsity.Sparse]) ).toArray - @DataProvider(name = "sparsity_subset_pairs_4_3") - def sparsitySubsetPairs(): Array[Array[Object]] = ( + val sparsityPairs43: Array[(MatrixSparsity, MatrixSparsity.Sparse)] = ( for { s1 <- sparsities_4_3 s2 <- sparsities_4_3 if s2.isInstanceOf[MatrixSparsity.Sparse] - if isSubset(s1, s2) - } yield Array[Object](s1, s2) + } yield (s1, s2.asInstanceOf[MatrixSparsity.Sparse]) ).toArray - @Test(dataProvider = "sparsity_subset_pairs_4_3") - def newToOld(s1: MatrixSparsity, s2: MatrixSparsity.Sparse): Unit = - assertResult(newToOldReference(s1, s2))(s1.newToOldPos(s2)) + object checkNewToOld extends TestCases { + def apply( + s1: MatrixSparsity, + s2: MatrixSparsity.Sparse, + )(implicit loc: munit.Location + ): Unit = test("newToOld") { + assert(s1.newToOldPos(s2) == newToOldReference(s1, s2)) + } + } + + sparsitySubsetPairs43.foreach { case (s1, s2) => 
checkNewToOld(s1, s2) } - @Test(dataProvider = "sparsity_pairs_4_3") - def newToOldNonSubset(s1: MatrixSparsity, s2: MatrixSparsity.Sparse): Unit = - assertResult(newToOldReference(s1, s2))(s1.newToOldPosNonSubset(s2)) + object checkNewToOldNonSubset extends TestCases { + def apply( + s1: MatrixSparsity, + s2: MatrixSparsity.Sparse, + )(implicit loc: munit.Location + ): Unit = test("newToOldNonSubset") { + assert(s1.newToOldPosNonSubset(s2) == newToOldReference(s1, s2)) + } + } + + sparsityPairs43.foreach { case (s1, s2) => checkNewToOldNonSubset(s1, s2) } def sparsities_0_0: Iterator[MatrixSparsity] = Iterator( MatrixSparsity.dense(0, 0), MatrixSparsity.apply(0, 0, ArraySeq()), ) - @DataProvider(name = "sparsity_pairs_0_0") - def sparsityPairs00(): Array[Array[Object]] = ( + val sparsityPairs00: Array[(MatrixSparsity, MatrixSparsity.Sparse)] = ( for { s1 <- sparsities_0_0 - } yield Array[Object](s1, MatrixSparsity.apply(0, 0, ArraySeq())) + } yield (s1, MatrixSparsity.apply(0, 0, ArraySeq()).asInstanceOf[MatrixSparsity.Sparse]) ).toArray - @Test(dataProvider = "sparsity_pairs_0_0") - def newToOldDegenerate(s1: MatrixSparsity, s2: MatrixSparsity.Sparse): Unit = - assertResult(newToOldReference(s1, s2))(s1.newToOldPos(s2)) + object checkNewToOldDegenerate extends TestCases { + def apply( + s1: MatrixSparsity, + s2: MatrixSparsity.Sparse, + )(implicit loc: munit.Location + ): Unit = test("newToOldDegenerate") { + assert(s1.newToOldPos(s2) == newToOldReference(s1, s2)) + } + } + + sparsityPairs00.foreach { case (s1, s2) => checkNewToOldDegenerate(s1, s2) } + + object checkNewToOldNonSubsetDegenerate extends TestCases { + def apply( + s1: MatrixSparsity, + s2: MatrixSparsity.Sparse, + )(implicit loc: munit.Location + ): Unit = test("newToOldNonSubsetDegenerate") { + assert(s1.newToOldPosNonSubset(s2) == newToOldReference(s1, s2)) + } + } - @Test(dataProvider = "sparsity_pairs_0_0") - def newToOldNonSubsetDegenerate(s1: MatrixSparsity, s2: MatrixSparsity.Sparse): Unit = 
- assertResult(newToOldReference(s1, s2))(s1.newToOldPosNonSubset(s2)) + sparsityPairs00.foreach { case (s1, s2) => checkNewToOldNonSubsetDegenerate(s1, s2) } } diff --git a/hail/hail/test/src/is/hail/linalg/RichDenseMatrixDoubleSuite.scala b/hail/hail/test/src/is/hail/linalg/RichDenseMatrixDoubleSuite.scala index b8758b1219b..1ea729f594b 100644 --- a/hail/hail/test/src/is/hail/linalg/RichDenseMatrixDoubleSuite.scala +++ b/hail/hail/test/src/is/hail/linalg/RichDenseMatrixDoubleSuite.scala @@ -4,33 +4,30 @@ import is.hail.HailSuite import is.hail.linalg.implicits._ import breeze.linalg.{DenseMatrix => BDM} -import org.testng.annotations.Test class RichDenseMatrixDoubleSuite extends HailSuite { - @Test - def readWriteBDM(): Unit = { + test("readWriteBDM") { val m = BDM.rand[Double](256, 129) // 33024 doubles val fname = ctx.createTmpPath("test") m.write(fs, fname, bufferSpec = BlockMatrix.bufferSpec) val m2 = RichDenseMatrixDouble.read(fs, fname, BlockMatrix.bufferSpec) - assert(m === m2) + assertEquals(m, m2) } - @Test - def testReadWriteDoubles(): Unit = { + test("ReadWriteDoubles") { val file = ctx.createTmpPath("test") val m = BDM.rand[Double](50, 100) RichDenseMatrixDouble.exportToDoubles(fs, file, m, forceRowMajor = false): Unit val m2 = RichDenseMatrixDouble.importFromDoubles(fs, file, 50, 100, rowMajor = false) - assert(m === m2) + assertEquals(m, m2) val fileT = ctx.createTmpPath("test2") val mT = m.t RichDenseMatrixDouble.exportToDoubles(fs, fileT, mT, forceRowMajor = true): Unit val lmT2 = RichDenseMatrixDouble.importFromDoubles(fs, fileT, 100, 50, rowMajor = true) - assert(mT === lmT2) + assertEquals(mT, lmT2) interceptFatal("Premature") { RichDenseMatrixDouble.importFromDoubles(fs, fileT, 100, 100, rowMajor = true) diff --git a/hail/hail/test/src/is/hail/linalg/RichIndexedRowMatrixSuite.scala b/hail/hail/test/src/is/hail/linalg/RichIndexedRowMatrixSuite.scala index b0250b8d5fe..59ba2e16c14 100644 --- 
a/hail/hail/test/src/is/hail/linalg/RichIndexedRowMatrixSuite.scala +++ b/hail/hail/test/src/is/hail/linalg/RichIndexedRowMatrixSuite.scala @@ -8,14 +8,11 @@ import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix} import org.apache.spark.rdd.RDD -import org.scalatest.Inspectors.forAll -import org.scalatest.enablers.InspectorAsserting.assertingNatureOfAssertion -import org.testng.annotations.Test /** Testing RichIndexedRowMatrix. */ class RichIndexedRowMatrixSuite extends HailSuite { - @Test def testToBlockMatrixDense(): Unit = { + test("toBlockMatrixDense") { val nRows = 9L val nCols = 6L val data = Seq( @@ -33,25 +30,25 @@ class RichIndexedRowMatrixSuite extends HailSuite { val irm = new IndexedRowMatrix(indexedRows) val irmLocal = irm.toBlockMatrix().toLocalMatrix() - forAll(Seq(1, 2, 3, 4, 6, 7, 9, 10)) { blockSize => + Seq(1, 2, 3, 4, 6, 7, 9, 10).foreach { blockSize => val blockMat = irm.toHailBlockMatrix(blockSize) - assert(blockMat.nRows === nRows) - assert(blockMat.nCols === nCols) + assertEquals(blockMat.nRows, nRows) + assertEquals(blockMat.nCols, nCols) val blockMatAsBreeze = blockMat.toBreezeMatrix() - assert(blockMatAsBreeze.rows == irmLocal.numRows) - assert(blockMatAsBreeze.cols == irmLocal.numCols) - assert(blockMatAsBreeze.toArray.toIndexedSeq == irmLocal.toArray.toIndexedSeq) + assertEquals(blockMatAsBreeze.rows, irmLocal.numRows) + assertEquals(blockMatAsBreeze.cols, irmLocal.numCols) + assertEquals(blockMatAsBreeze.toArray.toIndexedSeq, irmLocal.toArray.toIndexedSeq) } - assertThrows[IllegalArgumentException] { + intercept[IllegalArgumentException] { irm.toHailBlockMatrix(-1) - } - assertThrows[IllegalArgumentException] { + }: Unit + intercept[IllegalArgumentException] { irm.toHailBlockMatrix(0) - } + }: Unit } - @Test def emptyBlocks(): Unit = { + test("emptyBlocks") { val nRows = 9 val nCols = 2 val data = Seq( @@ -65,15 +62,16 @@ 
class RichIndexedRowMatrixSuite extends HailSuite { val irmLocal = irm.toBlockMatrix().toLocalMatrix() val m = irm.toHailBlockMatrix(2) - assert(m.nRows == nRows) - assert(m.nCols == nCols) + assertEquals(m.nRows, nRows.toLong) + assertEquals(m.nCols, nCols.toLong) val blockMatAsBreeze = m.toBreezeMatrix() - assert(blockMatAsBreeze.toArray.toIndexedSeq == irmLocal.toArray.toIndexedSeq) - assert(m.blocks.count() == 5) + assertEquals(blockMatAsBreeze.toArray.toIndexedSeq, irmLocal.toArray.toIndexedSeq) + assertEquals(m.blocks.count(), 5L) m.dot(m.T).toBreezeMatrix(): Unit // assert no exception - assert(m.mapWithIndex { case (i, j, v) => i + 10 * j + v }.toBreezeMatrix() === + assertEquals( + m.mapWithIndex { case (i, j, v) => i + 10 * j + v }.toBreezeMatrix(), new BDM[Double]( nRows, nCols, @@ -81,6 +79,7 @@ class RichIndexedRowMatrixSuite extends HailSuite { 0.0, 1.0, 2.0, 4.0, 5.0, 6.0, 6.0, 7.0, 9.0, 10.0, 11.0, 12.0, 15.0, 16.0, 17.0, 16.0, 17.0, 20.0, ), - )) + ), + ) } } diff --git a/hail/hail/test/src/is/hail/linalg/RowMatrixSuite.scala b/hail/hail/test/src/is/hail/linalg/RowMatrixSuite.scala index d6d1b836824..b0861e64f26 100644 --- a/hail/hail/test/src/is/hail/linalg/RowMatrixSuite.scala +++ b/hail/hail/test/src/is/hail/linalg/RowMatrixSuite.scala @@ -3,14 +3,10 @@ package is.hail.linalg import is.hail.HailSuite import is.hail.collection.compat.immutable.ArraySeq import is.hail.expr.ir.ExportType -import is.hail.scalacheck._ import breeze.linalg.DenseMatrix -import org.scalatest -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.testng.annotations.Test -class RowMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { +class RowMatrixSuite extends HailSuite { private def rowArrayToRowMatrix( a: IndexedSeq[Array[Double]], nPartitions: Int = sc.defaultParallelism, @@ -37,8 +33,7 @@ class RowMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { new DenseMatrix[Double](nRows, nCols, a.view.flatten.toArray, 
0, nCols, isTranspose = true) } - @Test - def localizeRowMatrix(): Unit = { + test("localizeRowMatrix") { val fname = ctx.createTmpPath("test") val rowArrays = ArraySeq( @@ -51,11 +46,10 @@ class RowMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { BlockMatrix.fromBreezeMatrix(ctx, localMatrix).write(ctx, fname) - assert(rowMatrix.toBreezeMatrix() === localMatrix) + assertEquals(rowMatrix.toBreezeMatrix(), localMatrix) } - @Test - def readBlockSmall(): Unit = { + test("readBlockSmall") { val fname = ctx.createTmpPath("test") val localMatrix = DenseMatrix( @@ -67,29 +61,26 @@ class RowMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val rowMatrixFromBlock = RowMatrix.readBlockMatrix(ctx, fname, 1) - assert(rowMatrixFromBlock.toBreezeMatrix() == localMatrix) + assertEquals(rowMatrixFromBlock.toBreezeMatrix(), localMatrix) } - @Test - def readBlock(): Unit = - forAll(genDenseMatrix(9, 10)) { lm => - val fname = ctx.createTmpPath("test") - scalatest.Inspectors.forAll { - cartesian( - Seq(1, 2, 3, 4, 6, 7, 9, 10), - Seq(1, 2, 4, 9, 11), - ) - } { case (blockSize, partSize) => - BlockMatrix.fromBreezeMatrix(ctx, lm, blockSize).write( - ctx, - fname, - overwrite = true, - forceRowMajor = true, - ) - val rowMatrix = RowMatrix.readBlockMatrix(ctx, fname, partSize) - assert(rowMatrix.toBreezeMatrix() === lm) - } + test("readBlock") { + val lm = DenseMatrix.create(9, 10, Array.tabulate(9 * 10)(_.toDouble)) + val fname = ctx.createTmpPath("test") + cartesian( + Seq(1, 2, 3, 4, 6, 7, 9, 10), + Seq(1, 2, 4, 9, 11), + ).foreach { case (blockSize, partSize) => + BlockMatrix.fromBreezeMatrix(ctx, lm, blockSize).write( + ctx, + fname, + overwrite = true, + forceRowMajor = true, + ) + val rowMatrix = RowMatrix.readBlockMatrix(ctx, fname, partSize) + assertEquals(rowMatrix.toBreezeMatrix(), lm) } + } private def readCSV(fname: String): Array[Array[Double]] = fs.readLines(fname)(it => @@ -101,11 +92,13 @@ class RowMatrixSuite extends HailSuite with 
ScalaCheckDrivenPropertyChecks { private def exportImportAssert(`export`: (String) => Unit, expected: Array[Double]*): Unit = { val fname = ctx.createTmpPath("test") `export`(fname) - assert(readCSV(fname) === expected.toArray[Array[Double]]) + assertEquals( + readCSV(fname).map(_.toSeq).toSeq, + expected.toArray[Array[Double]].map(_.toSeq).toSeq, + ) } - @Test - def exportWithIndex(): Unit = { + test("exportWithIndex") { val rowArrays = ArraySeq( Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0), @@ -132,8 +125,7 @@ class RowMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { ) } - @Test - def exportSquare(): Unit = { + test("exportSquare") { val rowArrays = ArraySeq( Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0), @@ -208,8 +200,7 @@ class RowMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { ) } - @Test - def exportWide(): Unit = { + test("exportWide") { val rowArrays = ArraySeq( Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0), @@ -280,8 +271,7 @@ class RowMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { ) } - @Test - def exportTall(): Unit = { + test("exportTall") { val rowArrays = ArraySeq( Array(1.0, 2.0), Array(4.0, 5.0), @@ -354,8 +344,7 @@ class RowMatrixSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { ) } - @Test - def exportBig(): Unit = { + test("exportBig") { val rowArrays: ArraySeq[Array[Double]] = ArraySeq.tabulate(20)(r => Array.tabulate(30)(c => 30.0 * c + r)) val rowMatrix = rowArrayToRowMatrix(rowArrays) diff --git a/hail/hail/test/src/is/hail/linalg/RowPartitionerSuite.scala b/hail/hail/test/src/is/hail/linalg/RowPartitionerSuite.scala index 0bbaa1363a7..777d87d90ca 100644 --- a/hail/hail/test/src/is/hail/linalg/RowPartitionerSuite.scala +++ b/hail/hail/test/src/is/hail/linalg/RowPartitionerSuite.scala @@ -3,25 +3,20 @@ package is.hail.linalg import is.hail.collection.compat.immutable.ArraySeq import org.scalacheck.Arbitrary.arbitrary -import org.scalatest -import 
org.scalatestplus.scalacheck.CheckerAsserting.assertingNatureOfAssertion -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test +import org.scalacheck.Prop._ -class RowPartitionerSuite extends TestNGSuite with ScalaCheckDrivenPropertyChecks { - @Test - def testGetPartition(): Unit = { +class RowPartitionerSuite extends munit.ScalaCheckSuite { + test("GetPartition") { val partitionStarts = ArraySeq[Long](0, 0, 0, 4, 5, 5, 8, 10, 10) val partitionCounts = Array(0, 0, 4, 1, 0, 3, 2, 0) val keyPart = partitionCounts.zipWithIndex.flatMap { case (count, pi) => Array.fill(count)(pi) } val rp = RowPartitioner(partitionStarts) - assert(rp.numPartitions == 8) - assert((0 until 10).forall(i => keyPart(i) == rp.getPartition(i.toLong))) + assertEquals(rp.numPartitions, 8) + (0 until 10).foreach(i => assertEquals(rp.getPartition(i.toLong), keyPart(i))) } - @Test def testFindInterval(): Unit = { + property("FindInterval") { def naiveFindInterval(a: IndexedSeq[Long], key: Long): Int = { if (a.length == 0 || key < a(0)) -1 @@ -37,13 +32,11 @@ class RowPartitionerSuite extends TestNGSuite with ScalaCheckDrivenPropertyCheck val moreKeys = Array(Long.MinValue, -1000L, -1L, 0L, 1L, 1000L, Long.MaxValue) - forAll(arbitrary[ArraySeq[Long]] map { _.sorted }) { a => - whenever(a.nonEmpty) { - scalatest.Inspectors.forAll(a ++ moreKeys) { key => - assert( - !(key > a.head && key < a.last) || - RowPartitioner.findInterval(a, key) == naiveFindInterval(a, key) - ) + forAll(arbitrary[ArraySeq[Long]].map(_.sorted)) { a => + (a.nonEmpty) ==> { + (a ++ moreKeys).forall { key => + !(key > a.head && key < a.last) || + RowPartitioner.findInterval(a, key) == naiveFindInterval(a, key) } } } diff --git a/hail/hail/test/src/is/hail/lir/CompileTimeRequirednessSuite.scala b/hail/hail/test/src/is/hail/lir/CompileTimeRequirednessSuite.scala index 4faa96ff20d..c5ce8202a8c 100644 --- 
a/hail/hail/test/src/is/hail/lir/CompileTimeRequirednessSuite.scala +++ b/hail/hail/test/src/is/hail/lir/CompileTimeRequirednessSuite.scala @@ -3,10 +3,8 @@ package is.hail.lir import is.hail.HailSuite import is.hail.asm4s._ -import org.testng.annotations.Test - class CompileTimeRequirednessSuite extends HailSuite { - @Test def testCodeBooleanFolding(): Unit = { + test("CodeBooleanFolding") { val cFalse = const(false) val cTrue = const(true) diff --git a/hail/hail/test/src/is/hail/lir/LIRSplitSuite.scala b/hail/hail/test/src/is/hail/lir/LIRSplitSuite.scala index 060ceb1157c..3a30d9586ee 100644 --- a/hail/hail/test/src/is/hail/lir/LIRSplitSuite.scala +++ b/hail/hail/test/src/is/hail/lir/LIRSplitSuite.scala @@ -4,11 +4,9 @@ import is.hail.HailSuite import is.hail.asm4s._ import is.hail.expr.ir.{EmitFunctionBuilder, ParamType} -import org.testng.annotations.Test - class LIRSplitSuite extends HailSuite { - @Test def testSplitPreservesParameterMutation(): Unit = { + test("SplitPreservesParameterMutation") { val f = EmitFunctionBuilder[Unit](ctx, "F") f.emitWithBuilder { cb => val mb = f.newEmitMethod("m", IndexedSeq[ParamType](typeInfo[Long]), typeInfo[Unit]) diff --git a/hail/hail/test/src/is/hail/methods/ExprSuite.scala b/hail/hail/test/src/is/hail/methods/ExprSuite.scala index f9fe089095f..7217d3fc731 100644 --- a/hail/hail/test/src/is/hail/methods/ExprSuite.scala +++ b/hail/hail/test/src/is/hail/methods/ExprSuite.scala @@ -13,15 +13,13 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.scalacheck.Arbitrary._ import org.scalacheck.Gen -import org.scalatestplus.scalacheck.CheckerAsserting.assertingNatureOfAssertion -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.testng.annotations.Test +import org.scalacheck.Prop.forAll -class ExprSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { +class ExprSuite extends HailSuite with munit.ScalaCheckSuite { def sm: HailStateManager = ctx.stateManager - @Test def 
testTypePretty(): Unit = { + property("TypePretty") { // for arbType val sb = new StringBuilder @@ -31,27 +29,24 @@ class ExprSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val res = sb.result() val parsed = IRParser.parseType(res) t == parsed - } - - forAll { (t: Type) => - sb.clear() - t.pretty(sb, 0, compact = false) - val res = sb.result() - val parsed = IRParser.parseType(res) - t == parsed - } - - forAll { (t: Type) => - val s = t.parsableString() - val parsed = IRParser.parseType(s) - assert(t == parsed) - } + } ++ + forAll { (t: Type) => + sb.clear() + t.pretty(sb, 0, compact = false) + val res = sb.result() + val parsed = IRParser.parseType(res) + t == parsed + } ++ + forAll { (t: Type) => + val s = t.parsableString() + val parsed = IRParser.parseType(s) + t == parsed + } } - @Test def testEscaping(): Unit = - forAll((s: String) => assert(s == unescapeString(escapeString(s)))) + property("Escaping") = forAll((s: String) => s == unescapeString(escapeString(s))) - @Test def testEscapingSimple(): Unit = { + property("EscapingSimple") { // a == 0x61, _ = 0x5f assert(escapeStringSimple("abc", '_', _ => false) == "abc") assert(escapeStringSimple("abc", '_', _ == 'a') == "_61bc") @@ -64,20 +59,22 @@ class ExprSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(unescapeStringSimple("my name is _u540d_u8c26", '_') == "my name is 名谦") forAll(Gen.asciiPrintableStr) { (s: String) => - assert(s == unescapeStringSimple( + s == unescapeStringSimple( escapeStringSimple(s, '_', _.isLetterOrDigit, _.isLetterOrDigit), '_', - )) + ) } } - @Test def testImportEmptyJSONObjectAsStruct(): Unit = + test("ImportEmptyJSONObjectAsStruct") { assert(JSONAnnotationImpex.importAnnotation(parse("{}"), TStruct()) == Row()) + } - @Test def testExportEmptyJSONObjectAsStruct(): Unit = + test("ExportEmptyJSONObjectAsStruct") { assert(compact(render(JSONAnnotationImpex.exportAnnotation(Row(), TStruct()))) == "{}") + } - @Test def testRoundTripEmptyJSONObject(): 
Unit = { + test("RoundTripEmptyJSONObject") { val actual = JSONAnnotationImpex.exportAnnotation( JSONAnnotationImpex.importAnnotation(parse("{}"), TStruct()), TStruct(), @@ -85,35 +82,33 @@ class ExprSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(compact(render(actual)) == "{}") } - @Test def testRoundTripEmptyStruct(): Unit = { + test("RoundTripEmptyStruct") { val actual = JSONAnnotationImpex.importAnnotation( JSONAnnotationImpex.exportAnnotation(Row(), TStruct()), TStruct(), ) - assert(actual == Row()) + assertEquals(actual, Row()) } - @Test def testImpexes(): Unit = { - + property("Impexes") { val g = for { t <- arbitrary[Type] a <- genNullable(ctx, t) } yield (t, a) forAll(g) { case (t, a) => - assert(JSONAnnotationImpex.importAnnotation( + JSONAnnotationImpex.importAnnotation( JSONAnnotationImpex.exportAnnotation(a, t), t, - ) == a) - } - - forAll(g) { case (t, a) => - val string = compact(JSONAnnotationImpex.exportAnnotation(a, t)) - assert(JSONAnnotationImpex.importAnnotation(parse(string), t) == a) - } + ) == a + } ++ + forAll(g) { case (t, a) => + val string = compact(JSONAnnotationImpex.exportAnnotation(a, t)) + JSONAnnotationImpex.importAnnotation(parse(string), t) == a + } } - @Test def testOrdering(): Unit = { + property("Ordering") { val intOrd = TInt32.ordering(ctx.stateManager) assert(intOrd.compare(-2, -2) == 0) @@ -130,7 +125,7 @@ class ExprSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { forAll(g) { case (t, a, b) => val ord = t.ordering(ctx.stateManager) - assert(ord.compare(a, b) == -ord.compare(b, a)) + ord.compare(a, b) == -ord.compare(b, a) } } } diff --git a/hail/hail/test/src/is/hail/methods/LocalLDPruneSuite.scala b/hail/hail/test/src/is/hail/methods/LocalLDPruneSuite.scala index f4ffebf2ef8..3b21a163a52 100644 --- a/hail/hail/test/src/is/hail/methods/LocalLDPruneSuite.scala +++ b/hail/hail/test/src/is/hail/methods/LocalLDPruneSuite.scala @@ -10,12 +10,9 @@ import is.hail.variant._ import 
breeze.linalg.{Vector => BVector} import org.apache.spark.rdd.RDD -import org.scalacheck.{Gen, Properties} +import org.scalacheck.Gen import org.scalacheck.Gen._ import org.scalacheck.Prop.forAll -import org.scalatest -import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper -import org.testng.annotations.Test object LocalLDPruneSuite { val variantByteOverhead = 50 @@ -111,7 +108,7 @@ object LocalLDPruneSuite { } } -class LocalLDPruneSuite extends HailSuite { +class LocalLDPruneSuite extends HailSuite with munit.ScalaCheckSuite { val memoryPerCoreBytes = 256L * 1024 * 1024 val nCores = 4 @@ -231,19 +228,19 @@ class LocalLDPruneSuite extends HailSuite { locallyUncorrelated.fold(true)((bool1, bool2) => bool1 && bool2) } - @Test def testBitPackUnpack(): Unit = { + test("BitPackUnpack") { val calls1 = ArraySeq(-1, 0, 1, 2, 1, 1, 0, 0, 0, 0, 2, 2, -1, -1, -1, -1).map(toC2) val calls2 = ArraySeq(0, 1, 2, 2, 2, 0, -1, -1).map(toC2) val calls3 = calls1 ++ ArraySeq.fill(32 - calls1.length)(toC2(0)) ++ calls2 - scalatest.Inspectors.forAll(ArraySeq(calls1, calls2, calls3)) { calls => - scalatest.Inspectors.forAll(LocalLDPruneSuite.fromCalls(calls).toSeq) { bpv => - bpv.unpack().map(toC2) shouldBe calls + ArraySeq(calls1, calls2, calls3).foreach { calls => + LocalLDPruneSuite.fromCalls(calls).toSeq.foreach { bpv => + assertEquals(bpv.unpack().map(toC2), calls) } } } - @Test def testR2(): Unit = { + test("R2") { val calls = Array( Array(1, 0, 0, 0, 0, 0, 0, 0).map(toC2), Array(1, 1, 1, 1, 1, 1, 1, 1).map(toC2), @@ -290,7 +287,7 @@ class LocalLDPruneSuite extends HailSuite { assert(D_==(LocalLDPrune.computeR2(bvi1, bvi2), 1d)) } - object Spec extends Properties("LDPrune") { + { val vectorGen: Gen[(Call, IndexedSeq[BoxedCall], IndexedSeq[BoxedCall])] = for { nSamples: Int <- choose(1, 1000) @@ -298,47 +295,43 @@ class LocalLDPruneSuite extends HailSuite { v2 <- containerOfN[ArraySeq, BoxedCall](nSamples, choose(-1, 2).map(toC2)) } yield (nSamples, v1, v2) - 
(property("bitPacked pack and unpack give same as orig") = - forAll(vectorGen) { case (_: Int, v1: IndexedSeq[BoxedCall], _) => + property("bitPacked pack and unpack give same as orig") = forAll(vectorGen) { + case (_: Int, v1: IndexedSeq[BoxedCall], _) => val bpv = LocalLDPruneSuite.fromCalls(v1) - bpv match { - case Some(x) => LocalLDPruneSuite.fromCalls(x.unpack().map(toC2)).get.gs == x.gs - case None => true + bpv.foreach { x => + assert(LocalLDPruneSuite.fromCalls(x.unpack().map(toC2)).get.gs sameElements x.gs) } - }): Unit - - (property("R2 bitPacked same as BVector") = - forAll(vectorGen) { - case (nSamples: Int, v1: IndexedSeq[BoxedCall], v2: IndexedSeq[BoxedCall]) => - val bv1 = LocalLDPruneSuite.fromCalls(v1) - val bv2 = LocalLDPruneSuite.fromCalls(v2) - - val sgs1 = - LocalLDPruneSuite.normalizedHardCalls(v1).map(math.sqrt(1d / nSamples) * BVector(_)) - val sgs2 = - LocalLDPruneSuite.normalizedHardCalls(v2).map(math.sqrt(1d / nSamples) * BVector(_)) - - (bv1, bv2, sgs1, sgs2) match { - case (Some(a), Some(b), Some(c: BVector[Double]), Some(d: BVector[Double])) => - val rBreeze = c.dot(d): Double - val r2Breeze = rBreeze * rBreeze - val r2BitPacked = LocalLDPrune.computeR2(a, b) - - val isSame = - D_==(r2BitPacked, r2Breeze) && D_>=(r2BitPacked, 0d) && D_<=(r2BitPacked, 1d) - if (!isSame) { - println(s"breeze=$r2Breeze bitPacked=$r2BitPacked nSamples=$nSamples") - } - isSame - case _ => true - } - }): Unit - } + } - @Test def testRandom(): Unit = Spec.check() + property("R2 bitPacked same as BVector") = forAll(vectorGen) { + case (nSamples: Int, v1: IndexedSeq[BoxedCall], v2: IndexedSeq[BoxedCall]) => + val bv1 = LocalLDPruneSuite.fromCalls(v1) + val bv2 = LocalLDPruneSuite.fromCalls(v2) + + val sgs1 = + LocalLDPruneSuite.normalizedHardCalls(v1).map(math.sqrt(1d / nSamples) * BVector(_)) + val sgs2 = + LocalLDPruneSuite.normalizedHardCalls(v2).map(math.sqrt(1d / nSamples) * BVector(_)) + + (bv1, bv2, sgs1, sgs2) match { + case (Some(a), Some(b), Some(c: 
BVector[Double]), Some(d: BVector[Double])) => + val rBreeze = c.dot(d): Double + val r2Breeze = rBreeze * rBreeze + val r2BitPacked = LocalLDPrune.computeR2(a, b) + + val isSame = + D_==(r2BitPacked, r2Breeze) && D_>=(r2BitPacked, 0d) && D_<=(r2BitPacked, 1d) + if (!isSame) { + println(s"breeze=$r2Breeze bitPacked=$r2BitPacked nSamples=$nSamples") + } + isSame + case _ => true + } + } + } - @Test def testIsLocallyUncorrelated(): Unit = { + test("IsLocallyUncorrelated") { val locallyPrunedVariantsTable = LocalLDPrune(ctx, mt, r2Threshold = 0.2, windowSize = 1000000, maxQueueSize = maxQueueSize) assert(isLocallyUncorrelated(mt, locallyPrunedVariantsTable, 0.2, 1000000)) diff --git a/hail/hail/test/src/is/hail/methods/MultiArray2Suite.scala b/hail/hail/test/src/is/hail/methods/MultiArray2Suite.scala index 112a469422b..c265fa6d212 100644 --- a/hail/hail/test/src/is/hail/methods/MultiArray2Suite.scala +++ b/hail/hail/test/src/is/hail/methods/MultiArray2Suite.scala @@ -1,106 +1,105 @@ package is.hail.methods -import is.hail.HailSuite import is.hail.utils.MultiArray2 -import org.testng.annotations.Test - -class MultiArray2Suite extends HailSuite { - @Test def test() = { - - // test multiarray of size 0 will be created +class MultiArray2Suite extends munit.FunSuite { + test("fill with size 0") { MultiArray2.fill[Int](0, 0)(0): Unit + } - // test multiarray of size 0 that apply nothing out - assertThrows[IllegalArgumentException] { + test("apply on size 0 throws") { + intercept[IllegalArgumentException] { val ma0 = MultiArray2.fill[Int](0, 0)(0) ma0(0, 0) - } + }: Unit + } - // test array index out of bounds on row slice - assertThrows[ArrayIndexOutOfBoundsException] { + test("row slice out of bounds") { + intercept[ArrayIndexOutOfBoundsException] { val foo = MultiArray2.fill[Int](5, 5)(0) foo.row(0)(5) - } + }: Unit + } - // bad multiarray initiation -- negative number - assertThrows[IllegalArgumentException] { + test("negative row count") { + 
intercept[IllegalArgumentException] { MultiArray2.fill[Int](-5, 5)(0) - } + }: Unit + } - // bad multiarray initiation -- negative number - assertThrows[IllegalArgumentException] { + test("negative column count") { + intercept[IllegalArgumentException] { MultiArray2.fill[Int](5, -5)(0) - } + }: Unit + } + test("update and apply") { val ma1 = MultiArray2.fill[Int](10, 3)(0) for ((i, j) <- ma1.indices) ma1.update(i, j, i * j) - assert(ma1(2, 2) == 4) - assert(ma1(6, 1) == 6) + assertEquals(ma1(2, 2), 4) + assertEquals(ma1(6, 1), 6) - // Catch exception if try to apply value that is not in indices of multiarray - assertThrows[IllegalArgumentException] { + intercept[IllegalArgumentException] { ma1(100, 100) - } + }: Unit val ma2 = MultiArray2.fill[Int](10, 3)(0) for ((i, j) <- ma2.indices) ma2.update(i, j, i + j) - assert(ma2(2, 2) == 4) - assert(ma2(6, 1) == 7) + assertEquals(ma2(2, 2), 4) + assertEquals(ma2(6, 1), 7) - // Test zip with two ints + // zip two int arrays val ma3 = ma1.zip(ma2) - assert(ma3(2, 2) == ((4, 4))) - assert(ma3(6, 1) == ((6, 7))) + assertEquals(ma3(2, 2), (4, 4)) + assertEquals(ma3(6, 1), (6, 7)) - // Test zip with multi-arrays of different types + // zip arrays of different types val ma4 = MultiArray2.fill[String](10, 3)("foo") val ma5 = ma1.zip(ma4) - assert(ma5(2, 2) == ((4, "foo"))) - assert(ma5(0, 0) == ((0, "foo"))) + assertEquals(ma5(2, 2), (4, "foo")) + assertEquals(ma5(0, 0), (0, "foo")) - // Test row slice + // row slice for { row <- ma5.rows idx <- 0 until row.length } - assert(row(idx) == ((row.i * idx, "foo"))) + assertEquals(row(idx), (row.i * idx, "foo")) - assertThrows[IllegalArgumentException] { + intercept[IllegalArgumentException] { ma5.row(100) - } + }: Unit - assertThrows[ArrayIndexOutOfBoundsException] { + intercept[ArrayIndexOutOfBoundsException] { val x = ma5.row(0) x(100) - } + }: Unit - assertThrows[IllegalArgumentException] { + intercept[IllegalArgumentException] { ma5.row(-5) - } + }: Unit - 
assertThrows[IllegalArgumentException] { + intercept[IllegalArgumentException] { ma5.column(100) - } + }: Unit - assertThrows[IllegalArgumentException] { + intercept[IllegalArgumentException] { ma5.column(-5) - } + }: Unit - assertThrows[ArrayIndexOutOfBoundsException] { + intercept[ArrayIndexOutOfBoundsException] { val x = ma5.column(0) x(100) - } + }: Unit - // Test column slice + // column slice for { column <- ma5.columns idx <- 0 until column.length } - assert(column(idx) == ((column.j * idx, "foo"))) - + assertEquals(column(idx), (column.j * idx, "foo")) } } diff --git a/hail/hail/test/src/is/hail/methods/SkatSuite.scala b/hail/hail/test/src/is/hail/methods/SkatSuite.scala index 447d937df82..32ade2d60b8 100644 --- a/hail/hail/test/src/is/hail/methods/SkatSuite.scala +++ b/hail/hail/test/src/is/hail/methods/SkatSuite.scala @@ -4,11 +4,10 @@ import is.hail.HailSuite import is.hail.utils._ import breeze.linalg._ -import org.testng.annotations.Test class SkatSuite extends HailSuite { - @Test def smallNLargeNEqualityTest(): Unit = { + test("smallNLargeNEquality") { val rand = scala.util.Random rand.setSeed(0) diff --git a/hail/hail/test/src/is/hail/rvd/RVDPartitionerSuite.scala b/hail/hail/test/src/is/hail/rvd/RVDPartitionerSuite.scala index 12fe2c05977..ff318d7e38a 100644 --- a/hail/hail/test/src/is/hail/rvd/RVDPartitionerSuite.scala +++ b/hail/hail/test/src/is/hail/rvd/RVDPartitionerSuite.scala @@ -7,16 +7,14 @@ import is.hail.types.virtual.{TInt32, TStruct} import is.hail.utils.Interval import org.apache.spark.sql.Row -import org.testng.ITestContext -import org.testng.annotations.{BeforeMethod, Test} class RVDPartitionerSuite extends HailSuite { val kType = TStruct(("A", TInt32), ("B", TInt32), ("C", TInt32)) var partitioner: RVDPartitioner = _ - @BeforeMethod - def setupPartitioner(context: ITestContext): Unit = { + override def beforeEach(context: BeforeEach): Unit = { + super.beforeEach(context) partitioner = new RVDPartitioner( ctx.stateManager, kType, @@ 
-28,7 +26,7 @@ class RVDPartitionerSuite extends HailSuite { ) } - @Test def testExtendKey(): Unit = { + test("ExtendKey") { val p = new RVDPartitioner( ctx.stateManager, TStruct(("A", TInt32), ("B", TInt32)), @@ -47,67 +45,70 @@ class RVDPartitionerSuite extends HailSuite { )) } - @Test def testGetPartitionWithPartitionKeys(): Unit = { - assert(partitioner.lowerBound(Row(-1, 7)) == 0) - assert(partitioner.upperBound(Row(-1, 7)) == 0) + test("GetPartitionWithPartitionKeys") { + assertEquals(partitioner.lowerBound(Row(-1, 7)), 0) + assertEquals(partitioner.upperBound(Row(-1, 7)), 0) - assert(partitioner.lowerBound(Row(4, 2)) == 0) - assert(partitioner.upperBound(Row(4, 2)) == 1) + assertEquals(partitioner.lowerBound(Row(4, 2)), 0) + assertEquals(partitioner.upperBound(Row(4, 2)), 1) - assert(partitioner.lowerBound(Row(4, 3)) == 1) - assert(partitioner.upperBound(Row(4, 3)) == 2) + assertEquals(partitioner.lowerBound(Row(4, 3)), 1) + assertEquals(partitioner.upperBound(Row(4, 3)), 2) - assert(partitioner.lowerBound(Row(5, -10259)) == 1) - assert(partitioner.upperBound(Row(5, -10259)) == 2) + assertEquals(partitioner.lowerBound(Row(5, -10259)), 1) + assertEquals(partitioner.upperBound(Row(5, -10259)), 2) - assert(partitioner.lowerBound(Row(7, 9)) == 2) - assert(partitioner.upperBound(Row(7, 9)) == 2) + assertEquals(partitioner.lowerBound(Row(7, 9)), 2) + assertEquals(partitioner.upperBound(Row(7, 9)), 2) - assert(partitioner.lowerBound(Row(12, 19)) == 3) - assert(partitioner.upperBound(Row(12, 19)) == 3) + assertEquals(partitioner.lowerBound(Row(12, 19)), 3) + assertEquals(partitioner.upperBound(Row(12, 19)), 3) } - @Test def testGetPartitionWithLargerKeys(): Unit = { - assert(partitioner.lowerBound(Row(0, 1, 3)) == 0) - assert(partitioner.upperBound(Row(0, 1, 3)) == 0) + test("GetPartitionWithLargerKeys") { + assertEquals(partitioner.lowerBound(Row(0, 1, 3)), 0) + assertEquals(partitioner.upperBound(Row(0, 1, 3)), 0) - assert(partitioner.lowerBound(Row(2, 7, 5)) == 
0) - assert(partitioner.upperBound(Row(2, 7, 5)) == 1) + assertEquals(partitioner.lowerBound(Row(2, 7, 5)), 0) + assertEquals(partitioner.upperBound(Row(2, 7, 5)), 1) - assert(partitioner.lowerBound(Row(4, 2, 1, 2.7, "bar")) == 0) + assertEquals(partitioner.lowerBound(Row(4, 2, 1, 2.7, "bar")), 0) - assert(partitioner.lowerBound(Row(7, 9, 7)) == 2) - assert(partitioner.upperBound(Row(7, 9, 7)) == 2) + assertEquals(partitioner.lowerBound(Row(7, 9, 7)), 2) + assertEquals(partitioner.upperBound(Row(7, 9, 7)), 2) - assert(partitioner.lowerBound(Row(11, 1, 42)) == 3) + assertEquals(partitioner.lowerBound(Row(11, 1, 42)), 3) } - @Test def testGetPartitionPKWithSmallerKeys(): Unit = { - assert(partitioner.lowerBound(Row(2)) == 0) - assert(partitioner.upperBound(Row(2)) == 1) + test("GetPartitionPKWithSmallerKeys") { + assertEquals(partitioner.lowerBound(Row(2)), 0) + assertEquals(partitioner.upperBound(Row(2)), 1) - assert(partitioner.lowerBound(Row(4)) == 0) - assert(partitioner.upperBound(Row(4)) == 2) + assertEquals(partitioner.lowerBound(Row(4)), 0) + assertEquals(partitioner.upperBound(Row(4)), 2) - assert(partitioner.lowerBound(Row(11)) == 3) - assert(partitioner.upperBound(Row(11)) == 3) + assertEquals(partitioner.lowerBound(Row(11)), 3) + assertEquals(partitioner.upperBound(Row(11)), 3) } - @Test def testGetPartitionRange(): Unit = { - assert(partitioner.queryInterval(Interval(Row(3, 4), Row(7, 11), true, true)) == Seq(0, 1, 2)) - assert(partitioner.queryInterval(Interval(Row(3, 4), Row(7, 9), true, false)) == Seq(0, 1)) - assert(partitioner.queryInterval(Interval(Row(4), Row(5), true, true)) == Seq(0, 1)) - assert(partitioner.queryInterval(Interval(Row(4), Row(5), false, true)) == Seq(1)) - assert(partitioner.queryInterval(Interval(Row(-1, 7), Row(0, 9), true, false)) == Seq()) + test("GetPartitionRange") { + assertEquals( + partitioner.queryInterval(Interval(Row(3, 4), Row(7, 11), true, true)), + Seq(0, 1, 2), + ) + 
assertEquals(partitioner.queryInterval(Interval(Row(3, 4), Row(7, 9), true, false)), Seq(0, 1)) + assertEquals(partitioner.queryInterval(Interval(Row(4), Row(5), true, true)), Seq(0, 1)) + assertEquals(partitioner.queryInterval(Interval(Row(4), Row(5), false, true)), Seq(1)) + assert(partitioner.queryInterval(Interval(Row(-1, 7), Row(0, 9), true, false)).isEmpty) } - @Test def testGetSafePartitionKeyRange(): Unit = { + test("GetSafePartitionKeyRange") { assert(partitioner.queryKey(Row(0, 0)).isEmpty) assert(partitioner.queryKey(Row(7, 10)).isEmpty) - assert(partitioner.queryKey(Row(7, 11)) == Range.inclusive(2, 2)) + assertEquals(partitioner.queryKey(Row(7, 11)), Range.inclusive(2, 2)) } - @Test def testGenerateDisjoint(): Unit = { + test("GenerateDisjoint") { val intervals = ArraySeq( Interval(Row(1, 0, 4), Row(4, 3, 2), true, false), Interval(Row(4, 3, 5), Row(7, 9, 1), true, false), @@ -148,21 +149,21 @@ class RVDPartitionerSuite extends HailSuite { )) } - @Test def testGenerateEmptyKey(): Unit = { + test("GenerateEmptyKey") { val intervals1 = ArraySeq(Interval(Row(), Row(), true, true)) val intervals5 = ArraySeq.fill(5)(Interval(Row(), Row(), true, true)) val p5 = RVDPartitioner.generate(ctx.stateManager, FastSeq(), TStruct.empty, intervals5) - assert(p5.rangeBounds == intervals1) + assertEquals(p5.rangeBounds, intervals1) val p1 = RVDPartitioner.generate(ctx.stateManager, FastSeq(), TStruct.empty, intervals1) - assert(p1.rangeBounds == intervals1) + assertEquals(p1.rangeBounds, intervals1) val p0 = RVDPartitioner.generate(ctx.stateManager, FastSeq(), TStruct.empty, FastSeq()) assert(p0.rangeBounds.isEmpty) } - @Test def testIntersect(): Unit = { + test("Intersect") { val kType = TStruct(("key", TInt32)) val left = new RVDPartitioner( diff --git a/hail/hail/test/src/is/hail/services/BatchClientSuite.scala b/hail/hail/test/src/is/hail/services/BatchClientSuite.scala index 821849177f4..3b517186d7d 100644 --- 
a/hail/hail/test/src/is/hail/services/BatchClientSuite.scala +++ b/hail/hail/test/src/is/hail/services/BatchClientSuite.scala @@ -10,23 +10,16 @@ import is.hail.utils._ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future -import java.lang.reflect.Method import java.nio.file.Path -import org.scalatest.Inspectors.forAll -import org.scalatest.enablers.InspectorAsserting.assertingNatureOfAssertion -import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.{AfterClass, BeforeClass, BeforeMethod, Test} - -class BatchClientSuite extends TestNGSuite { +class BatchClientSuite extends munit.FunSuite { private[this] var client: BatchClient = _ private[this] var batchId: Int = _ private[this] var parentJobGroupId: Int = _ - @BeforeClass - def createClientAndBatch(): Unit = { + override def beforeAll(): Unit = { + super.beforeAll() client = BatchClient( DeployConfig.default, @@ -44,25 +37,25 @@ class BatchClientSuite extends TestNGSuite { ) } - @BeforeMethod - def createEmptyParentJobGroup(m: Method): Unit = { + override def beforeEach(context: BeforeEach): Unit = { + super.beforeEach(context) parentJobGroupId = client.newJobGroup( req = JobGroupRequest( batch_id = batchId, absolute_parent_id = 0, token = tokenUrlSafe, - attributes = Map("name" -> m.getName), + attributes = Map("name" -> context.test.name), jobs = FastSeq(), ) )._1 } - @AfterClass - def closeClient(): Unit = + override def afterAll(): Unit = { client.close() + super.afterAll() + } - @Test - def testCancelAfterNFailures(): Unit = { + test("CancelAfterNFailures") { val (jobGroupId, _) = client.newJobGroup( req = JobGroupRequest( batch_id = batchId, @@ -89,13 +82,12 @@ class BatchClientSuite extends TestNGSuite { ) ) val result = client.waitForJobGroup(batchId, jobGroupId) - assert(result.state == Failure) - assert(result.n_jobs == 2) - assert(result.n_failed == 1) + 
assertEquals(result.state, Failure) + assertEquals(result.n_jobs, 2) + assertEquals(result.n_failed, 1) } - @Test - def testGetJobGroupJobsByState(): Unit = { + test("GetJobGroupJobsByState") { val (jobGroupId, _) = client.newJobGroup( req = JobGroupRequest( batch_id = batchId, @@ -120,19 +112,17 @@ class BatchClientSuite extends TestNGSuite { ) ) client.waitForJobGroup(batchId, jobGroupId): Unit - forAll(Array(JobStates.Failed, JobStates.Success)) { state => - forAll(client.getJobGroupJobs(batchId, jobGroupId, Some(state))) { jobs => - assert(jobs.length == 1) - assert(jobs(0).state == state) + Array(JobStates.Failed, JobStates.Success).foreach { state => + client.getJobGroupJobs(batchId, jobGroupId, Some(state)).foreach { jobs => + assertEquals(jobs.length, 1) + assertEquals(jobs(0).state, state) assert(jobs.head.end_time.isDefined) } } } - @Test - def testNewJobGroup(): Unit = - // The query driver submits a job group per stage with one job per partition - forAll(1 to 2) { i => + test("NewJobGroup") { + (1 to 2).foreach { i => val (jobGroupId, _) = client.newJobGroup( req = JobGroupRequest( batch_id = batchId, @@ -152,11 +142,11 @@ class BatchClientSuite extends TestNGSuite { ) val result = client.getJobGroup(batchId, jobGroupId) - assert(result.n_jobs == i) + assertEquals(result.n_jobs, i) } + } - @Test - def testJvmJob(): Unit = { + test("JvmJob") { val (jobGroupId, _) = client.newJobGroup( req = JobGroupRequest( batch_id = batchId, @@ -177,11 +167,10 @@ class BatchClientSuite extends TestNGSuite { ) val result = client.getJobGroup(batchId, jobGroupId) - assert(result.n_jobs == 1) + assertEquals(result.n_jobs, 1) } - @Test // #15118 - def testWaitForCancelledJobGroup(): Unit = { + test("WaitForCancelledJobGroup") { // #15118 val (jobGroupId, _) = client.newJobGroup( req = JobGroupRequest( batch_id = batchId, @@ -207,6 +196,6 @@ class BatchClientSuite extends TestNGSuite { }: Unit val jg = client.waitForJobGroup(batchId, jobGroupId) - jg.state shouldBe 
services.JobGroupStates.Cancelled + assertEquals(jg.state, services.JobGroupStates.Cancelled) } } diff --git a/hail/hail/test/src/is/hail/sparkextras/RichRDDSuite.scala b/hail/hail/test/src/is/hail/sparkextras/RichRDDSuite.scala index 6654157766d..23d90bc280b 100644 --- a/hail/hail/test/src/is/hail/sparkextras/RichRDDSuite.scala +++ b/hail/hail/test/src/is/hail/sparkextras/RichRDDSuite.scala @@ -5,10 +5,8 @@ import is.hail.collection.compat.immutable.ArraySeq import is.hail.expr.ir.ExportType import is.hail.sparkextras.implicits._ -import org.testng.annotations.Test - class RichRDDSuite extends HailSuite { - @Test def parallelWrite(): Unit = { + test("parallelWrite") { def read(file: String): Array[String] = fs.readLines(file)(_.map(_.value).toArray) val header = "my header is awesome!" diff --git a/hail/hail/test/src/is/hail/stats/FisherExactTestSuite.scala b/hail/hail/test/src/is/hail/stats/FisherExactTestSuite.scala index bf5d68aa584..be20e72faac 100644 --- a/hail/hail/test/src/is/hail/stats/FisherExactTestSuite.scala +++ b/hail/hail/test/src/is/hail/stats/FisherExactTestSuite.scala @@ -2,11 +2,9 @@ package is.hail.stats import is.hail.HailSuite -import org.testng.annotations.Test - class FisherExactTestSuite extends HailSuite { - @Test def testPvalue(): Unit = { + test("Pvalue") { val a = 5 val b = 10 val c = 95 diff --git a/hail/hail/test/src/is/hail/stats/GeneralizedChiSquaredDistributionSuite.scala b/hail/hail/test/src/is/hail/stats/GeneralizedChiSquaredDistributionSuite.scala index 69e829f1225..a0ef49d93ff 100644 --- a/hail/hail/test/src/is/hail/stats/GeneralizedChiSquaredDistributionSuite.scala +++ b/hail/hail/test/src/is/hail/stats/GeneralizedChiSquaredDistributionSuite.scala @@ -2,8 +2,6 @@ package is.hail.stats import is.hail.HailSuite -import org.testng.annotations.Test - class GeneralizedChiSquaredDistributionSuite extends HailSuite { private[this] def pgenchisq( c: Double, @@ -32,7 +30,7 @@ class GeneralizedChiSquaredDistributionSuite extends 
HailSuite { g == g2) } - @Test def test0(): Unit = { + test("0") { val (actualValue, actualTrace, actualFault) = pgenchisq( 1.0, Array(1, 1, 1), @@ -48,10 +46,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(0.76235, 744, 2, 0.03819, 53.37969, 0.0, 51), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test1(): Unit = { + test("1") { val (actualValue, actualTrace, actualFault) = pgenchisq( 7.0, Array(1, 1, 1), @@ -66,10 +64,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(1.57018, 625, 2, 0.03964, 34.66214, 0.04784, 51), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test2(): Unit = { + test("2") { val (actualValue, actualTrace, actualFault) = pgenchisq( 20.0, Array(1, 1, 1), @@ -84,10 +82,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(3.16244, 346, 1, 0.04602, 15.88681, 0.14159, 32), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test3(): Unit = { + test("3") { val (actualValue, actualTrace, actualFault) = pgenchisq( 2.0, Array(2, 2, 2), @@ -102,10 +100,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(0.84764, 74, 1, 0.03514, 2.55311, 0.0, 22), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test4(): Unit = { + test("4") { val (actualValue, actualTrace, actualFault) = pgenchisq( 20.0, Array(2, 2, 2), @@ -120,10 +118,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(1.74138, 66, 1, 0.03907, 2.55311, 0.0, 22), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test5(): Unit = { + test("5") { val (actualValue, actualTrace, actualFault) = pgenchisq( 60.0, Array(2, 2, 2), @@ -135,10 +133,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { ) 
assert(nearEqual(actualValue, 0.983897)) assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(3.72757, 50, 1, 0.052, 2.55311, 0.0, 22))) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test6(): Unit = { + test("6") { val (actualValue, actualTrace, actualFault) = pgenchisq( 10.0, Array(6, 4, 2), @@ -153,10 +151,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(1.20122, 18, 1, 0.02706, 0.46096, 0.0, 20), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test7(): Unit = { + test("7") { val (actualValue, actualTrace, actualFault) = pgenchisq( 50.0, Array(6, 4, 2), @@ -171,10 +169,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(2.06868, 15, 1, 0.03269, 0.46096, 0.0, 20), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test8(): Unit = { + test("8") { val (actualValue, actualTrace, actualFault) = pgenchisq( 120.0, Array(6, 4, 2), @@ -189,10 +187,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(3.58496, 10, 1, 0.05141, 0.46096, 0.0, 20), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test9(): Unit = { + test("9") { val (actualValue, actualTrace, actualFault) = pgenchisq( 10.0, Array(2, 4, 6), @@ -207,10 +205,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(1.29976, 27, 1, 0.03459, 0.88302, 0.0, 19), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test10(): Unit = { + test("10") { val (actualValue, actualTrace, actualFault) = pgenchisq( 30.0, Array(2, 4, 6), @@ -225,10 +223,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(2.01747, 24, 1, 0.03887, 0.88302, 0.0, 19), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test11(): Unit = { + test("11") 
{ val (actualValue, actualTrace, actualFault) = pgenchisq( 80.0, Array(2, 4, 6), @@ -243,10 +241,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(3.81157, 17, 1, 0.05628, 0.88302, 0.0, 19), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test12(): Unit = { + test("12") { val (actualValue, actualTrace, actualFault) = pgenchisq( 20.0, Array(6, 2), @@ -261,10 +259,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(1.16271, 16, 1, 0.01561, 0.24013, 0.0, 19), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test13(): Unit = { + test("13") { val (actualValue, actualTrace, actualFault) = pgenchisq( 100.0, Array(6, 2), @@ -279,10 +277,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(2.02277, 13, 1, 0.01949, 0.24013, 0.0, 19), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test14(): Unit = { + test("14") { val (actualValue, actualTrace, actualFault) = pgenchisq( 200.0, Array(6, 2), @@ -297,10 +295,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(3.09687, 10, 1, 0.02825, 0.24013, 0.0, 19), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test15(): Unit = { + test("15") { val (actualValue, actualTrace, actualFault) = pgenchisq( 10.0, Array(1, 1), @@ -315,10 +313,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(0.8712, 603, 2, 0.01628, 13.86318, 0.0, 49), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test16(): Unit = { + test("16") { val (actualValue, actualTrace, actualFault) = pgenchisq( 60.0, Array(1, 1), @@ -333,10 +331,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(1.69157, 340, 1, 0.02043, 6.93159, 0.24644, 
31), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test17(): Unit = { + test("17") { val (actualValue, actualTrace, actualFault) = pgenchisq( 150.0, Array(1, 1), @@ -351,10 +349,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(3.06625, 87, 1, 0.02888, 2.47557, 0.81533, 29), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test18(): Unit = { + test("18") { val (actualValue, actualTrace, actualFault) = pgenchisq( 45.0, Array(6, 4, 2, 2, 4, 6), @@ -366,10 +364,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { ) assert(nearEqual(actualValue, 0.01095)) assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(1.82147, 13, 1, 0.01582, 0.193, 0.0, 18))) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test19(): Unit = { + test("19") { val (actualValue, actualTrace, actualFault) = pgenchisq( 120.0, Array(6, 4, 2, 2, 4, 6), @@ -381,10 +379,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { ) assert(nearEqual(actualValue, 0.654735)) assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(2.73768, 11, 1, 0.0195, 0.193, 0.0, 18))) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test20(): Unit = { + test("20") { val (actualValue, actualTrace, actualFault) = pgenchisq( 210.0, Array(6, 4, 2, 2, 4, 6), @@ -396,10 +394,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { ) assert(nearEqual(actualValue, 0.984606)) assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(3.83651, 8, 1, 0.02707, 0.193, 0.0, 18))) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test21(): Unit = { + test("21") { val (actualValue, actualTrace, actualFault) = pgenchisq( 70.0, Array(6, 2, 1, 1), @@ -414,10 +412,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(1.65876, 10, 1, 0.01346, 0.12785, 0.0, 18), )) - 
assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test22(): Unit = { + test("22") { val (actualValue, actualTrace, actualFault) = pgenchisq( 160.0, Array(6, 2, 1, 1), @@ -432,10 +430,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(2.34799, 9, 1, 0.01668, 0.12785, 0.0, 18), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test23(): Unit = { + test("23") { val (actualValue, actualTrace, actualFault) = pgenchisq( 260.0, Array(6, 2, 1, 1), @@ -450,10 +448,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(3.11236, 7, 1, 0.02271, 0.12785, 0.0, 18), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test24(): Unit = { + test("24") { val (actualValue, actualTrace, actualFault) = pgenchisq( -40.0, Array(6, 2, 1, 1), @@ -468,10 +466,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(1.42913, 10, 1, 0.01483, 0.12785, 0.0, 19), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test25(): Unit = { + test("25") { val (actualValue, actualTrace, actualFault) = pgenchisq( 40.0, Array(6, 2, 1, 1), @@ -486,10 +484,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(1.42909, 8, 1, 0.01771, 0.12785, 0.0, 19), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test26(): Unit = { + test("26") { val (actualValue, actualTrace, actualFault) = pgenchisq( 140.0, Array(6, 2, 1, 1), @@ -504,10 +502,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(2.19476, 10, 1, 0.01381, 0.12785, 0.0, 19), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test27(): Unit = { + test("27") { val (actualValue, actualTrace, actualFault) = pgenchisq( 120.0, Array(6, 4, 2, 2, 4, 6, 6, 2, 1, 1), @@ -522,10 
+520,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(2.33438, 9, 1, 0.01202, 0.09616, 0.0, 18), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test28(): Unit = { + test("28") { val (actualValue, actualTrace, actualFault) = pgenchisq( 240.0, Array(6, 4, 2, 2, 4, 6, 6, 2, 1, 1), @@ -537,10 +535,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { ) assert(nearEqual(actualValue, 0.573625)) assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(3.1401, 7, 1, 0.01561, 0.09616, 0.0, 18))) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test29(): Unit = { + test("29") { val (actualValue, actualTrace, actualFault) = pgenchisq( 400.0, Array(6, 4, 2, 2, 4, 6, 6, 2, 1, 1), @@ -552,10 +550,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { ) assert(nearEqual(actualValue, 0.988332)) assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(4.2142, 6, 1, 0.01812, 0.09616, 0.0, 18))) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test30(): Unit = { + test("30") { val (actualValue, actualTrace, actualFault) = pgenchisq( 5.0, Array(1, 10), @@ -570,10 +568,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(0.95892, 163, 1, 0.00841, 1.3638, 0.0, 22), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test31(): Unit = { + test("31") { val (actualValue, actualTrace, actualFault) = pgenchisq( 25.0, Array(1, 10), @@ -588,10 +586,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(1.72922, 159, 1, 0.00864, 1.3638, 0.0, 22), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test32(): Unit = { + test("32") { val (actualValue, actualTrace, actualFault) = pgenchisq( 100.0, Array(1, 10), @@ -606,10 +604,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { 
actualTrace, DaviesAlgorithmTrace(4.61788, 143, 1, 0.00963, 1.3638, 0.0, 22), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test33(): Unit = { + test("33") { val (actualValue, actualTrace, actualFault) = pgenchisq( 10.0, Array(1, 20), @@ -624,10 +622,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(1.26245, 97, 1, 0.00839, 0.80736, 0.0, 21), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test34(): Unit = { + test("34") { val (actualValue, actualTrace, actualFault) = pgenchisq( 40.0, Array(1, 20), @@ -642,10 +640,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(2.16513, 93, 1, 0.00874, 0.80736, 0.0, 21), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test35(): Unit = { + test("35") { val (actualValue, actualTrace, actualFault) = pgenchisq( 100.0, Array(1, 20), @@ -660,10 +658,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(3.97055, 86, 1, 0.00954, 0.80736, 0.0, 21), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test36(): Unit = { + test("36") { val (actualValue, actualTrace, actualFault) = pgenchisq( 20.0, Array(1, 30), @@ -678,10 +676,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(1.65684, 81, 1, 0.00843, 0.67453, 0.0, 20), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test37(): Unit = { + test("37") { val (actualValue, actualTrace, actualFault) = pgenchisq( 50.0, Array(1, 30), @@ -696,10 +694,10 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(2.44382, 78, 1, 0.00878, 0.67453, 0.0, 20), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } - @Test def test38(): Unit = { + test("38") { val (actualValue, actualTrace, actualFault) = pgenchisq( 
100.0, Array(1, 30), @@ -714,6 +712,6 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { actualTrace, DaviesAlgorithmTrace(3.75545, 72, 1, 0.00944, 0.67453, 0.0, 20), )) - assert(actualFault == 0) + assertEquals(actualFault, 0) } } diff --git a/hail/hail/test/src/is/hail/stats/LeveneHaldaneSuite.scala b/hail/hail/test/src/is/hail/stats/LeveneHaldaneSuite.scala index d32e8b6d4df..841f1ce9b85 100644 --- a/hail/hail/test/src/is/hail/stats/LeveneHaldaneSuite.scala +++ b/hail/hail/test/src/is/hail/stats/LeveneHaldaneSuite.scala @@ -3,10 +3,8 @@ package is.hail.stats import is.hail.utils._ import org.apache.commons.math3.util.CombinatoricsUtils.factorialLog -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test -class LeveneHaldaneSuite extends TestNGSuite { +class LeveneHaldaneSuite extends munit.FunSuite { def LH(n: Int, nA: Int)(nAB: Int): Double = { assert(nA >= 0 && nA <= n) @@ -27,29 +25,26 @@ class LeveneHaldaneSuite extends TestNGSuite { val examples = List((15, 10), (15, 9), (15, 0), (15, 15), (1, 0), (1, 1), (0, 0), (1526, 431), (1526, 430)) - @Test def pmfTest(): Unit = { - + test("pmf") { def test(e: (Int, Int)): Boolean = { val (n, nA) = e val p0 = LeveneHaldane(n, nA).probability _ val p1 = LH(n, nA) _ (-2 to nA + 2).forall(nAB => D_==(p0(nAB), p1(nAB))) } - examples foreach { e => assert(test(e)) } + examples.foreach(e => assert(test(e))) } - @Test def modeTest(): Unit = { - + test("mode") { def test(e: (Int, Int)): Boolean = { val (n, nA) = e val LH = LeveneHaldane(n, nA) D_==(LH.probability(LH.mode), (nA % 2 to nA by 2).map(LH.probability).max) } - examples foreach { e => assert(test(e)) } + examples.foreach(e => assert(test(e))) } - @Test def meanTest(): Unit = { - + test("mean") { def test(e: (Int, Int)): Boolean = { val (n, nA) = e val LH = LeveneHaldane(n, nA) @@ -58,11 +53,10 @@ class LeveneHaldaneSuite extends TestNGSuite { (LH.getSupportLowerBound to LH.getSupportUpperBound).map(i => i * 
LH.probability(i)).sum, ) } - examples foreach { e => assert(test(e)) } + examples.foreach(e => assert(test(e))) } - @Test def varianceTest(): Unit = { - + test("variance") { def test(e: (Int, Int)): Boolean = { val (n, nA) = e val LH = LeveneHaldane(n, nA) @@ -71,11 +65,10 @@ class LeveneHaldaneSuite extends TestNGSuite { (LH.getSupportLowerBound to LH.getSupportUpperBound).map(i => i * i * LH.probability(i)).sum, ) } - examples foreach { e => assert(test(e)) } + examples.foreach(e => assert(test(e))) } - @Test def exactTestsTest(): Unit = { - + test("exactTests") { def test(e: (Int, Int)): Boolean = { val (n, nA) = e val LH = LeveneHaldane(n, nA) @@ -99,7 +92,7 @@ class LeveneHaldaneSuite extends TestNGSuite { ) ) } - examples foreach { e => assert(test(e)) } + examples.foreach(e => assert(test(e))) } } diff --git a/hail/hail/test/src/is/hail/stats/LogisticRegressionModelSuite.scala b/hail/hail/test/src/is/hail/stats/LogisticRegressionModelSuite.scala index 91f8e22391c..e6d636497d4 100644 --- a/hail/hail/test/src/is/hail/stats/LogisticRegressionModelSuite.scala +++ b/hail/hail/test/src/is/hail/stats/LogisticRegressionModelSuite.scala @@ -4,11 +4,10 @@ import is.hail.HailSuite import is.hail.utils._ import breeze.linalg._ -import org.testng.annotations.Test class LogisticRegressionModelSuite extends HailSuite { - @Test def covariatesVsInterceptOnlyTest(): Unit = { + test("covariatesVsInterceptOnly") { /* R code: * y0 = c(0, 0, 1, 1, 1, 1) c1 = c(0, 2, 1, -2, -2, 4) c2 = c(-1, 3, 5, 0, -4, 3) logfit <- @@ -60,7 +59,7 @@ class LogisticRegressionModelSuite extends HailSuite { assert(D_==(scoreStats.p, 0.8408652791, tolerance = 1.0e-5)) } - @Test def gtsAndCovariatesVsCovariatesOnlyTest(): Unit = { + test("gtsAndCovariatesVsCovariatesOnly") { /* R code: * y0 <- c(0, 0, 1, 1, 1, 1) c1 <- c(0, 2, 1, -2, -2, 4) c2 <- c(-1, 3, 5, 0, -4, 3) gts <- c(0, @@ -125,7 +124,7 @@ class LogisticRegressionModelSuite extends HailSuite { assert(D_==(scoreStats.p, 0.3724319159, 
tolerance = 1.0e-5)) } - @Test def firthSeparationTest(): Unit = { + test("firthSeparation") { val y = DenseVector(0d, 0d, 0d, 1d, 1d, 1d) val X = y.asDenseMatrix.t diff --git a/hail/hail/test/src/is/hail/stats/StatsSuite.scala b/hail/hail/test/src/is/hail/stats/StatsSuite.scala index ffb81a86acf..eebc9bd1f1f 100644 --- a/hail/hail/test/src/is/hail/stats/StatsSuite.scala +++ b/hail/hail/test/src/is/hail/stats/StatsSuite.scala @@ -6,11 +6,10 @@ import is.hail.utils._ import org.apache.commons.math3.distribution.{ ChiSquaredDistribution, GammaDistribution, NormalDistribution, } -import org.testng.annotations.Test class StatsSuite extends HailSuite { - @Test def chiSquaredTailTest(): Unit = { + test("chiSquaredTail") { val chiSq1 = new ChiSquaredDistribution(1) assert(D_==(pchisqtail(1d, 1), 1 - chiSq1.cumulativeProbability(1d))) assert(D_==(pchisqtail(5.52341d, 1), 1 - chiSq1.cumulativeProbability(5.52341d))) @@ -34,7 +33,7 @@ class StatsSuite extends HailSuite { assert(D_==(qchisqtail(5.507248e-89, 1), 400)) } - @Test def gammaTest(): Unit = { + test("gamma") { val gammaDist1 = new GammaDistribution(2.0, 1.0) val gammaDist2 = new GammaDistribution(1.0, 2.0) val gammaDist3 = new GammaDistribution(1.0, 1.0) @@ -111,7 +110,7 @@ class StatsSuite extends HailSuite { } } - @Test def normalTest(): Unit = { + test("normal") { val normalDist = new NormalDistribution() assert(D_==(pnorm(1), normalDist.cumulativeProbability(1))) assert(math.abs(pnorm(-10) - normalDist.cumulativeProbability(-10)) < 1e-10) @@ -126,7 +125,7 @@ class StatsSuite extends HailSuite { assert(D_==(qnorm(2.753624e-89), -20)) } - @Test def poissonTest(): Unit = { + test("poisson") { // compare with R assert(D_==(dpois(5, 10), 0.03783327)) assert(qpois(0.3, 10) == 8) @@ -145,7 +144,7 @@ class StatsSuite extends HailSuite { assert(ppois(30, 1, lowerTail = false, logP = false) > 0) } - @Test def betaTest(): Unit = { + test("beta") { val tol = 1e-5 assert(D_==(dbeta(.2, 1, 3), 1.92, tol)) @@ -161,7 +160,7 @@ 
class StatsSuite extends HailSuite { } - @Test def entropyTest(): Unit = { + test("entropy") { assert(D_==(entropy("accctg"), 1.79248, tolerance = 1e-5)) assert(D_==(entropy(Array(2, 3, 4, 5, 6, 6, 4)), 2.23593, tolerance = 1e-5)) diff --git a/hail/hail/test/src/is/hail/stats/eigSymDSuite.scala b/hail/hail/test/src/is/hail/stats/eigSymDSuite.scala index 2d12073f217..9f9e381e5aa 100644 --- a/hail/hail/test/src/is/hail/stats/eigSymDSuite.scala +++ b/hail/hail/test/src/is/hail/stats/eigSymDSuite.scala @@ -6,12 +6,9 @@ import is.hail.utils._ import breeze.linalg.{eigSym, svd, DenseMatrix, DenseVector} import org.apache.commons.math3.random.JDKRandomGenerator -import org.scalatest.Inspectors.forAll -import org.scalatest.enablers.InspectorAsserting.assertingNatureOfAssertion -import org.testng.annotations.Test class eigSymDSuite extends HailSuite { - @Test def eigSymTest(): Unit = { + test("eigSym") { val seed = 0 val rand = new JDKRandomGenerator() @@ -29,21 +26,21 @@ class eigSymDSuite extends HailSuite { val eigSymDK = eigSymD(K) // eigSymD = svdW - forAll(0 until n) { j => + (0 until n).foreach { j => assert(D_==(svdW.S(j) * svdW.S(j), eigSymDK.eigenvalues(n - j - 1))) for (i <- 0 until n) assert(D_==(math.abs(svdW.U(i, j)), math.abs(eigSymDK.eigenvectors(i, n - j - 1)))) } // eigSymR = svdK - forAll(0 until n) { j => + (0 until n).foreach { j => assert(D_==(svdK.S(j), eigSymDK.eigenvalues(n - j - 1))) for (i <- 0 until n) assert(D_==(math.abs(svdK.U(i, j)), math.abs(eigSymDK.eigenvectors(i, n - j - 1)))) } // eigSymD = eigSym - forAll(0 until n) { j => + (0 until n).foreach { j => assert(D_==(eigSymK.eigenvalues(j), eigSymDK.eigenvalues(j))) for (i <- 0 until n) assert(D_==(math.abs(eigSymK.eigenvectors(i, j)), math.abs(eigSymDK.eigenvectors(i, j)))) @@ -100,13 +97,13 @@ class eigSymDSuite extends HailSuite { timeSymEig() } - @Test def triSolveTest(): Unit = { + test("triSolve") { val seed = 0 val rand = new JDKRandomGenerator() rand.setSeed(seed) - forAll(1 to 5) { 
n => + (1 to 5).foreach { n => val A = DenseMatrix.zeros[Double](n, n) (0 until n).foreach(i => (i until n).foreach(j => A(i, j) = rand.nextGaussian())) diff --git a/hail/hail/test/src/is/hail/types/encoded/ETypeSuite.scala b/hail/hail/test/src/is/hail/types/encoded/ETypeSuite.scala index 4bf3dbbc575..b1954d7532a 100644 --- a/hail/hail/test/src/is/hail/types/encoded/ETypeSuite.scala +++ b/hail/hail/test/src/is/hail/types/encoded/ETypeSuite.scala @@ -11,51 +11,49 @@ import is.hail.types.physical._ import org.apache.spark.sql.Row import org.json4s.jackson.Serialization -import org.testng.annotations.{DataProvider, Test} class ETypeSuite extends HailSuite { - @DataProvider(name = "etypes") - def etypes(): Array[Array[Any]] = { - Array[EType]( - EInt32Required, - EInt32Optional, - EInt64Required, - EFloat32Optional, - EFloat32Required, - EFloat64Optional, - EFloat64Required, - EBooleanOptional, - EBinaryRequired, - EBinaryOptional, - EBinaryRequired, - EArray(EInt32Required, required = false), - EArray(EArray(EInt32Optional, required = true), required = true), - EBaseStruct(FastSeq(), required = true), - EBaseStruct( - FastSeq(EField("x", EBinaryRequired, 0), EField("y", EFloat64Optional, 1)), - required = true, - ), - ENDArrayColumnMajor(EFloat64Required, 3), - EStructOfArrays( - FastSeq( - EField("a", EArray(EInt32Required, true), 0), - EField("b", EArray(EInt32Optional, true), 1), - ), - required = true, - structRequired = false, + val etypes: Array[EType] = Array( + EInt32Required, + EInt32Optional, + EInt64Required, + EFloat32Optional, + EFloat32Required, + EFloat64Optional, + EFloat64Required, + EBooleanOptional, + EBinaryRequired, + EBinaryOptional, + EBinaryRequired, + EArray(EInt32Required, required = false), + EArray(EArray(EInt32Optional, required = true), required = true), + EBaseStruct(FastSeq(), required = true), + EBaseStruct( + FastSeq(EField("x", EBinaryRequired, 0), EField("y", EFloat64Optional, 1)), + required = true, + ), + 
ENDArrayColumnMajor(EFloat64Required, 3), + EStructOfArrays( + FastSeq( + EField("a", EArray(EInt32Required, true), 0), + EField("b", EArray(EInt32Optional, true), 1), ), - ).map(t => Array(t: Any)) + required = true, + structRequired = false, + ), + ) + + object checkSerialization extends TestCases { + def apply(etype: EType)(implicit loc: munit.Location): Unit = + test("Serialization") { + implicit val formats = AbstractRVDSpec.formats + val s = Serialization.write(etype) + assertEquals(Serialization.read[EType](s), etype) + } } - @Test def testDataProvider(): Unit = etypes(): Unit - - @Test(dataProvider = "etypes") - def testSerialization(etype: EType): Unit = { - implicit val formats = AbstractRVDSpec.formats - val s = Serialization.write(etype) - assert(Serialization.read[EType](s) == etype) - } + etypes.foreach(checkSerialization(_)) def encodeDecode(inPType: PType, eType: EType, outPType: PType, data: Annotation): Annotation = { val fb = EmitFunctionBuilder[Long, OutputBuffer, Unit](ctx, "fb") @@ -95,10 +93,10 @@ class ETypeSuite extends HailSuite { def assertEqualEncodeDecode(inPType: PType, eType: EType, outPType: PType, data: Annotation) : Unit = { val encodeDecodeResult = encodeDecode(inPType, eType, outPType, data) - assert(encodeDecodeResult == data) + assertEquals(encodeDecodeResult, data) } - @Test def testDifferentRequirednessEncodeDecode(): Unit = { + test("DifferentRequirednessEncodeDecode") { val inPType = PCanonicalArray( PCanonicalStruct( @@ -138,7 +136,7 @@ class ETypeSuite extends HailSuite { assertEqualEncodeDecode(inPType, etype, outPType, data) } - @Test def testNDArrayEncodeDecode(): Unit = { + test("NDArrayEncodeDecode") { val pTypeInt0 = PCanonicalNDArray(PInt32Required, 0, true) val eTypeInt0 = ENDArrayColumnMajor(EInt32Required, 0, true) val dataInt0 = new SafeNDArray(IndexedSeq[Long](), FastSeq(0)) @@ -161,8 +159,10 @@ class ETypeSuite extends HailSuite { val eTypeDouble3 = ENDArrayColumnMajor(EFloat64Required, 3, false) val 
dataDouble3 = new SafeNDArray(IndexedSeq(3L, 2L, 1L), FastSeq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) - assert(encodeDecode(pTypeDouble3, eTypeDouble3, pTypeDouble3, dataDouble3) == - new SafeNDArray(IndexedSeq(3L, 2L, 1L), FastSeq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))) + assertEquals( + encodeDecode(pTypeDouble3, eTypeDouble3, pTypeDouble3, dataDouble3), + new SafeNDArray(IndexedSeq(3L, 2L, 1L), FastSeq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)), + ) // Test for skipping val pStructContainingNDArray = PCanonicalStruct(true, "a" -> pTypeInt2, "b" -> PInt32Optional) @@ -177,26 +177,23 @@ class ETypeSuite extends HailSuite { val dataStruct = Row(dataInt2, 3) - assert(encodeDecode( - pStructContainingNDArray, - eStructContainingNDArray, - pOnlyReadB, - dataStruct, - ) == - Row(3)) + assertEquals( + encodeDecode(pStructContainingNDArray, eStructContainingNDArray, pOnlyReadB, dataStruct), + Row(3), + ) } - @Test def testArrayOfString(): Unit = { + test("ArrayOfString") { val etype = EArray(EBinary(false), false) val toEncode = PCanonicalArray(PCanonicalStringRequired, false) val toDecode = PCanonicalArray(PCanonicalStringOptional, false) val longListOfStrings = (0 until 36).map(idx => s"foo_name_sample_$idx") val data = longListOfStrings - assert(encodeDecode(toEncode, etype, toDecode, data) == data) + assertEquals(encodeDecode(toEncode, etype, toDecode, data), data) } - @Test def testStructOfArrays(): Unit = { + test("StructOfArrays") { val etype = EStructOfArrays( FastSeq( diff --git a/hail/hail/test/src/is/hail/types/physical/PBaseStructSuite.scala b/hail/hail/test/src/is/hail/types/physical/PBaseStructSuite.scala index 42a24869dce..a036873586b 100644 --- a/hail/hail/test/src/is/hail/types/physical/PBaseStructSuite.scala +++ b/hail/hail/test/src/is/hail/types/physical/PBaseStructSuite.scala @@ -2,10 +2,8 @@ package is.hail.types.physical import is.hail.annotations.Annotation -import org.testng.annotations.Test - class PBaseStructSuite extends PhysicalTestUtils { - @Test def testStructCopy(): 
Unit = { + test("StructCopy") { def runTests(deepCopy: Boolean, interpret: Boolean = false): Unit = { copyTestExecutor( PCanonicalStruct(), @@ -178,7 +176,7 @@ class PBaseStructSuite extends PhysicalTestUtils { runTests(false, true) } - @Test def tupleCopyTests(): Unit = { + test("TupleCopy") { def runTests(deepCopy: Boolean, interpret: Boolean = false): Unit = { copyTestExecutor( PCanonicalTuple(false, PCanonicalString(true), PCanonicalString(true)), diff --git a/hail/hail/test/src/is/hail/types/physical/PBinarySuite.scala b/hail/hail/test/src/is/hail/types/physical/PBinarySuite.scala index 81e68c81aaf..d1fa1bc385a 100644 --- a/hail/hail/test/src/is/hail/types/physical/PBinarySuite.scala +++ b/hail/hail/test/src/is/hail/types/physical/PBinarySuite.scala @@ -1,9 +1,7 @@ package is.hail.types.physical -import org.testng.annotations.Test - class PBinarySuite extends PhysicalTestUtils { - @Test def testCopy(): Unit = { + test("Copy") { def runTests(deepCopy: Boolean, interpret: Boolean = false): Unit = { copyTestExecutor( PCanonicalString(), diff --git a/hail/hail/test/src/is/hail/types/physical/PCallSuite.scala b/hail/hail/test/src/is/hail/types/physical/PCallSuite.scala index d2250d851b4..69be04fb15f 100644 --- a/hail/hail/test/src/is/hail/types/physical/PCallSuite.scala +++ b/hail/hail/test/src/is/hail/types/physical/PCallSuite.scala @@ -1,9 +1,7 @@ package is.hail.types.physical -import org.testng.annotations.Test - class PCallSuite extends PhysicalTestUtils { - @Test def copyTests(): Unit = { + test("copyTests") { def runTests(deepCopy: Boolean, interpret: Boolean = false): Unit = { copyTestExecutor( PCanonicalCall(), diff --git a/hail/hail/test/src/is/hail/types/physical/PContainerTest.scala b/hail/hail/test/src/is/hail/types/physical/PContainerTest.scala index db8d180a048..5be7bbad48e 100644 --- a/hail/hail/test/src/is/hail/types/physical/PContainerTest.scala +++ b/hail/hail/test/src/is/hail/types/physical/PContainerTest.scala @@ -5,10 +5,6 @@ import 
is.hail.asm4s._ import is.hail.collection.FastSeq import is.hail.expr.ir.EmitFunctionBuilder -import org.scalatest.Inspectors.forAll -import org.scalatest.enablers.InspectorAsserting.assertingNatureOfAssertion -import org.testng.annotations.Test - class PContainerTest extends PhysicalTestUtils { def nullInByte(nElements: Int, missingElement: Int) = { IndexedSeq.tabulate(nElements) { i => @@ -66,70 +62,70 @@ class PContainerTest extends PhysicalTestUtils { res } - @Test def checkFirstNonZeroByte(): Unit = { + test("FirstNonZeroByte") { val sourceType = PCanonicalArray(PInt64(false)) - assert(testContainsNonZeroBits(sourceType, nullInByte(0, 0)) == false) - - assert(testContainsNonZeroBits(sourceType, nullInByte(1, 0)) == false) - assert(testContainsNonZeroBits(sourceType, nullInByte(1, 1)) == true) - - assert(testContainsNonZeroBits(sourceType, nullInByte(8, 0)) == false) - assert(testContainsNonZeroBits(sourceType, nullInByte(8, 1)) == true) - assert(testContainsNonZeroBits(sourceType, nullInByte(8, 8)) == true) - - assert(testContainsNonZeroBits(sourceType, nullInByte(32, 0)) == false) - assert(testContainsNonZeroBits(sourceType, nullInByte(31, 31)) == true) - assert(testContainsNonZeroBits(sourceType, nullInByte(32, 32)) == true) - assert(testContainsNonZeroBits(sourceType, nullInByte(33, 33)) == true) - - assert(testContainsNonZeroBits(sourceType, nullInByte(64, 0)) == false) - assert(testContainsNonZeroBits(sourceType, nullInByte(64, 1)) == true) - assert(testContainsNonZeroBits(sourceType, nullInByte(64, 32)) == true) - assert(testContainsNonZeroBits(sourceType, nullInByte(64, 33)) == true) - assert(testContainsNonZeroBits(sourceType, nullInByte(64, 64)) == true) - - assert(testContainsNonZeroBits(sourceType, nullInByte(68, 0)) == false) - assert(testContainsNonZeroBits(sourceType, nullInByte(68, 1)) == true) - assert(testContainsNonZeroBits(sourceType, nullInByte(68, 32)) == true) - assert(testContainsNonZeroBits(sourceType, nullInByte(68, 33)) == true) - 
assert(testContainsNonZeroBits(sourceType, nullInByte(68, 64)) == true) - - assert(testContainsNonZeroBits(sourceType, nullInByte(72, 0)) == false) - assert(testContainsNonZeroBits(sourceType, nullInByte(72, 1)) == true) - assert(testContainsNonZeroBits(sourceType, nullInByte(72, 32)) == true) - assert(testContainsNonZeroBits(sourceType, nullInByte(72, 33)) == true) - assert(testContainsNonZeroBits(sourceType, nullInByte(72, 64)) == true) - - assert(testContainsNonZeroBits(sourceType, nullInByte(73, 0)) == false) - assert(testContainsNonZeroBits(sourceType, nullInByte(73, 1)) == true) - assert(testContainsNonZeroBits(sourceType, nullInByte(73, 32)) == true) - assert(testContainsNonZeroBits(sourceType, nullInByte(73, 33)) == true) - assert(testContainsNonZeroBits(sourceType, nullInByte(73, 64)) == true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(0, 0)), false) + + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(1, 0)), false) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(1, 1)), true) + + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(8, 0)), false) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(8, 1)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(8, 8)), true) + + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(32, 0)), false) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(31, 31)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(32, 32)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(33, 33)), true) + + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(64, 0)), false) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(64, 1)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(64, 32)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(64, 33)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(64, 64)), true) + + 
assertEquals(testContainsNonZeroBits(sourceType, nullInByte(68, 0)), false) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(68, 1)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(68, 32)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(68, 33)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(68, 64)), true) + + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(72, 0)), false) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(72, 1)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(72, 32)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(72, 33)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(72, 64)), true) + + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(73, 0)), false) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(73, 1)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(73, 32)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(73, 33)), true) + assertEquals(testContainsNonZeroBits(sourceType, nullInByte(73, 64)), true) } - @Test def checkFirstNonZeroByteStaged(): Unit = { + test("FirstNonZeroByteStaged") { val sourceType = PCanonicalArray(PInt64(false)) - assert(testContainsNonZeroBitsStaged(sourceType, nullInByte(32, 0)) == false) - assert(testContainsNonZeroBitsStaged(sourceType, nullInByte(73, 64)) == true) + assertEquals(testContainsNonZeroBitsStaged(sourceType, nullInByte(32, 0)), false) + assertEquals(testContainsNonZeroBitsStaged(sourceType, nullInByte(73, 64)), true) } - @Test def checkHasMissingValues(): Unit = { + test("HasMissingValues") { val sourceType = PCanonicalArray(PInt64(false)) - assert(testHasMissingValues(sourceType, nullInByte(1, 0)) == false) - assert(testHasMissingValues(sourceType, nullInByte(1, 1)) == true) - assert(testHasMissingValues(sourceType, nullInByte(2, 1)) == true) + 
assertEquals(testHasMissingValues(sourceType, nullInByte(1, 0)), false) + assertEquals(testHasMissingValues(sourceType, nullInByte(1, 1)), true) + assertEquals(testHasMissingValues(sourceType, nullInByte(2, 1)), true) - forAll(Seq(2, 16, 31, 32, 33, 50, 63, 64, 65, 90, 127, 128, 129)) { num => - forAll(1 to num) { missing => - assert(testHasMissingValues(sourceType, nullInByte(num, missing)) == true) + Seq(2, 16, 31, 32, 33, 50, 63, 64, 65, 90, 127, 128, 129).foreach { num => + (1 to num).foreach { missing => + assertEquals(testHasMissingValues(sourceType, nullInByte(num, missing)), true) } } } - @Test def arrayCopyTest(): Unit = { + test("arrayCopy") { /* Note: can't test where data is null due to ArrayStack.top semantics (ScalaToRegionValue: * assert(size_ > 0)) */ def runTests(deepCopy: Boolean, interpret: Boolean): Unit = { @@ -310,7 +306,7 @@ class PContainerTest extends PhysicalTestUtils { runTests(false, interpret = true) } - @Test def dictCopyTests(): Unit = { + test("dictCopy") { def runTests(deepCopy: Boolean, interpret: Boolean): Unit = { copyTestExecutor( PCanonicalDict(PCanonicalString(), PInt32()), @@ -343,7 +339,7 @@ class PContainerTest extends PhysicalTestUtils { runTests(false, interpret = true) } - @Test def setCopyTests(): Unit = { + test("setCopy") { def runTests(deepCopy: Boolean, interpret: Boolean): Unit = { copyTestExecutor( PCanonicalSet(PCanonicalString(true)), diff --git a/hail/hail/test/src/is/hail/types/physical/PIntervalSuite.scala b/hail/hail/test/src/is/hail/types/physical/PIntervalSuite.scala index 2ca046e14f8..5aa86434f00 100644 --- a/hail/hail/test/src/is/hail/types/physical/PIntervalSuite.scala +++ b/hail/hail/test/src/is/hail/types/physical/PIntervalSuite.scala @@ -7,10 +7,8 @@ import is.hail.types.physical.stypes.concrete.{SUnreachableInterval, SUnreachabl import is.hail.types.virtual.{TInt32, TInterval} import is.hail.utils._ -import org.testng.annotations.Test - class PIntervalSuite extends PhysicalTestUtils { - @Test def 
copyTests(): Unit = { + test("copyTests") { def runTests(deepCopy: Boolean, interpret: Boolean = false): Unit = { copyTestExecutor( PCanonicalInterval(PInt64()), @@ -61,7 +59,7 @@ class PIntervalSuite extends PhysicalTestUtils { } // Just makes sure we can generate code to store an unreachable interval - @Test def storeUnreachable(): Unit = { + test("storeUnreachable") { val ust = SUnreachableInterval(TInterval(TInt32)) val usv = new SUnreachableIntervalValue(ust) val pt = PCanonicalInterval(PInt32Required, true) diff --git a/hail/hail/test/src/is/hail/types/physical/PNDArraySuite.scala b/hail/hail/test/src/is/hail/types/physical/PNDArraySuite.scala index 93c81e40472..fe3d69ff098 100644 --- a/hail/hail/test/src/is/hail/types/physical/PNDArraySuite.scala +++ b/hail/hail/test/src/is/hail/types/physical/PNDArraySuite.scala @@ -7,11 +7,14 @@ import is.hail.expr.ir.{EmitCodeBuilder, EmitFunctionBuilder} import is.hail.methods.LocalWhitening import is.hail.types.physical.stypes.interfaces.{ColonIndex => Colon, _} +import scala.concurrent.duration.Duration + import org.apache.spark.sql.Row -import org.testng.annotations.Test class PNDArraySuite extends PhysicalTestUtils { - @Test def copyTests(): Unit = { + override val munitTimeout = Duration(60, "s") + + test("copyTests") { def runTests(deepCopy: Boolean, interpret: Boolean = false): Unit = { copyTestExecutor( PCanonicalNDArray(PInt64(true), 1), @@ -29,7 +32,7 @@ class PNDArraySuite extends PhysicalTestUtils { runTests(false, true) } - @Test def testWhitenBase(): Unit = { + test("WhitenBase") { val fb = EmitFunctionBuilder[Region, Double](ctx, "whiten_test") val matType = PCanonicalNDArray(PFloat64Required, 2) val vecType = PCanonicalNDArray(PFloat64Required, 1) @@ -99,7 +102,7 @@ class PNDArraySuite extends PhysicalTestUtils { } } - @Test def testQrPivot(): Unit = { + test("QrPivot") { val fb = EmitFunctionBuilder[Region, Double](ctx, "whiten_test") val matType = PCanonicalNDArray(PFloat64Required, 2) val vecType = 
PCanonicalNDArray(PFloat64Required, 1) @@ -226,7 +229,7 @@ class PNDArraySuite extends PhysicalTestUtils { Xw } - @Test def testWhitenNonrecur(): Unit = { + test("WhitenNonrecur") { val fb = EmitFunctionBuilder[Region, Unit](ctx, "whiten_test") val matType = PCanonicalNDArray(PFloat64Required, 2) val m = SizeValueStatic(2000) @@ -315,11 +318,10 @@ class PNDArraySuite extends PhysicalTestUtils { val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, region) f(region) - succeed } } - @Test def testWhiten(): Unit = { + test("Whiten") { val fb = EmitFunctionBuilder[Region, Unit](ctx, "whiten_test") val matType = PCanonicalNDArray(PFloat64Required, 2) val m = SizeValueStatic(2000) @@ -377,11 +379,10 @@ class PNDArraySuite extends PhysicalTestUtils { val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, region) f(region) - succeed } } - @Test def testRefCounted(): Unit = { + test("RefCounted") { val nd = PCanonicalNDArray(PInt32Required, 1) val region1 = Region(pool = this.pool) @@ -423,34 +424,34 @@ class PNDArraySuite extends PhysicalTestUtils { val result1Data = nd.unstagedDataFirstElementPointer(result1) // Check number of ndarrays in each region: - assert(region1.memory.listNDArrayRefs().size == 1) - assert(region1.memory.listNDArrayRefs()(0) == result1Data) + assertEquals(region1.memory.listNDArrayRefs().size, 1) + assertEquals(region1.memory.listNDArrayRefs()(0), result1Data) - assert(region2.memory.listNDArrayRefs().size == 2) - assert(region2.memory.listNDArrayRefs()(1) == result1Data) + assertEquals(region2.memory.listNDArrayRefs().size, 2) + assertEquals(region2.memory.listNDArrayRefs()(1), result1Data) // Check that the reference count of ndarray1 is 2: val rc1A = Region.loadLong(result1Data - Region.sharedChunkHeaderBytes) - assert(rc1A == 2) + assertEquals(rc1A, 2L) region1.clear() - assert(region1.memory.listNDArrayRefs().size == 0) + assertEquals(region1.memory.listNDArrayRefs().size, 0) // Check that ndarray 1 
wasn't actually cleared, ref count should just be 1 now: val rc1B = Region.loadLong(result1Data - Region.sharedChunkHeaderBytes) - assert(rc1B == 1) + assertEquals(rc1B, 1L) - assert(region3.memory.listNDArrayRefs().size == 0) + assertEquals(region3.memory.listNDArrayRefs().size, 0) // Do an unstaged copy into region3 nd.copyFromAddress(ctx.stateManager, region3, nd, result1, true): Unit - assert(region3.memory.listNDArrayRefs().size == 1) + assertEquals(region3.memory.listNDArrayRefs().size, 1) // Check that clearing region2 removes both ndarrays region2.clear() - assert(region2.memory.listNDArrayRefs().size == 0) + assertEquals(region2.memory.listNDArrayRefs().size, 0) } - @Test def testUnstagedCopy(): Unit = { + test("UnstagedCopy") { val region1 = Region(pool = this.pool) val region2 = Region(pool = this.pool) val x = SafeNDArray(IndexedSeq(3L, 2L), (0 until 6).map(_.toDouble)) @@ -462,11 +463,11 @@ class PNDArraySuite extends PhysicalTestUtils { // Deep copy same ptype just increments reference count, doesn't change the address. val dataAddr1 = Region.loadAddress(pNd.representation.loadField(ndAddr1, 2)) val dataAddr2 = Region.loadAddress(pNd.representation.loadField(ndAddr2, 2)) - assert(dataAddr1 == dataAddr2) - assert(Region.getSharedChunkRefCount(dataAddr1) == 2) - assert(unsafe1 == unsafe2) + assertEquals(dataAddr1, dataAddr2) + assertEquals(Region.getSharedChunkRefCount(dataAddr1), 2L) + assertEquals(unsafe1, unsafe2) region1.clear() - assert(Region.getSharedChunkRefCount(dataAddr1) == 1) + assertEquals(Region.getSharedChunkRefCount(dataAddr1), 1L) // Deep copy with elements that contain pointers, so have to actually do a full copy // FIXME: Currently ndarrays do not support this, reference counting needs to account for this. 
@@ -497,6 +498,6 @@ class PNDArraySuite extends PhysicalTestUtils { val unsafe5 = UnsafeRow.read(pNDOfStructs1, region1, addr5) val addr6 = pNDOfStructs2.copyFromAddress(ctx.stateManager, region2, pNDOfStructs1, addr5, true) val unsafe6 = UnsafeRow.read(pNDOfStructs2, region2, addr6) - assert(unsafe5 == unsafe6) + assertEquals(unsafe5, unsafe6) } } diff --git a/hail/hail/test/src/is/hail/types/physical/PTypeSuite.scala b/hail/hail/test/src/is/hail/types/physical/PTypeSuite.scala index f04dbdbe337..9e83cafd5c1 100644 --- a/hail/hail/test/src/is/hail/types/physical/PTypeSuite.scala +++ b/hail/hail/test/src/is/hail/types/physical/PTypeSuite.scala @@ -9,68 +9,74 @@ import is.hail.variant.ReferenceGenome import org.apache.spark.sql.Row import org.json4s.jackson.Serialization -import org.testng.annotations.{DataProvider, Test} class PTypeSuite extends HailSuite { - @DataProvider(name = "ptypes") - def ptypes(): Array[Array[Any]] = { - Array[PType]( - PInt32(true), - PInt32(false), - PInt64(true), - PInt64(false), - PFloat32(true), - PFloat64(true), - PBoolean(true), - PCanonicalCall(true), - PCanonicalBinary(false), - PCanonicalString(true), - PCanonicalLocus(ReferenceGenome.GRCh37, false), - PCanonicalArray(PInt32Required, true), - PCanonicalSet(PInt32Required, false), - PCanonicalDict(PInt32Required, PCanonicalString(true), true), - PCanonicalInterval(PInt32Optional, false), - PCanonicalTuple( - FastSeq(PTupleField(1, PInt32Required), PTupleField(3, PCanonicalString(false))), - true, - ), - PCanonicalStruct( - FastSeq(PField("foo", PInt32Required, 0), PField("bar", PCanonicalString(false), 1)), - true, - ), - ).map(t => Array(t: Any)) - } - - @Test def testPTypesDataProvider(): Unit = ptypes(): Unit + val ptypes: Array[PType] = Array( + PInt32(true), + PInt32(false), + PInt64(true), + PInt64(false), + PFloat32(true), + PFloat64(true), + PBoolean(true), + PCanonicalCall(true), + PCanonicalBinary(false), + PCanonicalString(true), + PCanonicalLocus(ReferenceGenome.GRCh37, 
false), + PCanonicalArray(PInt32Required, true), + PCanonicalSet(PInt32Required, false), + PCanonicalDict(PInt32Required, PCanonicalString(true), true), + PCanonicalInterval(PInt32Optional, false), + PCanonicalTuple( + FastSeq(PTupleField(1, PInt32Required), PTupleField(3, PCanonicalString(false))), + true, + ), + PCanonicalStruct( + FastSeq(PField("foo", PInt32Required, 0), PField("bar", PCanonicalString(false), 1)), + true, + ), + ) - @Test(dataProvider = "ptypes") - def testSerialization(ptype: PType): Unit = { - implicit val formats = AbstractRVDSpec.formats - val s = Serialization.write(ptype) - assert(Serialization.read[PType](s) == ptype) + object checkSerialization extends TestCases { + def apply(ptype: PType)(implicit loc: munit.Location): Unit = + test("Serialization") { + implicit val formats = AbstractRVDSpec.formats + val s = Serialization.write(ptype) + assertEquals(Serialization.read[PType](s), ptype) + } } - @Test def testLiteralPType(): Unit = { - assert(PType.literalPType(TInt32, 5) == PInt32(true)) - assert(PType.literalPType(TInt32, null) == PInt32()) + ptypes.foreach(checkSerialization(_)) - assert(PType.literalPType(TArray(TInt32), null) == PCanonicalArray(PInt32(true))) - assert(PType.literalPType(TArray(TInt32), FastSeq(1, null)) == PCanonicalArray(PInt32(), true)) - assert(PType.literalPType(TArray(TInt32), FastSeq(1, 5)) == PCanonicalArray(PInt32(true), true)) + test("LiteralPType") { + assertEquals(PType.literalPType(TInt32, 5), PInt32(true)) + assertEquals(PType.literalPType(TInt32, null), PInt32()) - assert(PType.literalPType( - TInterval(TInt32), - Interval(5, null, false, true), - ) == PCanonicalInterval(PInt32(), true)) + assertEquals(PType.literalPType(TArray(TInt32), null), PCanonicalArray(PInt32(true))) + assertEquals( + PType.literalPType(TArray(TInt32), FastSeq(1, null)), + PCanonicalArray(PInt32(), true), + ) + assertEquals( + PType.literalPType(TArray(TInt32), FastSeq(1, 5)), + PCanonicalArray(PInt32(true), true), + ) + + 
assertEquals( + PType.literalPType(TInterval(TInt32), Interval(5, null, false, true)), + PCanonicalInterval(PInt32(), true), + ) val p = TStruct("a" -> TInt32, "b" -> TInt32) val d = TDict(p, p) - assert(PType.literalPType(d, Map(Row(3, null) -> Row(null, 3))) == + assertEquals( + PType.literalPType(d, Map(Row(3, null) -> Row(null, 3))), PCanonicalDict( PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32()), PCanonicalStruct(true, "a" -> PInt32(), "b" -> PInt32(true)), true, - )) + ), + ) } } diff --git a/hail/hail/test/src/is/hail/types/physical/PhysicalTestUtils.scala b/hail/hail/test/src/is/hail/types/physical/PhysicalTestUtils.scala index c8154cb6178..5220c4c2f60 100644 --- a/hail/hail/test/src/is/hail/types/physical/PhysicalTestUtils.scala +++ b/hail/hail/test/src/is/hail/types/physical/PhysicalTestUtils.scala @@ -55,7 +55,7 @@ abstract class PhysicalTestUtils extends HailSuite { return } - fail(e) + fail("unexpected error in interpret path", e) } return @@ -81,7 +81,7 @@ abstract class PhysicalTestUtils extends HailSuite { return } - fail(e) + fail("unexpected error in compile path", e) } if (compileSuccess && expectCompileError) { @@ -102,7 +102,7 @@ abstract class PhysicalTestUtils extends HailSuite { return } - fail(e) + fail("unexpected runtime error", e) } logger.info(s"Copied value: $copy, Source value: $sourceValue") diff --git a/hail/hail/test/src/is/hail/types/physical/stypes/concrete/SStructViewSuite.scala b/hail/hail/test/src/is/hail/types/physical/stypes/concrete/SStructViewSuite.scala index 0e6db0d34f1..707b10e0519 100644 --- a/hail/hail/test/src/is/hail/types/physical/stypes/concrete/SStructViewSuite.scala +++ b/hail/hail/test/src/is/hail/types/physical/stypes/concrete/SStructViewSuite.scala @@ -7,8 +7,6 @@ import is.hail.types.physical.stypes.interfaces.SBaseStruct import is.hail.types.tcoerce import is.hail.types.virtual.{TInt32, TInt64, TStruct} -import org.testng.annotations.Test - class SStructViewSuite extends HailSuite { val xyz: 
SBaseStruct = @@ -22,7 +20,7 @@ class SStructViewSuite extends HailSuite { ) ) - @Test def testCastRename(): Unit = { + test("CastRename") { val newtype = TStruct("x" -> TStruct("b" -> TInt32)) val expected = @@ -32,10 +30,10 @@ class SStructViewSuite extends HailSuite { rename = newtype, ) - assert(SStructView.subset(FastSeq("z"), xyz).castRename(newtype) == expected) + assertEquals(SStructView.subset(FastSeq("z"), xyz).castRename(newtype), expected) } - @Test def testSubsetRenameSubset(): Unit = { + test("SubsetRenameSubset") { val subset = SStructView.subset( FastSeq("x"), @@ -51,12 +49,13 @@ class SStructViewSuite extends HailSuite { rename = TStruct("x" -> TStruct("b" -> TInt32)), ) - assert(subset == expected) + assertEquals(subset, expected) } - @Test def testAssertIsomorphism(): Unit = - assertThrows[AssertionError] { + test("AssertIsomorphism") { + intercept[AssertionError] { SStructView.subset(FastSeq("x", "y"), xyz) .castRename(TStruct("x" -> TInt64, "x" -> TInt32)) } + } } diff --git a/hail/hail/test/src/is/hail/types/virtual/TStructSuite.scala b/hail/hail/test/src/is/hail/types/virtual/TStructSuite.scala index e2391d2b152..12b8d1c5e88 100644 --- a/hail/hail/test/src/is/hail/types/virtual/TStructSuite.scala +++ b/hail/hail/test/src/is/hail/types/virtual/TStructSuite.scala @@ -5,116 +5,170 @@ import is.hail.annotations.{Annotation, Inserter} import is.hail.collection.FastSeq import org.apache.spark.sql.Row -import org.testng.annotations.{DataProvider, Test} class TStructSuite extends HailSuite { - @DataProvider(name = "isPrefixOf") - def isPrefixOfData: Array[Array[Any]] = - Array( - Array(TStruct.empty, TStruct.empty, true), - Array(TStruct.empty, TStruct("a" -> TVoid), true), - Array(TStruct("a" -> TVoid), TStruct.empty, false), - Array(TStruct("a" -> TVoid), TStruct("a" -> TVoid), true), - Array(TStruct("a" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), true), - Array(TStruct("a" -> TVoid, "b" -> TVoid), TStruct("a" -> TVoid), false), - // isPrefixOf 
ignores field names and compares the ordered sequence of types. - // Consider joins for example - we only care that the key fields have the same types - // so we compare the key types (which are structs) for equality ignoring field names. - // isPrefixOf is used in similar cases involving key types where we don't care about names. - Array(TStruct("b" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), true), - ) - - @Test(dataProvider = "isPrefixOf") - def testIsPrefixOf(a: TStruct, b: TStruct, isPrefix: Boolean): Unit = - assert(a.isPrefixOf(b) == isPrefix, s"expected $a `isPrefixOf` $b == $isPrefix") - - @DataProvider(name = "isSubsetOf") - def isSubsetOfData: Array[Array[Any]] = - Array( - Array(TStruct.empty, TStruct.empty, true), - Array(TStruct.empty, TStruct("a" -> TVoid), true), - Array(TStruct("a" -> TVoid), TStruct.empty, false), - Array(TStruct("a" -> TVoid), TStruct("a" -> TVoid), true), - Array(TStruct("a" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), true), - Array(TStruct("a" -> TVoid, "b" -> TVoid), TStruct("a" -> TVoid), false), - Array(TStruct("b" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), true), - ) - - @Test(dataProvider = "isSubsetOf") - def testIsSubsetOf(a: TStruct, b: TStruct, isSubset: Boolean): Unit = - assert(a.isSubsetOf(b) == isSubset, s"expected $a `isSubsetOf` $b == $isSubset") - - @DataProvider(name = "structInsert") - def structInsertData: Array[Array[Any]] = - Array( - Array(TStruct("a" -> TInt32), FastSeq("a"), TInt32, TStruct("a" -> TInt32)), - Array(TStruct("a" -> TInt32), FastSeq("b"), TInt32, TStruct("a" -> TInt32, "b" -> TInt32)), - Array(TStruct("a" -> TInt32), FastSeq("a"), TVoid, TStruct("a" -> TVoid)), - Array( - TStruct("a" -> TInt32), - FastSeq("a", "b"), - TInt32, - TStruct("a" -> TStruct("b" -> TInt32)), - ), - Array(TStruct.empty, FastSeq("a"), TInt32, TStruct("a" -> TInt32)), - Array(TStruct.empty, FastSeq("a", "b"), TInt32, TStruct("a" -> TStruct("b" -> TInt32))), - ) - - @Test(dataProvider = "structInsert") - 
def testStructInsert(base: TStruct, path: IndexedSeq[String], signature: Type, expected: TStruct) - : Unit = - assert(base.structInsert(signature, path) == expected) - - @Test def testInsertEmptyPath(): Unit = - assertThrows[IllegalArgumentException] { + object checkIsPrefixOf extends TestCases { + def apply( + a: TStruct, + b: TStruct, + isPrefix: Boolean, + )(implicit loc: munit.Location + ): Unit = test("isPrefixOf") { + assert(a.isPrefixOf(b) == isPrefix, s"expected $a `isPrefixOf` $b == $isPrefix") + } + } + + checkIsPrefixOf(TStruct.empty, TStruct.empty, true) + checkIsPrefixOf(TStruct.empty, TStruct("a" -> TVoid), true) + checkIsPrefixOf(TStruct("a" -> TVoid), TStruct.empty, false) + checkIsPrefixOf(TStruct("a" -> TVoid), TStruct("a" -> TVoid), true) + checkIsPrefixOf(TStruct("a" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), true) + checkIsPrefixOf(TStruct("a" -> TVoid, "b" -> TVoid), TStruct("a" -> TVoid), false) + // isPrefixOf ignores field names and compares the ordered sequence of types. + // Consider joins for example - we only care that the key fields have the same types + // so we compare the key types (which are structs) for equality ignoring field names. + // isPrefixOf is used in similar cases involving key types where we don't care about names. 
+ checkIsPrefixOf(TStruct("b" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), true) + + object checkIsSubsetOf extends TestCases { + def apply( + a: TStruct, + b: TStruct, + isSubset: Boolean, + )(implicit loc: munit.Location + ): Unit = test("isSubsetOf") { + assert(a.isSubsetOf(b) == isSubset, s"expected $a `isSubsetOf` $b == $isSubset") + } + } + + checkIsSubsetOf(TStruct.empty, TStruct.empty, true) + checkIsSubsetOf(TStruct.empty, TStruct("a" -> TVoid), true) + checkIsSubsetOf(TStruct("a" -> TVoid), TStruct.empty, false) + checkIsSubsetOf(TStruct("a" -> TVoid), TStruct("a" -> TVoid), true) + checkIsSubsetOf(TStruct("a" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), true) + checkIsSubsetOf(TStruct("a" -> TVoid, "b" -> TVoid), TStruct("a" -> TVoid), false) + checkIsSubsetOf(TStruct("b" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), true) + + object checkStructInsert extends TestCases { + def apply( + base: TStruct, + path: IndexedSeq[String], + signature: Type, + expected: TStruct, + )(implicit loc: munit.Location + ): Unit = test("structInsert") { + assertEquals(base.structInsert(signature, path), expected) + } + } + + checkStructInsert(TStruct("a" -> TInt32), FastSeq("a"), TInt32, TStruct("a" -> TInt32)) + + checkStructInsert( + TStruct("a" -> TInt32), + FastSeq("b"), + TInt32, + TStruct("a" -> TInt32, "b" -> TInt32), + ) + + checkStructInsert(TStruct("a" -> TInt32), FastSeq("a"), TVoid, TStruct("a" -> TVoid)) + + checkStructInsert( + TStruct("a" -> TInt32), + FastSeq("a", "b"), + TInt32, + TStruct("a" -> TStruct("b" -> TInt32)), + ) + + checkStructInsert(TStruct.empty, FastSeq("a"), TInt32, TStruct("a" -> TInt32)) + + checkStructInsert( + TStruct.empty, + FastSeq("a", "b"), + TInt32, + TStruct("a" -> TStruct("b" -> TInt32)), + ) + + test("InsertEmptyPath") { + intercept[IllegalArgumentException] { TStruct.empty.insert(TInt32, FastSeq()) } + } + + object checkInsert extends TestCases { + def apply( + inserter: Inserter, + base: Annotation, + value: Any, + 
expected: Annotation, + )(implicit loc: munit.Location + ): Unit = test("insert") { + assertEquals(inserter(base, value), expected) + } + } + + checkInsert(TStruct("a" -> TInt32).insert(TInt32, FastSeq("a"))._2, null, 0, Row(0)) + checkInsert(TStruct("a" -> TInt32).insert(TInt32, FastSeq("a"))._2, Row(0), 1, Row(1)) + checkInsert(TStruct("a" -> TInt32).insert(TInt32, FastSeq("b"))._2, null, 0, Row(null, 0)) + checkInsert(TStruct("a" -> TInt32).insert(TInt32, FastSeq("b"))._2, Row(0), 1, Row(0, 1)) + checkInsert(TStruct.empty.insert(TInt32, FastSeq("a", "b"))._2, null, 0, Row(Row(0))) + checkInsert(TStruct("a" -> TInt32).insert(TInt32, FastSeq("a", "b"))._2, Row(0), 1, Row(Row(1))) + + object checkIsIsomorphicTo extends TestCases { + def apply( + a: TStruct, + b: TStruct, + isIsomorphic: Boolean, + )(implicit loc: munit.Location + ): Unit = test("isIsomorphicTo") { + assert( + (a isIsomorphicTo b) == isIsomorphic, + s"expected $a isIsomorphicTo $b == $isIsomorphic", + ) + } + } + + checkIsIsomorphicTo(TStruct.empty, TStruct.empty, true) + checkIsIsomorphicTo(TStruct.empty, TStruct("a" -> TVoid), false) + checkIsIsomorphicTo(TStruct("a" -> TVoid), TStruct.empty, false) + checkIsIsomorphicTo(TStruct("a" -> TVoid), TStruct("b" -> TVoid), true) + + checkIsIsomorphicTo( + TStruct("a" -> TStruct("b" -> TVoid)), + TStruct("b" -> TStruct("a" -> TVoid)), + true, + ) + + checkIsIsomorphicTo(TStruct("a" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), false) + checkIsIsomorphicTo(TStruct("a" -> TVoid, "b" -> TVoid), TStruct("a" -> TVoid), false) + + object checkIsJoinableWith extends TestCases { + def apply( + a: TStruct, + b: TStruct, + isJoinable: Boolean, + )(implicit loc: munit.Location + ): Unit = test("isJoinableWith") { + assert((a isJoinableWith b) == isJoinable, s"expected $a isJoinableWith $b == $isJoinable") + } + } + + checkIsJoinableWith(TStruct.empty, TStruct.empty, true) + checkIsJoinableWith(TStruct.empty, TStruct("a" -> TVoid), false) + 
checkIsJoinableWith(TStruct("a" -> TVoid), TStruct.empty, false) + checkIsJoinableWith(TStruct("a" -> TVoid), TStruct("b" -> TVoid), true) + + checkIsJoinableWith( + TStruct("a" -> TStruct("a" -> TVoid)), + TStruct("b" -> TStruct("a" -> TVoid)), + true, + ) + + checkIsJoinableWith( + TStruct("a" -> TStruct("a" -> TVoid)), + TStruct("b" -> TStruct("b" -> TVoid)), + false, + ) - @DataProvider(name = "inserter") - def inserterData: Array[Array[Any]] = - Array( - Array(TStruct("a" -> TInt32).insert(TInt32, FastSeq("a"))._2, null, 0, Row(0)), - Array(TStruct("a" -> TInt32).insert(TInt32, FastSeq("a"))._2, Row(0), 1, Row(1)), - Array(TStruct("a" -> TInt32).insert(TInt32, FastSeq("b"))._2, null, 0, Row(null, 0)), - Array(TStruct("a" -> TInt32).insert(TInt32, FastSeq("b"))._2, Row(0), 1, Row(0, 1)), - Array(TStruct.empty.insert(TInt32, FastSeq("a", "b"))._2, null, 0, Row(Row(0))), - Array(TStruct("a" -> TInt32).insert(TInt32, FastSeq("a", "b"))._2, Row(0), 1, Row(Row(1))), - ) - - @Test(dataProvider = "inserter") - def testInsert(inserter: Inserter, base: Annotation, value: Any, expected: Annotation): Unit = - assert(inserter(base, value) == expected) - - @DataProvider(name = "isIsomorphicTo") - def isIsomorphicToData: Array[Array[Any]] = - Array( - Array(TStruct.empty, TStruct.empty, true), - Array(TStruct.empty, TStruct("a" -> TVoid), false), - Array(TStruct("a" -> TVoid), TStruct.empty, false), - Array(TStruct("a" -> TVoid), TStruct("b" -> TVoid), true), - Array(TStruct("a" -> TStruct("b" -> TVoid)), TStruct("b" -> TStruct("a" -> TVoid)), true), - Array(TStruct("a" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), false), - Array(TStruct("a" -> TVoid, "b" -> TVoid), TStruct("a" -> TVoid), false), - ) - - @Test(dataProvider = "isIsomorphicTo") - def testIsIsomorphicTo(a: TStruct, b: TStruct, isIsomorphic: Boolean): Unit = - assert((a isIsomorphicTo b) == isIsomorphic, s"expected $a isIsomorphicTo $b == $isIsomorphic") - - @DataProvider(name = "isJoinableWith") - def 
isJoinableWithData: Array[Array[Any]] = - Array( - Array(TStruct.empty, TStruct.empty, true), - Array(TStruct.empty, TStruct("a" -> TVoid), false), - Array(TStruct("a" -> TVoid), TStruct.empty, false), - Array(TStruct("a" -> TVoid), TStruct("b" -> TVoid), true), - Array(TStruct("a" -> TStruct("a" -> TVoid)), TStruct("b" -> TStruct("a" -> TVoid)), true), - Array(TStruct("a" -> TStruct("a" -> TVoid)), TStruct("b" -> TStruct("b" -> TVoid)), false), - Array(TStruct("a" -> TVoid, "b" -> TVoid), TStruct("a" -> TVoid), false), - Array(TStruct("b" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), false), - ) - - @Test(dataProvider = "isJoinableWith") - def testIsJoinableWith(a: TStruct, b: TStruct, isJoinable: Boolean): Unit = - assert((a isJoinableWith b) == isJoinable, s"expected $a isJoinableWith $b == $isJoinable") + checkIsJoinableWith(TStruct("a" -> TVoid, "b" -> TVoid), TStruct("a" -> TVoid), false) + checkIsJoinableWith(TStruct("b" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), false) } diff --git a/hail/hail/test/src/is/hail/utils/BufferedAggregatorIteratorSuite.scala b/hail/hail/test/src/is/hail/utils/BufferedAggregatorIteratorSuite.scala index 3d616f63be6..c7124a3765a 100644 --- a/hail/hail/test/src/is/hail/utils/BufferedAggregatorIteratorSuite.scala +++ b/hail/hail/test/src/is/hail/utils/BufferedAggregatorIteratorSuite.scala @@ -4,11 +4,7 @@ import scala.collection.compat._ import org.scalacheck.Gen import org.scalacheck.Gen._ -import org.scalatest.matchers.should.Matchers.{be, convertToAnyShouldWrapper} -import org.scalatestplus.scalacheck.CheckerAsserting.assertingNatureOfAssertion -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test +import org.scalacheck.Prop.forAll class SumAgg { var x: Long = 0 @@ -22,7 +18,7 @@ class SumAgg { s"${getClass.getSimpleName}($x)" } -class BufferedAggregatorIteratorSuite extends TestNGSuite with ScalaCheckDrivenPropertyChecks { 
+class BufferedAggregatorIteratorSuite extends munit.ScalaCheckSuite { private[this] lazy val gen: Gen[(Array[(Int, Long)], Int)] = for { @@ -30,7 +26,7 @@ class BufferedAggregatorIteratorSuite extends TestNGSuite with ScalaCheckDrivenP len <- choose(1, 5) } yield (data, len) - @Test def test(): Unit = + property("BufferedAggregatorIterator matches simple groupBy") { forAll(gen) { case (arr, bufferSize) => val simple: Map[Int, Long] = arr.groupBy(_._1).map { case (k, a) => k -> a.map(_._2).sum } @@ -47,7 +43,8 @@ .toArray .groupMapReduce(_._1)(_._2)(_ comb _) .view.mapValues(_.x).toMap - simple should be(buffAgg) + simple == buffAgg } + } } diff --git a/hail/hail/test/src/is/hail/utils/GraphSuite.scala b/hail/hail/test/src/is/hail/utils/GraphSuite.scala index a46e2ed9337..f22b4112344 100644 --- a/hail/hail/test/src/is/hail/utils/GraphSuite.scala +++ b/hail/hail/test/src/is/hail/utils/GraphSuite.scala @@ -2,10 +2,7 @@ package is.hail.utils import scala.collection.mutable -import org.scalatest.matchers.should.Matchers._ -import org.testng.annotations.Test - -class GraphSuite { +class GraphSuite extends munit.FunSuite { import Graph._ @@ -14,23 +11,20 @@ class GraphSuite { x.forall(x => g(x).intersect(s).isEmpty) } - @Test def simple(): Unit = { + test("simple") { { val actual = maximalIndependentSet(Array((0 -> 1))) - actual should ((contain theSameElementsAs Array(0)) or (contain theSameElementsAs Array(1))) + assert(actual.toSet == Set(0) || actual.toSet == Set(1)) } { val actual = maximalIndependentSet(Array(0 -> 1, 0 -> 2)) - actual should contain theSameElementsAs Array(1, 2) + assertEquals(actual.toSet, Set(1, 2)) } { val actual = maximalIndependentSet(Array(0 -> 1, 0 -> 2, 3 -> 1, 3 -> 2)) - actual should ((contain theSameElementsAs Array(1, 2)) or (contain theSameElementsAs Array( - 0, - 3, - ))) + assert(actual.toSet == Set(1, 2) || actual.toSet == Set(0, 3)) } { @@ -40,23 +34,23 @@ 
class GraphSuite { } } - @Test def longCycle(): Unit = { + test("longCycle") { val g = mkGraph(0 -> 1, 1 -> 2, 2 -> 3, 3 -> 4, 4 -> 5, 5 -> 6, 6 -> 0) val actual = maximalIndependentSet(g) assert(isIndependentSet(actual, g)) - assert(actual.length == 3) + assertEquals(actual.length, 3) } - @Test def twoPopularNodes(): Unit = { + test("twoPopularNodes") { val g = mkGraph(0 -> 1, 0 -> 2, 0 -> 3, 4 -> 5, 4 -> 6, 4 -> 0) val actual = maximalIndependentSet(g) assert(isIndependentSet(actual, g)) - assert(actual.length == 5) + assertEquals(actual.length, 5) } - @Test def totallyDisconnected(): Unit = { + test("totallyDisconnected") { val expected = 0 until 10 val m = mutable.HashMap.empty[Int, mutable.Set[Int]] @@ -65,52 +59,52 @@ class GraphSuite { val actual = maximalIndependentSet(m) - actual should contain theSameElementsAs expected + assertEquals(actual.toSet, expected.toSet) } - @Test def disconnected(): Unit = { + test("disconnected") { val g = mkGraph(for (i <- 0 until 10) yield (i, i + 10)) val actual = maximalIndependentSet(g) assert(isIndependentSet(actual, g)) - assert(actual.length == 10) + assertEquals(actual.length, 10) } - @Test def selfEdge(): Unit = { + test("selfEdge") { val g = mkGraph(0 -> 0, 1 -> 2, 1 -> 3) val actual = maximalIndependentSet(g) assert(isIndependentSet(actual, g)) - actual should contain theSameElementsAs Array(2, 3) + assertEquals(actual.toSet, Set(2, 3)) } - @Test def emptyGraph(): Unit = { + test("emptyGraph") { val g = mkGraph[Int]() val actual = maximalIndependentSet(g) - assert(actual === Array[Int]()) + assertEquals(actual.toSeq, Seq.empty[Int]) } - @Test def tieBreakingOfBipartiteGraphWorks(): Unit = { + test("tieBreakingOfBipartiteGraphWorks") { val g = mkGraph(for (i <- 0 until 10) yield (i, i + 10)) // prefer to remove big numbers val actual = maximalIndependentSet(g, Some((l: Int, r: Int) => (l - r).toDouble)) assert(isIndependentSet(actual, g)) - assert(actual.length == 10) + assertEquals(actual.length, 10) 
assert(actual.forall(_ < 10)) } - @Test def tieBreakingInLongCycleWorks(): Unit = { + test("tieBreakingInLongCycleWorks") { val g = mkGraph(0 -> 1, 1 -> 2, 2 -> 3, 3 -> 4, 4 -> 5, 5 -> 6, 6 -> 0) // prefers to remove small numbers val actual = maximalIndependentSet(g, Some((l: Int, r: Int) => (r - l).toDouble)) assert(isIndependentSet(actual, g)) - assert(actual.length == 3) - actual should contain theSameElementsAs Array[Int](1, 3, 6) + assertEquals(actual.length, 3) + assertEquals(actual.toSet, Set(1, 3, 6)) } } diff --git a/hail/hail/test/src/is/hail/utils/IntervalSuite.scala b/hail/hail/test/src/is/hail/utils/IntervalSuite.scala index f86a81dac8a..62c77842c3c 100644 --- a/hail/hail/test/src/is/hail/utils/IntervalSuite.scala +++ b/hail/hail/test/src/is/hail/utils/IntervalSuite.scala @@ -9,10 +9,6 @@ import is.hail.rvd.RVDPartitioner import is.hail.types.virtual.{TInt32, TStruct} import org.apache.spark.sql.Row -import org.scalatest.Inspectors.forAll -import org.scalatest.enablers.InspectorAsserting.assertingNatureOfAssertion -import org.testng.ITestContext -import org.testng.annotations.{BeforeMethod, Test} class IntervalSuite extends HailSuite { @@ -33,8 +29,8 @@ class IntervalSuite extends HailSuite { var test_itrees: IndexedSeq[SetIntervalTree] = _ - @BeforeMethod - def setupIntervalTrees(context: ITestContext): Unit = { + override def beforeEach(context: BeforeEach): Unit = { + super.beforeEach(context) test_itrees = SetIntervalTree(ctx, ArraySeq[(SetInterval, Int)]()) +: test_intervals.flatMap { i1 => SetIntervalTree(ctx, ArraySeq(i1).zipWithIndex) +: @@ -47,39 +43,43 @@ class IntervalSuite extends HailSuite { } } - @Test def interval_agrees_with_set_interval_greater_than_point(): Unit = - forAll(cartesian(test_intervals, points)) { case (set_interval, p) => + test("interval agrees with set interval greater than point") { + cartesian(test_intervals, points).foreach { case (set_interval, p) => val interval = set_interval.interval assert( 
interval.isAbovePosition(pord, p) == set_interval.doubledPointSet.forall(dp => dp > 2 * p) ) } + } - @Test def interval_agrees_with_set_interval_less_than_point(): Unit = - forAll(cartesian(test_intervals, points)) { case (set_interval, p) => + test("interval agrees with set interval less than point") { + cartesian(test_intervals, points).foreach { case (set_interval, p) => val interval = set_interval.interval assert( interval.isBelowPosition(pord, p) == set_interval.doubledPointSet.forall(dp => dp < 2 * p) ) } + } - @Test def interval_agrees_with_set_interval_contains(): Unit = - forAll(cartesian(test_intervals, points)) { case (set_interval, p) => + test("interval agrees with set interval contains") { + cartesian(test_intervals, points).foreach { case (set_interval, p) => val interval = set_interval.interval assert(interval.contains(pord, p) == set_interval.contains(p)) } + } - @Test def interval_agrees_with_set_interval_includes(): Unit = - forAll(cartesian(test_intervals, test_intervals)) { case (set_interval1, set_interval2) => + test("interval agrees with set interval includes") { + cartesian(test_intervals, test_intervals).foreach { case (set_interval1, set_interval2) => val interval1 = set_interval1.interval val interval2 = set_interval2.interval assert(interval1.includes(pord, interval2) == set_interval1.includes(set_interval2)) } + } - @Test def interval_agrees_with_set_interval_probably_overlaps(): Unit = - forAll(cartesian(test_intervals, test_intervals)) { case (set_interval1, set_interval2) => + test("interval agrees with set interval probably overlaps") { + cartesian(test_intervals, test_intervals).foreach { case (set_interval1, set_interval2) => val interval1 = set_interval1.interval val interval2 = set_interval2.interval assert( @@ -87,9 +87,10 @@ class IntervalSuite extends HailSuite { set_interval1.probablyOverlaps(set_interval2) ) } + } - @Test def interval_agrees_with_set_interval_definitely_disjoint(): Unit = - forAll(cartesian(test_intervals, 
test_intervals)) { case (set_interval1, set_interval2) => + test("interval agrees with set interval definitely disjoint") { + cartesian(test_intervals, test_intervals).foreach { case (set_interval1, set_interval2) => val interval1 = set_interval1.interval val interval2 = set_interval2.interval assert( @@ -97,30 +98,34 @@ class IntervalSuite extends HailSuite { set_interval1.definitelyDisjoint(set_interval2) ) } + } - @Test def interval_agrees_with_set_interval_disjoint_greater_than(): Unit = - forAll(cartesian(test_intervals, test_intervals)) { case (set_interval1, set_interval2) => + test("interval agrees with set interval disjoint greater than") { + cartesian(test_intervals, test_intervals).foreach { case (set_interval1, set_interval2) => val interval1 = set_interval1.interval val interval2 = set_interval2.interval assert(interval1.isAbove(pord, interval2) == set_interval1.isAboveInterval(set_interval2)) } + } - @Test def interval_agrees_with_set_interval_disjoint_less_than(): Unit = - forAll(cartesian(test_intervals, test_intervals)) { case (set_interval1, set_interval2) => + test("interval agrees with set interval disjoint less than") { + cartesian(test_intervals, test_intervals).foreach { case (set_interval1, set_interval2) => val interval1 = set_interval1.interval val interval2 = set_interval2.interval assert(interval1.isBelow(pord, interval2) == set_interval1.isBelowInterval(set_interval2)) } + } - @Test def interval_agrees_with_set_interval_mergeable(): Unit = - forAll(cartesian(test_intervals, test_intervals)) { case (set_interval1, set_interval2) => + test("interval agrees with set interval mergeable") { + cartesian(test_intervals, test_intervals).foreach { case (set_interval1, set_interval2) => val interval1 = set_interval1.interval val interval2 = set_interval2.interval assert(interval1.canMergeWith(pord, interval2) == set_interval1.mergeable(set_interval2)) } + } - @Test def interval_agrees_with_set_interval_merge(): Unit = - 
forAll(cartesian(test_intervals, test_intervals)) { case (set_interval1, set_interval2) => + test("interval agrees with set interval merge") { + cartesian(test_intervals, test_intervals).foreach { case (set_interval1, set_interval2) => val interval1 = set_interval1.interval val interval2 = set_interval2.interval assert( @@ -128,9 +133,10 @@ class IntervalSuite extends HailSuite { set_interval1.union(set_interval2).map(_.interval) ) } + } - @Test def interval_agrees_with_set_interval_intersect(): Unit = - forAll(cartesian(test_intervals, test_intervals)) { case (set_interval1, set_interval2) => + test("interval agrees with set interval intersect") { + cartesian(test_intervals, test_intervals).foreach { case (set_interval1, set_interval2) => val interval1 = set_interval1.interval val interval2 = set_interval2.interval assert( @@ -138,43 +144,49 @@ class IntervalSuite extends HailSuite { set_interval1.intersect(set_interval2).map(_.interval) ) } + } - @Test def interval_tree_agrees_with_set_interval_tree_contains(): Unit = - forAll(cartesian(test_itrees, points)) { case (set_itree, p) => + test("interval tree agrees with set interval tree contains") { + cartesian(test_itrees, points).foreach { case (set_itree, p) => val itree = set_itree.intervalTree assert(itree.contains(Row(p)) == set_itree.contains(p)) } + } - @Test def interval_tree_agrees_with_set_interval_tree_probably_overlaps(): Unit = - forAll(cartesian(test_itrees, test_intervals)) { case (set_itree, set_interval) => + test("interval tree agrees with set interval tree probably overlaps") { + cartesian(test_itrees, test_intervals).foreach { case (set_itree, set_interval) => val itree = set_itree.intervalTree val interval = set_interval.rowInterval assert(itree.overlaps(interval) == set_itree.probablyOverlaps(set_interval)) } + } - @Test def interval_tree_agrees_with_set_interval_tree_definitely_disjoint(): Unit = - forAll(cartesian(test_itrees, test_intervals)) { case (set_itree, set_interval) => + 
test("interval tree agrees with set interval tree definitely disjoint") { + cartesian(test_itrees, test_intervals).foreach { case (set_itree, set_interval) => val itree = set_itree.intervalTree val interval = set_interval.rowInterval assert(itree.isDisjointFrom(interval) == set_itree.definitelyDisjoint(set_interval)) } + } - @Test def interval_tree_agrees_with_set_interval_tree_query_values(): Unit = - forAll(cartesian(test_itrees, points)) { case (set_itree, point) => + test("interval tree agrees with set interval tree query values") { + cartesian(test_itrees, points).foreach { case (set_itree, point) => val itree = set_itree.intervalTree val result = itree.queryKey(Row(point)) assert(result.areDistinct()) assert(result.toSet == set_itree.queryValues(point)) } + } - @Test def interval_tree_agrees_with_set_interval_tree_query_overlapping_values(): Unit = - forAll(cartesian(test_itrees, test_intervals)) { case (set_itree, set_interval) => + test("interval tree agrees with set interval tree query overlapping values") { + cartesian(test_itrees, test_intervals).foreach { case (set_itree, set_interval) => val itree = set_itree.intervalTree val interval = set_interval.rowInterval val result = itree.queryInterval(interval) assert(result.areDistinct()) assert(result.toSet == set_itree.queryProbablyOverlappingValues(set_interval)) } + } } object SetInterval { diff --git a/hail/hail/test/src/is/hail/utils/PartitionCountsSuite.scala b/hail/hail/test/src/is/hail/utils/PartitionCountsSuite.scala index 2f3997a6f9d..00d0d26f67c 100644 --- a/hail/hail/test/src/is/hail/utils/PartitionCountsSuite.scala +++ b/hail/hail/test/src/is/hail/utils/PartitionCountsSuite.scala @@ -2,12 +2,9 @@ package is.hail.utils import is.hail.utils.PartitionCounts._ -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test +class PartitionCountsSuite extends munit.FunSuite { -class PartitionCountsSuite extends TestNGSuite { - - @Test def testHeadPCs() = { + test("HeadPCs") { for ( 
((a, n), b) <- Seq( (IndexedSeq(0L), 5L) -> IndexedSeq(0L), @@ -19,10 +16,10 @@ class PartitionCountsSuite extends TestNGSuite { (IndexedSeq(4L, 5L, 6L), 20L) -> IndexedSeq(4L, 5L, 6L), ) ) - assert(getHeadPCs(a, n) == b, s"getHeadPartitionCounts($a, $n)") + assertEquals(getHeadPCs(a, n), b, s"getHeadPartitionCounts($a, $n)") } - @Test def testTailPCs() = { + test("TailPCs") { for ( ((a, n), b) <- Seq( (IndexedSeq(0L), 5L) -> IndexedSeq(0L), @@ -34,15 +31,16 @@ class PartitionCountsSuite extends TestNGSuite { (IndexedSeq(4L, 5L, 6L), 20L) -> IndexedSeq(4L, 5L, 6L), ) ) { - assert(getTailPCs(a, n) == b, s"getTailPartitionCounts($a, $n)") - assert( - getTailPCs(a, n) == getHeadPCs(a.reverse, n).reverse, + assertEquals(getTailPCs(a, n), b, s"getTailPartitionCounts($a, $n)") + assertEquals( + getTailPCs(a, n), + getHeadPCs(a.reverse, n).reverse, s"getTailPartitionCounts($a, $n) via head", ) } } - @Test def testIncrementalPCSubset() = { + test("IncrementalPCSubset") { val pcs = Array(0L, 0L, 5L, 6L, 4L, 3L, 3L, 3L, 2L, 1L) def headOffset(n: Long) = @@ -51,8 +49,8 @@ class PartitionCountsSuite extends TestNGSuite { for (n <- 0L until pcs.sum) { val PCSubsetOffset(i, nKeep, nDrop) = headOffset(n) val total = (0 to i).map(j => if (j == i) nKeep else pcs(j)).sum - assert(nKeep + nDrop == pcs(i)) - assert(total == n) + assertEquals(nKeep + nDrop, pcs(i)) + assertEquals(total, n) } def tailOffset(n: Long) = @@ -61,8 +59,8 @@ class PartitionCountsSuite extends TestNGSuite { for (n <- 0L until pcs.sum) { val PCSubsetOffset(i, nKeep, nDrop) = tailOffset(n) val total = (i to (pcs.length - 1)).map(j => if (j == i) nKeep else pcs(j)).sum - assert(nKeep + nDrop == pcs(i)) - assert(total == n) + assertEquals(nKeep + nDrop, pcs(i)) + assertEquals(total, n) } } } diff --git a/hail/hail/test/src/is/hail/utils/RowIntervalSuite.scala b/hail/hail/test/src/is/hail/utils/RowIntervalSuite.scala index 7bdc450fb8c..3e922b89fff 100644 --- a/hail/hail/test/src/is/hail/utils/RowIntervalSuite.scala 
+++ b/hail/hail/test/src/is/hail/utils/RowIntervalSuite.scala @@ -8,7 +8,6 @@ import is.hail.rvd.{PartitionBoundOrdering, RVDPartitioner} import is.hail.types.virtual.{TBoolean, TInt32, TStruct} import org.apache.spark.sql.Row -import org.testng.annotations.Test class RowIntervalSuite extends HailSuite { lazy val t = TStruct("a" -> TInt32, "b" -> TInt32, "c" -> TInt32) @@ -35,7 +34,7 @@ class RowIntervalSuite extends HailSuite { )(ExecStrategy.compileOnly) } - @Test def testContains(): Unit = { + test("Contains") { assertContains(Interval(Row(0, 1, 5), Row(1, 2, 4), true, true), Row(1, 1, 3)) assertContains(Interval(Row(0, 1, 5), Row(1, 2, 4), true, true), Row(0, 1, 5)) assertContains( @@ -85,7 +84,7 @@ class RowIntervalSuite extends HailSuite { assert(!Interval(Row(0, 1, 5, 7), Row(2, 1, 4, 5), false, false).contains(pord, Row(0, 1, 5))) } - @Test def testAbovePosition(): Unit = { + test("AbovePosition") { assert(Interval(Row(0, 1, 5), Row(1, 2, 4), true, true).isAbovePosition(pord, Row(0, 1, 4))) assert(Interval(Row(0, 1, 5), Row(1, 2, 4), false, true).isAbovePosition(pord, Row(0, 1, 5))) assert(!Interval(Row(0, 1, 5), Row(1, 2, 4), true, true).isAbovePosition(pord, Row(0, 1, 5))) @@ -105,7 +104,7 @@ class RowIntervalSuite extends HailSuite { )) } - @Test def testBelowPosition(): Unit = { + test("BelowPosition") { assert(Interval(Row(0, 1, 5), Row(1, 2, 4), true, true).isBelowPosition(pord, Row(1, 2, 5))) assert(Interval(Row(0, 1, 5), Row(1, 2, 4), true, false).isBelowPosition(pord, Row(1, 2, 4))) assert(!Interval(Row(0, 1, 5), Row(1, 2, 4), true, true).isBelowPosition(pord, Row(1, 2, 4))) @@ -116,7 +115,7 @@ class RowIntervalSuite extends HailSuite { assert(!Interval(Row(1, 1, 8), Row(1, 2), true, true).isBelowPosition(pord, Row(1, 2, 5))) } - @Test def testAbutts(): Unit = { + test("Abutts") { assert(Interval(Row(0, 1, 5), Row(1, 2, 4), true, true).abutts( pord, Interval(Row(1, 2, 4), Row(1, 3, 4), false, true), @@ -136,7 +135,7 @@ class RowIntervalSuite 
extends HailSuite { )) } - @Test def testLteqWithOverlap(): Unit = { + test("LteqWithOverlap") { val eord = pord.intervalEndpointOrdering assert(!eord.lteqWithOverlap(3)( IntervalEndpoint(Row(0, 1, 6), -1), @@ -210,7 +209,7 @@ class RowIntervalSuite extends HailSuite { )) } - @Test def testIsValid(): Unit = { + test("IsValid") { assert(Interval.isValid(pord, Row(0, 1, 5), Row(0, 2), false, false)) assert(!Interval.isValid(pord, Row(0, 1, 5), Row(0, 0), false, false)) assert(Interval.isValid(pord, Row(0, 1, 5), Row(0, 1), false, true)) diff --git a/hail/hail/test/src/is/hail/utils/SemanticVersionSuite.scala b/hail/hail/test/src/is/hail/utils/SemanticVersionSuite.scala index 0f3b2b47240..a638830318e 100644 --- a/hail/hail/test/src/is/hail/utils/SemanticVersionSuite.scala +++ b/hail/hail/test/src/is/hail/utils/SemanticVersionSuite.scala @@ -1,10 +1,7 @@ package is.hail.utils -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test - -class SemanticVersionSuite extends TestNGSuite { - @Test def testOrdering(): Unit = { +class SemanticVersionSuite extends munit.FunSuite { + test("Ordering") { val versions = Array( SemanticVersion(1, 1, 0), SemanticVersion(1, 1, 1), diff --git a/hail/hail/test/src/is/hail/utils/SpillingCollectIteratorSuite.scala b/hail/hail/test/src/is/hail/utils/SpillingCollectIteratorSuite.scala index b2205236b55..54734162c76 100644 --- a/hail/hail/test/src/is/hail/utils/SpillingCollectIteratorSuite.scala +++ b/hail/hail/test/src/is/hail/utils/SpillingCollectIteratorSuite.scala @@ -2,16 +2,14 @@ package is.hail.utils import is.hail.HailSuite -import org.testng.annotations.Test - class SpillingCollectIteratorSuite extends HailSuite { - @Test def addOneElement(): Unit = { + test("addOneElement") { val array = (0 to 1234) val sci = SpillingCollectIterator(ctx.localTmpdir, fs, sc.parallelize(array, 99), 100) assert(sci.hasNext) - assert(sci.next() == 0) + assertEquals(sci.next(), 0) assert(sci.hasNext) - assert(sci.next() == 1) + 
assertEquals(sci.next(), 1) assert(sci.toArray sameElements (2 to 1234)) } } diff --git a/hail/hail/test/src/is/hail/utils/TreeTraversalSuite.scala b/hail/hail/test/src/is/hail/utils/TreeTraversalSuite.scala index 0106d66c23a..d3637e55fb1 100644 --- a/hail/hail/test/src/is/hail/utils/TreeTraversalSuite.scala +++ b/hail/hail/test/src/is/hail/utils/TreeTraversalSuite.scala @@ -1,32 +1,29 @@ package is.hail.utils -import org.testng.Assert -import org.testng.annotations.Test - -class TreeTraversalSuite { +class TreeTraversalSuite extends munit.FunSuite { def binaryTree(i: Int): Iterator[Int] = (1 to 2).map(2 * i + _).iterator.filter(_ < 7) - @Test def testPostOrder() = - Assert.assertEquals( - TreeTraversal.postOrder(binaryTree)(0).toArray, - Array(3, 4, 1, 5, 6, 2, 0), - "", + test("PostOrder") { + assertEquals( + TreeTraversal.postOrder(binaryTree)(0).toArray.toSeq, + Array(3, 4, 1, 5, 6, 2, 0).toSeq, ) + } - @Test def testPreOrder() = - Assert.assertEquals( - TreeTraversal.preOrder(binaryTree)(0).toArray, - Array(0, 1, 3, 4, 2, 5, 6), - "", + test("PreOrder") { + assertEquals( + TreeTraversal.preOrder(binaryTree)(0).toArray.toSeq, + Array(0, 1, 3, 4, 2, 5, 6).toSeq, ) + } - @Test def levelOrder() = - Assert.assertEquals( - TreeTraversal.levelOrder(binaryTree)(0).toArray, - (0 to 6).toArray, - "", + test("levelOrder") { + assertEquals( + TreeTraversal.levelOrder(binaryTree)(0).toArray.toSeq, + (0 to 6).toArray.toSeq, ) + } } diff --git a/hail/hail/test/src/is/hail/utils/UnionFindSuite.scala b/hail/hail/test/src/is/hail/utils/UnionFindSuite.scala index 734cf984453..98e2392e4e9 100644 --- a/hail/hail/test/src/is/hail/utils/UnionFindSuite.scala +++ b/hail/hail/test/src/is/hail/utils/UnionFindSuite.scala @@ -1,31 +1,26 @@ package is.hail.utils -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test - -class UnionFindSuite extends TestNGSuite { - @Test - def emptyUnionFindHasNoSets(): Unit = - assert(new UnionFind().size == 0) +class 
UnionFindSuite extends munit.FunSuite { + test("emptyUnionFindHasNoSets") { + assertEquals(new UnionFind().size, 0) + } - @Test - def growingPastInitialCapacityOK(): Unit = { + test("growingPastInitialCapacityOK") { val uf = new UnionFind(4) uf.makeSet(0) uf.makeSet(1) uf.makeSet(2) uf.makeSet(3) uf.makeSet(4) - assert(uf.find(0) == 0) - assert(uf.find(1) == 1) - assert(uf.find(2) == 2) - assert(uf.find(3) == 3) - assert(uf.find(4) == 4) - assert(uf.size == 5) + assertEquals(uf.find(0), 0) + assertEquals(uf.find(1), 1) + assertEquals(uf.find(2), 2) + assertEquals(uf.find(3), 3) + assertEquals(uf.find(4), 4) + assertEquals(uf.size, 5) } - @Test - def simpleUnions(): Unit = { + test("simpleUnions") { val uf = new UnionFind() uf.makeSet(0) @@ -34,12 +29,11 @@ class UnionFindSuite extends TestNGSuite { uf.union(0, 1) val (x, y) = (uf.find(0), uf.find(1)) - assert(x == y) + assertEquals(x, y) assert(x == 0 || x == 1) } - @Test - def nonMonotonicMakeSet(): Unit = { + test("nonMonotonicMakeSet") { val uf = new UnionFind() uf.makeSet(1000) @@ -47,28 +41,27 @@ class UnionFindSuite extends TestNGSuite { uf.makeSet(4097) uf.makeSet(4095) - assert(uf.find(1000) == 1000) - assert(uf.find(1024) == 1024) - assert(uf.find(4097) == 4097) - assert(uf.find(4095) == 4095) + assertEquals(uf.find(1000), 1000) + assertEquals(uf.find(1024), 1024) + assertEquals(uf.find(4097), 4097) + assertEquals(uf.find(4095), 4095) assert(!uf.sameSet(1000, 1024)) assert(!uf.sameSet(1000, 4097)) assert(!uf.sameSet(1000, 4095)) assert(!uf.sameSet(1024, 4097)) assert(!uf.sameSet(1024, 4095)) assert(!uf.sameSet(4097, 4095)) - assert(uf.size == 4) + assertEquals(uf.size, 4) } - @Test - def multipleUnions(): Unit = { + test("multipleUnions") { val uf = new UnionFind() uf.makeSet(1) uf.makeSet(2) uf.makeSet(3) uf.makeSet(4) - assert(uf.size == 4) + assertEquals(uf.size, 4) uf.union(1, 2) @@ -76,7 +69,7 @@ class UnionFindSuite extends TestNGSuite { assert(!uf.sameSet(1, 3)) assert(!uf.sameSet(1, 4)) 
assert(!uf.sameSet(3, 4)) - assert(uf.size == 3) + assertEquals(uf.size, 3) uf.union(1, 4) @@ -84,18 +77,17 @@ class UnionFindSuite extends TestNGSuite { assert(!uf.sameSet(1, 3)) assert(uf.sameSet(1, 4)) assert(!uf.sameSet(3, 4)) - assert(uf.size == 2) + assertEquals(uf.size, 2) uf.union(2, 3) assert(uf.sameSet(1, 2)) assert(uf.sameSet(1, 3)) assert(uf.sameSet(1, 4)) - assert(uf.size == 1) + assertEquals(uf.size, 1) } - @Test - def unionsNoInterveningFinds(): Unit = { + test("unionsNoInterveningFinds") { val uf = new UnionFind() uf.makeSet(1) @@ -105,14 +97,14 @@ class UnionFindSuite extends TestNGSuite { uf.makeSet(5) uf.makeSet(6) - assert(uf.size == 6) + assertEquals(uf.size, 6) uf.union(1, 2) uf.union(1, 4) uf.union(5, 3) uf.union(2, 6) - assert(uf.size == 2) + assertEquals(uf.size, 2) assert(uf.sameSet(1, 2)) assert(uf.sameSet(1, 4)) assert(uf.sameSet(5, 3)) @@ -120,8 +112,7 @@ class UnionFindSuite extends TestNGSuite { assert(!uf.sameSet(1, 5)) } - @Test - def sameSetWorks(): Unit = { + test("sameSetWorks") { val uf = new UnionFind() uf.makeSet(1) diff --git a/hail/hail/test/src/is/hail/utils/UtilsSuite.scala b/hail/hail/test/src/is/hail/utils/UtilsSuite.scala index 967150d461a..5786ec9e0ef 100644 --- a/hail/hail/test/src/is/hail/utils/UtilsSuite.scala +++ b/hail/hail/test/src/is/hail/utils/UtilsSuite.scala @@ -9,12 +9,10 @@ import is.hail.sparkextras.implicits._ import org.apache.spark.storage.StorageLevel import org.scalacheck.Gen import org.scalacheck.Gen.containerOf -import org.scalatestplus.scalacheck.CheckerAsserting.assertingNatureOfAssertion -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.testng.annotations.Test +import org.scalacheck.Prop.forAll -class UtilsSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { - @Test def testD_==(): Unit = { +class UtilsSuite extends HailSuite with munit.ScalaCheckSuite { + test("D_==") { assert(D_==(1, 1)) assert(D_==(1, 1 + 1e-7)) assert(!D_==(1, 1 + 1e-5)) @@ -26,7 +24,7 
@@ class UtilsSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(D_==(1e-320, -1e-320)) } - @Test def testFlushDouble(): Unit = { + test("FlushDouble") { assert(flushDouble(8.0e-308) == 8.0e-308) assert(flushDouble(-8.0e-308) == -8.0e-308) assert(flushDouble(8.0e-309) == 0.0) @@ -34,7 +32,7 @@ class UtilsSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(flushDouble(0.0) == 0.0) } - @Test def testAreDistinct(): Unit = { + test("AreDistinct") { assert(Array().areDistinct()) assert(Array(1).areDistinct()) assert(Array(1, 2).areDistinct()) @@ -42,7 +40,7 @@ class UtilsSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(!Array(1, 2, 1).areDistinct()) } - @Test def testIsIncreasing(): Unit = { + test("IsIncreasing") { assert(Seq[Int]().isIncreasing) assert(Seq(1).isIncreasing) assert(Seq(1, 2).isIncreasing) @@ -52,7 +50,7 @@ class UtilsSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(Array(1, 2).isIncreasing) } - @Test def testIsSorted(): Unit = { + test("IsSorted") { assert(Seq[Int]().isSorted) assert(Seq(1).isSorted) assert(Seq(1, 2).isSorted) @@ -62,7 +60,7 @@ class UtilsSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(Array(1, 1).isSorted) } - @Test def testPairRDDNoDup(): Unit = { + test("PairRDDNoDup") { val answer1 = Array((1, (1, Option(1))), (2, (4, Option(2))), (3, (9, Option(3))), (4, (16, Option(4)))) val pairRDD1 = sc.parallelize(ArraySeq(1, 2, 3, 4)).map(i => (i, i * i)) @@ -70,10 +68,10 @@ class UtilsSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val join = pairRDD1.leftOuterJoin(pairRDD2.distinct()) assert(join.collect().sortBy(t => t._1) sameElements answer1) - assert(join.count() == 4) + assertEquals(join.count(), 4L) } - @Test def testForallExists(): Unit = { + test("ForallExists") { val rdd1 = sc.parallelize(ArraySeq(1, 2, 3, 4, 5)) assert(rdd1.forall(_ > 0)) @@ -83,7 +81,7 @@ class UtilsSuite extends HailSuite with 
ScalaCheckDrivenPropertyChecks { assert(!rdd1.exists(_ < 0)) } - @Test def testSortFileListEntry(): Unit = { + test("SortFileListEntry") { val fs = new HadoopFS(new SerializableHadoopConfiguration(sc.hadoopConfiguration)) val partFileNames = fs.glob(getTestResource("part-*")) @@ -93,10 +91,11 @@ class UtilsSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { ).last ) - assert(partFileNames(0) == "part-40001" && partFileNames(1) == "part-100001") + assert(partFileNames(0) == "part-40001") + assert(partFileNames(1) == "part-100001") } - @Test def storageLevelStringTest() = { + test("storageLevelString") { val sls = List( "NONE", "DISK_ONLY", "DISK_ONLY_2", "MEMORY_ONLY", "MEMORY_ONLY_2", "MEMORY_ONLY_SER", "MEMORY_ONLY_SER_2", @@ -106,7 +105,7 @@ class UtilsSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { sls.foreach(sl => StorageLevel.fromString(sl)) } - @Test def testDictionaryOrdering(): Unit = { + test("DictionaryOrdering") { val stringList = Seq("Cats", "Crayon", "Dog") val longestToShortestLength = Ordering.by[String, Int](-_.length) @@ -114,43 +113,48 @@ class UtilsSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { val alphabetically = Ordering.String val ord1 = dictionaryOrdering(alphabetically, byFirstLetter, longestToShortestLength) - assert(stringList.sorted(ord1) == stringList) + assertEquals(stringList.sorted(ord1), stringList) val ord2 = dictionaryOrdering(byFirstLetter, longestToShortestLength) - assert(stringList.sorted(ord2) == Seq("Crayon", "Cats", "Dog")) + assertEquals(stringList.sorted(ord2), Seq("Crayon", "Cats", "Dog")) } - @Test def testCollectAsSet(): Unit = + property("CollectAsSet") = forAll(containerOf[ArraySeq, Int](Gen.choose(-1000, 1000)), Gen.choose(1, 10)) { case (values, parts) => val rdd = sc.parallelize(values, numSlices = parts) - assert(rdd.collectAsSet() == rdd.collect().toSet) + rdd.collectAsSet() == rdd.collect().toSet } - @Test def testDigitsNeeded(): Unit = { - assert(digitsNeeded(0) == 1) - 
assert(digitsNeeded(1) == 1) - assert(digitsNeeded(7) == 1) - assert(digitsNeeded(9) == 1) - assert(digitsNeeded(13) == 2) - assert(digitsNeeded(30173) == 5) + test("DigitsNeeded") { + assertEquals(digitsNeeded(0), 1) + assertEquals(digitsNeeded(1), 1) + assertEquals(digitsNeeded(7), 1) + assertEquals(digitsNeeded(9), 1) + assertEquals(digitsNeeded(13), 2) + assertEquals(digitsNeeded(30173), 5) } - @Test def toMapUniqueEmpty(): Unit = - assert(toMapIfUnique(Seq[(Int, Int)]())(x => x % 2) == Right(Map())) + test("toMapUniqueEmpty") { + assertEquals(toMapIfUnique(Seq[(Int, Int)]())(x => x % 2), Right(Map[Int, Int]())) + } - @Test def toMapUniqueSingleton(): Unit = - assert(toMapIfUnique(Seq(1 -> 2))(x => x % 2) == Right(Map(1 -> 2))) + test("toMapUniqueSingleton") { + assertEquals(toMapIfUnique(Seq(1 -> 2))(x => x % 2), Right(Map(1 -> 2))) + } - @Test def toMapUniqueSmallNoDupe(): Unit = - assert(toMapIfUnique(Seq(1 -> 2, 3 -> 6, 10 -> 2))(x => x % 5) == - Right(Map(1 -> 2, 3 -> 6, 0 -> 2))) + test("toMapUniqueSmallNoDupe") { + assertEquals( + toMapIfUnique(Seq(1 -> 2, 3 -> 6, 10 -> 2))(x => x % 5), + Right(Map(1 -> 2, 3 -> 6, 0 -> 2)), + ) + } - @Test def toMapUniqueSmallDupes(): Unit = - assert(toMapIfUnique(Seq(1 -> 2, 6 -> 6, 10 -> 2))(x => x % 5) == - Left(Map(1 -> Seq(1, 6)))) + test("toMapUniqueSmallDupes") { + assertEquals(toMapIfUnique(Seq(1 -> 2, 6 -> 6, 10 -> 2))(x => x % 5), Left(Map(1 -> Seq(1, 6)))) + } - @Test def testItemPartition(): Unit = { - def test(n: Int, k: Int): Unit = { + test("ItemPartition") { + def check(n: Int, k: Int): Unit = { val a = new Array[Int](k) var prevj = 0 for (i <- 0 until n) { @@ -167,51 +171,51 @@ class UtilsSuite extends HailSuite with ScalaCheckDrivenPropertyChecks { assert(a sameElements p) } - test(0, 0) - test(0, 4) - test(2, 4) - test(2, 5) - test(12, 4) - test(12, 5) + check(0, 0) + check(0, 4) + check(2, 4) + check(2, 5) + check(12, 4) + check(12, 5) } - @Test def testTreeAggDepth(): Unit = { - 
assert(treeAggDepth(20, 20) == 1) - assert(treeAggDepth(20, 19) == 2) - assert(treeAggDepth(399, 20) == 2) - assert(treeAggDepth(400, 20) == 2) - assert(treeAggDepth(401, 20) == 3) - assert(treeAggDepth(0, 20) == 1) + test("TreeAggDepth") { + assertEquals(treeAggDepth(20, 20), 1) + assertEquals(treeAggDepth(20, 19), 2) + assertEquals(treeAggDepth(399, 20), 2) + assertEquals(treeAggDepth(400, 20), 2) + assertEquals(treeAggDepth(401, 20), 3) + assertEquals(treeAggDepth(0, 20), 1) } - @Test def testMerge(): Unit = { + test("Merge") { val lt: (Int, Int) => Boolean = _ < _ val empty: IndexedSeq[Int] = IndexedSeq.empty - assert(merge(empty, empty, lt) == empty) + assertEquals(merge(empty, empty, lt), empty) val ones: IndexedSeq[Int] = ArraySeq(1) - assert(merge(ones, empty, lt) == ones) - assert(merge(empty, ones, lt) == ones) + assertEquals(merge(ones, empty, lt), ones) + assertEquals(merge(empty, ones, lt), ones) val twos: IndexedSeq[Int] = ArraySeq(2) - assert(merge(ones, twos, lt) == (1 to 2)) - assert(merge(twos, ones, lt) == (1 to 2)) + assertEquals(merge(ones, twos, lt), (1 to 2).toIndexedSeq) + assertEquals(merge(twos, ones, lt), (1 to 2).toIndexedSeq) val threes: IndexedSeq[Int] = ArraySeq(3) - assert(merge(twos, ones ++ threes, lt) == (1 to 3)) + assertEquals(merge(twos, ones ++ threes, lt), (1 to 3).toIndexedSeq) // inputs need to be sorted - assert(merge(twos, threes ++ ones, lt) == Seq(2, 3, 1)) + assertEquals(merge(twos, threes ++ ones, lt), IndexedSeq(2, 3, 1)) } } diff --git a/hail/hail/test/src/is/hail/utils/prettyPrint/PrettyPrintWriterSuite.scala b/hail/hail/test/src/is/hail/utils/prettyPrint/PrettyPrintWriterSuite.scala index 8f0e4a1089c..4dd3495f9d6 100644 --- a/hail/hail/test/src/is/hail/utils/prettyPrint/PrettyPrintWriterSuite.scala +++ b/hail/hail/test/src/is/hail/utils/prettyPrint/PrettyPrintWriterSuite.scala @@ -1,14 +1,10 @@ package is.hail.utils.prettyPrint +import is.hail.TestCaseSupport import is.hail.collection.implicits.toRichIterator 
-import scala.jdk.CollectionConverters._ - -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.{DataProvider, Test} - -class PrettyPrintWriterSuite extends TestNGSuite { - def data: Array[(Doc, Array[(Int, Int, Int, String)])] = +class PrettyPrintWriterSuite extends munit.FunSuite with TestCaseSupport { + val data: Array[(Doc, Array[(Int, Int, Int, String)])] = Array( ( nest(2, hsep("prefix", sep("text", "to", "lay", "out"))), @@ -129,23 +125,27 @@ class PrettyPrintWriterSuite extends TestNGSuite { ), ) - @DataProvider(name = "data") - def flatData: java.util.Iterator[Array[Object]] = - (for { - (doc, cases) <- data.iterator - (width, ribbonWidth, maxLines, expected) <- cases.iterator - } yield Array(doc, Int.box(width), Int.box(ribbonWidth), Int.box(maxLines), expected)).asJava - - @Test(dataProvider = "data") - def testPP(doc: Doc, width: Integer, ribbonWidth: Integer, maxLines: Integer, expected: String) - : Unit = { - val ruler = "=" * width - assert(expected == s"$ruler\n${doc.render(width, ribbonWidth, maxLines)}\n$ruler") + object checkPP extends TestCases { + def apply( + doc: Doc, + width: Int, + ribbonWidth: Int, + maxLines: Int, + expected: String, + )(implicit loc: munit.Location + ): Unit = test("PP") { + val ruler = "=" * width + assertEquals(expected, s"$ruler\n${doc.render(width, ribbonWidth, maxLines)}\n$ruler") + } } - @Test def testIntersperse(): Unit = { + for { + (doc, cases) <- data + (width, ribbonWidth, maxLines, expected) <- cases + } checkPP(doc, width, ribbonWidth, maxLines, expected) + + test("Intersperse") { val it = Array("A", "B", "C").iterator.intersperse("(", ",", ")") - assert(it.mkString == "(A,B,C)") + assertEquals(it.mkString, "(A,B,C)") } - } diff --git a/hail/hail/test/src/is/hail/variant/GenotypeSuite.scala b/hail/hail/test/src/is/hail/variant/GenotypeSuite.scala index 8ed05ebb17e..9f9fac23742 100644 --- a/hail/hail/test/src/is/hail/variant/GenotypeSuite.scala +++ 
b/hail/hail/test/src/is/hail/variant/GenotypeSuite.scala @@ -1,41 +1,31 @@ package is.hail.variant import is.hail.scalacheck.partition -import is.hail.testUtils.Variant import is.hail.utils._ import org.scalacheck.Gen -import org.scalatest -import org.scalatestplus.scalacheck.CheckerAsserting.assertingNatureOfAssertion -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.scalatestplus.testng.TestNGSuite -import org.testng.annotations.Test +import org.scalacheck.Prop.forAll -class GenotypeSuite extends TestNGSuite with ScalaCheckDrivenPropertyChecks { +class GenotypeSuite extends munit.ScalaCheckSuite { - val v = Variant("1", 1, "A", "T") - - @Test def gtPairGtIndexIsId(): Unit = + property("gtPairGtIndexIsId") = forAll(Gen.choose(0, 32768), Gen.choose(0, 32768)) { (x, y) => val (j, k) = if (x < y) (x, y) else (y, x) val gt = AllelePair(j, k) - assert(Genotype.allelePair(Genotype.diploidGtIndex(gt)) == gt) + + Genotype.allelePair(Genotype.diploidGtIndex(gt)) == gt } def triangleNumberOf(i: Int): Int = (i * i + i) / 2 - @Test def gtIndexGtPairIsId(): Unit = - forAll(Gen.choose(0, 10000)) { idx => - assert(Genotype.diploidGtIndex(Genotype.allelePair(idx)) == idx) - } + property("gtIndexGtPairIsId") = + forAll(Gen.choose(0, 10000))(idx => Genotype.diploidGtIndex(Genotype.allelePair(idx)) == idx) - @Test def gtPairAndGtPairSqrtEqual(): Unit = - forAll(Gen.choose(0, 10000)) { idx => - assert(Genotype.allelePair(idx) == Genotype.allelePairSqrt(idx)) - } + property("gtPairAndGtPairSqrtEqual") = + forAll(Gen.choose(0, 10000))(idx => Genotype.allelePair(idx) == Genotype.allelePairSqrt(idx)) - @Test def testGtFromLinear(): Unit = { + property("GtFromLinear") = { val gen = for { nGenotype <- Gen.choose(2, 5).map(triangleNumberOf) @@ -44,19 +34,19 @@ class GenotypeSuite extends TestNGSuite with ScalaCheckDrivenPropertyChecks { forAll(gen) { gp => val gt = Option(uniqueMaxIndex(gp)) - assert(gp.sum == 32768) + assertEquals(gp.sum, 32768) val dMax = 
gp.max - scalatest.Inspectors.forAll(gt.toSeq) { gt => + gt.toSeq.foreach { gt => val dosageP = gp(gt) dosageP == dMax && gp.zipWithIndex.forall { case (d, index) => index == gt || d != dosageP } } - assert(gp.count(_ == dMax) > 1 || gt.contains(gp.indexOf(dMax))) + gp.count(_ == dMax) > 1 || gt.contains(gp.indexOf(dMax)) } } - @Test def testPlToDosage(): Unit = { + test("PlToDosage") { val gt0 = Genotype.plToDosage(0, 20, 100) val gt1 = Genotype.plToDosage(20, 0, 100) val gt2 = Genotype.plToDosage(20, 100, 0) @@ -66,17 +56,15 @@ class GenotypeSuite extends TestNGSuite with ScalaCheckDrivenPropertyChecks { assert(D_==(gt2, 1.980198019704931)) } - @Test def testCall(): Unit = { - scalatest.Inspectors.forAll(0 until 9) { gt => + test("Call") { + (0 until 9).foreach { gt => val c = Call2.fromUnphasedDiploidGtIndex(gt) - assert( - !Call.isPhased(c) && - Call.ploidy(c) == 2 && - Call.isDiploid(c) && - Call.isUnphasedDiploid(c) && - Call.unphasedDiploidGtIndex(c) == gt && - Call.alleleRepr(c) == gt - ) + assert(!Call.isPhased(c)) + assertEquals(Call.ploidy(c), 2) + assert(Call.isDiploid(c)) + assert(Call.isUnphasedDiploid(c)) + assertEquals(Call.unphasedDiploidGtIndex(c), gt) + assertEquals(Call.alleleRepr(c), gt) } val c0 = Call2(0, 0, phased = true) @@ -87,40 +75,60 @@ class GenotypeSuite extends TestNGSuite with ScalaCheckDrivenPropertyChecks { val x = Array((c0, 0, 0), (c1a, 1, 1), (c1b, 1, 2), (c2, 2, 4), (c4, 4, 8)) - assert(x.forall { case (c, unphasedGt, alleleRepr) => + x.foreach { case (c, unphasedGt, alleleRepr) => val alleles = Call.alleles(c) - c != Call2.fromUnphasedDiploidGtIndex(unphasedGt) && - Call.isPhased(c) && - Call.ploidy(c) == 2 && - Call.isDiploid(c) && - !Call.isUnphasedDiploid(c) && - Call.unphasedDiploidGtIndex(Call2(alleles(0), alleles(1))) == unphasedGt && - Call.alleleRepr(c) == alleleRepr - }) - - assert(Call.isHomRef(c0) && !Call.isHet(c0) && !Call.isHomVar(c0) && - !Call.isHetNonRef(c0) && !Call.isHetRef(c0) && !Call.isNonRef(c0)) - - 
assert(!Call.isHomRef(c1a) && Call.isHet(c1a) && !Call.isHomVar(c1a) && - !Call.isHetNonRef(c1a) && Call.isHetRef(c1a) && Call.isNonRef(c1a)) - - assert(!Call.isHomRef(c1b) && Call.isHet(c1b) && !Call.isHomVar(c1b) && - !Call.isHetNonRef(c1b) && Call.isHetRef(c1b) && Call.isNonRef(c1b)) - - assert(!Call.isHomRef(c2) && !Call.isHet(c2) && Call.isHomVar(c2) && - !Call.isHetNonRef(c2) && !Call.isHetRef(c2) && Call.isNonRef(c2)) - - assert(!Call.isHomRef(c4) && Call.isHet(c4) && !Call.isHomVar(c4) && - Call.isHetNonRef(c4) && !Call.isHetRef(c4) && Call.isNonRef(c4)) - - assert(Call.parse("-") == Call0()) - assert(Call.parse("|-") == Call0(true)) - assert(Call.parse("1") == Call1(1)) - assert(Call.parse("|1") == Call1(1, phased = true)) - assert(Call.parse("0/0") == Call2(0, 0)) - assert(Call.parse("0|1") == Call2(0, 1, phased = true)) - assertThrows[UnsupportedOperationException](Call.parse("1/1/1")) - assertThrows[UnsupportedOperationException](Call.parse("1|1|1")) + assert(c != Call2.fromUnphasedDiploidGtIndex(unphasedGt)) + assert(Call.isPhased(c)) + assertEquals(Call.ploidy(c), 2) + assert(Call.isDiploid(c)) + assert(!Call.isUnphasedDiploid(c)) + assertEquals(Call.unphasedDiploidGtIndex(Call2(alleles(0), alleles(1))), unphasedGt) + assertEquals(Call.alleleRepr(c), alleleRepr) + } + + assert(Call.isHomRef(c0)) + assert(!Call.isHet(c0)) + assert(!Call.isHomVar(c0)) + assert(!Call.isHetNonRef(c0)) + assert(!Call.isHetRef(c0)) + assert(!Call.isNonRef(c0)) + + assert(!Call.isHomRef(c1a)) + assert(Call.isHet(c1a)) + assert(!Call.isHomVar(c1a)) + assert(!Call.isHetNonRef(c1a)) + assert(Call.isHetRef(c1a)) + assert(Call.isNonRef(c1a)) + + assert(!Call.isHomRef(c1b)) + assert(Call.isHet(c1b)) + assert(!Call.isHomVar(c1b)) + assert(!Call.isHetNonRef(c1b)) + assert(Call.isHetRef(c1b)) + assert(Call.isNonRef(c1b)) + + assert(!Call.isHomRef(c2)) + assert(!Call.isHet(c2)) + assert(Call.isHomVar(c2)) + assert(!Call.isHetNonRef(c2)) + assert(!Call.isHetRef(c2)) + 
assert(Call.isNonRef(c2)) + + assert(!Call.isHomRef(c4)) + assert(Call.isHet(c4)) + assert(!Call.isHomVar(c4)) + assert(Call.isHetNonRef(c4)) + assert(!Call.isHetRef(c4)) + assert(Call.isNonRef(c4)) + + assertEquals(Call.parse("-"), Call0()) + assertEquals(Call.parse("|-"), Call0(true)) + assertEquals(Call.parse("1"), Call1(1)) + assertEquals(Call.parse("|1"), Call1(1, phased = true)) + assertEquals(Call.parse("0/0"), Call2(0, 0)) + assertEquals(Call.parse("0|1"), Call2(0, 1, phased = true)) + intercept[UnsupportedOperationException](Call.parse("1/1/1")): Unit + intercept[UnsupportedOperationException](Call.parse("1|1|1")): Unit val he = intercept[HailException](Call.parse("0/")) assert(he.msg.contains("invalid call expression")) } diff --git a/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala b/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala index 7371b2d7939..570cfbb89e4 100644 --- a/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala +++ b/hail/hail/test/src/is/hail/variant/LocusIntervalSuite.scala @@ -3,164 +3,237 @@ package is.hail.variant import is.hail.HailSuite import is.hail.utils._ -import org.testng.annotations.Test - class LocusIntervalSuite extends HailSuite { def rg = ctx.references(ReferenceGenome.GRCh37) - @Test def testParser(): Unit = { + test("Parser") { val xMax = rg.contigLength("X") val yMax = rg.contigLength("Y") val chr22Max = rg.contigLength("22") - assert(Locus.parseInterval("[1:100-1:101)", rg) == Interval( - Locus("1", 100), - Locus("1", 101), - true, - false, - )) - assert(Locus.parseInterval("[1:100-101)", rg) == Interval( - Locus("1", 100), - Locus("1", 101), - true, - false, - )) - assert(Locus.parseInterval("[X:100-101)", rg) == Interval( - Locus("X", 100), - Locus("X", 101), - true, - false, - )) - assert(Locus.parseInterval("[X:100-end)", rg) == Interval( - Locus("X", 100), - Locus("X", xMax), - true, - false, - )) - assert(Locus.parseInterval("[X:100-End)", rg) == Interval( - Locus("X", 100), - 
Locus("X", xMax), - true, - false, - )) - assert(Locus.parseInterval("[X:100-END)", rg) == Interval( - Locus("X", 100), - Locus("X", xMax), - true, - false, - )) - assert(Locus.parseInterval("[X:start-101)", rg) == Interval( - Locus("X", 1), - Locus("X", 101), - true, - false, - )) - assert(Locus.parseInterval("[X:Start-101)", rg) == Interval( - Locus("X", 1), - Locus("X", 101), - true, - false, - )) - assert(Locus.parseInterval("[X:START-101)", rg) == Interval( - Locus("X", 1), - Locus("X", 101), - true, - false, - )) - assert(Locus.parseInterval("[X:START-Y:END)", rg) == Interval( - Locus("X", 1), - Locus("Y", yMax), - true, - false, - )) - assert(Locus.parseInterval("[X-Y)", rg) == Interval( - Locus("X", 1), - Locus("Y", yMax), - true, - false, - )) - assert(Locus.parseInterval("[1-22)", rg) == Interval( - Locus("1", 1), - Locus("22", chr22Max), - true, - false, - )) + assertEquals( + Locus.parseInterval("[1:100-1:101)", rg), + Interval( + Locus("1", 100), + Locus("1", 101), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("[1:100-101)", rg), + Interval( + Locus("1", 100), + Locus("1", 101), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("[X:100-101)", rg), + Interval( + Locus("X", 100), + Locus("X", 101), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("[X:100-end)", rg), + Interval( + Locus("X", 100), + Locus("X", xMax), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("[X:100-End)", rg), + Interval( + Locus("X", 100), + Locus("X", xMax), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("[X:100-END)", rg), + Interval( + Locus("X", 100), + Locus("X", xMax), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("[X:start-101)", rg), + Interval( + Locus("X", 1), + Locus("X", 101), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("[X:Start-101)", rg), + Interval( + Locus("X", 1), + Locus("X", 101), + true, + false, + ), + ) + assertEquals( + 
Locus.parseInterval("[X:START-101)", rg), + Interval( + Locus("X", 1), + Locus("X", 101), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("[X:START-Y:END)", rg), + Interval( + Locus("X", 1), + Locus("Y", yMax), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("[X-Y)", rg), + Interval( + Locus("X", 1), + Locus("Y", yMax), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("[1-22)", rg), + Interval( + Locus("1", 1), + Locus("22", chr22Max), + true, + false, + ), + ) - assert(Locus.parseInterval("1:100-1:101", rg) == Interval( - Locus("1", 100), - Locus("1", 101), - true, - false, - )) - assert(Locus.parseInterval("1:100-101", rg) == Interval( - Locus("1", 100), - Locus("1", 101), - true, - false, - )) - assert(Locus.parseInterval("X:100-end", rg) == Interval( - Locus("X", 100), - Locus("X", xMax), - true, - true, - )) - assert(Locus.parseInterval("(X:100-End]", rg) == Interval( - Locus("X", 100), - Locus("X", xMax), - false, - true, - )) - assert(Locus.parseInterval("(X:100-END)", rg) == Interval( - Locus("X", 100), - Locus("X", xMax), - false, - false, - )) - assert(Locus.parseInterval("[X:start-101)", rg) == Interval( - Locus("X", 1), - Locus("X", 101), - true, - false, - )) - assert(Locus.parseInterval("(X:Start-101]", rg) == Interval( - Locus("X", 1), - Locus("X", 101), - false, - true, - )) - assert(Locus.parseInterval("X:START-101", rg) == Interval( - Locus("X", 1), - Locus("X", 101), - true, - false, - )) - assert(Locus.parseInterval("X:START-Y:END", rg) == Interval( - Locus("X", 1), - Locus("Y", yMax), - true, - true, - )) - assert(Locus.parseInterval("X-Y", rg) == Interval(Locus("X", 1), Locus("Y", yMax), true, true)) - assert(Locus.parseInterval("1-22", rg) == Interval( - Locus("1", 1), - Locus("22", chr22Max), - true, - true, - )) + assertEquals( + Locus.parseInterval("1:100-1:101", rg), + Interval( + Locus("1", 100), + Locus("1", 101), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("1:100-101", 
rg), + Interval( + Locus("1", 100), + Locus("1", 101), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("X:100-end", rg), + Interval( + Locus("X", 100), + Locus("X", xMax), + true, + true, + ), + ) + assertEquals( + Locus.parseInterval("(X:100-End]", rg), + Interval( + Locus("X", 100), + Locus("X", xMax), + false, + true, + ), + ) + assertEquals( + Locus.parseInterval("(X:100-END)", rg), + Interval( + Locus("X", 100), + Locus("X", xMax), + false, + false, + ), + ) + assertEquals( + Locus.parseInterval("[X:start-101)", rg), + Interval( + Locus("X", 1), + Locus("X", 101), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("(X:Start-101]", rg), + Interval( + Locus("X", 1), + Locus("X", 101), + false, + true, + ), + ) + assertEquals( + Locus.parseInterval("X:START-101", rg), + Interval( + Locus("X", 1), + Locus("X", 101), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("X:START-Y:END", rg), + Interval( + Locus("X", 1), + Locus("Y", yMax), + true, + true, + ), + ) + assertEquals( + Locus.parseInterval("X-Y", rg), + Interval(Locus("X", 1), Locus("Y", yMax), true, true), + ) + assertEquals( + Locus.parseInterval("1-22", rg), + Interval( + Locus("1", 1), + Locus("22", chr22Max), + true, + true, + ), + ) // test normalizing end points - assert(Locus.parseInterval(s"(X:100-${xMax + 1})", rg) == Interval( - Locus("X", 100), - Locus("X", xMax), - false, - true, - )) - assert(Locus.parseInterval(s"(X:0-$xMax]", rg) == Interval( - Locus("X", 1), - Locus("X", xMax), - true, - true, - )) + assertEquals( + Locus.parseInterval(s"(X:100-${xMax + 1})", rg), + Interval( + Locus("X", 100), + Locus("X", xMax), + false, + true, + ), + ) + assertEquals( + Locus.parseInterval(s"(X:0-$xMax]", rg), + Interval( + Locus("X", 1), + Locus("X", xMax), + true, + true, + ), + ) interceptFatal("Start 'X:0' is not within the range")(Locus.parseInterval( "[X:0-5)", rg, @@ -170,43 +243,61 @@ class LocusIntervalSuite extends HailSuite { rg, )) - 
assert(Locus.parseInterval("[16:29500000-30200000)", rg) == Interval( - Locus("16", 29500000), - Locus("16", 30200000), - true, - false, - )) - assert(Locus.parseInterval("[16:29.5M-30.2M)", rg) == Interval( - Locus("16", 29500000), - Locus("16", 30200000), - true, - false, - )) - assert(Locus.parseInterval("[16:29500K-30200K)", rg) == Interval( - Locus("16", 29500000), - Locus("16", 30200000), - true, - false, - )) - assert(Locus.parseInterval("[1:100K-2:200K)", rg) == Interval( - Locus("1", 100000), - Locus("2", 200000), - true, - false, - )) + assertEquals( + Locus.parseInterval("[16:29500000-30200000)", rg), + Interval( + Locus("16", 29500000), + Locus("16", 30200000), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("[16:29.5M-30.2M)", rg), + Interval( + Locus("16", 29500000), + Locus("16", 30200000), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("[16:29500K-30200K)", rg), + Interval( + Locus("16", 29500000), + Locus("16", 30200000), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("[1:100K-2:200K)", rg), + Interval( + Locus("1", 100000), + Locus("2", 200000), + true, + false, + ), + ) - assert(Locus.parseInterval("[1:1.111K-2000)", rg) == Interval( - Locus("1", 1111), - Locus("1", 2000), - true, - false, - )) - assert(Locus.parseInterval("[1:1.111111M-2000000)", rg) == Interval( - Locus("1", 1111111), - Locus("1", 2000000), - true, - false, - )) + assertEquals( + Locus.parseInterval("[1:1.111K-2000)", rg), + Interval( + Locus("1", 1111), + Locus("1", 2000), + true, + false, + ), + ) + assertEquals( + Locus.parseInterval("[1:1.111111M-2000000)", rg), + Interval( + Locus("1", 1111111), + Locus("1", 2000000), + true, + false, + ), + ) interceptFatal("invalid interval expression") { Locus.parseInterval("4::start-5:end", rg) @@ -228,15 +319,21 @@ class LocusIntervalSuite extends HailSuite { val gr38 = ctx.references(ReferenceGenome.GRCh38) val x = "[GL000197.1:3739-GL000202.1:7538)" - assert(Locus.parseInterval(x, 
gr37) == - Interval(Locus("GL000197.1", 3739), Locus("GL000202.1", 7538), true, false)) + assertEquals( + Locus.parseInterval(x, gr37), + Interval(Locus("GL000197.1", 3739), Locus("GL000202.1", 7538), true, false), + ) val y = "[HLA-DRB1*13:02:01:5-HLA-DRB1*14:05:01:100)" - assert(Locus.parseInterval(y, gr38) == - Interval(Locus("HLA-DRB1*13:02:01", 5), Locus("HLA-DRB1*14:05:01", 100), true, false)) + assertEquals( + Locus.parseInterval(y, gr38), + Interval(Locus("HLA-DRB1*13:02:01", 5), Locus("HLA-DRB1*14:05:01", 100), true, false), + ) val z = "[HLA-DRB1*13:02:01:5-100)" - assert(Locus.parseInterval(z, gr38) == - Interval(Locus("HLA-DRB1*13:02:01", 5), Locus("HLA-DRB1*13:02:01", 100), true, false)) + assertEquals( + Locus.parseInterval(z, gr38), + Interval(Locus("HLA-DRB1*13:02:01", 5), Locus("HLA-DRB1*13:02:01", 100), true, false), + ) } } diff --git a/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala b/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala index 62258dc3c0c..cf878e53fbc 100644 --- a/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala +++ b/hail/hail/test/src/is/hail/variant/ReferenceGenomeSuite.scala @@ -5,49 +5,56 @@ import is.hail.backend.HailStateManager import is.hail.collection.compat.immutable.ArraySeq import is.hail.expr.ir.EmitFunctionBuilder import is.hail.io.reference.{FASTAReader, FASTAReaderConfig, LiftOver} -import is.hail.scalacheck.{genLocus, genNullable} +import is.hail.scalacheck.{genLocus, genNonMissing} import is.hail.types.virtual.{TInterval, TLocus} import is.hail.utils._ import htsjdk.samtools.reference.ReferenceSequenceFileFactory import org.scalacheck.Prop.forAll -import org.testng.annotations.Test -class ReferenceGenomeSuite extends HailSuite { +class ReferenceGenomeSuite extends HailSuite with munit.ScalaCheckSuite { - @Test def testGRCh37(): Unit = { + test("GRCh37") { assert(ctx.references.contains(ReferenceGenome.GRCh37)) val grch37 = ctx.references(ReferenceGenome.GRCh37) - 
assert(grch37.inX("X") && grch37.inY("Y") && grch37.isMitochondrial("MT")) - assert(grch37.contigLength("1") == 249250621) + assert(grch37.inX("X")) + assert(grch37.inY("Y")) + assert(grch37.isMitochondrial("MT")) + assertEquals(grch37.contigLength("1"), 249250621) val parXLocus = Array(Locus("X", 2499520), Locus("X", 155260460)) val parYLocus = Array(Locus("Y", 50001), Locus("Y", 59035050)) val nonParXLocus = Array(Locus("X", 50), Locus("X", 50000000)) val nonParYLocus = Array(Locus("Y", 5000), Locus("Y", 10000000)) - assert(parXLocus.forall(grch37.inXPar) && parYLocus.forall(grch37.inYPar)) - assert(!nonParXLocus.forall(grch37.inXPar) && !nonParYLocus.forall(grch37.inYPar)) + assert(parXLocus.forall(grch37.inXPar)) + assert(parYLocus.forall(grch37.inYPar)) + assert(!nonParXLocus.forall(grch37.inXPar)) + assert(!nonParYLocus.forall(grch37.inYPar)) } - @Test def testGRCh38(): Unit = { + test("GRCh38") { assert(ctx.references.contains(ReferenceGenome.GRCh38)) val grch38 = ctx.references(ReferenceGenome.GRCh38) - assert(grch38.inX("chrX") && grch38.inY("chrY") && grch38.isMitochondrial("chrM")) - assert(grch38.contigLength("chr1") == 248956422) + assert(grch38.inX("chrX")) + assert(grch38.inY("chrY")) + assert(grch38.isMitochondrial("chrM")) + assertEquals(grch38.contigLength("chr1"), 248956422) val parXLocus38 = Array(Locus("chrX", 2781479), Locus("chrX", 156030895)) val parYLocus38 = Array(Locus("chrY", 50001), Locus("chrY", 57217415)) val nonParXLocus38 = Array(Locus("chrX", 50), Locus("chrX", 50000000)) val nonParYLocus38 = Array(Locus("chrY", 5000), Locus("chrY", 10000000)) - assert(parXLocus38.forall(grch38.inXPar) && parYLocus38.forall(grch38.inYPar)) - assert(!nonParXLocus38.forall(grch38.inXPar) && !nonParYLocus38.forall(grch38.inYPar)) + assert(parXLocus38.forall(grch38.inXPar)) + assert(parYLocus38.forall(grch38.inYPar)) + assert(!nonParXLocus38.forall(grch38.inXPar)) + assert(!nonParYLocus38.forall(grch38.inYPar)) } - @Test def testAssertions(): Unit = { 
+ test("Assertions") { interceptFatal("Must have at least one contig in the reference genome.")( ReferenceGenome("test", ArraySeq.empty[String], Map.empty[String, Int]) ) @@ -102,46 +109,46 @@ class ReferenceGenomeSuite extends HailSuite { )) } - @Test def testContigRemap(): Unit = { + test("ContigRemap") { val mapping = Map("23" -> "foo") interceptFatal("have remapped contigs in reference genome")( ctx.references(ReferenceGenome.GRCh37).validateContigRemap(mapping) ) } - @Test def testComparisonOps(): Unit = { + test("ComparisonOps") { val rg = ctx.references(ReferenceGenome.GRCh37) // Test contigs assert(rg.compare("3", "18") < 0) assert(rg.compare("18", "3") > 0) - assert(rg.compare("7", "7") == 0) + assertEquals(rg.compare("7", "7"), 0) assert(rg.compare("3", "X") < 0) assert(rg.compare("X", "3") > 0) - assert(rg.compare("X", "X") == 0) + assertEquals(rg.compare("X", "X"), 0) assert(rg.compare("X", "Y") < 0) assert(rg.compare("Y", "X") > 0) assert(rg.compare("Y", "MT") < 0) } - @Test def testWriteToFile(): Unit = { + test("WriteToFile") { val tmpFile = ctx.createTmpPath("grWrite", "json") val rg = ctx.references(ReferenceGenome.GRCh37) rg.copy(name = "GRCh37_2").write(fs, tmpFile) val gr2 = ReferenceGenome.fromFile(fs, tmpFile) - assert((rg.contigs sameElements gr2.contigs) && - rg.lengths == gr2.lengths && - rg.xContigs == gr2.xContigs && - rg.yContigs == gr2.yContigs && - rg.mtContigs == gr2.mtContigs && - (rg.parInput sameElements gr2.parInput)) + assert(rg.contigs sameElements gr2.contigs) + assertEquals(rg.lengths, gr2.lengths) + assertEquals(rg.xContigs, gr2.xContigs) + assertEquals(rg.yContigs, gr2.yContigs) + assertEquals(rg.mtContigs, gr2.mtContigs) + assert(rg.parInput sameElements gr2.parInput) } - @Test def testFasta(): Unit = { + property("Fasta") { val fastaFile = getTestResource("fake_reference.fasta") val fastaFileGzip = getTestResource("fake_reference.fasta.gz") val indexFile = getTestResource("fake_reference.fasta.fai") @@ -162,90 +169,93 @@ 
class ReferenceGenomeSuite extends HailSuite { new java.io.File(uriPath(refReaderPathGz)) ) - { - "cache gives same base as from file" |: forAll(genLocus(rg)) { l => - val contig = l.contig - val pos = l.position - val expected = refReader.getSubsequenceAt(contig, pos.toLong, pos.toLong).getBaseString - val expectedGz = - refReaderGz.getSubsequenceAt(contig, pos.toLong, pos.toLong).getBaseString - assert(expected == expectedGz, "wat: fasta files don't have the same data") - fr.lookup(contig, pos, 0, 0) == expected && frGzip.lookup(contig, pos, 0, 0) == expectedGz - } - }.check() - - { - "interval test" |: forAll( - genNullable(ctx, TInterval(TLocus(rg.name))).suchThat(_ != null) - ) { - case i: Interval => - val start = i.start.asInstanceOf[Locus] - val end = i.end.asInstanceOf[Locus] - - val ordering = TLocus(rg.name).ordering(HailStateManager(Map(rg.name -> rg))) - - def getHtsjdkIntervalSequence: String = { - val sb = new StringBuilder - var pos = start - while (ordering.lteq(pos, end) && pos != null) { - val endPos = - if (pos.contig != end.contig) rg.contigLength(pos.contig) else end.position - sb ++= refReader.getSubsequenceAt( - pos.contig, - pos.position.toLong, - endPos.toLong, - ).getBaseString - pos = - if (rg.contigsIndex.get(pos.contig) == rg.contigs.length - 1) - null - else - Locus(rg.contigs(rg.contigsIndex.get(pos.contig) + 1), 1) - } - sb.result() + assertEquals(fr.lookup("a", 25, 0, 5), "A") + assertEquals(fr.lookup("b", 1, 5, 0), "T") + assertEquals(fr.lookup("c", 5, 10, 10), "GGATCCGTGC") + assertEquals( + fr.lookup(Interval( + Locus("a", 1), + Locus("a", 5), + includesStart = true, + includesEnd = false, + )), + "AGGT", + ) + assertEquals( + fr.lookup(Interval( + Locus("a", 20), + Locus("b", 5), + includesStart = false, + includesEnd = false, + )), + "ACGTATAAT", + ) + assertEquals( + fr.lookup(Interval( + Locus("a", 20), + Locus("c", 5), + includesStart = false, + includesEnd = false, + )), + "ACGTATAATTAAATTAGCCAGGAT", + ) + + ("cache gives 
same base as from file" |: forAll(genLocus(rg)) { l => + val contig = l.contig + val pos = l.position + val expected = refReader.getSubsequenceAt(contig, pos.toLong, pos.toLong).getBaseString + val expectedGz = + refReaderGz.getSubsequenceAt(contig, pos.toLong, pos.toLong).getBaseString + assertEquals(expected, expectedGz, "fasta files don't have the same data") + fr.lookup(contig, pos, 0, 0) == expected && frGzip.lookup(contig, pos, 0, 0) == expectedGz + }) ++ ("interval test" |: forAll( + genNonMissing(ctx, TInterval(TLocus(rg.name))) + ) { + case i: Interval => + val start = i.start.asInstanceOf[Locus] + val end = i.end.asInstanceOf[Locus] + + val ordering = TLocus(rg.name).ordering(HailStateManager(Map(rg.name -> rg))) + + def getHtsjdkIntervalSequence: String = { + val sb = new StringBuilder + var pos = start + while (ordering.lteq(pos, end) && pos != null) { + val endPos = + if (pos.contig != end.contig) rg.contigLength(pos.contig) else end.position + sb ++= refReader.getSubsequenceAt( + pos.contig, + pos.position.toLong, + endPos.toLong, + ).getBaseString + pos = + if (rg.contigsIndex.get(pos.contig) == rg.contigs.length - 1) + null + else + Locus(rg.contigs(rg.contigsIndex.get(pos.contig) + 1), 1) } + sb.result() + } - fr.lookup( - Interval(start, end, includesStart = true, includesEnd = true) - ) == getHtsjdkIntervalSequence - } - }.check() - - assert(fr.lookup("a", 25, 0, 5) == "A") - assert(fr.lookup("b", 1, 5, 0) == "T") - assert(fr.lookup("c", 5, 10, 10) == "GGATCCGTGC") - assert(fr.lookup(Interval( - Locus("a", 1), - Locus("a", 5), - includesStart = true, - includesEnd = false, - )) == "AGGT") - assert(fr.lookup(Interval( - Locus("a", 20), - Locus("b", 5), - includesStart = false, - includesEnd = false, - )) == "ACGTATAAT") - assert(fr.lookup(Interval( - Locus("a", 20), - Locus("c", 5), - includesStart = false, - includesEnd = false, - )) == "ACGTATAATTAAATTAGCCAGGAT") + fr.lookup( + Interval(start, end, includesStart = true, includesEnd = true) + ) 
== getHtsjdkIntervalSequence + }) } } - @Test def testSerializeOnFB(): Unit = { + test("SerializeOnFB") { val grch38 = ctx.references(ReferenceGenome.GRCh38) val fb = EmitFunctionBuilder[String, Boolean](ctx, "serialize_rg") val rgfield = fb.getReferenceGenome(grch38.name) fb.emit(rgfield.invoke[String, Boolean]("isValidContig", fb.getCodeParam[String](1))) ctx.scopedExecution { (cl, fs, tc, r) => val f = fb.resultWithIndex()(cl, fs, tc, r) - assert(f("X") == grch38.isValidContig("X")) + assertEquals(f("X"), grch38.isValidContig("X")) } } - @Test def testSerializeWithLiftoverOnFB(): Unit = + test("SerializeWithLiftoverOnFB") { ctx.local(references = ReferenceGenome.builtinReferences()) { ctx => val grch37 = ctx.references(ReferenceGenome.GRCh37) val liftoverFile = getTestResource("grch37_to_grch38_chr20.over.chain.gz") @@ -264,11 +274,11 @@ class ReferenceGenomeSuite extends HailSuite { ctx.scopedExecution { (cl, fs, tc, r) => val f = fb.resultWithIndex()(cl, fs, tc, r) - assert(f("GRCh38", Locus("20", 60001), 0.95) == grch37.liftoverLocus( - "GRCh38", - Locus("20", 60001), - 0.95, - )) + assertEquals( + f("GRCh38", Locus("20", 60001), 0.95), + grch37.liftoverLocus("GRCh38", Locus("20", 60001), 0.95), + ) } } + } } diff --git a/hail/hail/test/src/is/hail/variant/vsm/PartitioningSuite.scala b/hail/hail/test/src/is/hail/variant/vsm/PartitioningSuite.scala index 5de4fe23683..0480f39ec2b 100644 --- a/hail/hail/test/src/is/hail/variant/vsm/PartitioningSuite.scala +++ b/hail/hail/test/src/is/hail/variant/vsm/PartitioningSuite.scala @@ -8,10 +8,8 @@ import is.hail.expr.ir.{Interpret, MatrixAnnotateRowsTable, TableLiteral, TableV import is.hail.rvd.RVD import is.hail.types.virtual.{TInt32, TStruct, TableType} -import org.testng.annotations.Test - class PartitioningSuite extends HailSuite { - @Test def testShuffleOnEmptyRDD(): Unit = { + test("ShuffleOnEmptyRDD") { val typ = TableType(TStruct("tidx" -> TInt32), FastSeq("tidx"), TStruct.empty) val t = TableLiteral( 
TableValue( diff --git a/hail/hail/utils/src/is/hail/utils/package.scala b/hail/hail/utils/src/is/hail/utils/package.scala index 12c0f733693..002fb9587b4 100644 --- a/hail/hail/utils/src/is/hail/utils/package.scala +++ b/hail/hail/utils/src/is/hail/utils/package.scala @@ -232,29 +232,31 @@ package object utils extends ErrorHandling with Logging { num / denom val defaultTolerance = 1e-6 + val defaultAbsTolerance = 1e-30 - def D_epsilon(a: Double, b: Double, tolerance: Double = defaultTolerance): Double = - math.max(java.lang.Double.MIN_NORMAL, tolerance * math.max(math.abs(a), math.abs(b))) + def D_epsilon(a: Double, b: Double, tolerance: Double = defaultTolerance, absTolerance: Double = defaultAbsTolerance): Double = + defaultAbsTolerance + tolerance * math.max(math.abs(a), math.abs(b)) - def D_==(a: Double, b: Double, tolerance: Double = defaultTolerance): Boolean = - a == b || math.abs(a - b) <= D_epsilon(a, b, tolerance) + // This is more or less the implementation of numpy.isclose + def D_==(a: Double, b: Double, tolerance: Double = defaultTolerance, absTolerance: Double = defaultAbsTolerance): Boolean = + a == b || math.abs(a - b) <= D_epsilon(a, b, tolerance, absTolerance) - def D_!=(a: Double, b: Double, tolerance: Double = defaultTolerance): Boolean = - !(a == b) && math.abs(a - b) > D_epsilon(a, b, tolerance) + def D_!=(a: Double, b: Double, tolerance: Double = defaultTolerance, absTolerance: Double = defaultAbsTolerance): Boolean = + !(a == b) && math.abs(a - b) > D_epsilon(a, b, tolerance, absTolerance) - def D_<(a: Double, b: Double, tolerance: Double = defaultTolerance): Boolean = - !(a == b) && a - b < -D_epsilon(a, b, tolerance) + def D_<(a: Double, b: Double, tolerance: Double = defaultTolerance, absTolerance: Double = defaultAbsTolerance): Boolean = + !(a == b) && a - b < -D_epsilon(a, b, tolerance, absTolerance) - def D_<=(a: Double, b: Double, tolerance: Double = defaultTolerance): Boolean = - (a == b) || a - b <= D_epsilon(a, b, tolerance) + def 
D_<=(a: Double, b: Double, tolerance: Double = defaultTolerance, absTolerance: Double = defaultAbsTolerance): Boolean = + (a == b) || a - b <= D_epsilon(a, b, tolerance, absTolerance) - def D_>(a: Double, b: Double, tolerance: Double = defaultTolerance): Boolean = - !(a == b) && a - b > D_epsilon(a, b, tolerance) + def D_>(a: Double, b: Double, tolerance: Double = defaultTolerance, absTolerance: Double = defaultAbsTolerance): Boolean = + !(a == b) && a - b > D_epsilon(a, b, tolerance, absTolerance) - def D_>=(a: Double, b: Double, tolerance: Double = defaultTolerance): Boolean = - (a == b) || a - b >= -D_epsilon(a, b, tolerance) + def D_>=(a: Double, b: Double, tolerance: Double = defaultTolerance, absTolerance: Double = defaultAbsTolerance): Boolean = + (a == b) || a - b >= -D_epsilon(a, b, tolerance, absTolerance) - def D0_==(x: Double, y: Double, tolerance: Double = defaultTolerance): Boolean = + def D0_==(x: Double, y: Double, tolerance: Double = defaultTolerance, absTolerance: Double = defaultAbsTolerance): Boolean = if (x.isNaN) y.isNaN else if (x.isPosInfinity) @@ -262,7 +264,7 @@ package object utils extends ErrorHandling with Logging { else if (x.isNegInfinity) y.isNegInfinity else - D_==(x, y, tolerance) + D_==(x, y, tolerance, absTolerance) def flushDouble(a: Double): Double = if (math.abs(a) < java.lang.Double.MIN_NORMAL) 0.0 else a diff --git a/hail/mill-build/mill-build/src/MvnCoordinate.scala b/hail/mill-build/mill-build/src/MvnCoordinate.scala index 31b83b98c61..c2b25a3d46c 100644 --- a/hail/mill-build/mill-build/src/MvnCoordinate.scala +++ b/hail/mill-build/mill-build/src/MvnCoordinate.scala @@ -285,6 +285,7 @@ object MvnCoordinate: val `metrics-jvm` = "io.dropwizard.metrics:metrics-jvm" val `mill-scalafix` = "com.goyeau::mill-scalafix" val `minlog` = "com.esotericsoftware:minlog" + val `munit` = "org.scalameta::munit" val `mockito-scala` = "org.mockito::mockito-scala" val `netty-all` = "io.netty:netty-all" val `netty-buffer` = 
"io.netty:netty-buffer" diff --git a/hail/testng-build.xml b/hail/testng-build.xml deleted file mode 100644 index e0f4d14a1b3..00000000000 --- a/hail/testng-build.xml +++ /dev/null @@ -1,10 +0,0 @@ - - - - - - - - - - diff --git a/hail/testng-fs.xml b/hail/testng-fs.xml deleted file mode 100644 index 990a20e5fbd..00000000000 --- a/hail/testng-fs.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - diff --git a/hail/testng-services.xml b/hail/testng-services.xml deleted file mode 100644 index 0b25e13f634..00000000000 --- a/hail/testng-services.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - diff --git a/hail/testng.xml b/hail/testng.xml deleted file mode 100644 index 116463c32b8..00000000000 --- a/hail/testng.xml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - - - - - - - -