Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 0ffc9ea

Browse files
committed
VideoToVideo process prototype
1 parent 648df54 commit 0ffc9ea

34 files changed

+1491
-286
lines changed

OnnxStack.Console/Examples/StableDiffusionBatch.cs

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using OnnxStack.StableDiffusion.Common;
1+
using OnnxStack.Core.Image;
2+
using OnnxStack.StableDiffusion.Common;
23
using OnnxStack.StableDiffusion.Config;
34
using OnnxStack.StableDiffusion.Enums;
45
using OnnxStack.StableDiffusion.Helpers;
@@ -61,11 +62,11 @@ public async Task RunAsync()
6162
var batchIndex = 0;
6263
var callback = (DiffusionProgress progress) =>
6364
{
64-
batchIndex = progress.ProgressValue;
65-
OutputHelpers.WriteConsole($"Image: {progress.ProgressValue}/{progress.ProgressMax} - Step: {progress.SubProgressValue}/{progress.SubProgressMax}", ConsoleColor.Cyan);
65+
batchIndex = progress.BatchValue;
66+
OutputHelpers.WriteConsole($"Image: {progress.BatchValue}/{progress.BatchMax} - Step: {progress.StepValue}/{progress.StepMax}", ConsoleColor.Cyan);
6667
};
6768

68-
await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, callback))
69+
await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, default))
6970
{
7071
var outputFilename = Path.Combine(_outputDirectory, $"{batchIndex}_{result.SchedulerOptions.Seed}.png");
7172
var image = result.ImageResult.ToImage();

OnnxStack.Core/Image/Extensions.cs

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using SixLabors.ImageSharp.PixelFormats;
3+
using SixLabors.ImageSharp;
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
using System.Text;
8+
using System.Threading.Tasks;
9+
using System.IO;
10+
11+
namespace OnnxStack.Core.Image
12+
{
13+
public static class Extensions
14+
{
15+
public static Image<Rgba32> ToImage(this DenseTensor<float> imageTensor)
16+
{
17+
var height = imageTensor.Dimensions[2];
18+
var width = imageTensor.Dimensions[3];
19+
var hasAlpha = imageTensor.Dimensions[1] == 4;
20+
var result = new Image<Rgba32>(width, height);
21+
for (var y = 0; y < height; y++)
22+
{
23+
for (var x = 0; x < width; x++)
24+
{
25+
result[x, y] = new Rgba32(
26+
CalculateByte(imageTensor, 0, y, x),
27+
CalculateByte(imageTensor, 1, y, x),
28+
CalculateByte(imageTensor, 2, y, x),
29+
hasAlpha ? CalculateByte(imageTensor, 3, y, x) : byte.MaxValue
30+
);
31+
}
32+
}
33+
return result;
34+
}
35+
36+
/// <summary>
37+
/// Converts to image byte array.
38+
/// </summary>
39+
/// <param name="imageTensor">The image tensor.</param>
40+
/// <returns></returns>
41+
public static byte[] ToImageBytes(this DenseTensor<float> imageTensor)
42+
{
43+
using (var image = imageTensor.ToImage())
44+
using (var memoryStream = new MemoryStream())
45+
{
46+
image.SaveAsPng(memoryStream);
47+
return memoryStream.ToArray();
48+
}
49+
}
50+
51+
/// <summary>
52+
/// Converts to image byte array.
53+
/// </summary>
54+
/// <param name="imageTensor">The image tensor.</param>
55+
/// <returns></returns>
56+
public static async Task<byte[]> ToImageBytesAsync(this DenseTensor<float> imageTensor)
57+
{
58+
using (var image = imageTensor.ToImage())
59+
using (var memoryStream = new MemoryStream())
60+
{
61+
await image.SaveAsPngAsync(memoryStream);
62+
return memoryStream.ToArray();
63+
}
64+
}
65+
66+
67+
private static byte CalculateByte(Tensor<float> imageTensor, int index, int y, int x)
68+
{
69+
return (byte)Math.Round(Math.Clamp(imageTensor[0, index, y, x] / 2 + 0.5, 0, 1) * 255);
70+
}
71+
72+
}
73+
}

OnnxStack.Core/Services/IVideoService.cs

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using OnnxStack.Core.Video;
1+
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using OnnxStack.Core.Video;
23
using System.Collections.Generic;
34
using System.IO;
45
using System.Threading;
@@ -87,6 +88,14 @@ public interface IVideoService
8788
/// <returns></returns>
8889
Task<VideoOutput> CreateVideoAsync(VideoFrames videoFrames, CancellationToken cancellationToken = default);
8990

91+
// <summary>
92+
/// Creates and MP4 video from a collection of PNG images.
93+
/// </summary>
94+
/// <param name="videoTensor">The video frames.</param>
95+
/// <param name="videoFPS">The video FPS.</param>
96+
/// <param name="cancellationToken">The cancellation token.</param>
97+
/// <returns></returns>
98+
Task<VideoOutput> CreateVideoAsync(DenseTensor<float> videoTensor, float videoFPS, CancellationToken cancellationToken = default);
9099

91100
/// <summary>
92101
/// Streams frames as PNG as they are processed from a video source

OnnxStack.Core/Services/VideoService.cs

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using FFMpegCore;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
23
using OnnxStack.Core.Config;
34
using OnnxStack.Core.Video;
45
using System;
@@ -94,6 +95,20 @@ public async Task<VideoOutput> CreateVideoAsync(VideoFrames videoFrames, Cancell
9495
}
9596

9697

98+
/// <summary>
99+
/// Creates and MP4 video from a collection of PNG images.
100+
/// </summary>
101+
/// <param name="videoTensor">The video tensor.</param>
102+
/// <param name="videoFPS">The video FPS.</param>
103+
/// <param name="cancellationToken">The cancellation token.</param>
104+
/// <returns></returns>
105+
public async Task<VideoOutput> CreateVideoAsync(DenseTensor<float> videoTensor, float videoFPS, CancellationToken cancellationToken = default)
106+
{
107+
var videoFrames = await videoTensor.ToVideoFramesAsBytesAsync().ToListAsync(cancellationToken);
108+
return await CreateVideoInternalAsync(videoFrames, videoFPS, cancellationToken);
109+
}
110+
111+
97112
/// <summary>
98113
/// Creates and MP4 video from a collection of PNG images.
99114
/// </summary>
@@ -141,6 +156,7 @@ public async Task<VideoFrames> CreateFramesAsync(byte[] videoBytes, float videoF
141156
{
142157
var videoInfo = await GetVideoInfoAsync(videoBytes, cancellationToken);
143158
var videoFrames = await CreateFramesInternalAsync(videoBytes, videoFPS, cancellationToken).ToListAsync(cancellationToken);
159+
videoInfo = videoInfo with { FPS = videoFPS };
144160
return new VideoFrames(videoInfo, videoFrames);
145161
}
146162

@@ -190,7 +206,7 @@ public IAsyncEnumerable<byte[]> StreamFramesAsync(byte[] videoBytes, float targe
190206
/// <returns></returns>
191207
private async Task<VideoInfo> GetVideoInfoInternalAsync(MemoryStream videoStream, CancellationToken cancellationToken = default)
192208
{
193-
var result = await FFProbe.AnalyseAsync(videoStream, cancellationToken: cancellationToken).ConfigureAwait(false);
209+
var result = await FFProbe.AnalyseAsync(videoStream).ConfigureAwait(false);
194210
return new VideoInfo(result.PrimaryVideoStream.Width, result.PrimaryVideoStream.Height, result.Duration, (int)result.PrimaryVideoStream.FrameRate);
195211
}
196212

OnnxStack.Core/Video/Extensions.cs

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using OnnxStack.Core.Image;
3+
using SixLabors.ImageSharp;
4+
using SixLabors.ImageSharp.PixelFormats;
5+
using System.Collections.Generic;
6+
7+
namespace OnnxStack.Core.Video
8+
{
9+
public static class Extensions
10+
{
11+
public static IEnumerable<DenseTensor<float>> ToVideoFrames(this DenseTensor<float> videoTensor)
12+
{
13+
var count = videoTensor.Dimensions[0];
14+
var dimensions = videoTensor.Dimensions.ToArray();
15+
dimensions[0] = 1;
16+
17+
var newLength = (int)videoTensor.Length / count;
18+
for (int i = 0; i < count; i++)
19+
{
20+
var start = i * newLength;
21+
yield return new DenseTensor<float>(videoTensor.Buffer.Slice(start, newLength), dimensions);
22+
}
23+
}
24+
25+
public static IEnumerable<byte[]> ToVideoFramesAsBytes(this DenseTensor<float> videoTensor)
26+
{
27+
foreach (var frame in videoTensor.ToVideoFrames())
28+
{
29+
yield return frame.ToImageBytes();
30+
}
31+
}
32+
33+
public static async IAsyncEnumerable<byte[]> ToVideoFramesAsBytesAsync(this DenseTensor<float> videoTensor)
34+
{
35+
foreach (var frame in videoTensor.ToVideoFrames())
36+
{
37+
yield return await frame.ToImageBytesAsync();
38+
}
39+
}
40+
41+
public static IEnumerable<Image<Rgba32>> ToVideoFramesAsImage(this DenseTensor<float> videoTensor)
42+
{
43+
foreach (var frame in videoTensor.ToVideoFrames())
44+
{
45+
yield return frame.ToImage();
46+
}
47+
}
48+
}
49+
}

OnnxStack.Core/Video/VideoInfo.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22

33
namespace OnnxStack.Core.Video
44
{
5-
public record VideoInfo(int Width, int Height, TimeSpan Duration, int FPS);
5+
public record VideoInfo(int Width, int Height, TimeSpan Duration, float FPS);
66
}

OnnxStack.Core/Video/VideoInput.cs

+15-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ public VideoInput() { }
2929
/// <param name="videoTensor">The video tensor.</param>
3030
public VideoInput(DenseTensor<float> videoTensor) => VideoTensor = videoTensor;
3131

32+
/// <summary>
33+
/// Initializes a new instance of the <see cref="VideoInput"/> class.
34+
/// </summary>
35+
/// <param name="videoFrames">The video frames.</param>
36+
public VideoInput(VideoFrames videoFrames) => VideoFrames = videoFrames;
37+
3238

3339
/// <summary>
3440
/// Gets the video bytes.
@@ -51,6 +57,13 @@ public VideoInput() { }
5157
public DenseTensor<float> VideoTensor { get; set; }
5258

5359

60+
/// <summary>
61+
/// Gets or sets the video frames.
62+
/// </summary>
63+
[JsonIgnore]
64+
public VideoFrames VideoFrames { get; set; }
65+
66+
5467
/// <summary>
5568
/// Gets a value indicating whether this instance has video.
5669
/// </summary>
@@ -60,6 +73,7 @@ public VideoInput() { }
6073
[JsonIgnore]
6174
public bool HasVideo => VideoBytes != null
6275
|| VideoStream != null
63-
|| VideoTensor != null;
76+
|| VideoTensor != null
77+
|| VideoFrames != null;
6478
}
6579
}

OnnxStack.StableDiffusion/Config/PromptOptions.cs

+5-2
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@ public class PromptOptions
2020

2121
public InputImage InputImageMask { get; set; }
2222

23-
public VideoFrames InputVideo { get; set; }
23+
public VideoInput InputVideo { get; set; }
2424

25-
public bool HasInputVideo => InputVideo?.Frames?.Count > 0;
25+
public float VideoInputFPS { get; set; }
26+
public float VideoOutputFPS { get; set; }
27+
28+
public bool HasInputVideo => InputVideo?.HasVideo ?? false;
2629
public bool HasInputImage => InputImage?.HasImage ?? false;
2730
public bool HasInputImageMask => InputImageMask?.HasImage ?? false;
2831
}

OnnxStack.StableDiffusion/Config/SchedulerOptions.cs

-2
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,6 @@ public record SchedulerOptions
8484
public float AestheticScore { get; set; } = 6f;
8585
public float AestheticNegativeScore { get; set; } = 2.5f;
8686

87-
public float VideoFPS { get; set; }
88-
8987
public bool IsKarrasScheduler
9088
{
9189
get

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

+57-14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using Microsoft.ML.OnnxRuntime.Tensors;
44
using OnnxStack.Core;
55
using OnnxStack.Core.Config;
6+
using OnnxStack.Core.Image;
67
using OnnxStack.Core.Model;
78
using OnnxStack.Core.Services;
89
using OnnxStack.StableDiffusion.Common;
@@ -113,15 +114,38 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(StableDiffusionModelS
113114
// Process prompts
114115
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);
115116

117+
// If video input, process frames
118+
if (promptOptions.HasInputVideo)
119+
{
120+
var frameIndex = 0;
121+
DenseTensor<float> videoTensor = null;
122+
var videoFrames = promptOptions.InputVideo.VideoFrames.Frames;
123+
var schedulerFrameCallback = CreateBatchCallback(progressCallback, videoFrames.Count, () => frameIndex);
124+
foreach (var videoFrame in videoFrames)
125+
{
126+
frameIndex++;
127+
promptOptions.InputImage = new InputImage(videoFrame);
128+
var frameResultTensor = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerFrameCallback, cancellationToken);
129+
130+
// Frame Progress
131+
ReportBatchProgress(progressCallback, frameIndex, videoFrames.Count, frameResultTensor);
132+
133+
// Concatenate frame
134+
videoTensor = videoTensor.Concatenate(frameResultTensor);
135+
}
136+
137+
_logger?.LogEnd($"Diffuse complete", diffuseTime);
138+
return videoTensor;
139+
}
140+
116141
// Run Scheduler steps
117142
var schedulerResult = await SchedulerStepAsync(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
118-
119143
_logger?.LogEnd($"Diffuse complete", diffuseTime);
120-
121144
return schedulerResult;
122145
}
123146

124147

148+
125149
/// <summary>
126150
/// Runs the stable diffusion batch loop
127151
/// </summary>
@@ -152,15 +176,11 @@ public virtual async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffu
152176
var batchSchedulerOptions = BatchGenerator.GenerateBatch(modelOptions, batchOptions, schedulerOptions);
153177

154178
var batchIndex = 1;
155-
var schedulerCallback = (DiffusionProgress progress) => progressCallback?.Invoke(new DiffusionProgress(batchIndex, batchSchedulerOptions.Count, progress.ProgressTensor)
156-
{
157-
SubProgressMax = progress.ProgressMax,
158-
SubProgressValue = progress.ProgressValue,
159-
});
179+
var batchSchedulerCallback = CreateBatchCallback(progressCallback, batchSchedulerOptions.Count, () => batchIndex);
160180
foreach (var batchSchedulerOption in batchSchedulerOptions)
161181
{
162182
var diffuseTime = _logger?.LogBegin("Diffuse starting...");
163-
yield return new BatchResult(batchSchedulerOption, await SchedulerStepAsync(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, schedulerCallback, cancellationToken));
183+
yield return new BatchResult(batchSchedulerOption, await SchedulerStepAsync(modelOptions, promptOptions, batchSchedulerOption, promptEmbeddings, performGuidance, batchSchedulerCallback, cancellationToken));
164184
_logger?.LogEnd($"Diffuse complete", diffuseTime);
165185
batchIndex++;
166186
}
@@ -264,9 +284,14 @@ protected static IReadOnlyList<NamedOnnxValue> CreateInputParameters(params Name
264284
/// <param name="progress">The progress.</param>
265285
/// <param name="progressMax">The progress maximum.</param>
266286
/// <param name="output">The output.</param>
267-
protected void ReportProgress(Action<DiffusionProgress> progressCallback, int progress, int progressMax, DenseTensor<float> output)
287+
protected void ReportProgress(Action<DiffusionProgress> progressCallback, int progress, int progressMax, DenseTensor<float> progressTensor)
268288
{
269-
progressCallback?.Invoke(new DiffusionProgress(progress, progressMax, output));
289+
progressCallback?.Invoke(new DiffusionProgress
290+
{
291+
StepMax = progressMax,
292+
StepValue = progress,
293+
StepTensor = progressTensor
294+
});
270295
}
271296

272297

@@ -279,13 +304,31 @@ protected void ReportProgress(Action<DiffusionProgress> progressCallback, int pr
279304
/// <param name="subProgress">The sub progress.</param>
280305
/// <param name="subProgressMax">The sub progress maximum.</param>
281306
/// <param name="output">The output.</param>
282-
protected void ReportProgress(Action<DiffusionProgress> progressCallback, int progress, int progressMax, int subProgress, int subProgressMax, DenseTensor<float> output)
307+
protected void ReportBatchProgress(Action<DiffusionProgress> progressCallback, int progress, int progressMax, DenseTensor<float> progressTensor)
308+
{
309+
progressCallback?.Invoke(new DiffusionProgress
310+
{
311+
BatchMax = progressMax,
312+
BatchValue = progress,
313+
BatchTensor = progressTensor
314+
});
315+
}
316+
317+
318+
private static Action<DiffusionProgress> CreateBatchCallback(Action<DiffusionProgress> progressCallback, int batchCount, Func<int> batchIndex)
283319
{
284-
progressCallback?.Invoke(new DiffusionProgress(progress, progressMax, output)
320+
if (progressCallback == null)
321+
return progressCallback;
322+
323+
return (DiffusionProgress progress) => progressCallback?.Invoke(new DiffusionProgress
285324
{
286-
SubProgressMax = subProgressMax,
287-
SubProgressValue = subProgress,
325+
StepMax = progress.StepMax,
326+
StepValue = progress.StepValue,
327+
StepTensor = progress.StepTensor,
328+
BatchMax = batchCount,
329+
BatchValue = batchIndex()
288330
});
289331
}
332+
290333
}
291334
}

0 commit comments

Comments
 (0)