Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 105 additions & 2 deletions src/hashconsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,62 @@ function isequal_argsvec(v1::ArgsT{T}, v2::ArgsT{T}, full::Bool) where {T}
return true
end

function metadata_isequal_metadict(m1::Base.ImmutableDict{DataType, Any}, m2::Base.ImmutableDict{DataType, Any})
xor(isdefined(m1, :parent), isdefined(m2, :parent)) && return false
isdefined(m1, :parent) || return true
isequal(m1.key, m2.key) || return false
metadata_isequal(m1.value, m2.value)::Bool || return false
return metadata_isequal_metadict(m1.parent, m2.parent)
end

function metadata_isequal(m1::MetadataT, m2::MetadataT)
@nospecialize m1 m2
if m1 === nothing && m2 === nothing
return true
elseif m1 === nothing || m2 === nothing
return false
end
return metadata_isequal_metadict(m1, m2)
end

function metadata_isequal(m1, m2)
@nospecialize m1 m2
typeof(m1) === typeof(m2) || return false
if m1 isa BasicSymbolic{SymReal} && m2 isa BasicSymbolic{SymReal}
return isequal_bsimpl(m1, m2, true)
elseif m1 isa BasicSymbolic{SafeReal} && m2 isa BasicSymbolic{SafeReal}
return isequal_bsimpl(m1, m2, true)
elseif m1 isa BasicSymbolic{TreeReal} && m2 isa BasicSymbolic{TreeReal}
return isequal_bsimpl(m1, m2, true)
elseif m1 isa Vector{BasicSymbolic{SymReal}} && m2 isa Vector{BasicSymbolic{SymReal}}
return isequal(m1, m2)
elseif m1 isa Vector{BasicSymbolic{SafeReal}} && m2 isa Vector{BasicSymbolic{SafeReal}}
return isequal(m1, m2)
elseif m1 isa Vector{BasicSymbolic{TreeReal}} && m2 isa Vector{BasicSymbolic{TreeReal}}
return isequal(m1, m2)
elseif m1 isa Matrix{BasicSymbolic{SymReal}} && m2 isa Matrix{BasicSymbolic{SymReal}}
return isequal(m1, m2)
elseif m1 isa Matrix{BasicSymbolic{SafeReal}} && m2 isa Matrix{BasicSymbolic{SafeReal}}
return isequal(m1, m2)
elseif m1 isa Matrix{BasicSymbolic{TreeReal}} && m2 isa Matrix{BasicSymbolic{TreeReal}}
return isequal(m1, m2)
elseif m1 isa Float64 && m2 isa Float64
return isequal(m1, m2)
elseif m1 isa Int && m2 isa Int
return isequal(m1, m2)
elseif m1 isa Vector{Float64} && m2 isa Vector{Float64}
return isequal(m1, m2)
elseif m1 isa Vector{Int} && m2 isa Vector{Int}
return isequal(m1, m2)
elseif m1 isa Matrix{Float64} && m2 isa Matrix{Float64}
return isequal(m1, m2)
elseif m1 isa Matrix{Int} && m2 isa Matrix{Int}
return isequal(m1, m2)
else
return isequal(m1, m2)::Bool
end
end

"""
$TYPEDSIGNATURES

Expand Down Expand Up @@ -144,7 +200,7 @@ function isequal_bsimpl(a::BSImpl.Type{T}, b::BSImpl.Type{T}, full::Bool) where
end
end
if full && partial && !(Ta <: BSImpl.Const)
partial = isequal(metadata(a), metadata(b))
partial = metadata_isequal(metadata(a), metadata(b))
end
return partial
end
Expand Down Expand Up @@ -286,6 +342,53 @@ function hash_maybe_fntype(T::TypeT, h::UInt)
end
end

function value_typed_hash(m, h::UInt)
@nospecialize m
h = hash(typeof(m), h)
if m isa BasicSymbolic{SymReal}
return hash_bsimpl(m, h, true)
elseif m isa BasicSymbolic{SafeReal}
return hash_bsimpl(m, h, true)
elseif m isa BasicSymbolic{TreeReal}
return hash_bsimpl(m, h, true)
elseif m isa Vector{BasicSymbolic{SymReal}}
return hash(m, h)
elseif m isa Vector{BasicSymbolic{SafeReal}}
return hash(m, h)
elseif m isa Vector{BasicSymbolic{TreeReal}}
return hash(m, h)
elseif m isa Matrix{BasicSymbolic{SymReal}}
return hash(m, h)
elseif m isa Matrix{BasicSymbolic{SafeReal}}
return hash(m, h)
elseif m isa Matrix{BasicSymbolic{TreeReal}}
return hash(m, h)
elseif m isa Float64
return hash(m, h)
elseif m isa Int
return hash(m, h)
else
return hash(m, h)::UInt
end
end

function hash_metadict(m::Base.ImmutableDict{DataType, Any}, h::UInt)
isdefined(m, :parent) || return h
kvh = hash(m.key, value_typed_hash(m.value, zero(UInt))) ⊻ h
return hash_metadict(m.parent, kvh)
end

function hash_metadata(m::MetadataT, h::UInt)
@nospecialize m
if m === nothing
return hash(nothing, h)
elseif m isa Base.ImmutableDict{DataType, Any}
hv = Base.hasha_seed
return hash(hash_metadict(m, hv), h)
end
_unreachable()
end

"""
hash_bsimpl(s::BSImpl.Type{T}, h::UInt, full) where {T}

Expand Down Expand Up @@ -359,7 +462,7 @@ function hash_bsimpl(s::BSImpl.Type{T}, h::UInt, full) where {T}
end

if full
partial = s.hash2 = Base.hash(metadata(s), partial)::UInt
partial = s.hash2 = hash_metadata(metadata(s), partial)::UInt
else
s.hash = partial
end
Expand Down
13 changes: 13 additions & 0 deletions test/hash_consing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,16 @@ end
@test ex2.hash2[] != h
end
end

@testset "Metadata type is considered in hashconsing" begin
@syms x
x1 = setmetadata(x, Int, 1)
x2 = setmetadata(x, Int, 1.0)
@test x1 !== x2
@test hash2(x1) != hash2(x2)

x1 = setmetadata(x, Int, [1])
x2 = setmetadata(x, Int, [1])
@test x1 === x2
@test hash2(x1) == hash2(x2)
end
Loading