Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 8df1489

Browse files
committed
Rework StableDiffusionService to be more generic
1 parent cb2f211 commit 8df1489

File tree

12 files changed

+318
-146
lines changed

12 files changed

+318
-146
lines changed

OnnxStack.Console/Examples/StableDebug.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using OnnxStack.StableDiffusion.Common;
22
using OnnxStack.StableDiffusion.Config;
33
using OnnxStack.StableDiffusion.Enums;
4+
using SixLabors.ImageSharp;
45
using System.Diagnostics;
56

67
namespace OnnxStack.Console.Runner
@@ -61,9 +62,10 @@ private async Task<bool> GenerateImage(PromptOptions prompt, SchedulerOptions op
6162
{
6263
var timestamp = Stopwatch.GetTimestamp();
6364
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}.png");
64-
var result = await _stableDiffusionService.TextToImageFile(prompt, options, outputFilename);
65+
var result = await _stableDiffusionService.GenerateAsImageAsync(prompt, options);
6566
if (result is not null)
6667
{
68+
await result.SaveAsPngAsync(outputFilename);
6769
OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
6870
OutputHelpers.WriteConsole($"Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Yellow);
6971
return true;

OnnxStack.Console/Examples/StableDiffusionExample.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using OnnxStack.StableDiffusion.Common;
33
using OnnxStack.StableDiffusion.Config;
44
using OnnxStack.StableDiffusion.Enums;
5+
using SixLabors.ImageSharp;
56

67
namespace OnnxStack.Console.Runner
78
{
@@ -59,11 +60,12 @@ public async Task RunAsync()
5960
private async Task<bool> GenerateImage(PromptOptions prompt, SchedulerOptions options)
6061
{
6162
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}.png");
62-
var result = await _stableDiffusionService.TextToImageFile(prompt, outputFilename);
63+
var result = await _stableDiffusionService.GenerateAsImageAsync(prompt, options);
6364
if (result == null)
6465
return false;
6566

66-
OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(result.FileName)}", ConsoleColor.Green);
67+
await result.SaveAsPngAsync(outputFilename);
68+
OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
6769
return true;
6870
}
6971
}

OnnxStack.Console/Examples/StableDiffusionGenerator.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using OnnxStack.StableDiffusion.Common;
33
using OnnxStack.StableDiffusion.Config;
44
using OnnxStack.StableDiffusion.Enums;
5+
using SixLabors.ImageSharp;
56
using System.Collections.ObjectModel;
67

78
namespace OnnxStack.Console.Runner
@@ -57,11 +58,12 @@ public async Task RunAsync()
5758
private async Task<bool> GenerateImage(PromptOptions prompt, SchedulerOptions options, string key)
5859
{
5960
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}_{key}.png");
60-
var result = await _stableDiffusionService.TextToImageFile(prompt, outputFilename);
61+
var result = await _stableDiffusionService.GenerateAsImageAsync(prompt, options);
6162
if (result == null)
6263
return false;
6364

64-
OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(result.FileName)}", ConsoleColor.Green);
65+
await result.SaveAsPngAsync(outputFilename);
66+
OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
6567
return true;
6668
}
6769

Original file line numberDiff line numberDiff line change
@@ -1,20 +1,54 @@
1-
using OnnxStack.StableDiffusion.Config;
2-
using OnnxStack.StableDiffusion.Results;
1+
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using OnnxStack.StableDiffusion.Config;
3+
using SixLabors.ImageSharp;
4+
using SixLabors.ImageSharp.PixelFormats;
35
using System;
4-
using System.Collections.Generic;
6+
using System.IO;
57
using System.Threading;
68
using System.Threading.Tasks;
79

810
namespace OnnxStack.StableDiffusion.Common
911
{
1012
public interface IStableDiffusionService
1113
{
12-
Task<ImageResult> TextToImage(PromptOptions prompt);
13-
Task<ImageResult> TextToImage(PromptOptions prompt, SchedulerOptions options);
14-
Task<ImageResult> TextToImage(PromptOptions prompt, SchedulerOptions options, Action<int, int> progress = null, CancellationToken cancellationToken = default);
14+
/// <summary>
15+
/// Generates the StableDiffusion image using the prompt and options provided.
16+
/// </summary>
17+
/// <param name="prompt">The prompt.</param>
18+
/// <param name="options">The Scheduler options.</param>
19+
/// <param name="progressCallback">The callback used to provide progress of the current InferenceSteps.</param>
20+
/// <param name="cancellationToken">The cancellation token.</param>
21+
/// <returns>The diffusion result as <see cref="DenseTensor{T}"/> of <see cref="float"/></returns>
22+
Task<DenseTensor<float>> GenerateAsync(PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
1523

16-
Task<ImageResult> TextToImageFile(PromptOptions prompt, string outputFile);
17-
Task<ImageResult> TextToImageFile(PromptOptions prompt, SchedulerOptions options, string outputFile);
18-
Task<ImageResult> TextToImageFile(PromptOptions prompt, SchedulerOptions options, string outputFile, Action<int, int> progress = null, CancellationToken cancellationToken = default);
24+
/// <summary>
25+
/// Generates the StableDiffusion image using the prompt and options provided.
26+
/// </summary>
27+
/// <param name="prompt">The prompt.</param>
28+
/// <param name="options">The Scheduler options.</param>
29+
/// <param name="progressCallback">The callback used to provide progress of the current InferenceSteps.</param>
30+
/// <param name="cancellationToken">The cancellation token.</param>
31+
/// <returns>The diffusion result as <see cref="Image{TPixel}"/> of <see cref="Rgb24"/></returns>
32+
Task<Image<Rgb24>> GenerateAsImageAsync(PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
33+
34+
/// <summary>
35+
/// Generates the StableDiffusion image using the prompt and options provided.
36+
/// </summary>
37+
/// <param name="prompt">The prompt.</param>
38+
/// <param name="options">The Scheduler options.</param>
39+
/// <param name="progressCallback">The callback used to provide progress of the current InferenceSteps.</param>
40+
/// <param name="cancellationToken">The cancellation token.</param>
41+
/// <returns>The diffusion result as <see cref="byte[]"/></returns>
42+
Task<byte[]> GenerateAsBytesAsync(PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
43+
44+
/// <summary>
45+
/// Generates the StableDiffusion image using the prompt and options provided.
46+
/// </summary>
47+
/// <param name="prompt">The prompt.</param>
48+
/// <param name="options">The Scheduler options.</param>
49+
/// <param name="progressCallback">The callback used to provide progress of the current InferenceSteps.</param>
50+
/// <param name="cancellationToken">The cancellation token.</param>
51+
/// <returns>The diffusion result as <see cref="System.IO.Stream"/></returns>
52+
Task<Stream> GenerateAsStreamAsync(PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
1953
}
2054
}

OnnxStack.StableDiffusion/Config/PromptOptions.cs

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using OnnxStack.StableDiffusion.Enums;
2+
using OnnxStack.StableDiffusion.Models;
23
using System.ComponentModel.DataAnnotations;
4+
using System.Text.Json.Serialization;
35

46
namespace OnnxStack.StableDiffusion.Config
57
{
@@ -12,7 +14,12 @@ public class PromptOptions
1214
[StringLength(512)]
1315
public string NegativePrompt { get; set; }
1416
public SchedulerType SchedulerType { get; set; }
15-
public string InputImage { get; set; }
16-
public bool HasInputImage => !string.IsNullOrEmpty(InputImage);
17+
18+
public InputImage InputImage { get; set; }
19+
20+
public InputImage InputImageMask { get; set; }
21+
22+
public bool HasInputImage => InputImage?.HasImage ?? false;
23+
public bool HasInputImageMask => InputImageMask?.HasImage ?? false;
1724
}
1825
}

OnnxStack.StableDiffusion/Helpers/ImageHelpers.cs

+89-41
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,84 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
2-
using OnnxStack.StableDiffusion.Config;
3-
using OnnxStack.StableDiffusion.Results;
2+
using OnnxStack.StableDiffusion.Models;
43
using SixLabors.ImageSharp;
54
using SixLabors.ImageSharp.PixelFormats;
65
using SixLabors.ImageSharp.Processing;
76
using System;
7+
using System.IO;
88

99
namespace OnnxStack.StableDiffusion.Helpers
1010
{
1111
internal static class ImageHelpers
1212
{
1313
/// <summary>
14-
/// Convert a Tensor to image.
14+
/// Converts to image.
1515
/// </summary>
16-
/// <param name="options">The options.</param>
1716
/// <param name="imageTensor">The image tensor.</param>
1817
/// <returns></returns>
19-
public static ImageResult TensorToImage(SchedulerOptions options, DenseTensor<float> imageTensor)
18+
public static Image<Rgb24> ToImage(this DenseTensor<float> imageTensor)
2019
{
21-
var result = new Image<Rgba32>(options.Width, options.Height);
22-
for (var y = 0; y < options.Height; y++)
20+
var height = imageTensor.Dimensions[2];
21+
var width = imageTensor.Dimensions[3];
22+
var result = new Image<Rgb24>(width, height);
23+
for (var y = 0; y < height; y++)
2324
{
24-
for (var x = 0; x < options.Width; x++)
25+
for (var x = 0; x < width; x++)
2526
{
26-
result[x, y] = new Rgba32(
27+
result[x, y] = new Rgb24(
2728
CalculateByte(imageTensor, 0, y, x),
2829
CalculateByte(imageTensor, 1, y, x),
2930
CalculateByte(imageTensor, 2, y, x)
3031
);
3132
}
3233
}
33-
return new ImageResult(result);
34+
return result;
3435
}
3536

37+
3638
/// <summary>
37-
/// Converts an DenseTensor image to Image<Rgba32>
39+
/// Converts to image byte array.
3840
/// </summary>
3941
/// <param name="imageTensor">The image tensor.</param>
40-
/// <param name="width">The width.</param>
41-
/// <param name="height">The height.</param>
4242
/// <returns></returns>
43-
public static Image<Rgba32> TensorToImage(DenseTensor<float> imageTensor, int width, int height)
43+
public static byte[] ToImageBytes(this DenseTensor<float> imageTensor)
4444
{
45-
var image = new Image<Rgba32>(width, height);
46-
for (var y = 0; y < height; y++)
45+
using (var image = imageTensor.ToImage())
46+
using (var memoryStream = new MemoryStream())
4747
{
48-
for (var x = 0; x < width; x++)
49-
{
50-
image[x, y] = new Rgba32(
51-
CalculateByte(imageTensor, 0, y, x),
52-
CalculateByte(imageTensor, 1, y, x),
53-
CalculateByte(imageTensor, 2, y, x)
54-
);
55-
}
48+
image.SaveAsPng(memoryStream);
49+
return memoryStream.ToArray();
5650
}
57-
return image;
51+
}
52+
53+
54+
/// <summary>
/// Converts the image tensor to a PNG-encoded memory stream.
/// </summary>
/// <param name="imageTensor">The image tensor.</param>
/// <returns>
/// A <see cref="MemoryStream"/> holding the PNG data, positioned at the start.
/// The stream is not disposed here; the caller is responsible for disposing it.
/// </returns>
public static Stream ToImageStream(this DenseTensor<float> imageTensor)
{
    using (var image = imageTensor.ToImage())
    {
        var memoryStream = new MemoryStream();
        image.SaveAsPng(memoryStream);

        // Writing leaves the position at the end of the data; rewind so that
        // consumers can read the image without having to seek first.
        memoryStream.Position = 0;
        return memoryStream;
    }
}
68+
69+
70+
/// <summary>
/// Converts an <see cref="InputImage"/> to a tensor using the first image source that is set.
/// Sources are checked in this order: ImagePath, ImageBytes, ImageStream, ImageTensor.
/// </summary>
/// <param name="imageData">The input image; may be null.</param>
/// <param name="width">The width passed to the file/bytes/stream loaders.</param>
/// <param name="height">The height passed to the file/bytes/stream loaders.</param>
/// <returns>
/// The image tensor, or null when <paramref name="imageData"/> is null or has no image source set.
/// </returns>
public static DenseTensor<float> ToDenseTensor(this InputImage imageData, int width, int height)
{
    // A missing input image is treated the same as one with no source set.
    if (imageData == null)
        return null;

    if (!string.IsNullOrEmpty(imageData.ImagePath))
        return TensorFromFile(imageData.ImagePath, width, height);
    if (imageData.ImageBytes != null)
        return TensorFromBytes(imageData.ImageBytes, width, height);
    if (imageData.ImageStream != null)
        return TensorFromStream(imageData.ImageStream, width, height);

    // Check the tensor property itself (not this method) before copying it.
    // Note: width/height are not applied to an already-built tensor.
    if (imageData.ImageTensor != null)
        return imageData.ImageTensor.ToDenseTensor(); // Note: Tensor Copy

    return null;
}
5983

6084

@@ -111,31 +135,55 @@ public static void TensorToImageDebug(DenseTensor<float> imageTensor, int size,
111135
/// <param name="width">The width.</param>
112136
/// <param name="height">The height.</param>
113137
/// <returns></returns>
114-
public static DenseTensor<float> TensorFromImage(string filename, int width, int height)
138+
public static DenseTensor<float> TensorFromFile(string filename, int width, int height)
115139
{
116140
using (Image<Rgb24> image = Image.Load<Rgb24>(filename))
117141
{
118142
Resize(image, width, height);
119-
var imageArray = new DenseTensor<float>(new[] { 1, 3, width, height });
120-
image.ProcessPixelRows(img =>
121-
{
122-
for (int x = 0; x < width; x++)
123-
{
124-
for (int y = 0; y < height; y++)
125-
{
126-
var pixelSpan = img.GetRowSpan(y);
127-
imageArray[0, 0, y, x] = (pixelSpan[x].R / 255.0f) * 2.0f - 1.0f;
128-
imageArray[0, 1, y, x] = (pixelSpan[x].G / 255.0f) * 2.0f - 1.0f;
129-
imageArray[0, 2, y, x] = (pixelSpan[x].B / 255.0f) * 2.0f - 1.0f;
130-
}
131-
}
132-
});
133-
return imageArray;
143+
return ProcessPixels(width, height, image);
134144
}
135145
}
136146

137147

148+
/// <summary>
/// Creates an image tensor from an encoded image held in a byte array.
/// </summary>
/// <param name="imageBytes">The encoded image bytes.</param>
/// <param name="width">The width passed to Resize and ProcessPixels.</param>
/// <param name="height">The height passed to Resize and ProcessPixels.</param>
/// <returns>The tensor produced by ProcessPixels for the loaded image.</returns>
public static DenseTensor<float> TensorFromBytes(byte[] imageBytes, int width, int height)
{
    // The decoded image is only needed while building the tensor, so it is disposed on exit.
    using (Image<Rgb24> decodedImage = Image.Load<Rgb24>(imageBytes))
    {
        Resize(decodedImage, width, height);
        DenseTensor<float> pixelTensor = ProcessPixels(width, height, decodedImage);
        return pixelTensor;
    }
}
138156

157+
/// <summary>
/// Creates an image tensor from an encoded image read from a stream.
/// </summary>
/// <param name="imageStream">The stream containing the encoded image. It is not disposed by this method.</param>
/// <param name="width">The width passed to Resize and ProcessPixels.</param>
/// <param name="height">The height passed to Resize and ProcessPixels.</param>
/// <returns>The tensor produced by ProcessPixels for the loaded image.</returns>
public static DenseTensor<float> TensorFromStream(Stream imageStream, int width, int height)
{
    // Only the decoded image is owned here; the source stream stays with the caller.
    using (Image<Rgb24> decodedImage = Image.Load<Rgb24>(imageStream))
    {
        Resize(decodedImage, width, height);
        DenseTensor<float> pixelTensor = ProcessPixels(width, height, decodedImage);
        return pixelTensor;
    }
}
165+
166+
167+
168+
169+
/// <summary>
/// Copies the RGB pixels of an image into a [1, 3, height, width] tensor,
/// mapping each channel byte from 0..255 to the range -1..1.
/// </summary>
/// <param name="width">The number of pixel columns to read.</param>
/// <param name="height">The number of pixel rows to read.</param>
/// <param name="image">The source image. NOTE(review): assumed to be at least width x height, i.e. already resized by the caller - confirm.</param>
/// <returns>The tensor indexed as [0, channel, y, x] with channels in R, G, B order.</returns>
private static DenseTensor<float> ProcessPixels(int width, int height, Image<Rgb24> image)
{
    // Dimensions must be (batch, channel, height, width) to match the [0, c, y, x]
    // indexing below; the previous (width, height) order was only correct for square images.
    var imageArray = new DenseTensor<float>(new[] { 1, 3, height, width });
    image.ProcessPixelRows(img =>
    {
        for (int y = 0; y < height; y++)
        {
            // Fetch the row once per row rather than once per pixel.
            var pixelSpan = img.GetRowSpan(y);
            for (int x = 0; x < width; x++)
            {
                var pixel = pixelSpan[x];
                imageArray[0, 0, y, x] = (pixel.R / 255.0f) * 2.0f - 1.0f;
                imageArray[0, 1, y, x] = (pixel.G / 255.0f) * 2.0f - 1.0f;
                imageArray[0, 2, y, x] = (pixel.B / 255.0f) * 2.0f - 1.0f;
            }
        }
    });
    return imageArray;
}
139187

140188

141189
/// <summary>

0 commit comments

Comments
 (0)