Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 299c421

Browse files
committed
Update example UI
1 parent 13e90c2 commit 299c421

10 files changed

+106
-242
lines changed

OnnxStack.Core/Video/VideoHelper.cs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,16 @@ public static async Task WriteVideoFramesAsync(IEnumerable<OnnxImage> onnxImages
6363
/// <param name="cancellationToken">The cancellation token.</param>
6464
private static async Task WriteVideoFramesAsync(IEnumerable<OnnxImage> onnxImages, string filename, float frameRate, double aspectRatio, CancellationToken cancellationToken = default)
6565
{
66+
if (File.Exists(filename))
67+
File.Delete(filename);
68+
6669
using (var videoWriter = CreateWriter(filename, frameRate, aspectRatio))
6770
{
6871
// Start FFMPEG
6972
videoWriter.Start();
7073
foreach (var image in onnxImages)
7174
{
7275
// Write each frame to the input stream of FFMPEG
73-
await Task.Yield();
7476
await videoWriter.StandardInput.BaseStream.WriteAsync(image.GetImageBytes(), cancellationToken);
7577
}
7678

@@ -96,11 +98,23 @@ public static async Task<VideoInfo> ReadVideoInfoAsync(byte[] videoBytes)
9698
}
9799

98100

101+
/// <summary>
102+
/// Reads the video information.
103+
/// </summary>
104+
/// <param name="filename">The filename.</param>
105+
/// <returns></returns>
106+
public static async Task<VideoInfo> ReadVideoInfoAsync(string filename)
107+
{
108+
var result = await FFProbe.AnalyseAsync(filename).ConfigureAwait(false);
109+
return new VideoInfo(result.PrimaryVideoStream.Width, result.PrimaryVideoStream.Height, result.Duration, (int)result.PrimaryVideoStream.FrameRate);
110+
}
111+
112+
99113
/// <summary>
100114
/// Reads the video frames.
101115
/// </summary>
102116
/// <param name="videoBytes">The video bytes.</param>
103-
/// <param name="frameRate">The frame rate.</param>
117+
/// <param name="frameRate">The target frame rate.</param>
104118
/// <param name="cancellationToken">The cancellation token.</param>
105119
/// <returns></returns>
106120
public static async Task<List<OnnxImage>> ReadVideoFramesAsync(byte[] videoBytes, float frameRate = 15, CancellationToken cancellationToken = default)
@@ -111,6 +125,22 @@ public static async Task<List<OnnxImage>> ReadVideoFramesAsync(byte[] videoBytes
111125
}
112126

113127

128+
/// <summary>
129+
/// Reads the video frames.
130+
/// </summary>
131+
/// <param name="filename">The video bytes.</param>
132+
/// <param name="frameRate">The target frame rate.</param>
133+
/// <param name="cancellationToken">The cancellation token.</param>
134+
/// <returns></returns>
135+
public static async Task<List<OnnxImage>> ReadVideoFramesAsync(string filename, float frameRate = 15, CancellationToken cancellationToken = default)
136+
{
137+
var videoBytes = await File.ReadAllBytesAsync(filename, cancellationToken);
138+
return await CreateFramesInternalAsync(videoBytes, frameRate, cancellationToken)
139+
.Select(x => new OnnxImage(x))
140+
.ToListAsync(cancellationToken);
141+
}
142+
143+
114144
#region Private Members
115145

116146

OnnxStack.UI/Services/IStableDiffusionService.cs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
using Microsoft.ML.OnnxRuntime.Tensors;
2-
using OnnxStack.Core.Image;
1+
using OnnxStack.Core.Image;
2+
using OnnxStack.Core.Video;
33
using OnnxStack.StableDiffusion.Common;
44
using OnnxStack.StableDiffusion.Config;
55
using System;
6-
using System.Collections.Generic;
76
using System.Threading;
87
using System.Threading.Tasks;
98

@@ -68,18 +67,18 @@ public interface IStableDiffusionService
6867
/// <param name="progressCallback">The callback used to provide progess of the current InferenceSteps.</param>
6968
/// <param name="cancellationToken">The cancellation token.</param>
7069
/// <returns>The diffusion result as <see cref="DenseTensor<float>"/></returns>
71-
Task<OnnxImage> GenerateAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
70+
Task<OnnxImage> GenerateImageAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
71+
7272

7373
/// <summary>
74-
/// Generates a batch of StableDiffusion image using the prompt and options provided.
74+
/// Generates the StableDiffusion video using the prompt and options provided.
7575
/// </summary>
76-
/// <param name="modelOptions">The model options.</param>
77-
/// <param name="promptOptions">The prompt options.</param>
78-
/// <param name="schedulerOptions">The scheduler options.</param>
79-
/// <param name="batchOptions">The batch options.</param>
76+
/// <param name="model">The model.</param>
77+
/// <param name="prompt">The prompt.</param>
78+
/// <param name="options">The options.</param>
8079
/// <param name="progressCallback">The progress callback.</param>
8180
/// <param name="cancellationToken">The cancellation token.</param>
8281
/// <returns></returns>
83-
IAsyncEnumerable<BatchResult> GenerateBatchAsync(ModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
82+
Task<OnnxVideo> GenerateVideoAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
8483
}
8584
}

OnnxStack.UI/Services/IUpscaleService.cs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,5 @@ public interface IUpscaleService
4343
/// <param name="inputImage">The input image.</param>
4444
/// <returns></returns>
4545
Task<OnnxImage> GenerateAsync(UpscaleModelSet modelOptions, OnnxImage inputImage, CancellationToken cancellationToken = default);
46-
47-
48-
/// <summary>
49-
/// Generates the upscaled video.
50-
/// </summary>
51-
/// <param name="modelOptions">The model options.</param>
52-
/// <param name="videoInput">The video input.</param>
53-
/// <returns></returns>
54-
Task<DenseTensor<float>> GenerateAsync(UpscaleModelSet modelOptions, OnnxVideo videoInput, CancellationToken cancellationToken = default);
5546
}
5647
}

OnnxStack.UI/Services/StableDiffusionService.cs

Lines changed: 39 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,18 @@
11
using Microsoft.Extensions.Logging;
2-
using Microsoft.ML.OnnxRuntime.Tensors;
32
using OnnxStack.Core;
43
using OnnxStack.Core.Config;
54
using OnnxStack.Core.Image;
6-
using OnnxStack.Core.Services;
5+
using OnnxStack.Core.Video;
76
using OnnxStack.StableDiffusion.Common;
87
using OnnxStack.StableDiffusion.Config;
98
using OnnxStack.StableDiffusion.Enums;
109
using OnnxStack.StableDiffusion.Models;
1110
using OnnxStack.StableDiffusion.Pipelines;
1211
using OnnxStack.UI.Models;
13-
using SixLabors.ImageSharp;
1412
using SixLabors.ImageSharp.PixelFormats;
1513
using System;
1614
using System.Collections.Concurrent;
1715
using System.Collections.Generic;
18-
using System.IO;
19-
using System.Runtime.CompilerServices;
2016
using System.Threading;
2117
using System.Threading.Tasks;
2218

@@ -28,7 +24,6 @@ namespace OnnxStack.UI.Services
2824
/// <seealso cref="OnnxStack.StableDiffusion.Common.IStableDiffusionService" />
2925
public sealed class StableDiffusionService : IStableDiffusionService
3026
{
31-
private readonly IVideoService _videoService;
3227
private readonly ILogger<StableDiffusionService> _logger;
3328
private readonly OnnxStackUIConfig _configuration;
3429
private readonly Dictionary<IOnnxModel, IPipeline> _pipelines;
@@ -38,11 +33,10 @@ public sealed class StableDiffusionService : IStableDiffusionService
3833
/// Initializes a new instance of the <see cref="StableDiffusionService"/> class.
3934
/// </summary>
4035
/// <param name="schedulerService">The scheduler service.</param>
41-
public StableDiffusionService(OnnxStackUIConfig configuration, IVideoService videoService, ILogger<StableDiffusionService> logger)
36+
public StableDiffusionService(OnnxStackUIConfig configuration, ILogger<StableDiffusionService> logger)
4237
{
4338
_logger = logger;
4439
_configuration = configuration;
45-
_videoService = videoService;
4640
_pipelines = new Dictionary<IOnnxModel, IPipeline>();
4741
_controlNetSessions = new ConcurrentDictionary<IOnnxModel, ControlNetModel>();
4842
}
@@ -64,8 +58,6 @@ public async Task<bool> LoadModelAsync(StableDiffusionModelSet model)
6458
}
6559

6660

67-
68-
6961
/// <summary>
7062
/// Unloads the model.
7163
/// </summary>
@@ -95,6 +87,11 @@ public bool IsModelLoaded(StableDiffusionModelSet modelOptions)
9587
}
9688

9789

90+
/// <summary>
91+
/// Loads the model.
92+
/// </summary>
93+
/// <param name="model"></param>
94+
/// <returns></returns>
9895
public async Task<bool> LoadControlNetModelAsync(ControlNetModelSet model)
9996
{
10097
if (_controlNetSessions.ContainsKey(model))
@@ -106,6 +103,12 @@ public async Task<bool> LoadControlNetModelAsync(ControlNetModelSet model)
106103
return _controlNetSessions.TryAdd(model, controlNet);
107104
}
108105

106+
107+
/// <summary>
108+
/// Unloads the model.
109+
/// </summary>
110+
/// <param name="model"></param>
111+
/// <returns></returns>
109112
public Task<bool> UnloadControlNetModelAsync(ControlNetModelSet model)
110113
{
111114
if (_controlNetSessions.Remove(model, out var controlNet))
@@ -115,6 +118,14 @@ public Task<bool> UnloadControlNetModelAsync(ControlNetModelSet model)
115118
return Task.FromResult(true);
116119
}
117120

121+
122+
/// <summary>
123+
/// Determines whether the specified model is loaded
124+
/// </summary>
125+
/// <param name="modelOptions">The model options.</param>
126+
/// <returns>
127+
/// <c>true</c> if the specified model is loaded; otherwise, <c>false</c>.
128+
/// </returns>
118129
public bool IsControlNetModelLoaded(ControlNetModelSet modelOptions)
119130
{
120131
return _controlNetSessions.ContainsKey(modelOptions);
@@ -129,164 +140,55 @@ public bool IsControlNetModelLoaded(ControlNetModelSet modelOptions)
129140
/// <param name="progressCallback">The callback used to provide progess of the current InferenceSteps.</param>
130141
/// <param name="cancellationToken">The cancellation token.</param>
131142
/// <returns>The diffusion result as <see cref="DenseTensor<float>"/></returns>
132-
public async Task<OnnxImage> GenerateAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
133-
{
134-
return await DiffuseAsync(model, prompt, options, progressCallback, cancellationToken)
135-
.ContinueWith(t => new OnnxImage(t.Result), cancellationToken)
136-
.ConfigureAwait(false);
137-
}
138-
139-
140-
141-
142-
143-
/// <summary>
144-
/// Generates a batch of StableDiffusion image using the prompt and options provided.
145-
/// </summary>
146-
/// <param name="modelOptions">The model options.</param>
147-
/// <param name="promptOptions">The prompt options.</param>
148-
/// <param name="schedulerOptions">The scheduler options.</param>
149-
/// <param name="batchOptions">The batch options.</param>
150-
/// <param name="progressCallback">The progress callback.</param>
151-
/// <param name="cancellationToken">The cancellation token.</param>
152-
/// <returns></returns>
153-
public IAsyncEnumerable<BatchResult> GenerateBatchAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
154-
{
155-
return DiffuseBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken);
156-
}
157-
158-
159-
160-
161-
162-
163-
164-
165-
166-
/// <summary>
167-
/// Runs the diffusion process
168-
/// </summary>
169-
/// <param name="modelOptions">The model options.</param>
170-
/// <param name="promptOptions">The prompt options.</param>
171-
/// <param name="schedulerOptions">The scheduler options.</param>
172-
/// <param name="progress">The progress.</param>
173-
/// <param name="cancellationToken">The cancellation token.</param>
174-
/// <returns></returns>
175-
/// <exception cref="System.Exception">
176-
/// Pipeline not found or is unsupported
177-
/// or
178-
/// Diffuser not found or is unsupported
179-
/// or
180-
/// Scheduler '{schedulerOptions.SchedulerType}' is not compatible with the `{pipeline.PipelineType}` pipeline.
181-
/// </exception>
182-
private async Task<DenseTensor<float>> DiffuseAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<DiffusionProgress> progress = null, CancellationToken cancellationToken = default)
143+
public async Task<OnnxImage> GenerateImageAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
183144
{
184-
if (!_pipelines.TryGetValue(modelOptions.BaseModel, out var pipeline))
145+
if (!_pipelines.TryGetValue(model.BaseModel, out var pipeline))
185146
throw new Exception("Pipeline not found or is unsupported");
186147

187148
var controlNet = default(ControlNetModel);
188-
if (modelOptions.ControlNetModel is not null && !_controlNetSessions.TryGetValue(modelOptions.ControlNetModel, out controlNet))
149+
if (model.ControlNetModel is not null && !_controlNetSessions.TryGetValue(model.ControlNetModel, out controlNet))
189150
throw new Exception("ControlNet not loaded");
190151

191-
pipeline.ValidateInputs(promptOptions, schedulerOptions);
152+
pipeline.ValidateInputs(prompt, options);
192153

193-
await GenerateInputVideoFrames(promptOptions, progress);
194-
return await pipeline.RunAsync(promptOptions, schedulerOptions, controlNet, progress, cancellationToken);
154+
return await pipeline.GenerateImageAsync(prompt, options, controlNet, progressCallback, cancellationToken);
195155
}
196156

197157

198158
/// <summary>
199-
/// Runs the batch diffusion process.
159+
/// Generates the StableDiffusion video using the prompt and options provided.
200160
/// </summary>
201-
/// <param name="modelOptions">The model options.</param>
202-
/// <param name="promptOptions">The prompt options.</param>
203-
/// <param name="schedulerOptions">The scheduler options.</param>
204-
/// <param name="batchOptions">The batch options.</param>
205-
/// <param name="progress">The progress.</param>
161+
/// <param name="model">The model.</param>
162+
/// <param name="prompt">The prompt.</param>
163+
/// <param name="options">The options.</param>
164+
/// <param name="progressCallback">The progress callback.</param>
206165
/// <param name="cancellationToken">The cancellation token.</param>
207166
/// <returns></returns>
208167
/// <exception cref="System.Exception">
209168
/// Pipeline not found or is unsupported
210169
/// or
211-
/// Diffuser not found or is unsupported
212-
/// or
213-
/// Scheduler '{schedulerOptions.SchedulerType}' is not compatible with the `{pipeline.PipelineType}` pipeline.
170+
/// ControlNet not loaded
214171
/// </exception>
215-
private async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
172+
public async Task<OnnxVideo> GenerateVideoAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
216173
{
217-
if (!_pipelines.TryGetValue(modelOptions.BaseModel, out var pipeline))
174+
if (!_pipelines.TryGetValue(model.BaseModel, out var pipeline))
218175
throw new Exception("Pipeline not found or is unsupported");
219176

220177
var controlNet = default(ControlNetModel);
221-
if (modelOptions.ControlNetModel is not null && !_controlNetSessions.TryGetValue(modelOptions.ControlNetModel, out controlNet))
178+
if (model.ControlNetModel is not null && !_controlNetSessions.TryGetValue(model.ControlNetModel, out controlNet))
222179
throw new Exception("ControlNet not loaded");
223180

224-
pipeline.ValidateInputs(promptOptions, schedulerOptions);
181+
pipeline.ValidateInputs(prompt, options);
225182

226-
await GenerateInputVideoFrames(promptOptions, progressCallback);
227-
await foreach (var result in pipeline.RunBatchAsync(batchOptions, promptOptions, schedulerOptions, controlNet, progressCallback, cancellationToken))
228-
{
229-
yield return result;
230-
}
183+
return await pipeline.GenerateVideoAsync(prompt, options, controlNet, progressCallback, cancellationToken);
231184
}
232185

233186

234187
/// <summary>
235-
/// Generates the video result as bytes.
188+
/// Creates the pipeline.
236189
/// </summary>
237-
/// <param name="options">The options.</param>
238-
/// <param name="videoTensor">The video tensor.</param>
239-
/// <param name="progress">The progress.</param>
240-
/// <param name="cancellationToken">The cancellation token.</param>
241-
/// <returns></returns>
242-
private async Task<byte[]> GenerateVideoResultAsBytesAsync(DenseTensor<float> videoTensor, float videoFPS, Action<DiffusionProgress> progress = null, CancellationToken cancellationToken = default)
243-
{
244-
progress?.Invoke(new DiffusionProgress("Generating Video Result..."));
245-
var videoResult = await _videoService.CreateVideoAsync(videoTensor, videoFPS, cancellationToken);
246-
return videoResult.Data;
247-
}
248-
249-
250-
/// <summary>
251-
/// Generates the video result as stream.
252-
/// </summary>
253-
/// <param name="options">The options.</param>
254-
/// <param name="videoTensor">The video tensor.</param>
255-
/// <param name="progress">The progress.</param>
256-
/// <param name="cancellationToken">The cancellation token.</param>
190+
/// <param name="model">The model.</param>
257191
/// <returns></returns>
258-
private async Task<MemoryStream> GenerateVideoResultAsStreamAsync(DenseTensor<float> videoTensor, float videoFPS, Action<DiffusionProgress> progress = null, CancellationToken cancellationToken = default)
259-
{
260-
return new MemoryStream(await GenerateVideoResultAsBytesAsync(videoTensor, videoFPS, progress, cancellationToken));
261-
}
262-
263-
264-
/// <summary>
265-
/// Generates the input video frames.
266-
/// </summary>
267-
/// <param name="promptOptions">The prompt options.</param>
268-
/// <param name="progress">The progress.</param>
269-
private async Task GenerateInputVideoFrames(PromptOptions promptOptions, Action<DiffusionProgress> progress)
270-
{
271-
if (!promptOptions.HasInputVideo || promptOptions.InputVideo.VideoFrames is not null)
272-
return;
273-
274-
if (promptOptions.VideoInputFPS == 0 || promptOptions.VideoOutputFPS == 0)
275-
{
276-
var videoInfo = await _videoService.GetVideoInfoAsync(promptOptions.InputVideo);
277-
if (promptOptions.VideoInputFPS == 0)
278-
promptOptions.VideoInputFPS = videoInfo.FPS;
279-
280-
if (promptOptions.VideoOutputFPS == 0)
281-
promptOptions.VideoOutputFPS = videoInfo.FPS;
282-
}
283-
284-
var videoFrame = await _videoService.CreateFramesAsync(promptOptions.InputVideo, promptOptions.VideoInputFPS);
285-
progress?.Invoke(new DiffusionProgress($"Generating video frames @ {promptOptions.VideoInputFPS}fps"));
286-
promptOptions.InputVideo.VideoFrames = videoFrame;
287-
}
288-
289-
290192
private IPipeline CreatePipeline(StableDiffusionModelSet model)
291193
{
292194
return model.PipelineType switch

0 commit comments

Comments
 (0)