
[BUG]: Precision parameter not respected for float literals in TemplateExpressionSpec #1141

@GongJr0

Description

What happened?

While writing an automatic template generator, I've come across an issue where constants inside a template's combine string cause type promotion to Float64. This may be intended behavior, but I wanted to note its existence here as a bug report.

Example template:

TemplateExpressionSpec(
    combine="2*x + 1.0 + f1(x,y)", 
    expressions=["f1"],
    variable_names=["x", "y"],
)

Small script to reproduce the issue:

from pysr import PySRRegressor, TemplateExpressionSpec
import numpy as np

if __name__ == "__main__":
    binary_operators = ["+", "-", "*", "/"]
    unary_operators = ["sin", "asinh"]

    X = np.random.normal(size=(100, 2))

    noise = np.random.normal(scale=0.1, size=X.shape[0])
    y = 2*X[:, 0] + 1.0 + np.asinh(X[:, 0]) + np.sin(X[:, 1]) + noise

    spec = TemplateExpressionSpec(
        combine="2*x + 1.0 + f1(x,y)",
        expressions=["f1"],
        variable_names=["x", "y"],
    )

    model = PySRRegressor(
        binary_operators=binary_operators,
        unary_operators=unary_operators,
        expression_spec=spec,
        niterations=10,
        precision=32,
    )

    model.fit(X, y)
    print(model.get_best())

Possible Solution

In my case, I implemented a sympy.printing.str.StrPrinter subclass that wraps any float literal with the appropriate precision. The implementation:

  1. Walks the expression as a sympy object and marks float literals by wrapping them in the JFloat function.
  2. Converts rationals to p//q notation to avoid decimal literals (this step may be redundant).
  3. Substitutes every instance of JFloat with the appropriate f"Float{prec}" function.

The steps above convert the example template (2*x + 1.0 + f1(x,y)) to 'x*Float32(2//1) + Float32(1.00000000000000) + f1(x, y)' for the default precision of 32 bits.

Below is an implementation:

import sympy as sp
from sympy.printing.str import StrPrinter

from typing import Any

class JFloat(sp.Function):
    nargs = 1


class JuliaTypedPrinter(StrPrinter):
    """Printer that emits Julia-ish code and renders JFloat(x) as Float{prec}(x)."""

    def __init__(self, prec: int, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.prec = prec

    def _print_JFloat(self, expr: sp.Expr) -> str:
        # expr.args[0] is the numeric literal
        inner = expr.args[0]
        
        if isinstance(inner, sp.Float):
            # Use full precision string SymPy provides
            lit = sp.sstr(inner)
        elif isinstance(inner, sp.Rational):
            # Emit as "p//q" to avoid Julia parsing as Float64.
            lit = f"{inner.p}//{inner.q}"
        else:
            lit = sp.sstr(inner)

        return f"Float{self.prec}({lit})"

    def _print_Pow(self, expr: sp.Expr) -> str:
        # Julia uses ^, SymPy prints ** by default
        base, exp = expr.as_base_exp()
        return f"({self._print(base)})^({self._print(exp)})"

    def _print_Mul(self, expr: sp.Expr) -> str:
        # Parenthesize Add factors so that x*(y + 1) does not print as x*y + 1
        return "*".join(
            f"({self._print(a)})" if a.is_Add else self._print(a)
            for a in expr.args
        )

    def _print_Add(self, expr: sp.Expr) -> str:
        return " + ".join(self._print(a) for a in expr.args)


def _needs_float_wrap(expr: sp.Expr) -> bool:
    """Return True for numeric atoms that should be typed as Float{prec}."""
    # SymPy Float => definitely wrap
    if isinstance(expr, sp.Float):
        return True
    # Rational => usually wrap to avoid Float64 promotion in Julia
    if isinstance(expr, sp.Rational):
        return True

    return False


def wrap_numeric_literals(expr: sp.Expr) -> sp.Expr:
    """Wrap numeric literals in JFloat(...) where appropriate."""
    # Replace numeric atoms bottom-up; the query filter already guarantees
    # the replacement only sees atoms that _needs_float_wrap accepts.
    return expr.replace(
        lambda e: e.is_Number and _needs_float_wrap(e),
        lambda e: JFloat(e),  # pyright: ignore
    )


def sympy_to_julia_typed(expr: sp.Expr, prec: int) -> str:
    expr2 = wrap_numeric_literals(expr)
    out = JuliaTypedPrinter(prec=prec).doprint(expr2)
    return out
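For completeness, the class-name dispatch the printer relies on (StrPrinter looks up a `_print_<ClassName>` method for each node) can be checked with a stripped-down, self-contained sketch; `MiniPrinter` and the hard-coded `Float32` here are illustrative only:

```python
import sympy as sp
from sympy.printing.str import StrPrinter

class JFloat(sp.Function):
    """Marker node for literals that should be emitted as Float{prec}(...)."""
    nargs = 1

class MiniPrinter(StrPrinter):
    # StrPrinter dispatches on the class name of each node, so defining
    # _print_JFloat is enough to intercept the marker nodes.
    def _print_JFloat(self, expr: sp.Expr) -> str:
        return f"Float32({sp.sstr(expr.args[0])})"

x = sp.Symbol("x")
out = MiniPrinter().doprint(x + JFloat(sp.Float(1.0)))
print(out)  # the Float literal is rendered as Float32(...)
```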

Notes

My implementation requires the combine string to be parsed in Python and converted to a sympy object before being passed to Julia. There might be a more elegant solution on the Julia side.
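That parsing step can be sketched with sympy.sympify, declaring f1 as an undefined function so the call survives parsing. The names below mirror the example template and are illustrative, not the exact generator code:

```python
import sympy as sp

# Hypothetical sketch of the Python-side parsing step: turn the combine
# string into a sympy expression before the literal-wrapping pass runs.
combine = "2*x + 1.0 + f1(x, y)"

local_names = {name: sp.Symbol(name) for name in ("x", "y")}
local_names["f1"] = sp.Function("f1")  # undefined function, kept symbolic

expr = sp.sympify(combine, locals=local_names)

# These numeric atoms (the 2 and the 1.0) are what later get wrapped
# with Float{prec}(...).
numeric_atoms = expr.atoms(sp.Number)
```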

Version

1.5.9

Operating System

Windows

Package Manager

pip

Interface

Script (i.e., python my_script.py)

Relevant log output

juliacall.JuliaError: Element type of `x` is Float64 is different from element type of `y` which is Float32.
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] _loss(::Vector{Float64}, ::Vector{Float32}, ::L2DistLoss)
    @ SymbolicRegression.LossFunctionsModule C:\Users\guney\.julia\packages\SymbolicRegression\L5TJa\src\LossFunctions.jl:25
  [3] _eval_loss(tree::TemplateExpression{Float32, TemplateStructure{(:f1,), (), typeof(__sr_template_8838205672071328129), @NamedTuple{f1::Int64}, @NamedTuple{}}, Node{Float32}, ComposableExpression{Float32, Node{Float32}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(sin), typeof(asinh)}}, variable_names::Nothing, eval_options::EvalOptions{false, false, true, Nothing}}}, @NamedTuple{f1::ComposableExpression{Float32, Node{Float32}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(sin), typeof(asinh)}}, variable_names::Nothing, eval_options::EvalOptions{false, false, true, Nothing}}}}, @NamedTuple{structure::TemplateStructure{(:f1,), (), typeof(__sr_template_8838205672071328129), @NamedTuple{f1::Int64}, @NamedTuple{}}, operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(sin), typeof(asinh)}}, variable_names::Nothing, parameters::@NamedTuple{}}}, dataset::SymbolicRegression.CoreModule.DatasetModule.BasicDataset{Float32, Float32, Matrix{Float32}, Vector{Float32}, Nothing, @NamedTuple{}, Nothing, Nothing, Nothing, Nothing}, options::Options{SymbolicRegression.CoreModule.OptionsStructModule.ComplexityMapping{Int64, Int64}, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(sin), typeof(asinh)}}, Node, TemplateExpression, @NamedTuple{structure::TemplateStructure{(:f1,), (), typeof(__sr_template_8838205672071328129), @NamedTuple{f1::Int64}, @NamedTuple{}}}, MutationWeights, false, false, nothing, Nothing, 5}, regularization::Bool)
    @ SymbolicRegression.LossFunctionsModule C:\Users\guney\.julia\packages\SymbolicRegression\L5TJa\src\LossFunctions.jl:109
  [4] eval_loss(tree::TemplateExpression{Float32, TemplateStructure{(:f1,), (), typeof(__sr_template_8838205672071328129), @NamedTuple{f1::Int64}, @NamedTuple{}}, Node{Float32}, ComposableExpression{Float32, Node{Float32}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(sin), typeof(asinh)}}, variable_names::Nothing, eval_options::EvalOptions{false, false, true, Nothing}}}, @NamedTuple{f1::ComposableExpression{Float32, Node{Float32}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(sin), typeof(asinh)}}, variable_names::Nothing, eval_options::EvalOptions{false, false, true, Nothing}}}}, @NamedTuple{structure::TemplateStructure{(:f1,), (), typeof(__sr_template_8838205672071328129), @NamedTuple{f1::Int64}, @NamedTuple{}}, operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(sin), typeof(asinh)}}, variable_names::Nothing, parameters::@NamedTuple{}}}, dataset::SymbolicRegression.CoreModule.DatasetModule.BasicDataset{Float32, Float32, Matrix{Float32}, Vector{Float32}, Nothing, @NamedTuple{}, Nothing, Nothing, Nothing, Nothing}, options::Options{SymbolicRegression.CoreModule.OptionsStructModule.ComplexityMapping{Int64, Int64}, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(sin), typeof(asinh)}}, Node, TemplateExpression, @NamedTuple{structure::TemplateStructure{(:f1,), (), typeof(__sr_template_8838205672071328129), @NamedTuple{f1::Int64}, @NamedTuple{}}}, MutationWeights, false, false, nothing, Nothing, 5}; regularization::Bool, idx::Nothing)
    @ SymbolicRegression.LossFunctionsModule C:\Users\guney\.julia\packages\SymbolicRegression\L5TJa\src\LossFunctions.jl:155
  [5] eval_loss
    @ C:\Users\guney\.julia\packages\SymbolicRegression\L5TJa\src\LossFunctions.jl:139 [inlined]
  [6] update_baseline_loss!
    @ C:\Users\guney\.julia\packages\SymbolicRegression\L5TJa\src\LossFunctions.jl:225 [inlined]
  [7] _validate_options(datasets::Vector{SymbolicRegression.CoreModule.DatasetModule.BasicDataset{Float32, Float32, Matrix{Float32}, Vector{Float32}, Nothing, @NamedTuple{}, Nothing, Nothing, Nothing, Nothing}}, ropt::SymbolicRegression.SearchUtilsModule.RuntimeOptions{:multithreading, 1, true, Nothing}, options::Options{SymbolicRegression.CoreModule.OptionsStructModule.ComplexityMapping{Int64, Int64}, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(sin), typeof(asinh)}}, Node, TemplateExpression, @NamedTuple{structure::TemplateStructure{(:f1,), (), typeof(__sr_template_8838205672071328129), @NamedTuple{f1::Int64}, @NamedTuple{}}}, MutationWeights, false, false, nothing, Nothing, 5})
    @ SymbolicRegression C:\Users\guney\.julia\packages\SymbolicRegression\L5TJa\src\SymbolicRegression.jl:597
  [8] _equation_search(datasets::Vector{SymbolicRegression.CoreModule.DatasetModule.BasicDataset{Float32, Float32, Matrix{Float32}, Vector{Float32}, Nothing, @NamedTuple{}, Nothing, Nothing, Nothing, Nothing}}, ropt::SymbolicRegression.SearchUtilsModule.RuntimeOptions{:multithreading, 1, true, Nothing}, options::Options{SymbolicRegression.CoreModule.OptionsStructModule.ComplexityMapping{Int64, Int64}, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(sin), typeof(asinh)}}, Node, TemplateExpression, @NamedTuple{structure::TemplateStructure{(:f1,), (), typeof(__sr_template_8838205672071328129), @NamedTuple{f1::Int64}, @NamedTuple{}}}, MutationWeights, false, false, nothing, Nothing, 5}, saved_state::Nothing)
    @ SymbolicRegression C:\Users\guney\.julia\packages\SymbolicRegression\L5TJa\src\SymbolicRegression.jl:567
  [9] equation_search(datasets::Vector{SymbolicRegression.CoreModule.DatasetModule.BasicDataset{Float32, Float32, Matrix{Float32}, Vector{Float32}, Nothing, @NamedTuple{}, Nothing, Nothing, Nothing, Nothing}}; options::Options{SymbolicRegression.CoreModule.OptionsStructModule.ComplexityMapping{Int64, Int64}, OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(sin), typeof(asinh)}}, Node, TemplateExpression, @NamedTuple{structure::TemplateStructure{(:f1,), (), typeof(__sr_template_8838205672071328129), @NamedTuple{f1::Int64}, @NamedTuple{}}}, MutationWeights, false, false, nothing, Nothing, 5}, saved_state::Nothing, runtime_options::Nothing, runtime_options_kws::@Kwargs{niterations::Int64, parallelism::String, numprocs::Nothing, procs::Nothing, addprocs_function::Nothing, heap_size_hint_in_bytes::Nothing, worker_imports::Nothing, runtests::Bool, return_state::Bool, run_id::String, verbosity::Int64, logger::Nothing, progress::Bool, v_dim_out::Val{1}})
    @ SymbolicRegression C:\Users\guney\.julia\packages\SymbolicRegression\L5TJa\src\SymbolicRegression.jl:561
 [10] equation_search
    @ C:\Users\guney\.julia\packages\SymbolicRegression\L5TJa\src\SymbolicRegression.jl:542 [inlined]
 [11] #equation_search#23
    @ C:\Users\guney\.julia\packages\SymbolicRegression\L5TJa\src\SymbolicRegression.jl:511 [inlined]
 [12] equation_search
    @ C:\Users\guney\.julia\packages\SymbolicRegression\L5TJa\src\SymbolicRegression.jl:456 [inlined]
 [13] #equation_search#24
    @ C:\Users\guney\.julia\packages\SymbolicRegression\L5TJa\src\SymbolicRegression.jl:535 [inlined]
 [14] pyjlany_call(self::typeof(equation_search), args_::Py, kwargs_::Py)
    @ PythonCall.JlWrap C:\Users\guney\.julia\packages\PythonCall\avYrV\src\JlWrap\any.jl:44
 [15] _pyjl_callmethod(f::Any, self_::Ptr{PythonCall.C.PyObject}, args_::Ptr{PythonCall.C.PyObject}, nargs::Int64)
    @ PythonCall.JlWrap C:\Users\guney\.julia\packages\PythonCall\avYrV\src\JlWrap\base.jl:73
 [16] _pyjl_callmethod(o::Ptr{PythonCall.C.PyObject}, args::Ptr{PythonCall.C.PyObject})
    @ PythonCall.JlWrap.Cjl C:\Users\guney\.julia\packages\PythonCall\avYrV\src\JlWrap\C.jl:63

Extra Info

No response

Labels

bug (Something isn't working)