3
3
using Microsoft . ML . OnnxRuntime . Tensors ;
4
4
using OnnxStack . Core ;
5
5
using OnnxStack . Core . Config ;
6
+ using OnnxStack . Core . Image ;
6
7
using OnnxStack . Core . Model ;
7
8
using OnnxStack . Core . Services ;
8
9
using OnnxStack . StableDiffusion . Common ;
@@ -113,15 +114,38 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(StableDiffusionModelS
113
114
// Process prompts
114
115
var promptEmbeddings = await _promptService . CreatePromptAsync ( modelOptions , promptOptions , performGuidance ) ;
115
116
117
+ // If video input, process frames
118
+ if ( promptOptions . HasInputVideo )
119
+ {
120
+ var frameIndex = 0 ;
121
+ DenseTensor < float > videoTensor = null ;
122
+ var videoFrames = promptOptions . InputVideo . VideoFrames . Frames ;
123
+ var schedulerFrameCallback = CreateBatchCallback ( progressCallback , videoFrames . Count , ( ) => frameIndex ) ;
124
+ foreach ( var videoFrame in videoFrames )
125
+ {
126
+ frameIndex ++ ;
127
+ promptOptions . InputImage = new InputImage ( videoFrame ) ;
128
+ var frameResultTensor = await SchedulerStepAsync ( modelOptions , promptOptions , schedulerOptions , promptEmbeddings , performGuidance , schedulerFrameCallback , cancellationToken ) ;
129
+
130
+ // Frame Progress
131
+ ReportBatchProgress ( progressCallback , frameIndex , videoFrames . Count , frameResultTensor ) ;
132
+
133
+ // Concatenate frame
134
+ videoTensor = videoTensor . Concatenate ( frameResultTensor ) ;
135
+ }
136
+
137
+ _logger ? . LogEnd ( $ "Diffuse complete", diffuseTime ) ;
138
+ return videoTensor ;
139
+ }
140
+
116
141
// Run Scheduler steps
117
142
var schedulerResult = await SchedulerStepAsync ( modelOptions , promptOptions , schedulerOptions , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
118
-
119
143
_logger ? . LogEnd ( $ "Diffuse complete", diffuseTime ) ;
120
-
121
144
return schedulerResult ;
122
145
}
123
146
124
147
148
+
125
149
/// <summary>
126
150
/// Runs the stable diffusion batch loop
127
151
/// </summary>
@@ -152,15 +176,11 @@ public virtual async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffu
152
176
var batchSchedulerOptions = BatchGenerator . GenerateBatch ( modelOptions , batchOptions , schedulerOptions ) ;
153
177
154
178
var batchIndex = 1 ;
155
- var schedulerCallback = ( DiffusionProgress progress ) => progressCallback ? . Invoke ( new DiffusionProgress ( batchIndex , batchSchedulerOptions . Count , progress . ProgressTensor )
156
- {
157
- SubProgressMax = progress . ProgressMax ,
158
- SubProgressValue = progress . ProgressValue ,
159
- } ) ;
179
+ var batchSchedulerCallback = CreateBatchCallback ( progressCallback , batchSchedulerOptions . Count , ( ) => batchIndex ) ;
160
180
foreach ( var batchSchedulerOption in batchSchedulerOptions )
161
181
{
162
182
var diffuseTime = _logger ? . LogBegin ( "Diffuse starting..." ) ;
163
- yield return new BatchResult ( batchSchedulerOption , await SchedulerStepAsync ( modelOptions , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , schedulerCallback , cancellationToken ) ) ;
183
+ yield return new BatchResult ( batchSchedulerOption , await SchedulerStepAsync ( modelOptions , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , batchSchedulerCallback , cancellationToken ) ) ;
164
184
_logger ? . LogEnd ( $ "Diffuse complete", diffuseTime ) ;
165
185
batchIndex ++ ;
166
186
}
@@ -264,9 +284,14 @@ protected static IReadOnlyList<NamedOnnxValue> CreateInputParameters(params Name
264
284
/// <param name="progressCallback">The callback to receive the progress report; may be null.</param>
/// <param name="progress">The current step value.</param>
/// <param name="progressMax">The maximum step value.</param>
/// <param name="progressTensor">The tensor produced at this step.</param>
protected void ReportProgress(Action<DiffusionProgress> progressCallback, int progress, int progressMax, DenseTensor<float> progressTensor)
{
    // Nothing to report when no callback was supplied.
    if (progressCallback == null)
        return;

    var report = new DiffusionProgress
    {
        StepMax = progressMax,
        StepValue = progress,
        StepTensor = progressTensor
    };
    progressCallback.Invoke(report);
}
271
296
272
297
@@ -279,13 +304,31 @@ protected void ReportProgress(Action<DiffusionProgress> progressCallback, int pr
279
304
/// <param name="progressCallback">The callback to receive the progress report; may be null.</param>
/// <param name="progress">The current batch value.</param>
/// <param name="progressMax">The maximum batch value.</param>
/// <param name="progressTensor">The tensor produced for the current batch item.</param>
protected void ReportBatchProgress(Action<DiffusionProgress> progressCallback, int progress, int progressMax, DenseTensor<float> progressTensor)
{
    // Batch-level report: only the Batch* members are populated; step-level
    // progress is reported separately via ReportProgress.
    progressCallback?.Invoke(new DiffusionProgress
    {
        BatchMax = progressMax,
        BatchValue = progress,
        BatchTensor = progressTensor
    });
}
316
+
317
+
318
/// <summary>
/// Wraps a step-progress callback so each report is augmented with batch
/// position information (total count and the current index at report time).
/// </summary>
/// <param name="progressCallback">The callers progress callback; may be null.</param>
/// <param name="batchCount">The total number of items in the batch.</param>
/// <param name="batchIndex">Delegate that yields the current batch index when a report fires.</param>
/// <returns>The wrapping callback, or null when <paramref name="progressCallback"/> is null.</returns>
private static Action<DiffusionProgress> CreateBatchCallback(Action<DiffusionProgress> progressCallback, int batchCount, Func<int> batchIndex)
{
    // No wrapping needed when the caller did not ask for progress.
    if (progressCallback == null)
        return null;

    return progress =>
    {
        // batchIndex is evaluated lazily so the wrapper always reports
        // the index current at the time of the report, not at creation.
        var combined = new DiffusionProgress
        {
            StepMax = progress.StepMax,
            StepValue = progress.StepValue,
            StepTensor = progress.StepTensor,
            BatchMax = batchCount,
            BatchValue = batchIndex()
        };
        progressCallback(combined);
    };
}
332
+
290
333
}
291
334
}
0 commit comments