From 26cb93933ccf754e2d0902be538991ae237f1762 Mon Sep 17 00:00:00 2001 From: compujuckel Date: Thu, 27 Apr 2023 21:27:09 +0200 Subject: [PATCH 1/6] Allow specifying list of user ids in SocketGuild.DownloadUsersAsync --- .../API/Gateway/GuildMembersChunkEvent.cs | 6 +++ .../API/Gateway/RequestMembersParams.cs | 7 ++- .../DiscordSocketApiClient.cs | 10 +++- .../DiscordSocketClient.cs | 50 ++++++++++++------- .../Entities/Guilds/SocketGuild.cs | 8 ++- 5 files changed, 58 insertions(+), 23 deletions(-) diff --git a/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs b/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs index 26114bf541..80d1ed745b 100644 --- a/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs +++ b/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs @@ -8,5 +8,11 @@ internal class GuildMembersChunkEvent public ulong GuildId { get; set; } [JsonProperty("members")] public GuildMember[] Members { get; set; } + [JsonProperty("chunk_index")] + public int ChunkIndex { get; set; } + [JsonProperty("chunk_count")] + public int ChunkCount { get; set; } + [JsonProperty("nonce")] + public string Nonce { get; set; } } } diff --git a/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs b/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs index f7a63e330c..856a67e9bc 100644 --- a/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs +++ b/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs @@ -10,8 +10,11 @@ internal class RequestMembersParams public string Query { get; set; } [JsonProperty("limit")] public int Limit { get; set; } - [JsonProperty("guild_id")] - public IEnumerable GuildIds { get; set; } + public ulong GuildId { get; set; } + [JsonProperty("user_ids")] + public IEnumerable UserIds { get; set; } + [JsonProperty("nonce")] + public string Nonce { get; set; } } } diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs index 75960b173e..a3898ed2ff 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs @@ -379,11 +379,17 @@ public async Task SendPresenceUpdateAsync(UserStatus status, bool isAFK, long? s options.BucketId = GatewayBucket.Get(GatewayBucketType.PresenceUpdate).Id; await SendGatewayAsync(GatewayOpCode.PresenceUpdate, args, options: options).ConfigureAwait(false); } - public async Task SendRequestMembersAsync(IEnumerable guildIds, RequestOptions options = null) + public async Task SendRequestMembersAsync(ulong guildId, RequestOptions options = null) { options = RequestOptions.CreateOrClone(options); - await SendGatewayAsync(GatewayOpCode.RequestGuildMembers, new RequestMembersParams { GuildIds = guildIds, Query = "", Limit = 0 }, options: options).ConfigureAwait(false); + await SendGatewayAsync(GatewayOpCode.RequestGuildMembers, new RequestMembersParams { GuildId = guildId, Query = "", Limit = 0 }, options: options).ConfigureAwait(false); } + public async Task SendRequestMembersAsync(ulong guildId, IEnumerable userIds, string nonce, RequestOptions options = null) + { + options = RequestOptions.CreateOrClone(options); + await SendGatewayAsync(GatewayOpCode.RequestGuildMembers, new RequestMembersParams { GuildId = guildId, Limit = 0, UserIds = userIds, Nonce = nonce }, options: options).ConfigureAwait(false); + } + public async Task SendVoiceStateUpdateAsync(ulong guildId, ulong? channelId, bool selfDeaf, bool selfMute, RequestOptions options = null) { var payload = new VoiceStateUpdateParams diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index 924f5f645c..ce31e03257 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -37,6 +37,7 @@ public partial class DiscordSocketClient : BaseSocketClient, IDiscordClient private readonly ConnectionManager _connection; private readonly Logger _gatewayLogger; private readonly SemaphoreSlim _stateLock; + private readonly ConcurrentDictionary> _guildMembersRequestTasks; private string _sessionId; private int _lastSeq; @@ -51,6 +52,7 @@ public partial class DiscordSocketClient : BaseSocketClient, IDiscordClient private GatewayIntents _gatewayIntents; private ImmutableArray> _defaultStickers; private SocketSelfUser _previousSessionUser; + private long _guildMembersRequestCounter; /// /// Provides access to a REST-only client with a shared state from this client. @@ -183,6 +185,8 @@ private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClie e.ErrorContext.Handled = true; }; + _guildMembersRequestTasks = new ConcurrentDictionary>(); + ApiClient.SentGatewayMessage += async opCode => await _gatewayLogger.DebugAsync($"Sent {opCode}").ConfigureAwait(false); ApiClient.ReceivedGatewayEvent += ProcessMessageAsync; @@ -627,30 +631,33 @@ private async Task ProcessUserDownloadsAsync(IEnumerable guilds) { var cachedGuilds = guilds.ToImmutableArray(); - const short batchSize = 1; - ulong[] batchIds = new ulong[Math.Min(batchSize, cachedGuilds.Length)]; - Task[] batchTasks = new Task[batchIds.Length]; - int batchCount = (cachedGuilds.Length + (batchSize - 1)) / batchSize; + foreach (var guild in cachedGuilds) + { + await ApiClient.SendRequestMembersAsync(guild.Id).ConfigureAwait(false); + await guild.DownloaderPromise.ConfigureAwait(false); + } + } - for (int i = 0, k = 0; i < batchCount; i++) + public async Task DownloadUsersAsync(IGuild guild, IEnumerable userIds) + { + if (ConnectionState == ConnectionState.Connected) { - bool isLast = i == batchCount - 1; - int count = isLast ? (cachedGuilds.Length - (batchCount - 1) * batchSize) : batchSize; + EnsureGatewayIntent(GatewayIntents.GuildMembers); - for (int j = 0; j < count; j++, k++) + var socketGuild = GetGuild(guild.Id); + if (socketGuild != null) { - var guild = cachedGuilds[k]; - batchIds[j] = guild.Id; - batchTasks[j] = guild.DownloaderPromise; + await ProcessUserDownloadsAsync(socketGuild, userIds).ConfigureAwait(false); } + } + } - await ApiClient.SendRequestMembersAsync(batchIds).ConfigureAwait(false); - if (isLast && batchCount > 1) - await Task.WhenAll(batchTasks.Take(count)).ConfigureAwait(false); - else - await Task.WhenAll(batchTasks).ConfigureAwait(false); - } + private async Task ProcessUserDownloadsAsync(SocketGuild guild, IEnumerable userIds) + { + var nonce = Interlocked.Increment(ref _guildMembersRequestCounter).ToString(); + _guildMembersRequestTasks.TryAdd(nonce, new TaskCompletionSource()); + await ApiClient.SendRequestMembersAsync(guild.Id, userIds, nonce).ConfigureAwait(false); } /// @@ -1410,6 +1417,13 @@ private async Task ProcessMessageAsync(GatewayOpCode opCode, int? seq, string ty guild.CompleteDownloadUsers(); await TimedInvokeAsync(_guildMembersDownloadedEvent, nameof(GuildMembersDownloaded), guild).ConfigureAwait(false); } + + if (data.Nonce != null + && data.ChunkIndex + 1 >= data.ChunkCount + && _guildMembersRequestTasks.TryRemove(data.Nonce, out var tcs)) + { + tcs.TrySetResult(true); + } } else { @@ -2904,7 +2918,7 @@ private async Task ProcessMessageAsync(GatewayOpCode opCode, int? seq, string ty } break; #endregion - + #region Auto Moderation case "AUTO_MODERATION_RULE_CREATE": diff --git a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs index 9180ad92f3..20ee31916b 100644 --- a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs +++ b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs @@ -1260,6 +1260,12 @@ public async Task DownloadUsersAsync() { await Discord.DownloadUsersAsync(new[] { this }).ConfigureAwait(false); } + + public async Task DownloadUsersAsync(IEnumerable userIds) + { + await Discord.DownloadUsersAsync(this, userIds).ConfigureAwait(false); + } + internal void CompleteDownloadUsers() { _downloaderPromise.TrySetResultAsync(true); @@ -1406,7 +1412,7 @@ public Task CreateEventAsync( /// /// A task that represents the asynchronous get operation. The task result contains a read-only collection /// of the requested audit log entries. - /// + /// public IAsyncEnumerable> GetAuditLogsAsync(int limit, RequestOptions options = null, ulong? beforeId = null, ulong? userId = null, ActionType? actionType = null, ulong? afterId = null) => GuildHelper.GetAuditLogsAsync(this, Discord, beforeId, limit, options, userId: userId, actionType: actionType, afterId: afterId); From f8ace7992452f95ae17c54d9a216bfd80634a487 Mon Sep 17 00:00:00 2001 From: compujuckel Date: Fri, 28 Apr 2023 15:14:41 +0200 Subject: [PATCH 2/6] use Optional for optional parameters and (hopefully) fix CI errors --- src/Discord.Net.Core/Entities/Guilds/IGuild.cs | 11 +++++++++++ src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs | 6 +++++- .../API/Gateway/GuildMembersChunkEvent.cs | 7 ++++++- .../API/Gateway/RequestMembersParams.cs | 12 +++++++----- src/Discord.Net.WebSocket/DiscordSocketApiClient.cs | 9 ++++++++- src/Discord.Net.WebSocket/DiscordSocketClient.cs | 4 ++-- .../Entities/Guilds/SocketGuild.cs | 3 ++- 7 files changed, 41 insertions(+), 11 deletions(-) diff --git a/src/Discord.Net.Core/Entities/Guilds/IGuild.cs b/src/Discord.Net.Core/Entities/Guilds/IGuild.cs index aec5bff1e6..f7c0d04538 100644 --- a/src/Discord.Net.Core/Entities/Guilds/IGuild.cs +++ b/src/Discord.Net.Core/Entities/Guilds/IGuild.cs @@ -959,6 +959,17 @@ public interface IGuild : IDeletable, ISnowflakeEntity /// Task DownloadUsersAsync(); /// + /// Downloads specific users for this guild. + /// + /// + /// This method downloads all users specified in through the Gateway and caches them. + /// + /// The list of Discord user IDs to download + /// + /// A task that represents the asynchronous download operation. + /// + Task DownloadUsersAsync(IEnumerable userIds); + /// /// Prunes inactive users. /// /// diff --git a/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs b/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs index a6c2d2d998..e64b36da11 100644 --- a/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs +++ b/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs @@ -1512,6 +1512,10 @@ async Task> IGuild.GetUsersAsync(CacheMode mode, Task IGuild.DownloadUsersAsync() => throw new NotSupportedException(); /// + /// Downloading users is not supported for a REST-based guild. + Task IGuild.DownloadUsersAsync(IEnumerable userIds) => + throw new NotSupportedException(); + /// async Task> IGuild.SearchUsersAsync(string query, int limit, CacheMode mode, RequestOptions options) { if (mode == CacheMode.AllowDownload) @@ -1604,7 +1608,7 @@ async Task IGuild.GetAutoModRulesAsync(RequestOptions options) /// async Task IGuild.CreateAutoModRuleAsync(Action props, RequestOptions options) => await CreateAutoModRuleAsync(props, options).ConfigureAwait(false); - + /// async Task IGuild.GetOnboardingAsync(RequestOptions options) => await GetOnboardingAsync(options); diff --git a/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs b/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs index 80d1ed745b..e62dae08bf 100644 --- a/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs +++ b/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs @@ -1,4 +1,5 @@ using Newtonsoft.Json; +using System.Collections.Generic; namespace Discord.API.Gateway { @@ -12,7 +13,11 @@ internal class GuildMembersChunkEvent public int ChunkIndex { get; set; } [JsonProperty("chunk_count")] public int ChunkCount { get; set; } + [JsonProperty("not_found")] + public Optional> NotFound { get; set; } + [JsonProperty("presences")] + public Optional> Presences { get; set; } [JsonProperty("nonce")] - public string Nonce { get; set; } + public Optional Nonce { get; set; } } } diff --git a/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs b/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs index 856a67e9bc..d88fce51c5 100644 --- a/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs +++ b/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs @@ -6,15 +6,17 @@ namespace Discord.API.Gateway [JsonObject(MemberSerialization = MemberSerialization.OptIn)] internal class RequestMembersParams { + [JsonProperty("guild_id")] + public ulong GuildId { get; set; } [JsonProperty("query")] - public string Query { get; set; } + public Optional Query { get; set; } [JsonProperty("limit")] public int Limit { get; set; } - [JsonProperty("guild_id")] - public ulong GuildId { get; set; } + [JsonProperty("presences")] + public Optional Presences { get; set; } [JsonProperty("user_ids")] - public IEnumerable UserIds { get; set; } + public Optional> UserIds { get; set; } [JsonProperty("nonce")] - public string Nonce { get; set; } + public Optional Nonce { get; set; } } } diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs index a3898ed2ff..64a0496502 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs @@ -386,8 +386,15 @@ public async Task SendRequestMembersAsync(ulong guildId, RequestOptions options } public async Task SendRequestMembersAsync(ulong guildId, IEnumerable userIds, string nonce, RequestOptions options = null) { + var payload = new RequestMembersParams + { + GuildId = guildId, + Limit = 0, + UserIds = new Optional>(userIds), + Nonce = nonce + }; options = RequestOptions.CreateOrClone(options); - await SendGatewayAsync(GatewayOpCode.RequestGuildMembers, new RequestMembersParams { GuildId = guildId, Limit = 0, UserIds = userIds, Nonce = nonce }, options: options).ConfigureAwait(false); + await SendGatewayAsync(GatewayOpCode.RequestGuildMembers, payload, options: options).ConfigureAwait(false); } public async Task SendVoiceStateUpdateAsync(ulong guildId, ulong? channelId, bool selfDeaf, bool selfMute, RequestOptions options = null) diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index ce31e03257..3b0f9d51e2 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -1418,9 +1418,9 @@ private async Task ProcessMessageAsync(GatewayOpCode opCode, int? seq, string ty await TimedInvokeAsync(_guildMembersDownloadedEvent, nameof(GuildMembersDownloaded), guild).ConfigureAwait(false); } - if (data.Nonce != null + if (data.Nonce.IsSpecified && data.ChunkIndex + 1 >= data.ChunkCount - && _guildMembersRequestTasks.TryRemove(data.Nonce, out var tcs)) + && _guildMembersRequestTasks.TryRemove(data.Nonce.Value, out var tcs)) { tcs.TrySetResult(true); } diff --git a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs index 20ee31916b..659a7a1b7c 100644 --- a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs +++ b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs @@ -383,7 +383,7 @@ public IReadOnlyCollection Stickers /// /// /// Otherwise, you may need to enable to fetch - /// the full user list upon startup, or use to manually download + /// the full user list upon startup, or use to manually download /// the users. /// /// @@ -1261,6 +1261,7 @@ public async Task DownloadUsersAsync() await Discord.DownloadUsersAsync(new[] { this }).ConfigureAwait(false); } + /// public async Task DownloadUsersAsync(IEnumerable userIds) { await Discord.DownloadUsersAsync(this, userIds).ConfigureAwait(false); From df483728196ebfc387c17b91fafbb12728e0b91e Mon Sep 17 00:00:00 2001 From: compujuckel Date: Fri, 28 Apr 2023 16:51:59 +0200 Subject: [PATCH 3/6] await download task in ProcessUserDownloadsAsync --- src/Discord.Net.WebSocket/DiscordSocketClient.cs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index 3b0f9d51e2..1893245048 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -656,8 +656,10 @@ public async Task DownloadUsersAsync(IGuild guild, IEnumerable userIds) private async Task ProcessUserDownloadsAsync(SocketGuild guild, IEnumerable userIds) { var nonce = Interlocked.Increment(ref _guildMembersRequestCounter).ToString(); - _guildMembersRequestTasks.TryAdd(nonce, new TaskCompletionSource()); + var tcs = new TaskCompletionSource(); + _guildMembersRequestTasks.TryAdd(nonce, tcs); await ApiClient.SendRequestMembersAsync(guild.Id, userIds, nonce).ConfigureAwait(false); + await tcs.Task.ConfigureAwait(false); } /// From 3f599c83b6ec8221b8d4f70bcd4e75661eb4b08d Mon Sep 17 00:00:00 2001 From: compujuckel Date: Fri, 28 Apr 2023 17:02:11 +0200 Subject: [PATCH 4/6] add new DownloadUsersAsync to BaseSocketClient and DiscordShardedClient --- src/Discord.Net.WebSocket/BaseSocketClient.cs | 9 +++++++++ .../DiscordShardedClient.cs | 17 +++++++++++++++++ .../DiscordSocketClient.cs | 3 ++- 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/Discord.Net.WebSocket/BaseSocketClient.cs b/src/Discord.Net.WebSocket/BaseSocketClient.cs index 482a08a0f9..f3bb7b7ea1 100644 --- a/src/Discord.Net.WebSocket/BaseSocketClient.cs +++ b/src/Discord.Net.WebSocket/BaseSocketClient.cs @@ -235,6 +235,15 @@ private static DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config /// A task that represents the asynchronous download operation. /// public abstract Task DownloadUsersAsync(IEnumerable guilds); + /// + /// Attempts to download specific users into the user cache for the selected guilds. + /// + /// The guild to download the members from. + /// The list of Discord user IDs to download. + /// + /// A task that represents the asynchronous download operation. + /// + public abstract Task DownloadUsersAsync(IGuild guild, IEnumerable userIds); /// /// Creates a guild for the logged-in user who is in less than 10 active guilds. diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index c3809ba672..145e846a95 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -381,6 +381,23 @@ public override async Task DownloadUsersAsync(IEnumerable guilds) } } + /// + /// is + public override async Task DownloadUsersAsync(IGuild guild, IEnumerable userIds) + { + if (guild == null) + throw new ArgumentNullException(nameof(guild)); + + for (int i = 0; i < _shards.Length; i++) + { + int id = _shardIds[i]; + if (GetShardIdFor(guild) == id) + { + await _shards[i].DownloadUsersAsync(guild, userIds).ConfigureAwait(false); + } + } + } + private int GetLatency() { int total = 0; diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index 1893245048..0b915f11d7 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -638,7 +638,8 @@ private async Task ProcessUserDownloadsAsync(IEnumerable guilds) } } - public async Task DownloadUsersAsync(IGuild guild, IEnumerable userIds) + /// + public override async Task DownloadUsersAsync(IGuild guild, IEnumerable userIds) { if (ConnectionState == ConnectionState.Connected) { From 0d06290c6dd0e6307773cc6cc4edd28ea6c50acf Mon Sep 17 00:00:00 2001 From: compujuckel Date: Sat, 29 Apr 2023 13:35:35 +0200 Subject: [PATCH 5/6] small fix --- src/Discord.Net.WebSocket/BaseSocketClient.cs | 2 +- src/Discord.Net.WebSocket/DiscordShardedClient.cs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Discord.Net.WebSocket/BaseSocketClient.cs b/src/Discord.Net.WebSocket/BaseSocketClient.cs index f3bb7b7ea1..979638e5ef 100644 --- a/src/Discord.Net.WebSocket/BaseSocketClient.cs +++ b/src/Discord.Net.WebSocket/BaseSocketClient.cs @@ -236,7 +236,7 @@ private static DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config /// public abstract Task DownloadUsersAsync(IEnumerable guilds); /// - /// Attempts to download specific users into the user cache for the selected guilds. + /// Attempts to download specific users into the user cache for the selected guild. /// /// The guild to download the members from. /// The list of Discord user IDs to download. diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index 145e846a95..7f51862439 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -394,6 +394,7 @@ public override async Task DownloadUsersAsync(IGuild guild, IEnumerable u if (GetShardIdFor(guild) == id) { await _shards[i].DownloadUsersAsync(guild, userIds).ConfigureAwait(false); + break; } } } From 1aad00b86092ba1f6b3f109e240e4d1e878245de Mon Sep 17 00:00:00 2001 From: compujuckel Date: Fri, 30 Jun 2023 19:54:37 +0200 Subject: [PATCH 6/6] - allow cancellation of DownloadUsersAsync - throw exception when client not connected - chunk user downloads --- src/Discord.Net.Core/DiscordConfig.cs | 5 + .../Entities/Guilds/IGuild.cs | 21 +++- .../Extensions/EnumerableExtensions.cs | 103 ++++++++++++++++++ .../Entities/Guilds/RestGuild.cs | 5 + src/Discord.Net.WebSocket/BaseSocketClient.cs | 5 +- .../DiscordShardedClient.cs | 7 +- .../DiscordSocketClient.cs | 26 ++++- .../Entities/Guilds/SocketGuild.cs | 9 +- 8 files changed, 168 insertions(+), 13 deletions(-) create mode 100644 src/Discord.Net.Core/Extensions/EnumerableExtensions.cs diff --git a/src/Discord.Net.Core/DiscordConfig.cs b/src/Discord.Net.Core/DiscordConfig.cs index 3aacc30b6a..396c3069b8 100644 --- a/src/Discord.Net.Core/DiscordConfig.cs +++ b/src/Discord.Net.Core/DiscordConfig.cs @@ -232,5 +232,10 @@ public class DiscordConfig /// Returns the max length of an application description. /// public const int MaxApplicationDescriptionLength = 400; + + /// + /// Returns the max number of user IDs that can be requested in a Request Guild Members chunk. + /// + public const int MaxRequestedUserIdsPerRequestGuildMembersChunk = 100; } } diff --git a/src/Discord.Net.Core/Entities/Guilds/IGuild.cs b/src/Discord.Net.Core/Entities/Guilds/IGuild.cs index f7c0d04538..cdcd196c0c 100644 --- a/src/Discord.Net.Core/Entities/Guilds/IGuild.cs +++ b/src/Discord.Net.Core/Entities/Guilds/IGuild.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Globalization; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace Discord @@ -958,17 +959,33 @@ public interface IGuild : IDeletable, ISnowflakeEntity /// A task that represents the asynchronous download operation. /// Task DownloadUsersAsync(); + /// - /// Downloads specific users for this guild. + /// Downloads specific users for this guild with the default request timeout. /// /// /// This method downloads all users specified in through the Gateway and caches them. + /// Consider using when downloading a large number of users. /// - /// The list of Discord user IDs to download + /// The list of Discord user IDs to download. /// /// A task that represents the asynchronous download operation. /// + /// The timeout has elapsed. Task DownloadUsersAsync(IEnumerable userIds); + + /// + /// Downloads specific users for this guild. + /// + /// + /// This method downloads all users specified in through the Gateway and caches them. + /// + /// The list of Discord user IDs to download. + /// The cancellation token used to cancel the task. + /// + /// A task that represents the asynchronous download operation. + /// + Task DownloadUsersAsync(IEnumerable userIds, CancellationToken cancelToken); /// /// Prunes inactive users. /// diff --git a/src/Discord.Net.Core/Extensions/EnumerableExtensions.cs b/src/Discord.Net.Core/Extensions/EnumerableExtensions.cs new file mode 100644 index 0000000000..0ce262c90f --- /dev/null +++ b/src/Discord.Net.Core/Extensions/EnumerableExtensions.cs @@ -0,0 +1,103 @@ +// Based on https://github.com/dotnet/runtime/blob/main/src/libraries/System.Linq/src/System/Linq/Chunk.cs (only available on .NET 6+) +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; + +namespace Discord +{ + internal static class EnumerableExtensions + { + /// + /// Split the elements of a sequence into chunks of size at most . + /// + /// + /// Every chunk except the last will be of size . + /// The last chunk will contain the remaining elements and may be of a smaller size. + /// + /// + /// An whose elements to chunk. + /// + /// + /// Maximum size of each chunk. + /// + /// + /// The type of the elements of source. + /// + /// + /// An that contains the elements the input sequence split into chunks of size . + /// + /// + /// is null. + /// + /// + /// is below 1. + /// + public static IEnumerable Chunk(this IEnumerable source, int size) + { + Preconditions.NotNull(source, nameof(source)); + Preconditions.GreaterThan(size, 0, nameof(size)); + + return ChunkIterator(source, size); + } + + private static IEnumerable ChunkIterator(IEnumerable source, int size) + { + using IEnumerator e = source.GetEnumerator(); + + // Before allocating anything, make sure there's at least one element. + if (e.MoveNext()) + { + // Now that we know we have at least one item, allocate an initial storage array. This is not + // the array we'll yield. It starts out small in order to avoid significantly overallocating + // when the source has many fewer elements than the chunk size. + int arraySize = Math.Min(size, 4); + int i; + do + { + var array = new TSource[arraySize]; + + // Store the first item. + array[0] = e.Current; + i = 1; + + if (size != array.Length) + { + // This is the first chunk. As we fill the array, grow it as needed. + for (; i < size && e.MoveNext(); i++) + { + if (i >= array.Length) + { + arraySize = (int)Math.Min((uint)size, 2 * (uint)array.Length); + Array.Resize(ref array, arraySize); + } + + array[i] = e.Current; + } + } + else + { + // For all but the first chunk, the array will already be correctly sized. + // We can just store into it until either it's full or MoveNext returns false. + TSource[] local = array; // avoid bounds checks by using cached local (`array` is lifted to iterator object as a field) + Debug.Assert(local.Length == size); + for (; (uint)i < (uint)local.Length && e.MoveNext(); i++) + { + local[i] = e.Current; + } + } + + if (i != array.Length) + { + Array.Resize(ref array, i); + } + + yield return array; + } + while (i >= size && e.MoveNext()); + } + } + } +} diff --git a/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs b/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs index e64b36da11..fdf35b09ca 100644 --- a/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs +++ b/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs @@ -6,6 +6,7 @@ using System.Globalization; using System.IO; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Model = Discord.API.Guild; using WidgetModel = Discord.API.GuildWidget; @@ -1516,6 +1517,10 @@ Task IGuild.DownloadUsersAsync() => Task IGuild.DownloadUsersAsync(IEnumerable userIds) => throw new NotSupportedException(); /// + /// Downloading users is not supported for a REST-based guild. + Task IGuild.DownloadUsersAsync(IEnumerable userIds, CancellationToken cancelToken) => + throw new NotSupportedException(); + /// async Task> IGuild.SearchUsersAsync(string query, int limit, CacheMode mode, RequestOptions options) { if (mode == CacheMode.AllowDownload) diff --git a/src/Discord.Net.WebSocket/BaseSocketClient.cs b/src/Discord.Net.WebSocket/BaseSocketClient.cs index 979638e5ef..897aa67460 100644 --- a/src/Discord.Net.WebSocket/BaseSocketClient.cs +++ b/src/Discord.Net.WebSocket/BaseSocketClient.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace Discord.WebSocket @@ -235,15 +236,17 @@ private static DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config /// A task that represents the asynchronous download operation. /// public abstract Task DownloadUsersAsync(IEnumerable guilds); + /// /// Attempts to download specific users into the user cache for the selected guild. /// /// The guild to download the members from. /// The list of Discord user IDs to download. + /// The cancellation token used to cancel the task. /// /// A task that represents the asynchronous download operation. /// - public abstract Task DownloadUsersAsync(IGuild guild, IEnumerable userIds); + public abstract Task DownloadUsersAsync(IGuild guild, IEnumerable userIds, CancellationToken cancelToken = default); /// /// Creates a guild for the logged-in user who is in less than 10 active guilds. diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index 7f51862439..b4c59c4236 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -383,17 +383,16 @@ public override async Task DownloadUsersAsync(IEnumerable guilds) /// /// is - public override async Task DownloadUsersAsync(IGuild guild, IEnumerable userIds) + public override async Task DownloadUsersAsync(IGuild guild, IEnumerable userIds, CancellationToken cancelToken = default) { - if (guild == null) - throw new ArgumentNullException(nameof(guild)); + Preconditions.NotNull(guild, nameof(guild)); for (int i = 0; i < _shards.Length; i++) { int id = _shardIds[i]; if (GetShardIdFor(guild) == id) { - await _shards[i].DownloadUsersAsync(guild, userIds).ConfigureAwait(false); + await _shards[i].DownloadUsersAsync(guild, userIds, cancelToken).ConfigureAwait(false); break; } } diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index 0b915f11d7..b7ddc73139 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -639,7 +639,7 @@ private async Task ProcessUserDownloadsAsync(IEnumerable guilds) } /// - public override async Task DownloadUsersAsync(IGuild guild, IEnumerable userIds) + public override async Task DownloadUsersAsync(IGuild guild, IEnumerable userIds, CancellationToken cancelToken = default) { if (ConnectionState == ConnectionState.Connected) { @@ -648,19 +648,35 @@ public override async Task DownloadUsersAsync(IGuild guild, IEnumerable u var socketGuild = GetGuild(guild.Id); if (socketGuild != null) { - await ProcessUserDownloadsAsync(socketGuild, userIds).ConfigureAwait(false); + foreach (var chunk in userIds.Chunk(DiscordConfig.MaxRequestedUserIdsPerRequestGuildMembersChunk)) + { + await ProcessUserDownloadsAsync(socketGuild, chunk, cancelToken).ConfigureAwait(false); + } } } + else + { + throw new InvalidOperationException("Client not connected"); + } } - private async Task ProcessUserDownloadsAsync(SocketGuild guild, IEnumerable userIds) + private async Task ProcessUserDownloadsAsync(SocketGuild guild, IEnumerable userIds, CancellationToken cancelToken = default) { var nonce = Interlocked.Increment(ref _guildMembersRequestCounter).ToString(); var tcs = new TaskCompletionSource(); + using var registration = cancelToken.Register(() => tcs.TrySetCanceled()); _guildMembersRequestTasks.TryAdd(nonce, tcs); - await ApiClient.SendRequestMembersAsync(guild.Id, userIds, nonce).ConfigureAwait(false); - await tcs.Task.ConfigureAwait(false); + try + { + await ApiClient.SendRequestMembersAsync(guild.Id, userIds, nonce).ConfigureAwait(false); + await tcs.Task.ConfigureAwait(false); + cancelToken.ThrowIfCancellationRequested(); + } + finally + { + _guildMembersRequestTasks.TryRemove(nonce, out _); + } } /// diff --git a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs index 659a7a1b7c..d8f5ea7d25 100644 --- a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs +++ b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs @@ -1264,7 +1264,14 @@ public async Task DownloadUsersAsync() /// public async Task DownloadUsersAsync(IEnumerable userIds) { - await Discord.DownloadUsersAsync(this, userIds).ConfigureAwait(false); + using var cts = new CancellationTokenSource(DiscordConfig.DefaultRequestTimeout); + await DownloadUsersAsync(userIds, cts.Token).ConfigureAwait(false); + } + + /// + public async Task DownloadUsersAsync(IEnumerable userIds, CancellationToken cancelToken) + { + await Discord.DownloadUsersAsync(this, userIds, cancelToken).ConfigureAwait(false); } internal void CompleteDownloadUsers()