diff --git a/docs/src/api/api.md b/docs/src/api/api.md index e118d227e5..83d5bc449f 100644 --- a/docs/src/api/api.md +++ b/docs/src/api/api.md @@ -26,6 +26,7 @@ within_compile ```@docs ConcreteRArray ConcreteRNumber +ShapeDtypeStruct ``` ## Inspect Generated HLO diff --git a/examples/shape_dtype_struct_example.jl b/examples/shape_dtype_struct_example.jl new file mode 100644 index 0000000000..24d4a11b89 --- /dev/null +++ b/examples/shape_dtype_struct_example.jl @@ -0,0 +1,94 @@ +# Example: Using ShapeDtypeStruct for compilation without concrete arrays +# +# This example demonstrates how to use Reactant.ShapeDtypeStruct to compile +# functions without having to construct full ConcreteRArray instances with +# actual data. This is useful for: +# 1. Faster compilation when you only need shape/dtype information +# 2. Memory efficiency when working with large arrays +# 3. Similar workflow to JAX's ShapeDtypeStruct + +using Reactant + +# Example 1: Basic usage with a simple function +println("Example 1: Basic compilation with ShapeDtypeStruct") +println("=" ^ 60) + +# Define a simple function that sums an array +f_sum(x) = sum(x) + +# Instead of creating a full ConcreteRArray: +# x = Reactant.ConcreteRArray(rand(Float32, 10, 20)) # This allocates memory! + +# Use ShapeDtypeStruct to specify only shape and dtype: +spec = Reactant.ShapeDtypeStruct((10, 20), Float32) +println("Created ShapeDtypeStruct: ", spec) +println(" Shape: ", size(spec)) +println(" Element type: ", eltype(spec)) +println(" Dimensions: ", ndims(spec)) + +# Compile the function using the spec +compiled_f_sum = Reactant.compile(f_sum, (spec,)) +println("✓ Function compiled successfully") + +# Now execute with actual data +x_actual = Reactant.ConcreteRArray(rand(Float32, 10, 20)) +result = compiled_f_sum(x_actual) +println("Result: ", result, " (type: ", typeof(result), ")") +println() + +# Example 2: Multiple arguments +println("Example 2: Multiple arguments with ShapeDtypeStruct") +println("=" ^ 60) + +f_add(x, y) = x .+ y + +spec1 = Reactant.ShapeDtypeStruct((5, 5), Float64) +spec2 = Reactant.ShapeDtypeStruct((5, 5), Float64) + +compiled_f_add = Reactant.compile(f_add, (spec1, spec2)) +println("✓ Function with 2 arguments compiled") + +x_data = Reactant.ConcreteRArray(rand(Float64, 5, 5)) +y_data = Reactant.ConcreteRArray(rand(Float64, 5, 5)) +result_add = compiled_f_add(x_data, y_data) +println("Result shape: ", size(result_add)) +println() + +# Example 3: Different dtypes +println("Example 3: Compilation with different dtypes") +println("=" ^ 60) + +f_sin(x) = sin.(x) + +for dtype in [Float32, Float64] + spec = Reactant.ShapeDtypeStruct((100,), dtype) + compiled = Reactant.compile(f_sin, (spec,)) + + x = Reactant.ConcreteRArray(rand(dtype, 100)) + result = compiled(x) + println("✓ Compiled and ran for dtype: ", dtype) +end +println() + +# Example 4: Benefits demonstration +println("Example 4: Memory efficiency") +println("=" ^ 60) + +# For very large arrays, you can compile without allocating the full array: +large_spec = Reactant.ShapeDtypeStruct((10000, 10000), Float32) +println("Created spec for large array: ", size(large_spec)) +println(" This doesn't allocate ", prod(size(large_spec)) * sizeof(Float32) / 1e9, " GB of memory!") + +# Compile a function for this large array +f_large(x) = sum(x .* x) +compiled_large = Reactant.compile(f_large, (large_spec,)) +println("✓ Compiled function for large array without allocating memory") +println() + +println("All examples completed successfully!") +println() +println("Key Takeaways:") +println("1. ShapeDtypeStruct allows compilation without data allocation") +println("2. Same compiled function can be used with actual ConcreteRArray data") +println("3. Useful for large arrays and rapid prototyping") +println("4. Similar API to JAX's ShapeDtypeStruct") diff --git a/src/Reactant.jl b/src/Reactant.jl index 57ba65d442..a1a8532efb 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -124,6 +124,10 @@ unwrapped_eltype(::TracedRNumber{T}) where {T} = T unwrapped_eltype(::Type{<:AbstractArray{T,N}}) where {T,N} = unwrapped_eltype(T) unwrapped_eltype(::AbstractArray{T,N}) where {T,N} = unwrapped_eltype(T) +# For ShapeDtypeStruct +unwrapped_eltype(::Type{ShapeDtypeStruct{T,N}}) where {T,N} = T +unwrapped_eltype(::ShapeDtypeStruct{T,N}) where {T,N} = T + include("Ops.jl") Base.push!(no_rewrite_ancestor_modules, Ops) @@ -288,6 +292,7 @@ export ConcreteRArray, ConcretePJRTNumber, ConcreteIFRTArray, ConcreteIFRTNumber, + ShapeDtypeStruct, @compile, @code_hlo, @code_mhlo, diff --git a/src/Tracing.jl b/src/Tracing.jl index 4e5a88f51b..be05552aeb 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1407,6 +1407,42 @@ Base.@nospecializeinfer function make_tracer( return res end +Base.@nospecializeinfer function make_tracer( + seen, + @nospecialize(prev::ShapeDtypeStruct{T,N}), + @nospecialize(path), + mode; + kwargs..., +) where {T,N} + if mode == TracedToTypes + throw( + ArgumentError( + "ShapeDtypeStruct cannot be used as a function call argument; it is only valid for compilation signatures." + ), + ) + end + if mode == ArrayToConcrete + throw( + ErrorException( + "Cannot convert ShapeDtypeStruct to ConcreteRArray. ShapeDtypeStruct is only for compilation signatures." + ), + ) + end + # ShapeDtypeStruct behaves like ConcreteToTraced mode - creates a TracedRArray without data + # Accept both ConcreteToTraced and TracedSetPath modes + if mode != ConcreteToTraced && mode != TracedSetPath + throw( + ArgumentError( + "ShapeDtypeStruct can only be used with ConcreteToTraced or TracedSetPath mode, got $mode" + ), + ) + end + haskey(seen, prev) && return seen[prev]::TracedRArray{T,N} + res = TracedRArray{T,N}((path,), nothing, size(prev)) + seen[prev] = res + return res +end + Base.@nospecializeinfer function make_tracer( seen, prev::ConcretePJRTNumber{T}, diff --git a/src/Types.jl b/src/Types.jl index 948765684d..7b52f1e568 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -131,6 +131,55 @@ const AnyTracedRVector{T} = AnyTracedRArray{T,1} const AnyTracedRMatrix{T} = AnyTracedRArray{T,2} const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} +## ShapeDtypeStruct +""" + ShapeDtypeStruct{T,N}(shape::NTuple{N,Int}) + ShapeDtypeStruct(shape::NTuple{N,Int}, dtype::Type{T}) where {T,N} + +Lightweight structure that specifies the shape and element type (dtype) of an array +without allocating the actual array data. Similar to JAX's `ShapeDtypeStruct`. + +This is useful for compiling functions without constructing the full `ConcreteRArray`, +which can save memory and improve compilation performance. + +# Examples +```julia +# Specify shape and dtype for a 2D array +spec = Reactant.ShapeDtypeStruct((10, 20), Float32) + +# Compile a function using just the spec +f(x) = sum(x) +compiled_f = Reactant.compile(f, (spec,)) + +# Execute with actual data +x = Reactant.ConcreteRArray(rand(Float32, 10, 20)) +result = compiled_f(x) +``` + +See also: [`compile`](@ref), [`ConcreteRArray`](@ref) +""" +struct ShapeDtypeStruct{T,N} + shape::NTuple{N,Int} + + function ShapeDtypeStruct{T,N}(shape::NTuple{N,Int}) where {T,N} + return new{T,N}(shape) + end +end + +function ShapeDtypeStruct(shape::NTuple{N,Int}, dtype::Type{T}) where {T,N} + return ShapeDtypeStruct{T,N}(shape) +end + +function ShapeDtypeStruct(shape::Tuple{Vararg{Integer}}, dtype::Type{T}) where {T} + return ShapeDtypeStruct(map(Int, shape), dtype) +end + +Base.size(x::ShapeDtypeStruct) = x.shape +Base.ndims(::ShapeDtypeStruct{T,N}) where {T,N} = N +Base.eltype(::ShapeDtypeStruct{T}) where {T} = T + +@leaf ShapeDtypeStruct + # Concrete Types ## ConcretePJRTNumber mutable struct ConcretePJRTNumber{T,D} <: AbstractConcreteNumber{T} diff --git a/test/core/compile.jl b/test/core/compile.jl index dfb0d2c515..74b05415ac 100644 --- a/test/core/compile.jl +++ b/test/core/compile.jl @@ -593,3 +593,68 @@ end @test Array(y[:Mhalo]) ≈ [1.0f0, 2.0f0] @test Array(y[:x]) ≈ [2.0f0, 3.0f0] end + +@testset "ShapeDtypeStruct compilation" begin + @testset "Basic compilation with ShapeDtypeStruct" begin + # Define a simple function + f(x) = sum(x) + + # Compile using ShapeDtypeStruct instead of ConcreteRArray + spec = Reactant.ShapeDtypeStruct((10, 20), Float32) + compiled_f = Reactant.compile(f, (spec,)) + + # Execute with actual data + x = Reactant.ConcreteRArray(rand(Float32, 10, 20)) + result = compiled_f(x) + + @test result isa Reactant.ConcreteRNumber{Float32} + @test result ≈ sum(Array(x)) + end + + @testset "Multiple arguments with ShapeDtypeStruct" begin + f(x, y) = x .+ y + + spec1 = Reactant.ShapeDtypeStruct((5, 5), Float64) + spec2 = Reactant.ShapeDtypeStruct((5, 5), Float64) + compiled_f = Reactant.compile(f, (spec1, spec2)) + + x = Reactant.ConcreteRArray(rand(Float64, 5, 5)) + y = Reactant.ConcreteRArray(rand(Float64, 5, 5)) + result = compiled_f(x, y) + + @test result isa Reactant.ConcreteRArray{Float64,2} + @test result ≈ Array(x) .+ Array(y) + end + + @testset "ShapeDtypeStruct with different dtypes" begin + f(x) = sin.(x) + + for dtype in [Float32, Float64] + spec = Reactant.ShapeDtypeStruct((10,), dtype) + compiled_f = Reactant.compile(f, (spec,)) + + x = Reactant.ConcreteRArray(rand(dtype, 10)) + result = compiled_f(x) + + @test result isa Reactant.ConcreteRArray{dtype,1} + @test result ≈ sin.(Array(x)) + end + end + + @testset "ShapeDtypeStruct constructor variations" begin + # Test different constructor forms + spec1 = Reactant.ShapeDtypeStruct{Float32,2}((3, 4)) + @test size(spec1) == (3, 4) + @test eltype(spec1) == Float32 + @test ndims(spec1) == 2 + + spec2 = Reactant.ShapeDtypeStruct((3, 4), Float32) + @test size(spec2) == (3, 4) + @test eltype(spec2) == Float32 + + # Test with integer tuple (not Int tuple) + spec3 = Reactant.ShapeDtypeStruct((3, 4), Float64) + @test size(spec3) == (3, 4) + @test eltype(spec3) == Float64 + end +end