1
1
using Microsoft . Extensions . Logging ;
2
- using Microsoft . ML . OnnxRuntime . Tensors ;
3
2
using OnnxStack . Core ;
4
3
using OnnxStack . Core . Config ;
5
4
using OnnxStack . Core . Image ;
6
- using OnnxStack . Core . Services ;
5
+ using OnnxStack . Core . Video ;
7
6
using OnnxStack . StableDiffusion . Common ;
8
7
using OnnxStack . StableDiffusion . Config ;
9
8
using OnnxStack . StableDiffusion . Enums ;
10
9
using OnnxStack . StableDiffusion . Models ;
11
10
using OnnxStack . StableDiffusion . Pipelines ;
12
11
using OnnxStack . UI . Models ;
13
- using SixLabors . ImageSharp ;
14
12
using SixLabors . ImageSharp . PixelFormats ;
15
13
using System ;
16
14
using System . Collections . Concurrent ;
17
15
using System . Collections . Generic ;
18
- using System . IO ;
19
- using System . Runtime . CompilerServices ;
20
16
using System . Threading ;
21
17
using System . Threading . Tasks ;
22
18
@@ -28,7 +24,6 @@ namespace OnnxStack.UI.Services
28
24
/// <seealso cref="OnnxStack.StableDiffusion.Common.IStableDiffusionService" />
29
25
public sealed class StableDiffusionService : IStableDiffusionService
30
26
{
31
- private readonly IVideoService _videoService ;
32
27
private readonly ILogger < StableDiffusionService > _logger ;
33
28
private readonly OnnxStackUIConfig _configuration ;
34
29
private readonly Dictionary < IOnnxModel , IPipeline > _pipelines ;
@@ -38,11 +33,10 @@ public sealed class StableDiffusionService : IStableDiffusionService
38
33
/// Initializes a new instance of the <see cref="StableDiffusionService"/> class.
39
34
/// </summary>
40
35
/// <param name="schedulerService">The scheduler service.</param>
41
- public StableDiffusionService ( OnnxStackUIConfig configuration , IVideoService videoService , ILogger < StableDiffusionService > logger )
36
+ public StableDiffusionService ( OnnxStackUIConfig configuration , ILogger < StableDiffusionService > logger )
42
37
{
43
38
_logger = logger ;
44
39
_configuration = configuration ;
45
- _videoService = videoService ;
46
40
_pipelines = new Dictionary < IOnnxModel , IPipeline > ( ) ;
47
41
_controlNetSessions = new ConcurrentDictionary < IOnnxModel , ControlNetModel > ( ) ;
48
42
}
@@ -64,8 +58,6 @@ public async Task<bool> LoadModelAsync(StableDiffusionModelSet model)
64
58
}
65
59
66
60
67
-
68
-
69
61
/// <summary>
70
62
/// Unloads the model.
71
63
/// </summary>
@@ -95,6 +87,11 @@ public bool IsModelLoaded(StableDiffusionModelSet modelOptions)
95
87
}
96
88
97
89
90
+ /// <summary>
91
+ /// Loads the model.
92
+ /// </summary>
93
+ /// <param name="model"></param>
94
+ /// <returns></returns>
98
95
public async Task < bool > LoadControlNetModelAsync ( ControlNetModelSet model )
99
96
{
100
97
if ( _controlNetSessions . ContainsKey ( model ) )
@@ -106,6 +103,12 @@ public async Task<bool> LoadControlNetModelAsync(ControlNetModelSet model)
106
103
return _controlNetSessions . TryAdd ( model , controlNet ) ;
107
104
}
108
105
106
+
107
+ /// <summary>
108
+ /// Unloads the model.
109
+ /// </summary>
110
+ /// <param name="model"></param>
111
+ /// <returns></returns>
109
112
public Task < bool > UnloadControlNetModelAsync ( ControlNetModelSet model )
110
113
{
111
114
if ( _controlNetSessions . Remove ( model , out var controlNet ) )
@@ -115,6 +118,14 @@ public Task<bool> UnloadControlNetModelAsync(ControlNetModelSet model)
115
118
return Task . FromResult ( true ) ;
116
119
}
117
120
121
+
122
+ /// <summary>
123
+ /// Determines whether the specified model is loaded
124
+ /// </summary>
125
+ /// <param name="modelOptions">The model options.</param>
126
+ /// <returns>
127
+ /// <c>true</c> if the specified model is loaded; otherwise, <c>false</c>.
128
+ /// </returns>
118
129
public bool IsControlNetModelLoaded ( ControlNetModelSet modelOptions )
119
130
{
120
131
return _controlNetSessions . ContainsKey ( modelOptions ) ;
@@ -129,164 +140,55 @@ public bool IsControlNetModelLoaded(ControlNetModelSet modelOptions)
129
140
/// <param name="progressCallback">The callback used to provide progess of the current InferenceSteps.</param>
130
141
/// <param name="cancellationToken">The cancellation token.</param>
131
142
/// <returns>The diffusion result as <see cref="DenseTensor<float>"/></returns>
132
- public async Task < OnnxImage > GenerateAsync ( ModelOptions model , PromptOptions prompt , SchedulerOptions options , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
133
- {
134
- return await DiffuseAsync ( model , prompt , options , progressCallback , cancellationToken )
135
- . ContinueWith ( t => new OnnxImage ( t . Result ) , cancellationToken )
136
- . ConfigureAwait ( false ) ;
137
- }
138
-
139
-
140
-
141
-
142
-
143
- /// <summary>
144
- /// Generates a batch of StableDiffusion image using the prompt and options provided.
145
- /// </summary>
146
- /// <param name="modelOptions">The model options.</param>
147
- /// <param name="promptOptions">The prompt options.</param>
148
- /// <param name="schedulerOptions">The scheduler options.</param>
149
- /// <param name="batchOptions">The batch options.</param>
150
- /// <param name="progressCallback">The progress callback.</param>
151
- /// <param name="cancellationToken">The cancellation token.</param>
152
- /// <returns></returns>
153
- public IAsyncEnumerable < BatchResult > GenerateBatchAsync ( ModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
154
- {
155
- return DiffuseBatchAsync ( modelOptions , promptOptions , schedulerOptions , batchOptions , progressCallback , cancellationToken ) ;
156
- }
157
-
158
-
159
-
160
-
161
-
162
-
163
-
164
-
165
-
166
- /// <summary>
167
- /// Runs the diffusion process
168
- /// </summary>
169
- /// <param name="modelOptions">The model options.</param>
170
- /// <param name="promptOptions">The prompt options.</param>
171
- /// <param name="schedulerOptions">The scheduler options.</param>
172
- /// <param name="progress">The progress.</param>
173
- /// <param name="cancellationToken">The cancellation token.</param>
174
- /// <returns></returns>
175
- /// <exception cref="System.Exception">
176
- /// Pipeline not found or is unsupported
177
- /// or
178
- /// Diffuser not found or is unsupported
179
- /// or
180
- /// Scheduler '{schedulerOptions.SchedulerType}' is not compatible with the `{pipeline.PipelineType}` pipeline.
181
- /// </exception>
182
- private async Task < DenseTensor < float > > DiffuseAsync ( ModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , Action < DiffusionProgress > progress = null , CancellationToken cancellationToken = default )
143
+ public async Task < OnnxImage > GenerateImageAsync ( ModelOptions model , PromptOptions prompt , SchedulerOptions options , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
183
144
{
184
- if ( ! _pipelines . TryGetValue ( modelOptions . BaseModel , out var pipeline ) )
145
+ if ( ! _pipelines . TryGetValue ( model . BaseModel , out var pipeline ) )
185
146
throw new Exception ( "Pipeline not found or is unsupported" ) ;
186
147
187
148
var controlNet = default ( ControlNetModel ) ;
188
- if ( modelOptions . ControlNetModel is not null && ! _controlNetSessions . TryGetValue ( modelOptions . ControlNetModel , out controlNet ) )
149
+ if ( model . ControlNetModel is not null && ! _controlNetSessions . TryGetValue ( model . ControlNetModel , out controlNet ) )
189
150
throw new Exception ( "ControlNet not loaded" ) ;
190
151
191
- pipeline . ValidateInputs ( promptOptions , schedulerOptions ) ;
152
+ pipeline . ValidateInputs ( prompt , options ) ;
192
153
193
- await GenerateInputVideoFrames ( promptOptions , progress ) ;
194
- return await pipeline . RunAsync ( promptOptions , schedulerOptions , controlNet , progress , cancellationToken ) ;
154
+ return await pipeline . GenerateImageAsync ( prompt , options , controlNet , progressCallback , cancellationToken ) ;
195
155
}
196
156
197
157
198
158
/// <summary>
199
- /// Runs the batch diffusion process .
159
+ /// Generates the StableDiffusion video using the prompt and options provided .
200
160
/// </summary>
201
- /// <param name="modelOptions">The model options.</param>
202
- /// <param name="promptOptions">The prompt options.</param>
203
- /// <param name="schedulerOptions">The scheduler options.</param>
204
- /// <param name="batchOptions">The batch options.</param>
205
- /// <param name="progress">The progress.</param>
161
+ /// <param name="model">The model.</param>
162
+ /// <param name="prompt">The prompt.</param>
163
+ /// <param name="options">The options.</param>
164
+ /// <param name="progressCallback">The progress callback.</param>
206
165
/// <param name="cancellationToken">The cancellation token.</param>
207
166
/// <returns></returns>
208
167
/// <exception cref="System.Exception">
209
168
/// Pipeline not found or is unsupported
210
169
/// or
211
- /// Diffuser not found or is unsupported
212
- /// or
213
- /// Scheduler '{schedulerOptions.SchedulerType}' is not compatible with the `{pipeline.PipelineType}` pipeline.
170
+ /// ControlNet not loaded
214
171
/// </exception>
215
- private async IAsyncEnumerable < BatchResult > DiffuseBatchAsync ( ModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < DiffusionProgress > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
172
+ public async Task < OnnxVideo > GenerateVideoAsync ( ModelOptions model , PromptOptions prompt , SchedulerOptions options , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
216
173
{
217
- if ( ! _pipelines . TryGetValue ( modelOptions . BaseModel , out var pipeline ) )
174
+ if ( ! _pipelines . TryGetValue ( model . BaseModel , out var pipeline ) )
218
175
throw new Exception ( "Pipeline not found or is unsupported" ) ;
219
176
220
177
var controlNet = default ( ControlNetModel ) ;
221
- if ( modelOptions . ControlNetModel is not null && ! _controlNetSessions . TryGetValue ( modelOptions . ControlNetModel , out controlNet ) )
178
+ if ( model . ControlNetModel is not null && ! _controlNetSessions . TryGetValue ( model . ControlNetModel , out controlNet ) )
222
179
throw new Exception ( "ControlNet not loaded" ) ;
223
180
224
- pipeline . ValidateInputs ( promptOptions , schedulerOptions ) ;
181
+ pipeline . ValidateInputs ( prompt , options ) ;
225
182
226
- await GenerateInputVideoFrames ( promptOptions , progressCallback ) ;
227
- await foreach ( var result in pipeline . RunBatchAsync ( batchOptions , promptOptions , schedulerOptions , controlNet , progressCallback , cancellationToken ) )
228
- {
229
- yield return result ;
230
- }
183
+ return await pipeline . GenerateVideoAsync ( prompt , options , controlNet , progressCallback , cancellationToken ) ;
231
184
}
232
185
233
186
234
187
/// <summary>
235
- /// Generates the video result as bytes .
188
+ /// Creates the pipeline .
236
189
/// </summary>
237
- /// <param name="options">The options.</param>
238
- /// <param name="videoTensor">The video tensor.</param>
239
- /// <param name="progress">The progress.</param>
240
- /// <param name="cancellationToken">The cancellation token.</param>
241
- /// <returns></returns>
242
- private async Task < byte [ ] > GenerateVideoResultAsBytesAsync ( DenseTensor < float > videoTensor , float videoFPS , Action < DiffusionProgress > progress = null , CancellationToken cancellationToken = default )
243
- {
244
- progress ? . Invoke ( new DiffusionProgress ( "Generating Video Result..." ) ) ;
245
- var videoResult = await _videoService . CreateVideoAsync ( videoTensor , videoFPS , cancellationToken ) ;
246
- return videoResult . Data ;
247
- }
248
-
249
-
250
- /// <summary>
251
- /// Generates the video result as stream.
252
- /// </summary>
253
- /// <param name="options">The options.</param>
254
- /// <param name="videoTensor">The video tensor.</param>
255
- /// <param name="progress">The progress.</param>
256
- /// <param name="cancellationToken">The cancellation token.</param>
190
+ /// <param name="model">The model.</param>
257
191
/// <returns></returns>
258
- private async Task < MemoryStream > GenerateVideoResultAsStreamAsync ( DenseTensor < float > videoTensor , float videoFPS , Action < DiffusionProgress > progress = null , CancellationToken cancellationToken = default )
259
- {
260
- return new MemoryStream ( await GenerateVideoResultAsBytesAsync ( videoTensor , videoFPS , progress , cancellationToken ) ) ;
261
- }
262
-
263
-
264
- /// <summary>
265
- /// Generates the input video frames.
266
- /// </summary>
267
- /// <param name="promptOptions">The prompt options.</param>
268
- /// <param name="progress">The progress.</param>
269
- private async Task GenerateInputVideoFrames ( PromptOptions promptOptions , Action < DiffusionProgress > progress )
270
- {
271
- if ( ! promptOptions . HasInputVideo || promptOptions . InputVideo . VideoFrames is not null )
272
- return ;
273
-
274
- if ( promptOptions . VideoInputFPS == 0 || promptOptions . VideoOutputFPS == 0 )
275
- {
276
- var videoInfo = await _videoService . GetVideoInfoAsync ( promptOptions . InputVideo ) ;
277
- if ( promptOptions . VideoInputFPS == 0 )
278
- promptOptions . VideoInputFPS = videoInfo . FPS ;
279
-
280
- if ( promptOptions . VideoOutputFPS == 0 )
281
- promptOptions . VideoOutputFPS = videoInfo . FPS ;
282
- }
283
-
284
- var videoFrame = await _videoService . CreateFramesAsync ( promptOptions . InputVideo , promptOptions . VideoInputFPS ) ;
285
- progress ? . Invoke ( new DiffusionProgress ( $ "Generating video frames @ { promptOptions . VideoInputFPS } fps") ) ;
286
- promptOptions . InputVideo . VideoFrames = videoFrame ;
287
- }
288
-
289
-
290
192
private IPipeline CreatePipeline ( StableDiffusionModelSet model )
291
193
{
292
194
return model . PipelineType switch
0 commit comments