Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 7e58b38

Browse files
authored
Merge pull request #23 from saddam213/Automation
Image Batch Processing
2 parents 77a2b6a + adceb2c commit 7e58b38

37 files changed

+1200
-386
lines changed

OnnxStack.Console/Examples/StableDebug.cs

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
using OnnxStack.StableDiffusion.Common;
1+
using OnnxStack.StableDiffusion;
2+
using OnnxStack.StableDiffusion.Common;
23
using OnnxStack.StableDiffusion.Config;
34
using OnnxStack.StableDiffusion.Enums;
4-
using OnnxStack.StableDiffusion.Services;
55
using SixLabors.ImageSharp;
66
using System.Diagnostics;
77

@@ -37,11 +37,11 @@ public async Task RunAsync()
3737
{
3838
Prompt = prompt,
3939
NegativePrompt = negativePrompt,
40-
SchedulerType = SchedulerType.LMS
4140
};
4241

4342
var schedulerOptions = new SchedulerOptions
4443
{
44+
SchedulerType = SchedulerType.LMS,
4545
Seed = 624461087,
4646
//Seed = Random.Shared.Next(),
4747
GuidanceScale = 8,
@@ -54,9 +54,9 @@ public async Task RunAsync()
5454
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
5555
await _stableDiffusionService.LoadModel(model);
5656

57-
foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType))
57+
foreach (var schedulerType in model.PipelineType.GetSchedulerTypes())
5858
{
59-
promptOptions.SchedulerType = schedulerType;
59+
schedulerOptions.SchedulerType = schedulerType;
6060
OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green);
6161
await GenerateImage(model, promptOptions, schedulerOptions);
6262
}
@@ -72,12 +72,12 @@ public async Task RunAsync()
7272
private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options)
7373
{
7474
var timestamp = Stopwatch.GetTimestamp();
75-
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}.png");
75+
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}.png");
7676
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
7777
if (result is not null)
7878
{
7979
await result.SaveAsPngAsync(outputFilename);
80-
OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
80+
OutputHelpers.WriteConsole($"{options.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
8181
OutputHelpers.WriteConsole($"Elapsed: {Stopwatch.GetElapsedTime(timestamp)}ms", ConsoleColor.Yellow);
8282
return true;
8383
}
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
using OnnxStack.Core;
2-
using OnnxStack.StableDiffusion.Common;
1+
using OnnxStack.StableDiffusion.Common;
32
using OnnxStack.StableDiffusion.Config;
43
using OnnxStack.StableDiffusion.Enums;
5-
using OnnxStack.StableDiffusion.Helpers;
4+
using OnnxStack.StableDiffusion;
65
using SixLabors.ImageSharp;
6+
using OnnxStack.StableDiffusion.Helpers;
77

88
namespace OnnxStack.Console.Runner
99
{
@@ -31,68 +31,50 @@ public async Task RunAsync()
3131

3232
while (true)
3333
{
34-
OutputHelpers.WriteConsole("Please type a prompt and press ENTER", ConsoleColor.Yellow);
35-
var prompt = OutputHelpers.ReadConsole(ConsoleColor.Cyan);
36-
37-
OutputHelpers.WriteConsole("Please type a negative prompt and press ENTER (optional)", ConsoleColor.Yellow);
38-
var negativePrompt = OutputHelpers.ReadConsole(ConsoleColor.Cyan);
39-
40-
OutputHelpers.WriteConsole("Please enter a batch count and press ENTER", ConsoleColor.Yellow);
41-
var batch = OutputHelpers.ReadConsole(ConsoleColor.Cyan);
42-
int.TryParse(batch, out var batchCount);
43-
batchCount = Math.Max(1, batchCount);
4434

4535
var promptOptions = new PromptOptions
4636
{
47-
Prompt = prompt,
48-
NegativePrompt = negativePrompt,
49-
BatchCount = batchCount
37+
Prompt = "Photo of a cat"
5038
};
5139

5240
var schedulerOptions = new SchedulerOptions
5341
{
5442
Seed = Random.Shared.Next(),
5543

5644
GuidanceScale = 8,
57-
InferenceSteps = 22,
45+
InferenceSteps = 20,
5846
Strength = 0.6f
5947
};
6048

49+
var batchOptions = new BatchOptions
50+
{
51+
BatchType = BatchOptionType.Scheduler
52+
};
53+
6154
foreach (var model in _stableDiffusionService.Models)
6255
{
6356
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
6457
await _stableDiffusionService.LoadModel(model);
6558

66-
foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType))
59+
var batchIndex = 0;
60+
var callback = (int batch, int batchCount, int step, int steps) =>
61+
{
62+
batchIndex = batch;
63+
OutputHelpers.WriteConsole($"Image: {batch}/{batchCount} - Step: {step}/{steps}", ConsoleColor.Cyan);
64+
};
65+
66+
await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, callback))
6767
{
68-
promptOptions.SchedulerType = schedulerType;
69-
OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green);
70-
await GenerateImage(model, promptOptions, schedulerOptions);
68+
var outputFilename = Path.Combine(_outputDirectory, $"{batchIndex}_{result.SchedulerOptions.Seed}.png");
69+
var image = result.ImageResult.ToImage();
70+
await image.SaveAsPngAsync(outputFilename);
71+
OutputHelpers.WriteConsole($"Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
7172
}
7273

7374
OutputHelpers.WriteConsole($"Unloading Model `{model.Name}`...", ConsoleColor.Green);
7475
await _stableDiffusionService.UnloadModel(model);
7576
}
7677
}
7778
}
78-
79-
private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options)
80-
{
81-
82-
var result = await _stableDiffusionService.GenerateAsync(model, prompt, options);
83-
if (result == null)
84-
return false;
85-
86-
var imageTensors = result.Split(prompt.BatchCount);
87-
for (int i = 0; i < imageTensors.Length; i++)
88-
{
89-
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}_{i}.png");
90-
var image = imageTensors[i].ToImage();
91-
await image.SaveAsPngAsync(outputFilename);
92-
OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
93-
}
94-
95-
return true;
96-
}
9779
}
9880
}

OnnxStack.Console/Examples/StableDiffusionExample.cs

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
using OnnxStack.Core;
1+
using OnnxStack.StableDiffusion;
22
using OnnxStack.StableDiffusion.Common;
33
using OnnxStack.StableDiffusion.Config;
4-
using OnnxStack.StableDiffusion.Enums;
54
using SixLabors.ImageSharp;
65

76
namespace OnnxStack.Console.Runner
@@ -53,9 +52,9 @@ public async Task RunAsync()
5352
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
5453
await _stableDiffusionService.LoadModel(model);
5554

56-
foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType))
55+
foreach (var schedulerType in model.PipelineType.GetSchedulerTypes())
5756
{
58-
promptOptions.SchedulerType = schedulerType;
57+
schedulerOptions.SchedulerType = schedulerType;
5958
OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green);
6059
await GenerateImage(model, promptOptions, schedulerOptions);
6160
}
@@ -68,13 +67,13 @@ public async Task RunAsync()
6867

6968
private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options)
7069
{
71-
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}.png");
70+
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}.png");
7271
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
7372
if (result == null)
7473
return false;
7574

7675
await result.SaveAsPngAsync(outputFilename);
77-
OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
76+
OutputHelpers.WriteConsole($"{options.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
7877
return true;
7978
}
8079
}

OnnxStack.Console/Examples/StableDiffusionGenerator.cs

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
using OnnxStack.Core;
1+
using OnnxStack.StableDiffusion;
22
using OnnxStack.StableDiffusion.Common;
33
using OnnxStack.StableDiffusion.Config;
4-
using OnnxStack.StableDiffusion.Enums;
54
using SixLabors.ImageSharp;
65
using System.Collections.ObjectModel;
76

@@ -48,9 +47,9 @@ public async Task RunAsync()
4847
{
4948
Seed = Random.Shared.Next()
5049
};
51-
foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType))
50+
foreach (var schedulerType in model.PipelineType.GetSchedulerTypes())
5251
{
53-
promptOptions.SchedulerType = schedulerType;
52+
schedulerOptions.SchedulerType = schedulerType;
5453
OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green);
5554
await GenerateImage(model, promptOptions, schedulerOptions, generationPrompt.Key);
5655
}
@@ -65,13 +64,13 @@ public async Task RunAsync()
6564

6665
private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options, string key)
6766
{
68-
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}_{key}.png");
67+
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}_{key}.png");
6968
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
7069
if (result == null)
7170
return false;
7271

7372
await result.SaveAsPngAsync(outputFilename);
74-
OutputHelpers.WriteConsole($"{prompt.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
73+
OutputHelpers.WriteConsole($"{options.SchedulerType} Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
7574
return true;
7675
}
7776

OnnxStack.Console/Helpers.cs

-28
This file was deleted.

OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs

+49
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using OnnxStack.Core.Config;
33
using OnnxStack.Core.Model;
44
using OnnxStack.StableDiffusion.Config;
5+
using OnnxStack.StableDiffusion.Models;
56
using SixLabors.ImageSharp;
67
using SixLabors.ImageSharp.PixelFormats;
78
using System;
@@ -83,5 +84,53 @@ public interface IStableDiffusionService
8384
/// <param name="cancellationToken">The cancellation token.</param>
8485
/// <returns>The diffusion result as <see cref="System.IO.Stream"/></returns>
8586
Task<Stream> GenerateAsStreamAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
87+
88+
/// <summary>
89+
/// Generates a batch of StableDiffusion image using the prompt and options provided.
90+
/// </summary>
91+
/// <param name="modelOptions">The model options.</param>
92+
/// <param name="promptOptions">The prompt options.</param>
93+
/// <param name="schedulerOptions">The scheduler options.</param>
94+
/// <param name="batchOptions">The batch options.</param>
95+
/// <param name="progressCallback">The progress callback.</param>
96+
/// <param name="cancellationToken">The cancellation token.</param>
97+
/// <returns></returns>
98+
IAsyncEnumerable<BatchResult> GenerateBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
99+
100+
/// <summary>
101+
/// Generates a batch of StableDiffusion image using the prompt and options provided.
102+
/// </summary>
103+
/// <param name="modelOptions">The model options.</param>
104+
/// <param name="promptOptions">The prompt options.</param>
105+
/// <param name="schedulerOptions">The scheduler options.</param>
106+
/// <param name="batchOptions">The batch options.</param>
107+
/// <param name="progressCallback">The progress callback.</param>
108+
/// <param name="cancellationToken">The cancellation token.</param>
109+
/// <returns></returns>
110+
IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
111+
112+
/// <summary>
113+
/// Generates a batch of StableDiffusion image using the prompt and options provided.
114+
/// </summary>
115+
/// <param name="modelOptions">The model options.</param>
116+
/// <param name="promptOptions">The prompt options.</param>
117+
/// <param name="schedulerOptions">The scheduler options.</param>
118+
/// <param name="batchOptions">The batch options.</param>
119+
/// <param name="progressCallback">The progress callback.</param>
120+
/// <param name="cancellationToken">The cancellation token.</param>
121+
/// <returns></returns>
122+
IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
123+
124+
/// <summary>
125+
/// Generates a batch of StableDiffusion image using the prompt and options provided.
126+
/// </summary>
127+
/// <param name="modelOptions">The model options.</param>
128+
/// <param name="promptOptions">The prompt options.</param>
129+
/// <param name="schedulerOptions">The scheduler options.</param>
130+
/// <param name="batchOptions">The batch options.</param>
131+
/// <param name="progressCallback">The progress callback.</param>
132+
/// <param name="cancellationToken">The cancellation token.</param>
133+
/// <returns></returns>
134+
IAsyncEnumerable<Stream> GenerateBatchAsStreamAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
86135
}
87136
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using OnnxStack.StableDiffusion.Enums;
2+
3+
namespace OnnxStack.StableDiffusion.Config
4+
{
5+
public record BatchOptions
6+
{
7+
public BatchOptionType BatchType { get; set; }
8+
public float ValueTo { get; set; }
9+
public float ValueFrom { get; set; }
10+
public float Increment { get; set; } = 1f;
11+
}
12+
}

OnnxStack.StableDiffusion/Config/PromptOptions.cs

-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ public class PromptOptions
1414

1515
[StringLength(512)]
1616
public string NegativePrompt { get; set; }
17-
public SchedulerType SchedulerType { get; set; }
1817

1918
public int BatchCount { get; set; } = 1;
2019

OnnxStack.StableDiffusion/Config/SchedulerOptions.cs

+6-1
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,13 @@
44

55
namespace OnnxStack.StableDiffusion.Config
66
{
7-
public class SchedulerOptions
7+
public record SchedulerOptions
88
{
9+
/// <summary>
10+
/// Gets or sets the type of scheduler.
11+
/// </summary>
12+
public SchedulerType SchedulerType { get; set; }
13+
914
/// <summary>
1015
/// Gets or sets the height.
1116
/// </summary>

OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs

+15
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
using OnnxStack.StableDiffusion.Common;
33
using OnnxStack.StableDiffusion.Config;
44
using OnnxStack.StableDiffusion.Enums;
5+
using OnnxStack.StableDiffusion.Models;
56
using System;
7+
using System.Collections.Generic;
68
using System.Threading;
79
using System.Threading.Tasks;
810

@@ -33,5 +35,18 @@ public interface IDiffuser
3335
/// <param name="cancellationToken">The cancellation token.</param>
3436
/// <returns></returns>
3537
Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
38+
39+
40+
/// <summary>
41+
/// Runs the stable diffusion batch loop
42+
/// </summary>
43+
/// <param name="modelOptions">The model options.</param>
44+
/// <param name="promptOptions">The prompt options.</param>
45+
/// <param name="schedulerOptions">The scheduler options.</param>
46+
/// <param name="batchOptions">The batch options.</param>
47+
/// <param name="progressCallback">The progress callback.</param>
48+
/// <param name="cancellationToken">The cancellation token.</param>
49+
/// <returns></returns>
50+
IAsyncEnumerable<BatchResult> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default);
3651
}
3752
}

0 commit comments

Comments
 (0)