
PyTorch Support #275

@kennykos

Description

When I pass a PyTorch tensor to a workunit the same way I would pass a CuPy array, as in the following script,

import torch
import pykokkos as pk

@pk.workunit
def work(wid, a):
    a[wid] = a[wid] + 1

def main():
    N = 10
    a = torch.ones(N)
    pk.set_default_space(pk.Cuda)
    pk.parallel_for("work", 10, work, a=a)
    print(a)

main()

I get the following error:

Traceback (most recent call last):
  File "/work/09661/gkk345/ls6/3dcapsules/python/development/gridding/tmp4.py", line 15, in <module>
    main()
  File "/work/09661/gkk345/ls6/3dcapsules/python/development/gridding/tmp4.py", line 12, in main
    pk.parallel_for("work", 10, work, a=a)
  File "/work/09661/gkk345/ls6/pykokkos/pykokkos/interface/parallel_dispatch.py", line 158, in parallel_for
    runtime_singleton.runtime.run_workunit(
  File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/runtime.py", line 153, in run_workunit
    return self.execute_workunit(name, policy, workunit, operation, parser, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/runtime.py", line 199, in execute_workunit
    members: PyKokkosMembers = self.precompile_workunit(workunit, execution_space, updated_decorator, updated_types, types_signature, restrict_views, restrict_signature, **kwargs)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/runtime.py", line 86, in precompile_workunit
    members: PyKokkosMembers = self.compiler.compile_object(module_setup,
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/compiler.py", line 178, in compile_object
    entity.AST = parser.fix_types(entity, updated_types)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/parsers/parser.py", line 170, in fix_types
    arg_obj.annotation = self.get_annotation_node(update_type)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/parsers/parser.py", line 283, in get_annotation_node
    raise ValueError(f"Type inference for {type} is not supported")
ValueError: Type inference for Tensor is not supported

It would be very helpful for integration purposes if PyKokkos workunits accepted PyTorch tensors as arguments, the same way they already accept CuPy arrays.
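The traceback shows the failure comes from type inference in the parser (`get_annotation_node` in `parser.py`), which rejects any argument type it does not recognize. One possible direction is to infer the annotation from the standard `__cuda_array_interface__` protocol, which both CuPy arrays and CUDA-resident `torch.Tensor` objects expose, instead of from the concrete class name. The sketch below is hypothetical and is not how PyKokkos works today. The function `infer_view_annotation`, the `_TYPECODES` table, and the returned strings such as `View1D[float32]` are made-up names for illustration. CPU tensors, such as the `torch.ones(N)` in the script above, do not expose this protocol and would need a separate path.

```python
from typing import Any

# Hypothetical table mapping array-interface type codes (byte-order prefix
# stripped) to element type names. This is NOT PyKokkos's real type table.
_TYPECODES = {
    "f4": "float32",
    "f8": "float64",
    "i4": "int32",
    "i8": "int64",
}


def infer_view_annotation(obj: Any) -> str:
    """Hypothetical: infer a view annotation from __cuda_array_interface__.

    Duck-typing on the protocol rather than on the class name would let
    CuPy arrays and CUDA torch tensors share a single code path.
    """
    # CPU torch tensors raise AttributeError for this property, so getattr
    # falls back to None and they are rejected below.
    iface = getattr(obj, "__cuda_array_interface__", None)
    if iface is None:
        # Same wording as the error in the traceback above.
        raise ValueError(f"Type inference for {type(obj).__name__} is not supported")
    typestr = iface["typestr"]  # e.g. "<f4" for little-endian float32
    code = typestr[1:]          # drop the byte-order character
    if code not in _TYPECODES:
        raise ValueError(f"Unsupported element type {typestr!r}")
    ndim = len(iface["shape"])
    return f"View{ndim}D[{_TYPECODES[code]}]"
```

Because the check is protocol-based, any stand-in object that exposes a valid `__cuda_array_interface__` dict is handled the same way as a real device array. This makes the logic easy to exercise without a GPU.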