Skip to content

Commit 0aa0d8b

Browse files
fix: fix type promotion in late_binding_update_u0_p
1 parent dd0e30b commit 0aa0d8b

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

src/systems/nonlinear/initializesystem.jl

+19-17
Original file line numberDiff line numberDiff line change
@@ -589,13 +589,31 @@ function SciMLBase.remake_initialization_data(
589589
return SciMLBase.remake_initialization_data(sys, kws, newu0, t0, newp, newu0, newp)
590590
end
591591

592+
function promote_u0_p(u0, p::MTKParameters, t0)
593+
u0 = DiffEqBase.promote_u0(u0, p.tunable, t0)
594+
u0 = DiffEqBase.promote_u0(u0, p.initials, t0)
595+
596+
tunables = DiffEqBase.promote_u0(p.tunable, u0, t0)
597+
initials = DiffEqBase.promote_u0(p.initials, u0, t0)
598+
p = SciMLStructures.replace(SciMLStructures.Tunable(), p, tunables)
599+
p = SciMLStructures.replace(SciMLStructures.Initials(), p, initials)
600+
601+
return u0, p
602+
end
603+
604+
function promote_u0_p(u0, p::AbstractArray, t0)
605+
return DiffEqBase.promote_u0(u0, p, t0), DiffEqBase.promote_u0(p, u0, t0)
606+
end
607+
592608
function SciMLBase.late_binding_update_u0_p(
593609
prob, sys::AbstractSystem, u0, p, t0, newu0, newp)
594610
supports_initialization(sys) || return newu0, newp
595611

596612
initdata = prob.f.initialization_data
597613
meta = initdata === nothing ? nothing : initdata.metadata
598614

615+
newu0, newp = promote_u0_p(newu0, newp, t0)
616+
599617
# non-symbolic u0 updates initials...
600618
if !(eltype(u0) <: Pair)
601619
# if `p` is not provided or is symbolic
@@ -605,12 +623,7 @@ function SciMLBase.late_binding_update_u0_p(
605623
meta = initdata.metadata
606624
meta isa InitializationMetadata || return newu0, newp
607625
newp = p === missing ? copy(newp) : newp
608-
initials, repack, alias = SciMLStructures.canonicalize(
609-
SciMLStructures.Initials(), newp)
610-
if eltype(initials) != eltype(newu0)
611-
initials = DiffEqBase.promote_u0(initials, newu0, t0)
612-
newp = repack(initials)
613-
end
626+
614627
if length(newu0) != length(prob.u0)
615628
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
616629
end
@@ -619,17 +632,6 @@ function SciMLBase.late_binding_update_u0_p(
619632
end
620633

621634
newp = p === missing ? copy(newp) : newp
622-
newu0 = DiffEqBase.promote_u0(newu0, newp, t0)
623-
tunables, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
624-
if eltype(tunables) != eltype(newu0)
625-
tunables = DiffEqBase.promote_u0(tunables, newu0, t0)
626-
newp = repack(tunables)
627-
end
628-
initials, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
629-
if eltype(initials) != eltype(newu0)
630-
initials = DiffEqBase.promote_u0(initials, newu0, t0)
631-
newp = repack(initials)
632-
end
633635

634636
allsyms = all_symbols(sys)
635637
for (k, v) in u0

0 commit comments

Comments
 (0)