Skip to content

Commit 1810743

Browse files
Merge pull request #20 from JuliaComputing/as/source-tracking
feat: allow tracking equation source information in `TearingState`
2 parents 2689e8a + 4b840e3 commit 1810743

File tree

3 files changed

+40
-5
lines changed

3 files changed

+40
-5
lines changed

lib/ModelingToolkitTearing/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ BipartiteGraphs = "0.1.3"
2525
CommonSolve = "0.2"
2626
DocStringExtensions = "0.7, 0.8, 0.9"
2727
Graphs = "1"
28-
ModelingToolkitBase = "1.0.0"
28+
ModelingToolkitBase = "1.2.0"
2929
Moshi = "0.3"
3030
OffsetArrays = "1"
3131
OrderedCollections = "1.8.1"

lib/ModelingToolkitTearing/src/clock_inference/interface.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ function system_subset(ts::TearingState, ieqs::Vector{Int}, iieqs::Vector{Int},
7373
@set! ts.sys.initialization_eqs = initeqs[iieqs]
7474
@set! ts.original_eqs = ts.original_eqs[ieqs]
7575
@set! ts.structure = system_subset(ts.structure, ieqs, ivars)
76+
if !isempty(ts.eqs_source)
77+
@set! ts.eqs_source = ts.eqs_source[ieqs]
78+
end
7679
if all(eq -> eq.rhs isa StateMachineOperator, MTKBase.get_eqs(ts.sys))
7780
names = Symbol[]
7881
for eq in MTKBase.get_eqs(ts.sys)

lib/ModelingToolkitTearing/src/tearingstate.jl

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ mutable struct TearingState <: StateSelection.TransformationState{System}
5858
"""
5959
additional_observed::Vector{Equation}
6060
statemachines::Vector{System}
61+
"""
62+
Source information for each equation in the `TearingState`. `Vector{Symbol}` for each
63+
equation representing the path of the subsystem to which it belongs. Empty entries
64+
indicate unknown source. If this field is empty, either the system has no equations
65+
or source information is unknown.
66+
"""
67+
eqs_source::Vector{Vector{Symbol}}
6168
end
6269

6370
function Base.show(io::IO, state::TearingState)
@@ -83,15 +90,37 @@ function Base.push!(ev::EquationsView, eq)
8390
push!(ev.ts.extra_eqs, eq)
8491
end
8592

86-
function TearingState(sys::System; check::Bool = true, sort_eqs::Bool = true)
93+
function TearingState(sys::System, source_info::Union{Nothing, MTKBase.EquationSourceInformation} = nothing; check::Bool = true, sort_eqs::Bool = true)
8794
# flatten system
88-
sys = MTKBase.flatten(sys)
95+
if source_info === nothing
96+
sys = MTKBase.flatten(sys)
97+
else
98+
@assert isempty(MTKBase.get_systems(sys)) """
99+
If `source_info` is provided to `TearingState`, the system must be flattened.
100+
"""
101+
end
89102
sys = MTKBase.discrete_unknowns_to_parameters(sys)
90103
sys = MTKBase.discover_globalscoped(sys)
91104
MTKBase.check_no_parameter_equations(sys)
92105
iv = MTKBase.get_iv(sys)
106+
sources = Vector{Symbol}[]
93107
# flatten array equations
94-
eqs = MTKBase.flatten_equations(equations(sys))
108+
if source_info === nothing
109+
eqs = MTKBase.flatten_equations(equations(sys))
110+
else
111+
eqs = Equation[]
112+
@assert length(equations(sys)) == length(source_info.eqs_source) """
113+
Mismatch between source information provided to `TearingState` and the structure \
114+
of the system.
115+
"""
116+
for (eq, src) in zip(equations(sys), source_info.eqs_source)
117+
scal_eq = MTKBase.flatten_equation(eq)
118+
append!(eqs, scal_eq)
119+
for _ in scal_eq
120+
push!(sources, src)
121+
end
122+
end
123+
end
95124
original_eqs = copy(eqs)
96125
neqs = length(eqs)
97126
param_derivative_map = Dict{SymbolicT, SymbolicT}()
@@ -234,6 +263,9 @@ function TearingState(sys::System; check::Bool = true, sort_eqs::Bool = true)
234263
eqs = eqs[sortidxs]
235264
original_eqs = original_eqs[sortidxs]
236265
symbolic_incidence = symbolic_incidence[sortidxs]
266+
if !isempty(sources)
267+
sources = sources[sortidxs]
268+
end
237269
end
238270

239271
dervaridxs = OrderedSet{Int}()
@@ -262,7 +294,7 @@ function TearingState(sys::System; check::Bool = true, sort_eqs::Bool = true)
262294
structure = SystemStructure(complete(var_to_diff), complete(eq_to_diff),
263295
complete(graph), nothing, var_types, false)
264296
return TearingState(sys, fullvars, structure, Equation[], param_derivative_map,
265-
no_deriv_params, original_eqs, Equation[], typeof(sys)[])
297+
no_deriv_params, original_eqs, Equation[], typeof(sys)[], sources)
266298
end
267299

268300
function sort_fullvars(fullvars::Vector{SymbolicT}, dervaridxs::Vector{Int}, var_types::Vector{VariableType}, @nospecialize(iv::Union{SymbolicT, Nothing}))

0 commit comments

Comments
 (0)