-
Notifications
You must be signed in to change notification settings - Fork 56
Add ShapeDtypeStruct for compilation without concrete array allocation #2431
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
fe6bab1
57179de
bf9ccc2
b937ca2
c892e05
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| # Example: Using ShapeDtypeStruct for compilation without concrete arrays | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This would be the first script in the "examples" folder. IMHO they can quickly become outdated and confusing for new users if we don't test them, especially in Reactant, where we have been breaking the API quite a lot. So... maybe this is better handled in the docs as an |
||
| # | ||
| # This example demonstrates how to use Reactant.ShapeDtypeStruct to compile | ||
| # functions without having to construct full ConcreteRArray instances with | ||
| # actual data. This is useful for: | ||
| # 1. Faster compilation when you only need shape/dtype information | ||
| # 2. Memory efficiency when working with large arrays | ||
| # 3. Similar workflow to JAX's ShapeDtypeStruct | ||
|
|
||
| using Reactant | ||
|
|
||
| # Example 1: compile a reduction from a shape/dtype description alone | ||
| println("Example 1: Basic compilation with ShapeDtypeStruct") | ||
| println(repeat("=", 60)) | ||
|
|
||
| # Reduction under test: the total of all array elements | ||
| function f_sum(x) | ||
|     return sum(x) | ||
| end | ||
|
|
||
| # The conventional route materializes the input before compiling: | ||
| # x = Reactant.ConcreteRArray(rand(Float32, 10, 20)) # This allocates memory! | ||
|
|
||
| # Here only the shape and the element type are described: | ||
| spec = Reactant.ShapeDtypeStruct((10, 20), Float32) | ||
| println("Created ShapeDtypeStruct: ", spec) | ||
| println(" Shape: ", size(spec)) | ||
| println(" Element type: ", eltype(spec)) | ||
| println(" Dimensions: ", ndims(spec)) | ||
|
|
||
| # The spec stands in for the argument at compile time | ||
| compiled_f_sum = Reactant.compile(f_sum, (spec,)) | ||
| println("✓ Function compiled successfully") | ||
|
|
||
| # Real data is only needed once the compiled function is called | ||
| x_actual = Reactant.ConcreteRArray(rand(Float32, 10, 20)) | ||
| result = compiled_f_sum(x_actual) | ||
| println("Result: ", result, " (type: ", typeof(result), ")") | ||
| println() | ||
|
|
||
| # Example 2: a two-argument function compiled from two specs | ||
| println("Example 2: Multiple arguments with ShapeDtypeStruct") | ||
| println(repeat("=", 60)) | ||
|
|
||
| # Element-wise addition of the two inputs | ||
| function f_add(x, y) | ||
|     return x .+ y | ||
| end | ||
|
|
||
| # One spec per positional argument, both 5×5 Float64 | ||
| spec1 = Reactant.ShapeDtypeStruct((5, 5), Float64) | ||
| spec2 = Reactant.ShapeDtypeStruct((5, 5), Float64) | ||
|
|
||
| compiled_f_add = Reactant.compile(f_add, (spec1, spec2)) | ||
| println("✓ Function with 2 arguments compiled") | ||
|
|
||
| # Run the compiled function on concrete inputs matching the specs | ||
| x_data = Reactant.ConcreteRArray(rand(Float64, 5, 5)) | ||
| y_data = Reactant.ConcreteRArray(rand(Float64, 5, 5)) | ||
| result_add = compiled_f_add(x_data, y_data) | ||
| println("Result shape: ", size(result_add)) | ||
| println() | ||
|
|
||
| # Example 3: Different dtypes | ||
| println("Example 3: Compilation with different dtypes") | ||
| println("=" ^ 60) | ||
|
|
||
| # Element-wise sine; compiled once per element type below | ||
| f_sin(x) = sin.(x) | ||
|
|
||
| # Loop-local names are kept distinct from the script-level globals `spec` and | ||
| # `result`: assigning to an existing global name inside a top-level `for` loop | ||
| # triggers Julia's soft-scope ambiguity warning when the file is run as a script. | ||
| for dtype in (Float32, Float64) | ||
|     dtype_spec = Reactant.ShapeDtypeStruct((100,), dtype) | ||
|     compiled_sin = Reactant.compile(f_sin, (dtype_spec,)) | ||
|
|
||
|     x_in = Reactant.ConcreteRArray(rand(dtype, 100)) | ||
|     compiled_sin(x_in) # executed only to show the call succeeds; output unused | ||
|     println("✓ Compiled and ran for dtype: ", dtype) | ||
| end | ||
| println() | ||
|
|
||
| # Example 4: compiling for an array that is never materialized | ||
| println("Example 4: Memory efficiency") | ||
| println(repeat("=", 60)) | ||
|
|
||
| # Describe a 10000×10000 Float32 array without allocating it | ||
| large_spec = Reactant.ShapeDtypeStruct((10000, 10000), Float32) | ||
| println("Created spec for large array: ", size(large_spec)) | ||
| println( | ||
|     " This doesn't allocate ", | ||
|     prod(size(large_spec)) * sizeof(Float32) / 1e9, | ||
|     " GB of memory!", | ||
| ) | ||
|
|
||
| # Sum of squares, compiled against the spec only | ||
| function f_large(x) | ||
|     return sum(x .* x) | ||
| end | ||
| compiled_large = Reactant.compile(f_large, (large_spec,)) | ||
| println("✓ Compiled function for large array without allocating memory") | ||
| println() | ||
|
|
||
| # Closing summary, printed one entry per line (the empty entry is a blank line) | ||
| foreach( | ||
|     println, | ||
|     ( | ||
|         "All examples completed successfully!", | ||
|         "", | ||
|         "Key Takeaways:", | ||
|         "1. ShapeDtypeStruct allows compilation without data allocation", | ||
|         "2. Same compiled function can be used with actual ConcreteRArray data", | ||
|         "3. Useful for large arrays and rapid prototyping", | ||
|         "4. Similar API to JAX's ShapeDtypeStruct", | ||
|     ), | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1407,6 +1407,42 @@ Base.@nospecializeinfer function make_tracer( | |
| return res | ||
| end | ||
|
|
||
| Base.@nospecializeinfer function make_tracer( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mmm I think you're missing the |
||
| seen, | ||
| @nospecialize(prev::ShapeDtypeStruct{T,N}), | ||
| @nospecialize(path), | ||
| mode; | ||
| kwargs..., | ||
| ) where {T,N} | ||
| if mode == TracedToTypes | ||
| throw( | ||
| ArgumentError( | ||
| "ShapeDtypeStruct cannot be used as a function call argument; it is only valid for compilation signatures." | ||
| ), | ||
| ) | ||
| end | ||
| if mode == ArrayToConcrete | ||
| throw( | ||
| ErrorException( | ||
| "Cannot convert ShapeDtypeStruct to ConcreteRArray. ShapeDtypeStruct is only for compilation signatures." | ||
| ), | ||
| ) | ||
| end | ||
| # ShapeDtypeStruct behaves like ConcreteToTraced mode - creates a TracedRArray without data | ||
| # Accept both ConcreteToTraced and TracedSetPath modes | ||
| if mode != ConcreteToTraced && mode != TracedSetPath | ||
| throw( | ||
| ArgumentError( | ||
| "ShapeDtypeStruct can only be used with ConcreteToTraced or TracedSetPath mode, got $mode" | ||
| ), | ||
| ) | ||
| end | ||
|
Comment on lines
+1417
to
+1439
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this can be fused into one check |
||
| haskey(seen, prev) && return seen[prev]::TracedRArray{T,N} | ||
| res = TracedRArray{T,N}((path,), nothing, size(prev)) | ||
| seen[prev] = res | ||
| return res | ||
| end | ||
|
|
||
| Base.@nospecializeinfer function make_tracer( | ||
| seen, | ||
| prev::ConcretePJRTNumber{T}, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -131,6 +131,55 @@ const AnyTracedRVector{T} = AnyTracedRArray{T,1} | |
| const AnyTracedRMatrix{T} = AnyTracedRArray{T,2} | ||
| const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} | ||
|
|
||
| ## ShapeDtypeStruct | ||
| """ | ||
| ShapeDtypeStruct{T,N}(shape::NTuple{N,Int}) | ||
| ShapeDtypeStruct(shape::NTuple{N,Int}, dtype::Type{T}) where {T,N} | ||
|
|
||
| Lightweight structure that specifies the shape and element type (dtype) of an array | ||
| without allocating the actual array data. Similar to JAX's `ShapeDtypeStruct`. | ||
|
|
||
| This is useful for compiling functions without constructing the full `ConcreteRArray`, | ||
| which can save memory and improve compilation performance. | ||
|
|
||
| # Examples | ||
| ```julia | ||
| # Specify shape and dtype for a 2D array | ||
| spec = Reactant.ShapeDtypeStruct((10, 20), Float32) | ||
|
|
||
| # Compile a function using just the spec | ||
| f(x) = sum(x) | ||
| compiled_f = Reactant.compile(f, (spec,)) | ||
|
|
||
| # Execute with actual data | ||
| x = Reactant.ConcreteRArray(rand(Float32, 10, 20)) | ||
| result = compiled_f(x) | ||
| ``` | ||
|
|
||
| See also: [`compile`](@ref), [`ConcreteRArray`](@ref) | ||
| """ | ||
| struct ShapeDtypeStruct{T,N} | ||
| shape::NTuple{N,Int} | ||
|
Comment on lines
+161
to
+162
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think a The jax docs say this field is optional. |
||
|
|
||
| function ShapeDtypeStruct{T,N}(shape::NTuple{N,Int}) where {T,N} | ||
| return new{T,N}(shape) | ||
| end | ||
| end | ||
|
|
||
| # Convenience constructor: `T` is taken from `dtype` and `N` from the length of `shape`. | ||
| ShapeDtypeStruct(shape::NTuple{N,Int}, dtype::Type{T}) where {T,N} = | ||
|     ShapeDtypeStruct{T,N}(shape) | ||
|
|
||
| # Accept a shape tuple of any `Integer` entries; each entry is converted to `Int` | ||
| # before the two-argument constructor is called again with the converted tuple. | ||
| function ShapeDtypeStruct(shape::Tuple{Vararg{Integer}}, dtype::Type{T}) where {T} | ||
|     dims = map(Int, shape) | ||
|     return ShapeDtypeStruct(dims, dtype) | ||
| end | ||
|
|
||
| # Array-like queries on a spec. `size` reads the stored shape; `ndims` and `eltype` | ||
| # are read from the type parameters. | ||
| Base.size(x::ShapeDtypeStruct) = x.shape | ||
| Base.ndims(::ShapeDtypeStruct{T,N}) where {T,N} = N | ||
| Base.eltype(::ShapeDtypeStruct{T}) where {T} = T | ||
| # Type-domain methods, so that `eltype(typeof(spec))` and `ndims(typeof(spec))` agree | ||
| # with the instance methods instead of falling back to `Any` / raising `MethodError`. | ||
| Base.ndims(::Type{<:ShapeDtypeStruct{T,N}}) where {T,N} = N | ||
| Base.eltype(::Type{<:ShapeDtypeStruct{T}}) where {T} = T | ||
|
|
||
| @leaf ShapeDtypeStruct | ||
|
|
||
| # Concrete Types | ||
| ## ConcretePJRTNumber | ||
| mutable struct ConcretePJRTNumber{T,D} <: AbstractConcreteNumber{T} | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a personal opinion, so you can ignore it.
Since this models an "array" in the end, wouldn't it be better to suffix it with
`RArray`? Also, the
`dtype` name is used in NumPy, whereas in Julia we use `eltype`... so how about
`ShapeRArray`? `GhostRArray`? `ShadowRArray`?