Skip to content

Commit 8455e85

Browse files
feat: support SDEProblem and SDEFunction for System
1 parent 1f642fc commit 8455e85

File tree

3 files changed

+129
-0
lines changed

3 files changed

+129
-0
lines changed

src/problems/compatibility.jl

+10
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,13 @@ function check_no_noise(sys::System, T)
7979
"""))
8080
end
8181
end
82+
83+
function check_has_noise(sys::System, T)
84+
altT = is_dde(sys) ? DDEProblem : ODEProblem
85+
if get_noise_eqs(sys) === nothing
86+
throw(SystemCompatibilityError("""
87+
A system without noise cannot be used to construct a `$T`. Consider an \
88+
`$altT` instead.
89+
"""))
90+
end
91+
end

src/problems/sdeproblem.jl

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
@fallback_iip_specialize function SciMLBase.SDEFunction{iip, spec}(
2+
sys::System, _d = nothing, u0 = nothing, p = nothing; tgrad = false, jac = false,
3+
t = nothing, eval_expression = false, eval_module = @__MODULE__, sparse = false,
4+
steady_state = false, checkbounds = false, sparsity = false, analytic = nothing,
5+
simplify = false, cse = true, initialization_data = nothing,
6+
check_compatibility = true, kwargs...) where {iip, spec}
7+
check_complete(sys, SDEFunction)
8+
check_compatibility && check_compatible_system(SDEFunction, sys)
9+
10+
dvs = unknowns(sys)
11+
ps = parameters(sys)
12+
f = generate_rhs(sys, dvs, ps; expression = Val{false},
13+
eval_expression, eval_module, checkbounds = checkbounds, cse,
14+
kwargs...)
15+
g = generate_diffusion_function(sys, dvs, ps; expression = Val{false},
16+
eval_expression, eval_module, checkbounds, cse, kwargs...)
17+
18+
if spec === SciMLBase.FunctionWrapperSpecialize && iip
19+
if u0 === nothing || p === nothing || t === nothing
20+
error("u0, p, and t must be specified for FunctionWrapperSpecialize on ODEFunction.")
21+
end
22+
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
23+
end
24+
25+
if tgrad
26+
_tgrad = generate_tgrad(sys, dvs, ps; expression = Val{false},
27+
simplify, cse, eval_expression, eval_module, checkbounds, kwargs...)
28+
else
29+
_tgrad = nothing
30+
end
31+
32+
if jac
33+
_jac = generate_jacobian(sys, dvs, ps; expression = Val{false},
34+
simplify, sparse, cse, eval_expression, eval_module, checkbounds, kwargs...)
35+
else
36+
_jac = nothing
37+
end
38+
39+
M = calculate_massmatrix(sys)
40+
_M = concrete_massmatrix(M; sparse, u0)
41+
42+
observedfun = ObservedFunctionCache(
43+
sys; steady_state, eval_expression, eval_module, checkbounds, cse)
44+
45+
_W_sparsity = W_sparsity(sys)
46+
W_prototype = calculate_W_prototype(_W_sparsity; u0, sparse)
47+
48+
SDEFunction{iip, spec}(f, g;
49+
sys = sys,
50+
jac = _jac,
51+
tgrad = _tgrad,
52+
mass_matrix = _M,
53+
jac_prototype = W_prototype,
54+
observed = observedfun,
55+
sparsity = sparsity ? _W_sparsity : nothing,
56+
analytic = analytic,
57+
initialization_data)
58+
end
59+
60+
@fallback_iip_specialize function SciMLBase.SDEProblem{iip, spec}(
61+
sys::System, u0map, tspan, parammap = SciMLBase.NullParameters();
62+
callback = nothing, check_length = true, eval_expression = false,
63+
eval_module = @__MODULE__, check_compatibility = true, sparse = false,
64+
sparsenoise = sparse, kwargs...) where {iip, spec}
65+
check_complete(sys, SDEProblem)
66+
check_compatibility && check_compatible_system(SDEProblem, sys)
67+
68+
f, u0, p = process_SciMLProblem(SDEFunction{iip, spec}, sys, u0map, parammap;
69+
t = tspan !== nothing ? tspan[1] : tspan, check_length, eval_expression,
70+
eval_module, check_compatibility, sparse, kwargs...)
71+
72+
noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
73+
kwargs = process_kwargs(sys; callback, eval_expression, eval_module, kwargs...)
74+
# Call `remake` so it runs initialization if it is trivial
75+
return remake(SDEProblem{iip}(f, u0, tspan, p; noise, noise_rate_prototype, kwargs...))
76+
end
77+
78+
function check_compatible_system(T::Union{Type{SDEFunction}, Type{SDEProblem}}, sys::System)
79+
check_time_dependent(sys, T)
80+
check_not_dde(sys)
81+
check_no_cost(sys, T)
82+
check_no_constraints(sys, T)
83+
check_no_jumps(sys, T)
84+
check_has_noise(sys, T)
85+
end
86+
87+
function calculate_noise_and_rate_prototype(sys::System, u0; sparsenoise = false)
88+
noiseeqs = get_noise_eqs(sys)
89+
if noiseeqs isa AbstractVector
90+
# diagonal noise
91+
noise_rate_prototype = nothing
92+
noise = nothing
93+
elseif size(noiseeqs, 2) == 1
94+
# scalar noise
95+
noise_rate_prototype = nothing
96+
noise = WienerProcess(0.0, 0.0, 0.0)
97+
elseif sparsenoise
98+
I, J, V = findnz(SparseArrays.sparse(noiseeqs))
99+
noise_rate_prototype = SparseArrays.sparse(I, J, zero(eltype(u0)))
100+
noise = nothing
101+
else
102+
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
103+
noise = nothing
104+
end
105+
return noise, noise_rate_prototype
106+
end

src/systems/codegen.jl

+13
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,19 @@ function generate_rhs(sys::System, dvs = unknowns(sys),
8282
f_oop, f_iip)
8383
end
8484

85+
function generate_diffusion_function(sys::System, dvs = unknowns(sys),
86+
ps = parameters(sys; initial_parameters = true); expression = Val{true}, eval_expression = false,
87+
eval_module = @__MODULE__, kwargs...)
88+
eqs = get_noise_eqs(sys)
89+
p = reorder_parameters(sys, ps)
90+
res = build_function_wrapper(sys, eqs, dvs, p..., get_iv(sys); kwargs...)
91+
if expression == Val{true}
92+
return res
93+
end
94+
f_oop, f_iip = eval_or_rgf.(res; eval_expression, eval_module)
95+
return GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)
96+
end
97+
8598
function calculate_tgrad(sys::System; simplify = false)
8699
# We need to remove explicit time dependence on the unknown because when we
87100
# have `u(t) * t` we want to have the tgrad to be `u(t)` instead of `u'(t) *

0 commit comments

Comments
 (0)