@@ -58,6 +58,13 @@ mutable struct TearingState <: StateSelection.TransformationState{System}
5858 """
5959 additional_observed:: Vector{Equation}
6060 statemachines:: Vector{System}
61+ """
62+ Source information for each equation in the `TearingState`. `Vector{Symbol}` for each
63+ equation representing the path of the subsystem to which it belongs. Empty entries
64+ indicate unknown source. If this field is empty, either the system has no equations
65+ or source information is unknown.
66+ """
67+ eqs_source:: Vector{Vector{Symbol}}
6168end
6269
6370function Base. show (io:: IO , state:: TearingState )
@@ -83,15 +90,37 @@ function Base.push!(ev::EquationsView, eq)
8390 push! (ev. ts. extra_eqs, eq)
8491end
8592
86- function TearingState (sys:: System ; check:: Bool = true , sort_eqs:: Bool = true )
93+ function TearingState (sys:: System , source_info :: Union{Nothing, MTKBase.EquationSourceInformation} = nothing ; check:: Bool = true , sort_eqs:: Bool = true )
8794 # flatten system
88- sys = MTKBase. flatten (sys)
95+ if source_info === nothing
96+ sys = MTKBase. flatten (sys)
97+ else
98+ @assert isempty (MTKBase. get_systems (sys)) """
99+ If `source_info` is provided to `TearingState`, the system must be flattened.
100+ """
101+ end
89102 sys = MTKBase. discrete_unknowns_to_parameters (sys)
90103 sys = MTKBase. discover_globalscoped (sys)
91104 MTKBase. check_no_parameter_equations (sys)
92105 iv = MTKBase. get_iv (sys)
106+ sources = Vector{Symbol}[]
93107 # flatten array equations
94- eqs = MTKBase. flatten_equations (equations (sys))
108+ if source_info === nothing
109+ eqs = MTKBase. flatten_equations (equations (sys))
110+ else
111+ eqs = Equation[]
112+ @assert length (equations (sys)) == length (source_info. eqs_source) """
113+ Mismatch between source information provided to `TearingState` and the structure \
114+ of the system.
115+ """
116+ for (eq, src) in zip (equations (sys), source_info. eqs_source)
117+ scal_eq = MTKBase. flatten_equation (eq)
118+ append! (eqs, scal_eq)
119+ for _ in scal_eq
120+ push! (sources, src)
121+ end
122+ end
123+ end
95124 original_eqs = copy (eqs)
96125 neqs = length (eqs)
97126 param_derivative_map = Dict {SymbolicT, SymbolicT} ()
@@ -234,6 +263,9 @@ function TearingState(sys::System; check::Bool = true, sort_eqs::Bool = true)
234263 eqs = eqs[sortidxs]
235264 original_eqs = original_eqs[sortidxs]
236265 symbolic_incidence = symbolic_incidence[sortidxs]
266+ if ! isempty (sources)
267+ sources = sources[sortidxs]
268+ end
237269 end
238270
239271 dervaridxs = OrderedSet {Int} ()
@@ -262,7 +294,7 @@ function TearingState(sys::System; check::Bool = true, sort_eqs::Bool = true)
262294 structure = SystemStructure (complete (var_to_diff), complete (eq_to_diff),
263295 complete (graph), nothing , var_types, false )
264296 return TearingState (sys, fullvars, structure, Equation[], param_derivative_map,
265- no_deriv_params, original_eqs, Equation[], typeof (sys)[])
297+ no_deriv_params, original_eqs, Equation[], typeof (sys)[], sources )
266298end
267299
268300function sort_fullvars (fullvars:: Vector{SymbolicT} , dervaridxs:: Vector{Int} , var_types:: Vector{VariableType} , @nospecialize (iv:: Union{SymbolicT, Nothing} ))
0 commit comments