Skip to content

Commit ba0807d

Browse files
Merge pull request #3572 from AayushSabharwal/as/remake-copy-initials
fix: copy initials to `u0` if `u0` not provided to `remake`
2 parents fcf519a + 998a516 commit ba0807d

File tree

4 files changed

+163
-22
lines changed

4 files changed

+163
-22
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ DataInterpolations = "6.4"
9595
DataStructures = "0.17, 0.18"
9696
DeepDiffs = "1"
9797
DelayDiffEq = "5.50"
98-
DiffEqBase = "6.165.1"
98+
DiffEqBase = "6.170.1"
9999
DiffEqCallbacks = "2.16, 3, 4"
100100
DiffEqNoiseProcess = "5"
101101
DiffRules = "0.1, 1.0"

src/systems/nonlinear/initializesystem.jl

+54-19
Original file line numberDiff line numberDiff line change
@@ -589,26 +589,41 @@ function SciMLBase.remake_initialization_data(
589589
return SciMLBase.remake_initialization_data(sys, kws, newu0, t0, newp, newu0, newp)
590590
end
591591

592+
function promote_u0_p(u0, p::MTKParameters, t0)
593+
u0 = DiffEqBase.promote_u0(u0, p.tunable, t0)
594+
u0 = DiffEqBase.promote_u0(u0, p.initials, t0)
595+
596+
tunables = DiffEqBase.promote_u0(p.tunable, u0, t0)
597+
initials = DiffEqBase.promote_u0(p.initials, u0, t0)
598+
p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables)
599+
p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials)
600+
601+
return u0, p
602+
end
603+
604+
function promote_u0_p(u0, p::AbstractArray, t0)
605+
return DiffEqBase.promote_u0(u0, p, t0), DiffEqBase.promote_u0(p, u0, t0)
606+
end
607+
592608
function SciMLBase.late_binding_update_u0_p(
593609
prob, sys::AbstractSystem, u0, p, t0, newu0, newp)
594610
supports_initialization(sys) || return newu0, newp
595-
u0 === missing && return newu0, (p === missing ? copy(newp) : newp)
611+
612+
initdata = prob.f.initialization_data
613+
meta = initdata === nothing ? nothing : initdata.metadata
614+
615+
newu0, newp = promote_u0_p(newu0, newp, t0)
616+
596617
# non-symbolic u0 updates initials...
597618
if !(eltype(u0) <: Pair)
598619
# if `p` is not provided or is symbolic
599620
p === missing || eltype(p) <: Pair || return newu0, newp
600621
(newu0 === nothing || isempty(newu0)) && return newu0, newp
601-
initdata = prob.f.initialization_data
602622
initdata === nothing && return newu0, newp
603623
meta = initdata.metadata
604624
meta isa InitializationMetadata || return newu0, newp
605625
newp = p === missing ? copy(newp) : newp
606-
initials, repack, alias = SciMLStructures.canonicalize(
607-
SciMLStructures.Initials(), newp)
608-
if eltype(initials) != eltype(newu0)
609-
initials = DiffEqBase.promote_u0(initials, newu0, t0)
610-
newp = repack(initials)
611-
end
626+
612627
if length(newu0) != length(prob.u0)
613628
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
614629
end
@@ -617,17 +632,6 @@ function SciMLBase.late_binding_update_u0_p(
617632
end
618633

619634
newp = p === missing ? copy(newp) : newp
620-
newu0 = DiffEqBase.promote_u0(newu0, newp, t0)
621-
tunables, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
622-
if eltype(tunables) != eltype(newu0)
623-
tunables = DiffEqBase.promote_u0(tunables, newu0, t0)
624-
newp = repack(tunables)
625-
end
626-
initials, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
627-
if eltype(initials) != eltype(newu0)
628-
initials = DiffEqBase.promote_u0(initials, newu0, t0)
629-
newp = repack(initials)
630-
end
631635

632636
allsyms = all_symbols(sys)
633637
for (k, v) in u0
@@ -646,6 +650,37 @@ function SciMLBase.late_binding_update_u0_p(
646650
return newu0, newp
647651
end
648652

653+
function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw...)
654+
supports_initialization(sys) || return prob
655+
initdata = prob.f.initialization_data
656+
initdata isa SciMLBase.OverrideInitData || return prob
657+
meta = initdata.metadata
658+
meta isa InitializationMetadata || return prob
659+
meta.get_updated_u0 === nothing && return prob
660+
661+
u0 = state_values(prob)
662+
u0 === nothing && return prob
663+
664+
p = parameter_values(prob)
665+
t0 = is_time_dependent(prob) ? current_time(prob) : nothing
666+
667+
if p isa MTKParameters
668+
buffer = p.initials
669+
else
670+
buffer = p
671+
end
672+
673+
u0 = DiffEqBase.promote_u0(u0, buffer, t0)
674+
675+
if ArrayInterface.ismutable(u0)
676+
T = typeof(u0)
677+
else
678+
T = StaticArrays.similar_type(u0)
679+
end
680+
681+
return remake(prob; u0 = T(meta.get_updated_u0(prob, initdata.initializeprob)))
682+
end
683+
649684
"""
650685
$(TYPEDSIGNATURES)
651686

src/systems/problem_utils.jl

+57-2
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ properly.
769769
770770
$(TYPEDFIELDS)
771771
"""
772-
struct InitializationMetadata{R <: ReconstructInitializeprob, SIU}
772+
struct InitializationMetadata{R <: ReconstructInitializeprob, GUU, SIU}
773773
"""
774774
The `u0map` used to construct the initialization.
775775
"""
@@ -796,12 +796,62 @@ struct InitializationMetadata{R <: ReconstructInitializeprob, SIU}
796796
"""
797797
oop_reconstruct_u0_p::R
798798
"""
799+
A function which takes `(prob, initializeprob)` and return the `u0` to use for the problem.
800+
"""
801+
get_updated_u0::GUU
802+
"""
799803
A function which takes the `u0` of the problem and sets
800804
`Initial.(unknowns(sys))`.
801805
"""
802806
set_initial_unknowns!::SIU
803807
end
804808

809+
"""
810+
$(TYPEDEF)
811+
812+
A callable struct to use as the `get_updated_u0` field of `InitializationMetadata`.
813+
Returns the value to use for the `u0` of the problem.
814+
815+
# Fields
816+
817+
$(TYPEDFIELDS)
818+
"""
819+
struct GetUpdatedU0{GG, GIU}
820+
"""
821+
Mask with length `length(unknowns(sys))` denoting indices of variables which should
822+
take the guess value from `initializeprob`.
823+
"""
824+
guessvars::BitVector
825+
"""
826+
Function which returns the values of variables in `initializeprob` for which
827+
`guessvars` is `true`, in the order they occur in `unknowns(sys)`.
828+
"""
829+
get_guessvars::GG
830+
"""
831+
Function which returns `Initial.(unknowns(sys))` as a `Vector`.
832+
"""
833+
get_initial_unknowns::GIU
834+
end
835+
836+
function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem, op::AbstractDict)
837+
dvs = unknowns(sys)
838+
eqs = equations(sys)
839+
guessvars = trues(length(dvs))
840+
for (i, var) in enumerate(dvs)
841+
guessvars[i] = !isequal(get(op, var, nothing), Initial(var))
842+
end
843+
get_guessvars = getu(initsys, dvs[guessvars])
844+
get_initial_unknowns = getu(sys, Initial.(dvs))
845+
return GetUpdatedU0(guessvars, get_guessvars, get_initial_unknowns)
846+
end
847+
848+
function (guu::GetUpdatedU0)(prob, initprob)
849+
buffer = guu.get_initial_unknowns(prob)
850+
algebuf = view(buffer, guu.guessvars)
851+
copyto!(algebuf, guu.get_guessvars(initprob))
852+
return buffer
853+
end
854+
805855
"""
806856
$(TYPEDSIGNATURES)
807857
@@ -840,10 +890,15 @@ function maybe_build_initialization_problem(
840890
end
841891
initializeprob = remake(initializeprob; p = initp)
842892

893+
get_initial_unknowns = if is_time_dependent(sys)
894+
GetUpdatedU0(sys, initializeprob.f.sys, op)
895+
else
896+
nothing
897+
end
843898
meta = InitializationMetadata(
844899
u0map, pmap, guesses, Vector{Equation}(initialization_eqs),
845900
use_scc, ReconstructInitializeprob(sys, initializeprob.f.sys),
846-
setp(sys, Initial.(unknowns(sys))))
901+
get_initial_unknowns, setp(sys, Initial.(unknowns(sys))))
847902

848903
if is_time_dependent(sys)
849904
all_init_syms = Set(all_symbols(initializeprob))

test/initializationsystem.jl

+51
Original file line numberDiff line numberDiff line change
@@ -1512,3 +1512,54 @@ end
15121512
@inferred remake(prob; u0 = 2 .* prob.u0, p = prob.p)
15131513
@inferred solve(prob)
15141514
end
1515+
1516+
@testset "Issue#3570, #3552: `Initial`s/guesses are copied to `u0` during `solve`/`init`" begin
1517+
@parameters g
1518+
@variables x(t) [state_priority = 10] y(t) λ(t)
1519+
eqs = [D(D(x)) ~ λ * x
1520+
D(D(y)) ~ λ * y - g
1521+
x^2 + y^2 ~ 1]
1522+
@mtkbuild pend = ODESystem(eqs, t)
1523+
1524+
prob = ODEProblem(
1525+
pend, [x => (2 / 2), D(x) => 0.0], (0.0, 1.5),
1526+
[g => 1], guesses ==> 1, y => 2 / 2])
1527+
sol = solve(prob)
1528+
1529+
@testset "Guesses of initialization problem copied to algebraic variables" begin
1530+
prob.f.initialization_data.initializeprob[λ] = 1.0
1531+
prob2 = DiffEqBase.get_updated_symbolic_problem(
1532+
pend, prob; u0 = prob.u0, p = prob.p)
1533+
@test prob2[λ] 1.0
1534+
end
1535+
1536+
@testset "Initial values for algebraic variables are retained" begin
1537+
prob2 = ODEProblem(
1538+
pend, [x => (2 / 2), D(y) => 0.0], (0.0, 1.5),
1539+
[g => 1], guesses ==> 1, y => 2 / 2])
1540+
sol = solve(prob)
1541+
@test SciMLBase.successful_retcode(sol)
1542+
prob3 = DiffEqBase.get_updated_symbolic_problem(
1543+
pend, prob2; u0 = prob2.u0, p = prob2.p)
1544+
@test prob3[D(y)] 0.0
1545+
end
1546+
1547+
@testset "`setsym_oop`" begin
1548+
setter = setsym_oop(prob, [Initial(x)])
1549+
(u0, p) = setter(prob, [0.8])
1550+
new_prob = remake(prob; u0, p, initializealg = BrownFullBasicInit())
1551+
new_sol = solve(new_prob)
1552+
@test new_sol[x, 1] 0.8
1553+
integ = init(new_prob)
1554+
@test integ[x] 0.8
1555+
end
1556+
1557+
@testset "`setsym`" begin
1558+
@test prob.ps[Initial(x)] 2 / 2
1559+
prob.ps[Initial(x)] = 0.8
1560+
sol = solve(prob; initializealg = BrownFullBasicInit())
1561+
@test sol[x, 1] 0.8
1562+
integ = init(prob; initializealg = BrownFullBasicInit())
1563+
@test integ[x] 0.8
1564+
end
1565+
end

0 commit comments

Comments
 (0)