Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelog
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
66) PR #3359 for #3022. Uses the improved datatype functionality to
correctly transform reduction intrinsics into native code.

65) PR #3365 towards #2668. Updates more transformations to accept
keyword arguments instead of options.

Expand Down
3 changes: 1 addition & 2 deletions examples/nemo/scripts/omp_cpu_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ def trans(psyir):
subroutine,
hoist_local_arrays=False,
convert_array_notation=True,
# See issue #3022
loopify_array_intrinsics=psyir.name != "getincom.f90",
loopify_array_intrinsics=True,
convert_range_loops=True,
hoist_expressions=False,
scalarise_loops=False
Expand Down
3 changes: 1 addition & 2 deletions examples/nemo/scripts/omp_gpu_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ def trans(psyir):
subroutine,
hoist_local_arrays=False,
convert_array_notation=True,
# See issue #3022
loopify_array_intrinsics=psyir.name != "getincom.f90",
loopify_array_intrinsics=True,
convert_range_loops=True,
increase_array_ranks=not NEMOV4,
hoist_expressions=True
Expand Down
15 changes: 6 additions & 9 deletions examples/nemo/scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,15 +218,12 @@ def normalise_loops(
pass

if loopify_array_intrinsics:
filename = schedule.root.name
# TODO #3022: Some files have a bug in Maxval2LoopTrans
if filename not in ("flincom.f90", "histcom.f90", "restcom.f90"):
for intr in schedule.walk(IntrinsicCall):
if intr.intrinsic.name == "MAXVAL":
try:
Maxval2LoopTrans().apply(intr)
except TransformationError as err:
print(err.value)
for intr in schedule.walk(IntrinsicCall):
if intr.intrinsic.name == "MAXVAL":
try:
Maxval2LoopTrans().apply(intr, verbose=True)
except TransformationError as err:
print(err.value)

if convert_range_loops:
# Convert all array implicit loops to explicit loops
Expand Down
5 changes: 2 additions & 3 deletions src/psyclone/psyir/nodes/intrinsic_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,13 +564,12 @@ def _maxval_return_type(node: IntrinsicCall) -> DataType:

:returns: the computed datatype for the IntrinsicCall.
"""
dtype = ScalarType(node.argument_by_name("array").datatype.intrinsic,
node.argument_by_name("array").datatype.precision)
arg = node.argument_by_name("array")
dtype = arg.datatype.elemental_type
if "dim" not in node.argument_names:
return dtype
# We have a dimension specified. We don't know the resultant shape
# in any detail as its dependent on the value of dim
arg = node.argument_by_name("array")
return _type_of_arg_with_rank_minus_one(arg, dtype)


Expand Down
16 changes: 11 additions & 5 deletions src/psyclone/psyir/tools/type_info_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from psyclone.errors import InternalError
from psyclone.psyir.nodes import Reference
from psyclone.psyir.symbols.datatypes import (
ScalarType, UnresolvedType, DataType
ScalarType, UnresolvedType, DataType, ArrayType
)


Expand Down Expand Up @@ -102,16 +102,16 @@ def compute_precision(

def compute_scalar_type(
argtypes: list[DataType]
) -> ScalarType.Intrinsic:
) -> ScalarType:
'''
Examines the argtypes to determine the base type of the result of a
numerical operation with them as operands. Usesthe rules in Section 7.2
numerical operation with them as operands. Uses the rules in Section 7.2
of the Fortran2008 standard. If the type cannot be determined then an
instance of `UnresolvedType` is returned.

:param argtypes: the types of the arguments.

:returns: the base (scalar) type of the result of the input arguments.
:returns: the elemental type of the result of the input arguments.

:raises InternalError: If more than two argument types are provided.
:raises TypeError: If the types differ and any are not a numeric datatype.
Expand All @@ -130,7 +130,9 @@ def compute_scalar_type(
return UnresolvedType()

# If all the datatypes are the same then we can return the first.
if (argtypes[0] == argtypes[1]):
if argtypes[0] == argtypes[1]:
if isinstance(argtypes[0], ArrayType):
return argtypes[0].elemental_type
return argtypes[0]

# TODO 1590 - ensure support for complex numbers here in the future.
Expand All @@ -145,8 +147,12 @@ def compute_scalar_type(

# If either has REAL intrinsic type, the result is a REAL.
if argtypes[0].intrinsic == ScalarType.Intrinsic.REAL:
if isinstance(argtypes[0], ArrayType):
return argtypes[0].elemental_type
return argtypes[0]
if argtypes[1].intrinsic == ScalarType.Intrinsic.REAL:
if isinstance(argtypes[1], ArrayType):
return argtypes[1].elemental_type
return argtypes[1]

# Otherwise, the type of the result is not consistent with
Expand Down
Loading
Loading