Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit a68d7a9

Browse files
authored
Merge pull request #63 from saddam213/Init
Improve ModelSet runtime management
2 parents a4cd36f + ef91017 commit a68d7a9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+367
-218
lines changed

OnnxStack.Console/Examples/StableDebug.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public async Task RunAsync()
4848
Strength = 0.6f
4949
};
5050

51-
foreach (var model in _stableDiffusionService.Models)
51+
foreach (var model in _stableDiffusionService.ModelSets)
5252
{
5353
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
5454
await _stableDiffusionService.LoadModelAsync(model);
@@ -71,7 +71,7 @@ public async Task RunAsync()
7171
}
7272

7373

74-
private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options)
74+
private async Task<bool> GenerateImage(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options)
7575
{
7676
var timestamp = Stopwatch.GetTimestamp();
7777
var outputFilename = Path.Combine(_outputDirectory, $"{model.Name}_{options.Seed}_{options.SchedulerType}.png");

OnnxStack.Console/Examples/StableDiffusionBatch.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public async Task RunAsync()
5151
BatchType = BatchOptionType.Scheduler
5252
};
5353

54-
foreach (var model in _stableDiffusionService.Models)
54+
foreach (var model in _stableDiffusionService.ModelSets)
5555
{
5656
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
5757
await _stableDiffusionService.LoadModelAsync(model);

OnnxStack.Console/Examples/StableDiffusionExample.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public async Task RunAsync()
4747
Seed = Random.Shared.Next()
4848
};
4949

50-
foreach (var model in _stableDiffusionService.Models)
50+
foreach (var model in _stableDiffusionService.ModelSets)
5151
{
5252
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
5353
await _stableDiffusionService.LoadModelAsync(model);
@@ -65,7 +65,7 @@ public async Task RunAsync()
6565
}
6666
}
6767

68-
private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options)
68+
private async Task<bool> GenerateImage(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options)
6969
{
7070
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}.png");
7171
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);

OnnxStack.Console/Examples/StableDiffusionGenerator.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public async Task RunAsync()
3131
Directory.CreateDirectory(_outputDirectory);
3232

3333
var seed = Random.Shared.Next();
34-
foreach (var model in _stableDiffusionService.Models)
34+
foreach (var model in _stableDiffusionService.ModelSets)
3535
{
3636
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
3737
await _stableDiffusionService.LoadModelAsync(model);
@@ -62,7 +62,7 @@ public async Task RunAsync()
6262
OutputHelpers.ReadConsole(ConsoleColor.Gray);
6363
}
6464

65-
private async Task<bool> GenerateImage(ModelOptions model, PromptOptions prompt, SchedulerOptions options, string key)
65+
private async Task<bool> GenerateImage(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, string key)
6666
{
6767
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}_{key}.png");
6868
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);

OnnxStack.Console/Examples/StableDiffusionGif.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public async Task RunAsync()
5454
};
5555

5656
// Choose Model
57-
var model = _stableDiffusionService.Models.FirstOrDefault(x => x.Name == "LCM-Dreamshaper-V7");
57+
var model = _stableDiffusionService.ModelSets.FirstOrDefault(x => x.Name == "LCM-Dreamshaper-V7");
5858
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
5959
await _stableDiffusionService.LoadModelAsync(model);
6060

OnnxStack.Console/appsettings.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
}
77
},
88
"AllowedHosts": "*",
9-
"OnnxStackConfig": {
10-
"OnnxModelSets": [
9+
"StableDiffusionConfig": {
10+
"ModelSets": [
1111
{
1212
"Name": "StableDiffusion 1.5",
1313
"IsEnabled": true,

OnnxStack.Core/Config/IOnnxModelSetConfig.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ public interface IOnnxModelSetConfig : IOnnxModel
1111
int IntraOpNumThreads { get; set; }
1212
ExecutionMode ExecutionMode { get; set; }
1313
ExecutionProvider ExecutionProvider { get; set; }
14-
List<OnnxModelSessionConfig> ModelConfigurations { get; set; }
14+
List<OnnxModelConfig> ModelConfigurations { get; set; }
1515
}
1616
}

OnnxStack.Core/Config/OnnxModelSessionConfig.cs renamed to OnnxStack.Core/Config/OnnxModelConfig.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
namespace OnnxStack.Core.Config
55
{
6-
public class OnnxModelSessionConfig
6+
public class OnnxModelConfig
77
{
88
public OnnxModelType Type { get; set; }
99
public string OnnxModelPath { get; set; }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using System.Collections.Generic;
2+
3+
namespace OnnxStack.Core.Config
4+
{
5+
public class OnnxModelEqualityComparer : IEqualityComparer<IOnnxModel>
6+
{
7+
public bool Equals(IOnnxModel x, IOnnxModel y)
8+
{
9+
return x != null && y != null && x.Name == y.Name;
10+
}
11+
12+
public int GetHashCode(IOnnxModel obj)
13+
{
14+
return obj?.Name?.GetHashCode() ?? 0;
15+
}
16+
}
17+
}

OnnxStack.Core/Config/OnnxModelSetConfig.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ public class OnnxModelSetConfig : IOnnxModelSetConfig
1313
public int IntraOpNumThreads { get; set; }
1414
public ExecutionMode ExecutionMode { get; set; }
1515
public ExecutionProvider ExecutionProvider { get; set; }
16-
public List<OnnxModelSessionConfig> ModelConfigurations { get; set; }
16+
public List<OnnxModelConfig> ModelConfigurations { get; set; }
1717
}
1818
}
-11
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,11 @@
11
using OnnxStack.Common.Config;
2-
using System.Collections.Generic;
3-
using System.Linq;
42

53
namespace OnnxStack.Core.Config
64
{
75
public class OnnxStackConfig : IConfigSection
86
{
9-
public List<OnnxModelSetConfig> OnnxModelSets { get; set; } = new List<OnnxModelSetConfig>();
10-
117
public void Initialize()
128
{
13-
if (OnnxModelSets.IsNullOrEmpty())
14-
return;
15-
16-
foreach (var modelSet in OnnxModelSets)
17-
{
18-
modelSet.ApplyConfigurationOverrides();
19-
}
209
}
2110
}
2211
}

OnnxStack.Core/Extensions/Extensions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace OnnxStack.Core
1010
{
1111
public static class Extensions
1212
{
13-
public static SessionOptions GetSessionOptions(this OnnxModelSessionConfig configuration)
13+
public static SessionOptions GetSessionOptions(this OnnxModelConfig configuration)
1414
{
1515
var sessionOptions = new SessionOptions
1616
{

OnnxStack.Core/Model/OnnxModelSession.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ public class OnnxModelSession : IDisposable
99
{
1010
private readonly SessionOptions _options;
1111
private readonly InferenceSession _session;
12-
private readonly OnnxModelSessionConfig _configuration;
12+
private readonly OnnxModelConfig _configuration;
1313

1414
/// <summary>
1515
/// Initializes a new instance of the <see cref="OnnxModelSession"/> class.
1616
/// </summary>
1717
/// <param name="configuration">The configuration.</param>
1818
/// <param name="container">The container.</param>
1919
/// <exception cref="System.IO.FileNotFoundException">Onnx model file not found</exception>
20-
public OnnxModelSession(OnnxModelSessionConfig configuration, PrePackedWeightsContainer container)
20+
public OnnxModelSession(OnnxModelConfig configuration, PrePackedWeightsContainer container)
2121
{
2222
if (!File.Exists(configuration.OnnxModelPath))
2323
throw new FileNotFoundException("Onnx model file not found", configuration.OnnxModelPath);
@@ -44,7 +44,7 @@ public OnnxModelSession(OnnxModelSessionConfig configuration, PrePackedWeightsCo
4444
/// <summary>
4545
/// Gets the configuration.
4646
/// </summary>
47-
public OnnxModelSessionConfig Configuration => _configuration;
47+
public OnnxModelConfig Configuration => _configuration;
4848

4949

5050
/// <summary>

OnnxStack.Core/Model/OnnxModelSet.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public InferenceSession GetSession(OnnxModelType modelType)
7878
/// </summary>
7979
/// <param name="modelType">Type of the model.</param>
8080
/// <returns></returns>
81-
public OnnxModelSessionConfig GetConfiguration(OnnxModelType modelType)
81+
public OnnxModelConfig GetConfiguration(OnnxModelType modelType)
8282
{
8383
return _configuration.ModelConfigurations.FirstOrDefault(x => x.Type == modelType);
8484
}

OnnxStack.Core/Registration.cs

+18-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public static class Registration
1616
/// <param name="serviceCollection">The service collection.</param>
1717
public static void AddOnnxStack(this IServiceCollection serviceCollection)
1818
{
19-
serviceCollection.AddSingleton(ConfigManager.LoadConfiguration());
19+
serviceCollection.AddSingleton(TryLoadAppSettings());
2020
serviceCollection.AddSingleton<IOnnxModelService, OnnxModelService>();
2121
}
2222

@@ -43,5 +43,22 @@ public static void AddOnnxStackConfig<T>(this IServiceCollection serviceCollecti
4343
{
4444
serviceCollection.AddSingleton(ConfigManager.LoadConfiguration<T>());
4545
}
46+
47+
48+
/// <summary>
49+
/// Try load OnnxStackConfig from application settings if it exists.
50+
/// </summary>
51+
/// <returns></returns>
52+
private static OnnxStackConfig TryLoadAppSettings()
53+
{
54+
try
55+
{
56+
return ConfigManager.LoadConfiguration<OnnxStackConfig>();
57+
}
58+
catch
59+
{
60+
return new OnnxStackConfig();
61+
}
62+
}
4663
}
4764
}

OnnxStack.Core/Services/IOnnxModelService.cs

+7-9
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,26 @@ public interface IOnnxModelService : IDisposable
2626
/// <returns></returns>
2727
Task<bool> AddModelSet(IOnnxModelSetConfig modelSet);
2828

29-
3029
/// <summary>
3130
/// Adds a collection of ModelSet
3231
/// </summary>
3332
/// <param name="modelSets">The model sets.</param>
3433
Task AddModelSet(IEnumerable<IOnnxModelSetConfig> modelSets);
3534

36-
3735
/// <summary>
3836
/// Removes a model set.
3937
/// </summary>
4038
/// <param name="modelSet">The model set.</param>
4139
/// <returns></returns>
4240
Task<bool> RemoveModelSet(IOnnxModelSetConfig modelSet);
4341

42+
/// <summary>
43+
/// Updates the model set.
44+
/// </summary>
45+
/// <param name="modelSet">The model set.</param>
46+
/// <returns></returns>
47+
Task<bool> UpdateModelSet(IOnnxModelSetConfig modelSet);
48+
4449
/// <summary>
4550
/// Loads the model.
4651
/// </summary>
@@ -65,13 +70,6 @@ public interface IOnnxModelService : IDisposable
6570
bool IsModelLoaded(IOnnxModel model);
6671

6772

68-
/// <summary>
69-
/// Updates the model set.
70-
/// </summary>
71-
/// <param name="modelSet">The model set.</param>
72-
/// <returns></returns>
73-
bool UpdateModelSet(IOnnxModelSetConfig modelSet);
74-
7573
/// <summary>
7674
/// Determines whether the specified model type is enabled.
7775
/// </summary>

OnnxStack.Core/Services/OnnxModelService.cs

+18-22
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ namespace OnnxStack.Core.Services
1616
public sealed class OnnxModelService : IOnnxModelService
1717
{
1818
private readonly OnnxStackConfig _configuration;
19-
private readonly ConcurrentDictionary<string, OnnxModelSet> _onnxModelSets;
20-
private readonly ConcurrentDictionary<string, IOnnxModelSetConfig> _onnxModelSetConfigs;
19+
private readonly ConcurrentDictionary<IOnnxModel, OnnxModelSet> _onnxModelSets;
20+
private readonly ConcurrentDictionary<IOnnxModel, IOnnxModelSetConfig> _onnxModelSetConfigs;
2121

2222
/// <summary>
2323
/// Initializes a new instance of the <see cref="OnnxModelService"/> class.
@@ -26,8 +26,8 @@ public sealed class OnnxModelService : IOnnxModelService
2626
public OnnxModelService(OnnxStackConfig configuration)
2727
{
2828
_configuration = configuration;
29-
_onnxModelSets = new ConcurrentDictionary<string, OnnxModelSet>();
30-
_onnxModelSetConfigs = _configuration.OnnxModelSets.ToConcurrentDictionary(x => x.Name, x => x as IOnnxModelSetConfig);
29+
_onnxModelSets = new ConcurrentDictionary<IOnnxModel, OnnxModelSet>(new OnnxModelEqualityComparer());
30+
_onnxModelSetConfigs = new ConcurrentDictionary<IOnnxModel, IOnnxModelSetConfig>(new OnnxModelEqualityComparer());
3131
}
3232

3333

@@ -50,7 +50,7 @@ public OnnxModelService(OnnxStackConfig configuration)
5050
/// <returns></returns>
5151
public Task<bool> AddModelSet(IOnnxModelSetConfig modelSet)
5252
{
53-
return Task.FromResult(_onnxModelSetConfigs.TryAdd(modelSet.Name, modelSet));
53+
return Task.FromResult(_onnxModelSetConfigs.TryAdd(modelSet, modelSet));
5454
}
5555

5656
/// <summary>
@@ -74,7 +74,7 @@ public Task AddModelSet(IEnumerable<IOnnxModelSetConfig> modelSets)
7474
/// <returns></returns>
7575
public Task<bool> RemoveModelSet(IOnnxModelSetConfig modelSet)
7676
{
77-
return Task.FromResult(_onnxModelSetConfigs.TryRemove(modelSet.Name, out _));
77+
return Task.FromResult(_onnxModelSetConfigs.TryRemove(modelSet, out _));
7878
}
7979

8080

@@ -83,10 +83,10 @@ public Task<bool> RemoveModelSet(IOnnxModelSetConfig modelSet)
8383
/// </summary>
8484
/// <param name="modelSet">The model set.</param>
8585
/// <returns></returns>
86-
public bool UpdateModelSet(IOnnxModelSetConfig modelSet)
86+
public Task<bool> UpdateModelSet(IOnnxModelSetConfig modelSet)
8787
{
88-
_onnxModelSetConfigs.TryRemove(modelSet.Name, out _);
89-
return _onnxModelSetConfigs.TryAdd(modelSet.Name, modelSet);
88+
_onnxModelSetConfigs.TryRemove(modelSet, out _);
89+
return Task.FromResult(_onnxModelSetConfigs.TryAdd(modelSet, modelSet));
9090
}
9191

9292

@@ -120,7 +120,7 @@ public async Task<bool> UnloadModelAsync(IOnnxModel model)
120120
/// </returns>
121121
public bool IsModelLoaded(IOnnxModel model)
122122
{
123-
return _onnxModelSets.ContainsKey(model.Name);
123+
return _onnxModelSets.ContainsKey(model);
124124
}
125125

126126

@@ -251,7 +251,7 @@ private OnnxMetadata GetNodeMetadataInternal(IOnnxModel model, OnnxModelType mod
251251
/// <exception cref="System.Exception">Model {model.Name} has not been loaded</exception>
252252
private OnnxModelSet GetModelSet(IOnnxModel model)
253253
{
254-
if (!_onnxModelSets.TryGetValue(model.Name, out var modelSet))
254+
if (!_onnxModelSets.TryGetValue(model, out var modelSet))
255255
throw new Exception($"Model {model.Name} has not been loaded");
256256

257257
return modelSet;
@@ -266,17 +266,17 @@ private OnnxModelSet GetModelSet(IOnnxModel model)
266266
/// <exception cref="System.Exception">Model {model.Name} not found in configuration</exception>
267267
private OnnxModelSet LoadModelSet(IOnnxModel model)
268268
{
269-
if (_onnxModelSets.ContainsKey(model.Name))
270-
return _onnxModelSets[model.Name];
269+
if (_onnxModelSets.ContainsKey(model))
270+
return _onnxModelSets[model];
271271

272-
if (!_onnxModelSetConfigs.TryGetValue(model.Name, out var modelSetConfig))
273-
throw new Exception($"Model {model.Name} not found in configuration");
272+
if (!_onnxModelSetConfigs.TryGetValue(model, out var modelSetConfig))
273+
throw new Exception($"Model {model.Name} not found");
274274

275275
if (!modelSetConfig.IsEnabled)
276276
throw new Exception($"Model {model.Name} is not enabled");
277277

278278
var modelSet = new OnnxModelSet(modelSetConfig);
279-
_onnxModelSets.TryAdd(model.Name, modelSet);
279+
_onnxModelSets.TryAdd(model, modelSet);
280280
return modelSet;
281281
}
282282

@@ -288,10 +288,10 @@ private OnnxModelSet LoadModelSet(IOnnxModel model)
288288
/// <returns></returns>
289289
private bool UnloadModelSet(IOnnxModel model)
290290
{
291-
if (!_onnxModelSets.TryGetValue(model.Name, out var modelSet))
291+
if (!_onnxModelSets.TryGetValue(model, out _))
292292
return true;
293293

294-
if (_onnxModelSets.TryRemove(model.Name, out modelSet))
294+
if (_onnxModelSets.TryRemove(model, out var modelSet))
295295
{
296296
modelSet?.Dispose();
297297
return true;
@@ -310,9 +310,5 @@ public void Dispose()
310310
onnxModelSet?.Dispose();
311311
}
312312
}
313-
314-
315313
}
316-
317-
318314
}

OnnxStack.ImageUpscaler/Config/UpscaleModelSet.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ public class UpscaleModelSet : IOnnxModelSetConfig
1616
public int IntraOpNumThreads { get; set; }
1717
public ExecutionMode ExecutionMode { get; set; }
1818
public ExecutionProvider ExecutionProvider { get; set; }
19-
public List<OnnxModelSessionConfig> ModelConfigurations { get; set; }
19+
public List<OnnxModelConfig> ModelConfigurations { get; set; }
2020
}
2121
}

0 commit comments

Comments
 (0)