Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 07f4563

Browse files
authored
Merge pull request #65 from saddam213/LCM_XL
Add LatentConsistency XL pipeline
2 parents a68d7a9 + 0cfbea4 commit 07f4563

File tree

11 files changed

+562
-29
lines changed

11 files changed

+562
-29
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
4+
using OnnxStack.Core.Config;
5+
using OnnxStack.Core.Model;
6+
using OnnxStack.Core.Services;
7+
using OnnxStack.StableDiffusion.Common;
8+
using OnnxStack.StableDiffusion.Config;
9+
using OnnxStack.StableDiffusion.Enums;
10+
using OnnxStack.StableDiffusion.Helpers;
11+
using SixLabors.ImageSharp;
12+
using System;
13+
using System.Collections.Generic;
14+
using System.Linq;
15+
using System.Threading.Tasks;
16+
17+
namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL
{
    /// <summary>
    /// Image-to-image diffuser for the LatentConsistency XL pipeline.
    /// </summary>
    public sealed class ImageDiffuser : LatentConsistencyXLDiffuser
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="ImageDiffuser"/> class.
        /// </summary>
        /// <param name="onnxModelService">The onnx model service.</param>
        /// <param name="promptService">The prompt service.</param>
        /// <param name="logger">The logger.</param>
        public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<LatentConsistencyXLDiffuser> logger)
            : base(onnxModelService, promptService, logger)
        {
        }


        /// <summary>
        /// Gets the type of the diffuser.
        /// </summary>
        public override DiffuserType DiffuserType => DiffuserType.ImageToImage;


        /// <summary>
        /// Gets the timesteps, narrowed by the Strength option for image-to-image.
        /// </summary>
        /// <param name="options">The scheduler options.</param>
        /// <param name="scheduler">The scheduler.</param>
        /// <returns>The tail of the scheduler's timesteps to run.</returns>
        protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
        {
            // Image2Image: Strength controls how many of the final steps are executed
            var strengthSteps = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
            var skipCount = Math.Max(options.InferenceSteps - strengthSteps, 0);
            return scheduler.Timesteps.Skip(skipCount).ToList();
        }


        /// <summary>
        /// Prepares the latents for inference by VAE-encoding the input image,
        /// scaling the sample and adding scheduler noise for the given timesteps.
        /// </summary>
        /// <param name="model">The model set.</param>
        /// <param name="prompt">The prompt options (supplies the input image).</param>
        /// <param name="options">The scheduler options.</param>
        /// <param name="scheduler">The scheduler.</param>
        /// <param name="timesteps">The timesteps to noise against.</param>
        /// <returns>The noised latent sample.</returns>
        protected override async Task<DenseTensor<float>> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
        {
            //TODO: Model Config, Channels
            var encoderMetadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
            var inputTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
            using (var parameters = new OnnxInferenceParameters(encoderMetadata))
            {
                parameters.AddInputTensor(inputTensor);
                parameters.AddOutputBuffer(options.GetScaledDimension());

                var inferenceResults = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, parameters);
                using (var encoded = inferenceResults.First())
                {
                    // Scale the encoded sample, then noise it to match the start timestep
                    var sample = encoded.ToDenseTensor().MultiplyBy(model.ScaleFactor);
                    var noise = scheduler.CreateRandomSample(sample.Dimensions);
                    return scheduler.AddNoise(sample, noise, timesteps);
                }
            }
        }
    }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
4+
using OnnxStack.Core.Config;
5+
using OnnxStack.Core.Model;
6+
using OnnxStack.Core.Services;
7+
using OnnxStack.StableDiffusion.Common;
8+
using OnnxStack.StableDiffusion.Config;
9+
using OnnxStack.StableDiffusion.Enums;
10+
using OnnxStack.StableDiffusion.Helpers;
11+
using SixLabors.ImageSharp;
12+
using SixLabors.ImageSharp.Processing;
13+
using System;
14+
using System.Collections.Generic;
15+
using System.Diagnostics;
16+
using System.Linq;
17+
using System.Threading;
18+
using System.Threading.Tasks;
19+
20+
namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL
{
    /// <summary>
    /// Legacy inpainting diffuser for the LatentConsistency XL pipeline.
    /// Blends scheduler output with the re-noised original latents using an image mask.
    /// </summary>
    public sealed class InpaintLegacyDiffuser : LatentConsistencyXLDiffuser
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="InpaintLegacyDiffuser"/> class.
        /// </summary>
        /// <param name="onnxModelService">The onnx model service.</param>
        /// <param name="promptService">The prompt service.</param>
        /// <param name="logger">The logger.</param>
        public InpaintLegacyDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<LatentConsistencyXLDiffuser> logger)
            : base(onnxModelService, promptService, logger)
        {
        }


        /// <summary>
        /// Gets the type of the diffuser.
        /// </summary>
        public override DiffuserType DiffuserType => DiffuserType.ImageInpaintLegacy;


        /// <summary>
        /// Runs the scheduler steps.
        /// </summary>
        /// <param name="modelOptions">The model options.</param>
        /// <param name="promptOptions">The prompt options.</param>
        /// <param name="schedulerOptions">The scheduler options.</param>
        /// <param name="promptEmbeddings">The prompt embeddings.</param>
        /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>The decoded image tensor.</returns>
        protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
        {
            using (var scheduler = GetScheduler(schedulerOptions))
            {
                // Get timesteps
                var timesteps = GetTimesteps(schedulerOptions, scheduler);

                // Create latent sample
                var latentsOriginal = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);

                // Create masks sample
                var maskImage = PrepareMask(modelOptions, promptOptions, schedulerOptions);

                // Generate some noise
                var noise = scheduler.CreateRandomSample(latentsOriginal.Dimensions);

                // Add noise to original latent
                var latents = scheduler.AddNoise(latentsOriginal, noise, timesteps);

                // Get Model metadata
                var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);

                // Get Time ids
                var addTimeIds = GetAddTimeIds(modelOptions, schedulerOptions, performGuidance);

                // Loop though the timesteps
                var step = 0;
                foreach (var timestep in timesteps)
                {
                    step++;
                    var stepTime = Stopwatch.GetTimestamp();
                    cancellationToken.ThrowIfCancellationRequested();

                    // Create input tensor: duplicate the latent batch when classifier-free guidance is on
                    var inputLatent = performGuidance ? latents.Repeat(2) : latents;
                    var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
                    var timestepTensor = CreateTimestepTensor(timestep);

                    var outputChannels = performGuidance ? 2 : 1;
                    var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
                    using (var inferenceParameters = new OnnxInferenceParameters(metadata))
                    {
                        inferenceParameters.AddInputTensor(inputTensor);
                        inferenceParameters.AddInputTensor(timestepTensor);
                        inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                        inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
                        inferenceParameters.AddInputTensor(addTimeIds);
                        inferenceParameters.AddOutputBuffer(outputDimension);

                        var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
                        using (var result = results.First())
                        {
                            var noisePred = result.ToDenseTensor();

                            // Perform guidance
                            if (performGuidance)
                                noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);

                            // Scheduler Step
                            var steplatents = scheduler.Step(noisePred, timestep, latents).Result;

                            // Re-noise the original latents to the current timestep so the
                            // unmasked region tracks the source image at every step
                            var initLatentsProper = scheduler.AddNoise(latentsOriginal, noise, new[] { timestep });

                            // Apply mask and combine
                            latents = ApplyMaskedLatents(steplatents, initLatentsProper, maskImage);
                        }
                    }

                    progressCallback?.Invoke(step, timesteps.Count);
                    _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
                }

                // Decode Latents
                return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
            }
        }


        /// <summary>
        /// Gets the timesteps, narrowed by the Strength option.
        /// </summary>
        /// <param name="options">The scheduler options.</param>
        /// <param name="scheduler">The scheduler.</param>
        /// <returns>The tail of the scheduler's timesteps to run.</returns>
        protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
        {
            var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
            var start = Math.Max(options.InferenceSteps - inittimestep, 0);
            return scheduler.Timesteps.Skip(start).ToList();
        }


        /// <summary>
        /// Prepares the latents for inference by VAE-encoding the input image.
        /// </summary>
        /// <param name="model">The model set.</param>
        /// <param name="prompt">The prompt options (supplies the input image).</param>
        /// <param name="options">The scheduler options.</param>
        /// <param name="scheduler">The scheduler.</param>
        /// <param name="timesteps">The timesteps (noise is added later by the caller).</param>
        /// <returns>The scaled, un-noised latent sample.</returns>
        protected override async Task<DenseTensor<float>> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
        {
            var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });

            //TODO: Model Config, Channels
            var outputDimensions = options.GetScaledDimension();
            var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
            using (var inferenceParameters = new OnnxInferenceParameters(metadata))
            {
                inferenceParameters.AddInputTensor(imageTensor);
                inferenceParameters.AddOutputBuffer(outputDimensions);

                var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
                using (var result = results.First())
                {
                    var outputResult = result.ToDenseTensor();
                    var scaledSample = outputResult.MultiplyBy(model.ScaleFactor);
                    return scaledSample;
                }
            }
        }


        /// <summary>
        /// Prepares the inpaint mask: grayscale, resized to latent dimensions,
        /// with value 1 where the image should be regenerated (transparent pixels).
        /// </summary>
        /// <param name="modelOptions">The model options.</param>
        /// <param name="promptOptions">The prompt options (supplies the mask image).</param>
        /// <param name="schedulerOptions">The scheduler options.</param>
        /// <returns>A mask tensor shaped like the latents (NCHW).</returns>
        private DenseTensor<float> PrepareMask(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
        {
            using (var mask = promptOptions.InputImageMask.ToImage())
            {
                // Prepare the mask
                int width = schedulerOptions.GetScaledWidth();
                int height = schedulerOptions.GetScaledHeight();
                mask.Mutate(x => x.Grayscale());
                mask.Mutate(x => x.Resize(new Size(width, height), KnownResamplers.NearestNeighbor, true));

                // BUGFIX: dimensions must be NCHW { 1, 4, height, width } to match the
                // latent layout and the [0, c, y, x] indexing below; the previous
                // { 1, 4, width, height } was only correct for square outputs.
                var maskTensor = new DenseTensor<float>(new[] { 1, 4, height, width });
                mask.ProcessPixelRows(img =>
                {
                    // Iterate rows outermost so each row span is fetched only once
                    for (int y = 0; y < height; y++)
                    {
                        var pixelSpan = img.GetRowSpan(y);
                        for (int x = 0; x < width; x++)
                        {
                            // Transparent (A=0) => 1f => regenerate; opaque => 0f => keep
                            var value = 1f - (pixelSpan[x].A / 255.0f);
                            maskTensor[0, 0, y, x] = value;
                            maskTensor[0, 1, y, x] = value; // Needed for shape only
                            maskTensor[0, 2, y, x] = value; // Needed for shape only
                            maskTensor[0, 3, y, x] = value; // Needed for shape only
                        }
                    }
                });
                return maskTensor;
            }
        }


        /// <summary>
        /// Applies the masked latents: per element, mask=1 keeps the re-noised
        /// original, mask=0 keeps the scheduler output.
        /// </summary>
        /// <param name="latents">The latents from the scheduler step.</param>
        /// <param name="initLatentsProper">The re-noised original latents.</param>
        /// <param name="mask">The mask (same layout as the latents).</param>
        /// <returns>The blended latents.</returns>
        private DenseTensor<float> ApplyMaskedLatents(DenseTensor<float> latents, DenseTensor<float> initLatentsProper, DenseTensor<float> mask)
        {
            var result = new DenseTensor<float>(latents.Dimensions);
            for (int i = 0; i < result.Length; i++)
            {
                float maskValue = mask.GetValue(i);
                result.SetValue(i, initLatentsProper.GetValue(i) * maskValue + latents.GetValue(i) * (1f - maskValue));
            }
            return result;
        }
    }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core.Services;
4+
using OnnxStack.StableDiffusion.Common;
5+
using OnnxStack.StableDiffusion.Config;
6+
using OnnxStack.StableDiffusion.Diffusers.StableDiffusionXL;
7+
using OnnxStack.StableDiffusion.Enums;
8+
using OnnxStack.StableDiffusion.Models;
9+
using System.Collections.Generic;
10+
using System.Threading.Tasks;
11+
using System.Threading;
12+
using System;
13+
using OnnxStack.StableDiffusion.Schedulers.LatentConsistency;
14+
15+
namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL
{
    /// <summary>
    /// Base diffuser for the LatentConsistency XL pipeline. Builds on the
    /// StableDiffusion XL diffuser, clearing negative prompts (unsupported by
    /// LCM) and supplying the LCM scheduler.
    /// </summary>
    public abstract class LatentConsistencyXLDiffuser : StableDiffusionXLDiffuser
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="LatentConsistencyXLDiffuser"/> class.
        /// </summary>
        /// <param name="onnxModelService">The onnx model service.</param>
        /// <param name="promptService">The prompt service.</param>
        /// <param name="logger">The logger.</param>
        protected LatentConsistencyXLDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<StableDiffusionXLDiffuser> logger)
            : base(onnxModelService, promptService, logger) { }


        /// <summary>
        /// Gets the type of the pipeline.
        /// </summary>
        public override DiffuserPipelineType PipelineType => DiffuserPipelineType.LatentConsistencyXL;


        /// <summary>
        /// Runs the stable diffusion loop.
        /// </summary>
        /// <param name="modelOptions">The model options.</param>
        /// <param name="promptOptions">The prompt options.</param>
        /// <param name="schedulerOptions">The scheduler options.</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns></returns>
        public override Task<DenseTensor<float>> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
        {
            // LCM does not support negative prompting
            promptOptions.NegativePrompt = string.Empty;

            return base.DiffuseAsync(modelOptions, promptOptions, schedulerOptions, progressCallback, cancellationToken);
        }


        /// <summary>
        /// Runs the stable diffusion batch loop.
        /// </summary>
        /// <param name="modelOptions">The model options.</param>
        /// <param name="promptOptions">The prompt options.</param>
        /// <param name="schedulerOptions">The scheduler options.</param>
        /// <param name="batchOptions">The batch options.</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns></returns>
        public override IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, CancellationToken cancellationToken = default)
        {
            // LCM does not support negative prompting
            promptOptions.NegativePrompt = string.Empty;

            return base.DiffuseBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken);
        }


        /// <summary>
        /// Gets the scheduler for the requested scheduler type.
        /// </summary>
        /// <param name="options">The scheduler options.</param>
        /// <returns>An <see cref="LCMScheduler"/>, or default for unsupported types.</returns>
        protected override IScheduler GetScheduler(SchedulerOptions options)
        {
            if (options.SchedulerType == SchedulerType.LCM)
                return new LCMScheduler(options);

            return default;
        }
    }
}

0 commit comments

Comments
 (0)