diff --git a/Project.toml b/Project.toml index 35e03e2924..8a7cd147b4 100644 --- a/Project.toml +++ b/Project.toml @@ -95,7 +95,7 @@ DataInterpolations = "6.4" DataStructures = "0.17, 0.18" DeepDiffs = "1" DelayDiffEq = "5.50" -DiffEqBase = "6.165.1" +DiffEqBase = "6.170.1" DiffEqCallbacks = "2.16, 3, 4" DiffEqNoiseProcess = "5" DiffRules = "0.1, 1.0" diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 50c0862863..5995272be9 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -589,26 +589,41 @@ function SciMLBase.remake_initialization_data( return SciMLBase.remake_initialization_data(sys, kws, newu0, t0, newp, newu0, newp) end +function promote_u0_p(u0, p::MTKParameters, t0) + u0 = DiffEqBase.promote_u0(u0, p.tunable, t0) + u0 = DiffEqBase.promote_u0(u0, p.initials, t0) + + tunables = DiffEqBase.promote_u0(p.tunable, u0, t0) + initials = DiffEqBase.promote_u0(p.initials, u0, t0) + p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables) + p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials) + + return u0, p +end + +function promote_u0_p(u0, p::AbstractArray, t0) + return DiffEqBase.promote_u0(u0, p, t0), DiffEqBase.promote_u0(p, u0, t0) +end + function SciMLBase.late_binding_update_u0_p( prob, sys::AbstractSystem, u0, p, t0, newu0, newp) supports_initialization(sys) || return newu0, newp - u0 === missing && return newu0, (p === missing ? copy(newp) : newp) + + initdata = prob.f.initialization_data + meta = initdata === nothing ? nothing : initdata.metadata + + newu0, newp = promote_u0_p(newu0, newp, t0) + # non-symbolic u0 updates initials... if !(eltype(u0) <: Pair) # if `p` is not provided or is symbolic p === missing || eltype(p) <: Pair || return newu0, newp (newu0 === nothing || isempty(newu0)) && return newu0, newp - initdata = prob.f.initialization_data initdata === nothing && return newu0, newp meta = initdata.metadata meta isa InitializationMetadata || return newu0, newp newp = p === missing ? copy(newp) : newp - initials, repack, alias = SciMLStructures.canonicalize( - SciMLStructures.Initials(), newp) - if eltype(initials) != eltype(newu0) - initials = DiffEqBase.promote_u0(initials, newu0, t0) - newp = repack(initials) - end + if length(newu0) != length(prob.u0) throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))")) end @@ -617,17 +632,6 @@ function SciMLBase.late_binding_update_u0_p( end newp = p === missing ? copy(newp) : newp - newu0 = DiffEqBase.promote_u0(newu0, newp, t0) - tunables, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp) - if eltype(tunables) != eltype(newu0) - tunables = DiffEqBase.promote_u0(tunables, newu0, t0) - newp = repack(tunables) - end - initials, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp) - if eltype(initials) != eltype(newu0) - initials = DiffEqBase.promote_u0(initials, newu0, t0) - newp = repack(initials) - end allsyms = all_symbols(sys) for (k, v) in u0 @@ -646,6 +650,37 @@ function SciMLBase.late_binding_update_u0_p( return newu0, newp end +function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw...) + supports_initialization(sys) || return prob + initdata = prob.f.initialization_data + initdata isa SciMLBase.OverrideInitData || return prob + meta = initdata.metadata + meta isa InitializationMetadata || return prob + meta.get_updated_u0 === nothing && return prob + + u0 = state_values(prob) + u0 === nothing && return prob + + p = parameter_values(prob) + t0 = is_time_dependent(prob) ? current_time(prob) : nothing + + if p isa MTKParameters + buffer = p.initials + else + buffer = p + end + + u0 = DiffEqBase.promote_u0(u0, buffer, t0) + + if ArrayInterface.ismutable(u0) + T = typeof(u0) + else + T = StaticArrays.similar_type(u0) + end + + return remake(prob; u0 = T(meta.get_updated_u0(prob, initdata.initializeprob))) +end + """ $(TYPEDSIGNATURES) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 41c64b78c5..47bf3c678d 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -769,7 +769,7 @@ properly. $(TYPEDFIELDS) """ -struct InitializationMetadata{R <: ReconstructInitializeprob, SIU} +struct InitializationMetadata{R <: ReconstructInitializeprob, GUU, SIU} """ The `u0map` used to construct the initialization. """ @@ -796,12 +796,62 @@ struct InitializationMetadata{R <: ReconstructInitializeprob, SIU} """ oop_reconstruct_u0_p::R """ + A function which takes `(prob, initializeprob)` and return the `u0` to use for the problem. + """ + get_updated_u0::GUU + """ A function which takes the `u0` of the problem and sets `Initial.(unknowns(sys))`. """ set_initial_unknowns!::SIU end +""" + $(TYPEDEF) + +A callable struct to use as the `get_updated_u0` field of `InitializationMetadata`. +Returns the value to use for the `u0` of the problem. + +# Fields + +$(TYPEDFIELDS) +""" +struct GetUpdatedU0{GG, GIU} + """ + Mask with length `length(unknowns(sys))` denoting indices of variables which should + take the guess value from `initializeprob`. + """ + guessvars::BitVector + """ + Function which returns the values of variables in `initializeprob` for which + `guessvars` is `true`, in the order they occur in `unknowns(sys)`. + """ + get_guessvars::GG + """ + Function which returns `Initial.(unknowns(sys))` as a `Vector`. + """ + get_initial_unknowns::GIU +end + +function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem, op::AbstractDict) + dvs = unknowns(sys) + eqs = equations(sys) + guessvars = trues(length(dvs)) + for (i, var) in enumerate(dvs) + guessvars[i] = !isequal(get(op, var, nothing), Initial(var)) + end + get_guessvars = getu(initsys, dvs[guessvars]) + get_initial_unknowns = getu(sys, Initial.(dvs)) + return GetUpdatedU0(guessvars, get_guessvars, get_initial_unknowns) +end + +function (guu::GetUpdatedU0)(prob, initprob) + buffer = guu.get_initial_unknowns(prob) + algebuf = view(buffer, guu.guessvars) + copyto!(algebuf, guu.get_guessvars(initprob)) + return buffer +end + """ $(TYPEDSIGNATURES) @@ -840,10 +890,15 @@ function maybe_build_initialization_problem( end initializeprob = remake(initializeprob; p = initp) + get_initial_unknowns = if is_time_dependent(sys) + GetUpdatedU0(sys, initializeprob.f.sys, op) + else + nothing + end meta = InitializationMetadata( u0map, pmap, guesses, Vector{Equation}(initialization_eqs), use_scc, ReconstructInitializeprob(sys, initializeprob.f.sys), - setp(sys, Initial.(unknowns(sys)))) + get_initial_unknowns, setp(sys, Initial.(unknowns(sys)))) if is_time_dependent(sys) all_init_syms = Set(all_symbols(initializeprob)) diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 2804c70833..5c36fcba3e 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1512,3 +1512,54 @@ end @inferred remake(prob; u0 = 2 .* prob.u0, p = prob.p) @inferred solve(prob) end + +@testset "Issue#3570, #3552: `Initial`s/guesses are copied to `u0` during `solve`/`init`" begin + @parameters g + @variables x(t) [state_priority = 10] y(t) λ(t) + eqs = [D(D(x)) ~ λ * x + D(D(y)) ~ λ * y - g + x^2 + y^2 ~ 1] + @mtkbuild pend = ODESystem(eqs, t) + + prob = ODEProblem( + pend, [x => (√2 / 2), D(x) => 0.0], (0.0, 1.5), + [g => 1], guesses = [λ => 1, y => √2 / 2]) + sol = solve(prob) + + @testset "Guesses of initialization problem copied to algebraic variables" begin + prob.f.initialization_data.initializeprob[λ] = 1.0 + prob2 = DiffEqBase.get_updated_symbolic_problem( + pend, prob; u0 = prob.u0, p = prob.p) + @test prob2[λ] ≈ 1.0 + end + + @testset "Initial values for algebraic variables are retained" begin + prob2 = ODEProblem( + pend, [x => (√2 / 2), D(y) => 0.0], (0.0, 1.5), + [g => 1], guesses = [λ => 1, y => √2 / 2]) + sol = solve(prob) + @test SciMLBase.successful_retcode(sol) + prob3 = DiffEqBase.get_updated_symbolic_problem( + pend, prob2; u0 = prob2.u0, p = prob2.p) + @test prob3[D(y)] ≈ 0.0 + end + + @testset "`setsym_oop`" begin + setter = setsym_oop(prob, [Initial(x)]) + (u0, p) = setter(prob, [0.8]) + new_prob = remake(prob; u0, p, initializealg = BrownFullBasicInit()) + new_sol = solve(new_prob) + @test new_sol[x, 1] ≈ 0.8 + integ = init(new_prob) + @test integ[x] ≈ 0.8 + end + + @testset "`setsym`" begin + @test prob.ps[Initial(x)] ≈ √2 / 2 + prob.ps[Initial(x)] = 0.8 + sol = solve(prob; initializealg = BrownFullBasicInit()) + @test sol[x, 1] ≈ 0.8 + integ = init(prob; initializealg = BrownFullBasicInit()) + @test integ[x] ≈ 0.8 + end +end