Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/api/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ within_compile
```@docs
ConcreteRArray
ConcreteRNumber
ShapeDtypeStruct
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a personal opinion so you can ignore it.

since this models an "array" in the end, wouldn't it be better suffixing RArray?

also dtype name is used in numpy, where in Julia we use eltype...

so how about ShapeRArray? GhostRArray? ShadowRArray?

```

## Inspect Generated HLO
Expand Down
94 changes: 94 additions & 0 deletions examples/shape_dtype_struct_example.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Example: Using ShapeDtypeStruct for compilation without concrete arrays
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would be the first script in the "examples" folder. IMHO they can quickly become outdated and confusing for new users if we don't test them. specially in Reactant where we have been breaking API quite a lot.

so... maybe this is better handled in the docs as an @example?

#
# This example demonstrates how to use Reactant.ShapeDtypeStruct to compile
# functions without having to construct full ConcreteRArray instances with
# actual data. This is useful for:
# 1. Faster compilation when you only need shape/dtype information
# 2. Memory efficiency when working with large arrays
# 3. Similar workflow to JAX's ShapeDtypeStruct

using Reactant

# Example 1: Basic usage with a simple function
println("Example 1: Basic compilation with ShapeDtypeStruct")
println("=" ^ 60)

# Define a simple function that sums an array
f_sum(x) = sum(x)

# Instead of creating a full ConcreteRArray:
# x = Reactant.ConcreteRArray(rand(Float32, 10, 20)) # This allocates memory!

# Use ShapeDtypeStruct to specify only shape and dtype:
spec = Reactant.ShapeDtypeStruct((10, 20), Float32)
println("Created ShapeDtypeStruct: ", spec)
println(" Shape: ", size(spec))
println(" Element type: ", eltype(spec))
println(" Dimensions: ", ndims(spec))

# Compile the function using the spec
compiled_f_sum = Reactant.compile(f_sum, (spec,))
println("✓ Function compiled successfully")

# Now execute with actual data
x_actual = Reactant.ConcreteRArray(rand(Float32, 10, 20))
result = compiled_f_sum(x_actual)
println("Result: ", result, " (type: ", typeof(result), ")")
println()

# Example 2: Multiple arguments
println("Example 2: Multiple arguments with ShapeDtypeStruct")
println("=" ^ 60)

f_add(x, y) = x .+ y

spec1 = Reactant.ShapeDtypeStruct((5, 5), Float64)
spec2 = Reactant.ShapeDtypeStruct((5, 5), Float64)

compiled_f_add = Reactant.compile(f_add, (spec1, spec2))
println("✓ Function with 2 arguments compiled")

x_data = Reactant.ConcreteRArray(rand(Float64, 5, 5))
y_data = Reactant.ConcreteRArray(rand(Float64, 5, 5))
result_add = compiled_f_add(x_data, y_data)
println("Result shape: ", size(result_add))
println()

# Example 3: Different dtypes
println("Example 3: Compilation with different dtypes")
println("=" ^ 60)

f_sin(x) = sin.(x)

for dtype in [Float32, Float64]
spec = Reactant.ShapeDtypeStruct((100,), dtype)
compiled = Reactant.compile(f_sin, (spec,))

x = Reactant.ConcreteRArray(rand(dtype, 100))
result = compiled(x)
println("✓ Compiled and ran for dtype: ", dtype)
end
println()

# Example 4: Benefits demonstration
println("Example 4: Memory efficiency")
println("=" ^ 60)

# For very large arrays, you can compile without allocating the full array:
large_spec = Reactant.ShapeDtypeStruct((10000, 10000), Float32)
println("Created spec for large array: ", size(large_spec))
println(" This doesn't allocate ", prod(size(large_spec)) * sizeof(Float32) / 1e9, " GB of memory!")

# Compile a function for this large array
f_large(x) = sum(x .* x)
compiled_large = Reactant.compile(f_large, (large_spec,))
println("✓ Compiled function for large array without allocating memory")
println()

println("All examples completed successfully!")
println()
println("Key Takeaways:")
println("1. ShapeDtypeStruct allows compilation without data allocation")
println("2. Same compiled function can be used with actual ConcreteRArray data")
println("3. Useful for large arrays and rapid prototyping")
println("4. Similar API to JAX's ShapeDtypeStruct")
5 changes: 5 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ unwrapped_eltype(::TracedRNumber{T}) where {T} = T
unwrapped_eltype(::Type{<:AbstractArray{T,N}}) where {T,N} = unwrapped_eltype(T)
unwrapped_eltype(::AbstractArray{T,N}) where {T,N} = unwrapped_eltype(T)

# For ShapeDtypeStruct
unwrapped_eltype(::Type{ShapeDtypeStruct{T,N}}) where {T,N} = T
unwrapped_eltype(::ShapeDtypeStruct{T,N}) where {T,N} = T

include("Ops.jl")
Base.push!(no_rewrite_ancestor_modules, Ops)

Expand Down Expand Up @@ -288,6 +292,7 @@ export ConcreteRArray,
ConcretePJRTNumber,
ConcreteIFRTArray,
ConcreteIFRTNumber,
ShapeDtypeStruct,
@compile,
@code_hlo,
@code_mhlo,
Expand Down
36 changes: 36 additions & 0 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1407,6 +1407,42 @@ Base.@nospecializeinfer function make_tracer(
return res
end

Base.@nospecializeinfer function make_tracer(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmm I think you're missing the traced_type_inner rule, right?

seen,
@nospecialize(prev::ShapeDtypeStruct{T,N}),
@nospecialize(path),
mode;
kwargs...,
) where {T,N}
if mode == TracedToTypes
throw(
ArgumentError(
"ShapeDtypeStruct cannot be used as a function call argument; it is only valid for compilation signatures."
),
)
end
if mode == ArrayToConcrete
throw(
ErrorException(
"Cannot convert ShapeDtypeStruct to ConcreteRArray. ShapeDtypeStruct is only for compilation signatures."
),
)
end
# ShapeDtypeStruct behaves like ConcreteToTraced mode - creates a TracedRArray without data
# Accept both ConcreteToTraced and TracedSetPath modes
if mode != ConcreteToTraced && mode != TracedSetPath
throw(
ArgumentError(
"ShapeDtypeStruct can only be used with ConcreteToTraced or TracedSetPath mode, got $mode"
),
)
end
Comment on lines +1417 to +1439
Copy link
Collaborator

@mofeing mofeing Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be fused into one check

haskey(seen, prev) && return seen[prev]::TracedRArray{T,N}
res = TracedRArray{T,N}((path,), nothing, size(prev))
seen[prev] = res
return res
end

Base.@nospecializeinfer function make_tracer(
seen,
prev::ConcretePJRTNumber{T},
Expand Down
49 changes: 49 additions & 0 deletions src/Types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,55 @@ const AnyTracedRVector{T} = AnyTracedRArray{T,1}
const AnyTracedRMatrix{T} = AnyTracedRArray{T,2}
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}

## ShapeDtypeStruct
"""
ShapeDtypeStruct{T,N}(shape::NTuple{N,Int})
ShapeDtypeStruct(shape::NTuple{N,Int}, dtype::Type{T}) where {T,N}

Lightweight structure that specifies the shape and element type (dtype) of an array
without allocating the actual array data. Similar to JAX's `ShapeDtypeStruct`.

This is useful for compiling functions without constructing the full `ConcreteRArray`,
which can save memory and improve compilation performance.

# Examples
```julia
# Specify shape and dtype for a 2D array
spec = Reactant.ShapeDtypeStruct((10, 20), Float32)

# Compile a function using just the spec
f(x) = sum(x)
compiled_f = Reactant.compile(f, (spec,))

# Execute with actual data
x = Reactant.ConcreteRArray(rand(Float32, 10, 20))
result = compiled_f(x)
```

See also: [`compile`](@ref), [`ConcreteRArray`](@ref)
"""
struct ShapeDtypeStruct{T,N}
shape::NTuple{N,Int}
Comment on lines +161 to +162
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a sharding field is missing. We won't be able to shard otherwise.

The jax docs say this field is optional.


function ShapeDtypeStruct{T,N}(shape::NTuple{N,Int}) where {T,N}
return new{T,N}(shape)
end
end

function ShapeDtypeStruct(shape::NTuple{N,Int}, dtype::Type{T}) where {T,N}
return ShapeDtypeStruct{T,N}(shape)
end

function ShapeDtypeStruct(shape::Tuple{Vararg{Integer}}, dtype::Type{T}) where {T}
return ShapeDtypeStruct(map(Int, shape), dtype)
end

Base.size(x::ShapeDtypeStruct) = x.shape
Base.ndims(::ShapeDtypeStruct{T,N}) where {T,N} = N
Base.eltype(::ShapeDtypeStruct{T}) where {T} = T

@leaf ShapeDtypeStruct

# Concrete Types
## ConcretePJRTNumber
mutable struct ConcretePJRTNumber{T,D} <: AbstractConcreteNumber{T}
Expand Down
65 changes: 65 additions & 0 deletions test/core/compile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -593,3 +593,68 @@ end
@test Array(y[:Mhalo]) ≈ [1.0f0, 2.0f0]
@test Array(y[:x]) ≈ [2.0f0, 3.0f0]
end

@testset "ShapeDtypeStruct compilation" begin
@testset "Basic compilation with ShapeDtypeStruct" begin
# Define a simple function
f(x) = sum(x)

# Compile using ShapeDtypeStruct instead of ConcreteRArray
spec = Reactant.ShapeDtypeStruct((10, 20), Float32)
compiled_f = Reactant.compile(f, (spec,))

# Execute with actual data
x = Reactant.ConcreteRArray(rand(Float32, 10, 20))
result = compiled_f(x)

@test result isa Reactant.ConcreteRNumber{Float32}
@test result ≈ sum(Array(x))
end

@testset "Multiple arguments with ShapeDtypeStruct" begin
f(x, y) = x .+ y

spec1 = Reactant.ShapeDtypeStruct((5, 5), Float64)
spec2 = Reactant.ShapeDtypeStruct((5, 5), Float64)
compiled_f = Reactant.compile(f, (spec1, spec2))

x = Reactant.ConcreteRArray(rand(Float64, 5, 5))
y = Reactant.ConcreteRArray(rand(Float64, 5, 5))
result = compiled_f(x, y)

@test result isa Reactant.ConcreteRArray{Float64,2}
@test result ≈ Array(x) .+ Array(y)
end

@testset "ShapeDtypeStruct with different dtypes" begin
f(x) = sin.(x)

for dtype in [Float32, Float64]
spec = Reactant.ShapeDtypeStruct((10,), dtype)
compiled_f = Reactant.compile(f, (spec,))

x = Reactant.ConcreteRArray(rand(dtype, 10))
result = compiled_f(x)

@test result isa Reactant.ConcreteRArray{dtype,1}
@test result ≈ sin.(Array(x))
end
end

@testset "ShapeDtypeStruct constructor variations" begin
# Test different constructor forms
spec1 = Reactant.ShapeDtypeStruct{Float32,2}((3, 4))
@test size(spec1) == (3, 4)
@test eltype(spec1) == Float32
@test ndims(spec1) == 2

spec2 = Reactant.ShapeDtypeStruct((3, 4), Float32)
@test size(spec2) == (3, 4)
@test eltype(spec2) == Float32

# Test with integer tuple (not Int tuple)
spec3 = Reactant.ShapeDtypeStruct((3, 4), Float64)
@test size(spec3) == (3, 4)
@test eltype(spec3) == Float64
end
end