Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions src/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,20 @@ overwriting.
See `@register_array_symbolic` to register functions which return arrays.
"""
macro register_symbolic(expr, define_promotion = true, wrap_arrays = true)
f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr)

args′ = map((a, T) -> :($a::$T), argnames, Ts)
f, ftype, argnames, Ts, is_typed, ret_type = destructure_registration_expr(expr)

# For typed scalar arguments, wrap in ExactType{T} to prevent symbolic dispatch.
# For untyped arguments (defaulting to Real), use Real directly.
# For array types, don't use ExactType since arrays can contain symbolic elements.
# This prevents exponential growth in method count - see issue #1724
args′ = map(argnames, Ts, is_typed) do a, T, typed
Teval = Base.eval(__module__, T)
if typed && !(Teval <: AbstractArray)
:($a::$ExactType{$T})
else
:($a::$T)
end
end
ret_type = isnothing(ret_type) ? Real : ret_type
N = length(args′)
symbolicT = Union{BasicSymbolic{VartypeT}, AbstractArray{BasicSymbolic{VartypeT}}}
Expand Down Expand Up @@ -64,9 +75,12 @@ function destructure_registration_expr(expr)
f = expr.args[1]
args = expr.args[2:end]

# Default arg types to Real
# Default arg types to Real for untyped arguments
Ts = map(a -> a isa Symbol ? Real : (@assert(a.head == :(::)); a.args[2]), args)
argnames = map(a -> a isa Symbol ? a : a.args[1], args)
# Track which arguments were explicitly typed (not defaulting to Real)
# These should NOT get symbolic dispatch methods - see issue #1724
is_typed = map(a -> !(a isa Symbol), args)

ftype = if f isa Expr && f.head == :(::)
if length(f.args) == 1
Expand All @@ -77,7 +91,7 @@ function destructure_registration_expr(expr)
else
:($typeof($f))
end
f, ftype, argnames, Ts, ret_type
f, ftype, argnames, Ts, is_typed, ret_type
end

nested_unwrap(x) = unwrap(x)
Expand All @@ -96,7 +110,7 @@ symbolic_eltype(x::AbstractArray{BasicSymbolic{T}}) where {T} = eltype(symtype(C
symbolic_eltype(::AbstractArray{Num}) = Real
symbolic_eltype(::AbstractArray{symT}) where {eT, symT <: Arr{eT}} = eT

function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs = :(), define_promotion = true, wrap_arrays = true)
function register_array_symbolic(f, ftype, argnames, Ts, is_typed, ret_type, partial_defs = :(), define_promotion = true, wrap_arrays = true)
def_assignments = MacroTools.rmlines(partial_defs).args
defs = map(def_assignments) do ex
@assert ex.head == :(=)
Expand All @@ -118,7 +132,18 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs
eltype_expr = get(defs, :eltype, Any)
container_type = get(defs, :container_type, Array)

args′ = map((a, T) -> :($a::$T), argnames, Ts)
# For typed scalar arguments, wrap in ExactType{T} to prevent symbolic dispatch.
# For untyped arguments (defaulting to Real), use Real directly.
# For array types, don't use ExactType since arrays can contain symbolic elements.
# This prevents exponential growth in method count - see issue #1724
args′ = map(argnames, Ts, is_typed) do a, T, typed
Teval = Base.eval(@__MODULE__, T)
if typed && !(Teval <: AbstractArray)
:($a::$ExactType{$T})
else
:($a::$T)
end
end
N = length(args′)
symbolicT = Union{BasicSymbolic{VartypeT}, AbstractArray{BasicSymbolic{VartypeT}}}
assigns = macroexpand(@__MODULE__, :(Base.Cartesian.@nexprs $N i -> ($argnames[i] = args[i])))
Expand Down Expand Up @@ -215,6 +240,6 @@ overloads for one function, all the rest of the registers must set
overwriting.
"""
macro register_array_symbolic(expr, block, define_promotion = true, wrap_arrays = true)
f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr)
register_array_symbolic(f, ftype, argnames, Ts, ret_type, block, define_promotion, wrap_arrays)
f, ftype, argnames, Ts, is_typed, ret_type = destructure_registration_expr(expr)
register_array_symbolic(f, ftype, argnames, Ts, is_typed, ret_type, block, define_promotion, wrap_arrays)
end
62 changes: 47 additions & 15 deletions src/wrapper-types.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
export @symbolic_wrap, @wrapped

"""
ExactType{T}

A marker type used internally by @register_symbolic to indicate that an argument
should accept only the specified type T, without generating symbolic dispatch methods.
This is used for explicitly typed arguments to prevent exponential growth in method
count. See https://github.com/JuliaSymbolics/Symbolics.jl/issues/1724
"""
struct ExactType{T} end

# Turn A{X} <: B{Int, X} into
#
# B{Int, X} where X
Expand Down Expand Up @@ -112,24 +122,46 @@ function wrap_func_expr(mod, expr, wrap_arrays = true)
# However later while emitting methods we omit the one
# method where all arguments are (1) since those are
# expected to be defined outside Symbolics
#
# Special case: ExactType{T} is a marker indicating that only T should be
# accepted (no symbolic dispatch). This is used by @register_symbolic for
# explicitly typed arguments. See https://github.com/JuliaSymbolics/Symbolics.jl/issues/1724
if arg isa Expr && arg.head == :(::)
T = Base.eval(mod, arg.args[2])
Ts = has_symwrapper(T) ? (T, BasicSymbolic{VartypeT}, wrapper_type(T)) :
(T, BasicSymbolic{VartypeT})
if T <: AbstractArray && wrap_arrays
eT = eltype(T)
if eT == Any
eT = Real
end
_arr_type_fn = if hasmethod(ndims, Tuple{Type{T}})
(elT) -> AbstractArray{S, ndims(T)} where {S <: elT}
else
(elT) -> AbstractArray{S} where {S <: elT}
# Check for ExactType marker from @register_symbolic
if T isa Type && T <: ExactType
Ts = (T.parameters[1],)
elseif has_symwrapper(T)
Ts = (T, BasicSymbolic{VartypeT}, wrapper_type(T))
if T <: AbstractArray && wrap_arrays
eT = eltype(T)
if eT == Any
eT = Real
end
_arr_type_fn = if hasmethod(ndims, Tuple{Type{T}})
(elT) -> AbstractArray{S, ndims(T)} where {S <: elT}
else
(elT) -> AbstractArray{S} where {S <: elT}
end
if has_symwrapper(eT)
Ts = (Ts..., AbstractArray{BasicSymbolic{VartypeT}},
_arr_type_fn(wrapper_type(eT)))
else
Ts = (Ts..., AbstractArray{BasicSymbolic{VartypeT}})
end
end
if has_symwrapper(eT)
Ts = (Ts..., AbstractArray{BasicSymbolic{VartypeT}},
_arr_type_fn(wrapper_type(eT)))
else
else
Ts = (T, BasicSymbolic{VartypeT})
if T <: AbstractArray && wrap_arrays
eT = eltype(T)
if eT == Any
eT = Real
end
_arr_type_fn = if hasmethod(ndims, Tuple{Type{T}})
(elT) -> AbstractArray{S, ndims(T)} where {S <: elT}
else
(elT) -> AbstractArray{S} where {S <: elT}
end
Ts = (Ts..., AbstractArray{BasicSymbolic{VartypeT}})
end
end
Expand Down
80 changes: 80 additions & 0 deletions test/register_method_count.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using Test
using Symbolics

"""
Count the number of function definitions in a macro-expanded expression.
"""
function count_function_defs(expr)
count = Ref(0)
_count_function_defs!(expr, count)
return count[]
end

function _count_function_defs!(expr, count)
if expr isa Expr
if expr.head === :function
count[] += 1
end
for arg in expr.args
_count_function_defs!(arg, count)
end
end
end

@testset "Method count scaling for @register_symbolic" begin
# Test that the number of methods generated scales only with UNTYPED arguments.
#
# Structure of generated code:
# - 1 impl function (the actual implementation)
# - N dispatch functions (where N = prod(type_options per arg) - 1)
# - 1 promote_shape function
#
# Type options per argument:
# - Typed args (e.g., ::Float64) get 1 type option (the specified type)
# - Untyped args (defaulting to ::Real) get 3 type options: (Real, BasicSymbolic, Num)
#
# For k typed + m untyped args: 1^k * 3^m - 1 dispatch methods

# Case 1: 2 untyped args
# Type options: 3 × 3 = 9 combinations, minus (Real, Real) = 8 dispatch
# Total: 8 dispatch + 1 impl + 1 promote_shape = 10
f2u(a, b) = a * b
expr2u = @macroexpand @register_symbolic f2u(a, b)
n2u = count_function_defs(expr2u)
@test n2u == 10

# Case 2: 1 typed + 1 untyped
# Type options: 1 × 3 - 1 = 2 dispatch
# Total: 2 dispatch + 1 impl + 1 promote_shape = 4
f1t1u(a::Float64, b) = a * b
expr1t1u = @macroexpand @register_symbolic f1t1u(a::Float64, b)
n1t1u = count_function_defs(expr1t1u)
@test n1t1u == 4

# Case 3: 2 typed + 2 untyped
# Type options: 1 × 1 × 3 × 3 - 1 = 8 dispatch
# Total: 8 dispatch + 1 impl + 1 promote_shape = 10
f2t2u(a::Float64, b::Float64, c, d) = a * b * c * d
expr2t2u = @macroexpand @register_symbolic f2t2u(a::Float64, b::Float64, c, d)
n2t2u = count_function_defs(expr2t2u)
@test n2t2u == 10

# Case 4: 6 typed + 2 untyped
# Type options: 1^6 × 3^2 - 1 = 8 dispatch
# Total: 8 dispatch + 1 impl + 1 promote_shape = 10
f6t2u(a::Float64, b::Float64, c::Float64, d::Float64, e::Float64, f::Float64, g, h) =
a * b * c * d * e * f * g * h
expr6t2u = @macroexpand @register_symbolic f6t2u(
a::Float64, b::Float64, c::Float64, d::Float64, e::Float64, f::Float64, g, h)
n6t2u = count_function_defs(expr6t2u)
@test n6t2u == 10

# Case 5: All typed (4 args)
# Type options: 1^4 - 1 = 0 dispatch
# Total: 0 dispatch + 1 impl + 1 promote_shape = 2
f4t(a::Float64, b::Float64, c::Float64, d::Float64) = a * b * c * d
expr4t = @macroexpand @register_symbolic f4t(
a::Float64, b::Float64, c::Float64, d::Float64)
n4t = count_function_defs(expr4t)
@test n4t == 2
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Taylor Series Test" begin include("taylor.jl") end
@safetestset "Discontinuity registration test" begin include("discontinuities.jl") end
@safetestset "ODE solver test" begin include("diffeqs.jl") end
@safetestset "Method count test" begin include("register_method_count.jl") end
end
end

Expand Down
Loading