Skip to content

Commit 89553e9

Browse files
feat: add unified System type
1 parent d54449d commit 89553e9

File tree

1 file changed

+290
-0
lines changed

1 file changed

+290
-0
lines changed

src/systems/system.jl

+290
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
struct System <: AbstractSystem
2+
tag::UInt
3+
eqs::Vector{Equation}
4+
noise_eqs::Union{Nothing, AbstractArray}
5+
jumps::Vector{Any}
6+
constraints::Vector{Union{Equation, Inequality}}
7+
unknowns::Vector
8+
ps::Vector
9+
brownians::Vector
10+
iv::Union{Nothing, BasicSymbolic{Real}, Vector{BasicSymbolic{Real}}}
11+
observed::Vector{Equation}
12+
parameter_dependencies::Vector{Equation}
13+
var_to_name::Dict{Symbol, Any}
14+
name::Symbol
15+
description::String
16+
defaults::Dict
17+
guesses::Dict
18+
systems::Vector{System}
19+
initialization_eqs::Vector{Equation}
20+
continuous_events::Vector{SymbolicContinuousCallback}
21+
discrete_events::Vector{SymbolicDiscreteCallback}
22+
connector_type::Any
23+
assertions::Dict{BasicSymbolic, String}
24+
metadata::Any
25+
gui_metadata::Any # ?
26+
is_dde::Bool
27+
tstops::Vector{Any}
28+
tearing_state::Any
29+
namespacing::Bool
30+
complete::Bool
31+
index_cache::Union{Nothing, IndexCache}
32+
ignored_connections::Union{
33+
Nothing, Tuple{Vector{IgnoredAnalysisPoint}, Vector{IgnoredAnalysisPoint}}}
34+
parent::Union{Nothing, System}
35+
36+
function System(tag, eqs, noise_eqs, jumps, constraints, unknowns, ps,
37+
brownians, iv, observed, parameter_dependencies, var_to_name, name, description,
38+
defaults, guesses, systems, initialization_eqs, continuous_events, discrete_events,
39+
connector_type, assertions = Dict{BasicSymbolic, String}(),
40+
metadata = nothing, gui_metadata = nothing,
41+
is_dde = false, tstops = [], tearing_state = nothing, namespacing = true,
42+
complete = false, index_cache = nothing, ignored_connections = nothing,
43+
parent = nothing; checks::Union{Bool, Int} = true)
44+
45+
if (checks == true || (checks & CheckComponents) > 0) && iv !== nothing
46+
check_independent_variables([iv])
47+
check_variables(unknowns, iv)
48+
check_parameters(ps, iv)
49+
check_equations(eqs, iv)
50+
if noise_eqs !== nothing && size(noise_eqs, 1) != length(eqs)
51+
throw(IllFormedNoiseEquationsError(size(noise_eqs, 1), length(eqs)))
52+
end
53+
check_equations(equations(continuous_events), iv)
54+
check_subsystems(systems)
55+
56+
end
57+
if checks == true || (checks & CheckUnits) > 0
58+
u = __get_unit_type(unknowns, ps, iv)
59+
check_units(u, eqs)
60+
noise_eqs !== nothing && check_units(u, noise_eqs)
61+
isempty(constraints) || check_units(u, constraints)
62+
end
63+
new(tag, eqs, noise_eqs, jumps, constraints, unknowns, ps, brownians, iv,
64+
observed, parameter_dependencies, var_to_name, name, description, defaults,
65+
guesses, systems, initialization_eqs, continuous_events, discrete_events,
66+
connector_type, assertions, metadata, gui_metadata, is_dde,
67+
tstops, tearing_state, namespacing, complete, index_cache, ignored_connections,
68+
parent)
69+
end
70+
end
71+
72+
function System(eqs, iv, dvs, ps, brownians = [];
73+
constraints = Union{Equation, Inequality}[], noise_eqs = nothing, jumps = [],
74+
observed = Equation[], parameter_dependencies = Equation[], defaults = Dict(),
75+
guesses = Dict(), systems = System[], initialization_eqs = Equation[],
76+
cevents = SymbolicContinuousCallback[], devents = SymbolicDiscreteCallback[],
77+
connector_type = nothing, assertions = Dict{BasicSymbolic, String}(),
78+
metadata = nothing, gui_metadata = nothing, is_dde = nothing, tstops = [],
79+
tearing_state = nothing, ignored_connections = nothing, parent = nothing,
80+
description = "", name = nothing, discover_from_metadata = true, checks = true)
81+
82+
name === nothing && throw(NoNameError())
83+
84+
iv = iv isa Array ? unwrap.(iv) : unwrap(iv)
85+
ps = unwrap.(ps)
86+
dvs = unwrap.(dvs)
87+
filter!(!Base.Fix2(isdelay, iv), dvs)
88+
brownians = unwrap.(brownians)
89+
90+
if !(eqs isa AbstractArray)
91+
eqs = [eqs]
92+
end
93+
94+
if noise_eqs !== nothing
95+
noise_eqs = unwrap.(noise_eqs)
96+
end
97+
98+
parameter_dependencies, ps = process_parameter_dependencies(parameter_dependencies, ps)
99+
defaults = anydict(defaults)
100+
guesses = anydict(guesses)
101+
var_to_name = anydict()
102+
103+
let defaults = discover_from_metadata ? defaults : Dict(),
104+
guesses = discover_from_metadata ? guesses : Dict()
105+
process_variables!(var_to_name, defaults, guesses, dvs)
106+
process_variables!(var_to_name, defaults, guesses, ps)
107+
process_variables!(var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies])
108+
process_variables!(var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies])
109+
process_variables!(var_to_name, defaults, guesses, [eq.lhs for eq in observed])
110+
process_variables!(var_to_name, defaults, guesses, [eq.rhs for eq in observed])
111+
end
112+
filter!(!(isnothing last), defaults)
113+
filter!(!(isnothing last), guesses)
114+
defaults = anydict([unwrap(k) => unwrap(v) for (k, v) in defaults])
115+
guesses = anydict([unwrap(k) => unwrap(v) for (k, v) in guesses])
116+
117+
sysnames = nameof.(systems)
118+
unique_sysnames = Set(sysnames)
119+
if length(unique_sysnames) != length(sysnames)
120+
throw(NonUniqueSubsystemsError(sysnames, unique_sysnames))
121+
end
122+
123+
cevents = SymbolicContinuousCallbacks(cevents)
124+
devents = SymbolicDiscreteCallbacks(devents)
125+
126+
if iv === nothing && !isempty(cevents) || !isempty(devents)
127+
throw(EventsInTimeIndependentSystemError(cevents, devents))
128+
end
129+
130+
if is_dde === nothing
131+
is_dde = _check_if_dde(eqs, iv, systems)
132+
end
133+
134+
assertions = Dict{BasicSymbolic, String}(unwrap(k) => v for (k, v) in assertions)
135+
136+
System(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), eqs, noise_eqs, jumps, constraints, dvs, ps, brownians, iv, observed, parameter_dependencies, var_to_name, name, description, defaults, guesses, systems, initialization_eqs, cevents, devents, connector_type, assertions, metadata, gui_metadata, is_dde, tstops, tearing_state, true, false, nothing, ignored_connections, parent; checks)
137+
end
138+
139+
function System(eqs, iv; kwargs...)
140+
iv === nothing && return System(eqs; kwargs...)
141+
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)
142+
brownians = Set()
143+
for x in allunknowns
144+
x = unwrap(x)
145+
if getvariabletype(x) == BROWNIAN
146+
push!(brownians, x)
147+
end
148+
end
149+
setdiff!(allunknowns, brownians)
150+
151+
for eq in get(kwargs, :parameter_dependencies, Equation[])
152+
collect_vars!(allunknowns, ps, eq, iv)
153+
end
154+
155+
for eq in get(kwargs, :constraints, Equation[])
156+
collect_vars!(allunknowns, ps, eq, iv)
157+
end
158+
159+
for ssys in get(kwargs, :systems, System[])
160+
collect_scoped_vars!(allunknowns, ps, ssys, iv)
161+
end
162+
163+
objective = get(kwargs, :objective, nothing)
164+
if objective !== nothing
165+
collect_vars!(allunknowns, ps, objective, iv)
166+
end
167+
168+
for v in allunknowns
169+
isdelay(v, iv) || continue
170+
collect_vars!(allunknowns, ps, arguments(v)[1], iv)
171+
end
172+
173+
new_ps = gather_array_params(ps)
174+
algevars = setdiff(allunknowns, diffvars)
175+
176+
noiseeqs = get(kwargs, :noise_eqs, nothing)
177+
if noiseeqs !== nothing
178+
# validate noise equations
179+
noisedvs = OrderedSet()
180+
noiseps = OrderedSet()
181+
collect_vars!(noisedvs, noiseps, noiseeqs, iv)
182+
for dv in noisedvs
183+
dv allunknowns ||
184+
throw(ArgumentError("Variable $dv in noise equations is not an unknown of the system."))
185+
end
186+
end
187+
188+
return System(eqs, iv, collect(Iterators.flatten((diffvars, algevars))),
189+
collect(new_ps), brownians; kwargs...)
190+
end
191+
192+
function System(eqs; kwargs...)
193+
eqs = collect(eqs)
194+
195+
allunknowns = OrderedSet()
196+
ps = OrderedSet()
197+
for eq in eqs
198+
collect_vars!(allunknowns, ps, eq, nothing)
199+
end
200+
for eq in get(kwargs, :parameter_dependencies, Equation[])
201+
collect_vars!(allunknowns, ps, eq, nothing)
202+
end
203+
for ssys in get(kwargs, :systems, System[])
204+
collect_scoped_vars!(allunknowns, ps, ssys, nothing)
205+
end
206+
207+
new_ps = gather_array_params(ps)
208+
209+
return System(eqs, nothing, collect(allunknowns), collect(new_ps); kwargs...)
210+
end
211+
212+
function gather_array_params(ps)
213+
new_ps = OrderedSet()
214+
for p in ps
215+
if iscall(p) && operation(p) === getindex
216+
par = arguments(p)[begin]
217+
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
218+
all(par[i] in ps for i in eachindex(par))
219+
push!(new_ps, par)
220+
else
221+
push!(new_ps, p)
222+
end
223+
else
224+
if symbolic_type(p) == ArraySymbolic() &&
225+
Symbolics.shape(unwrap(p)) != Symbolics.Unknown()
226+
for i in eachindex(p)
227+
delete!(new_ps, p[i])
228+
end
229+
end
230+
push!(new_ps, p)
231+
end
232+
end
233+
return new_ps
234+
end
235+
236+
struct IllFormedNoiseEquationsError <: Exception
237+
noise_eqs_rows::Int
238+
eqs_length::Int
239+
end
240+
241+
function Base.showerror(io::IO, err::IllFormedNoiseEquationsError)
242+
print(io, """
243+
Noise equations are ill-formed. The number of rows much must number of drift \
244+
equations. `size(neqs, 1) == $(err.noise_eqs_rows) != length(eqs) == \
245+
$(err.eqs_length)`.
246+
""")
247+
end
248+
249+
function NoNameError()
250+
ArgumentError("""
251+
The `name` keyword must be provided. Please consider using the `@named` macro.
252+
""")
253+
end
254+
255+
struct NonUniqueSubsystemsError <: Exception
256+
names::Vector{Symbol}
257+
uniques::Set{Symbol}
258+
end
259+
260+
function Base.showerror(io::IO, err::NonUniqueSubsystemsError)
261+
dupes = Set{Symbol}()
262+
for n in err.names
263+
if !(n in err.uniques)
264+
push!(dupes, n)
265+
end
266+
delete!(err.uniques, n)
267+
end
268+
println(io, "System names must be unique. The following system names were duplicated:")
269+
for n in dupes
270+
println(io, " ", n)
271+
end
272+
end
273+
274+
struct EventsInTimeIndependentSystemError <: Exception
275+
cevents::Vector
276+
devents::Vector
277+
end
278+
279+
function Base.showerror(io::IO, err::EventsInTimeIndependentSystemError)
280+
println(io, """
281+
Events are not supported in time-indepent systems. Provide an independent variable to \
282+
make the system time-dependent or remove the events.
283+
284+
The following continuous events were provided:
285+
$(err.cevents)
286+
287+
The following discrete events were provided:
288+
$(err.devents)
289+
""")
290+
end

0 commit comments

Comments
 (0)