Skip to content

Commit ba5ecd9

Browse files
authored
Allow for vector of inputs in streaming inference (#472)
* Added fix for issue #318 * Added a MRE from issue #318 as test
1 parent 4cc142b commit ba5ecd9

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

src/model/plugins/reactivemp_inference.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,8 @@ getlabel(ref::GraphVariableRef) = ref.label
258258
getvariable(ref::GraphVariableRef) = ref.variable
259259
getname(ref::GraphVariableRef) = GraphPPL.getname(getlabel(ref))
260260

261+
GraphPPL.is_data(collection::AbstractArray{GraphVariableRef}) = all(GraphPPL.is_data, collection)
262+
261263
GraphPPL.is_data(ref::GraphVariableRef) = GraphPPL.is_data(ref.properties)
262264
GraphPPL.is_random(ref::GraphVariableRef) = GraphPPL.is_random(ref.properties)
263265
GraphPPL.is_constant(ref::GraphVariableRef) = GraphPPL.is_constant(ref.properties)

test/inference/inference_tests.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,33 @@ end
10501050
end
10511051
end
10521052

1053+
@testitem "vector of inputs in streaming inference" begin
1054+
import RxInfer: infer
1055+
@model function test_model(x, y, mx, vx)
1056+
for i in 1:3
1057+
x[i] ~ NormalMeanVariance(mx, vx)
1058+
end
1059+
my ~ NormalMeanVariance(0, 1)
1060+
y ~ NormalMeanVariance(my, 1.0)
1061+
end
1062+
1063+
d = [(x = rand(3), y = rand()) for i in 1:10]
1064+
datastream = from(d) |> map(NamedTuple{(:x, :y), Tuple{Vector{Float64}, Float64}}, (d) -> d)
1065+
1066+
foo(x) = 1.0
1067+
1068+
autoupdates = @autoupdates begin
1069+
mx = foo(q(my))
1070+
vx = foo(q(my))
1071+
end
1072+
1073+
engine = infer(model = test_model(), datastream = datastream, autoupdates = autoupdates, initialization = @initialization begin
1074+
q(my) = NormalMeanVariance(1.0, 1.0)
1075+
end)
1076+
1077+
@test engine.history === nothing
1078+
end
1079+
10531080
@testitem "Test misspecified types in infer function" begin
10541081
@model function rolling_die(y)
10551082
θ ~ Dirichlet(ones(6))

0 commit comments

Comments
 (0)