Skip to content

Commit d639418

Browse files
rohan-varmafacebook-github-bot
authored andcommitted
Add timeout injection to faulty agent for testing (pytorch#37485)
Summary: Pull Request resolved: pytorch#37485 Adds arbitrary timeout injection to faulty RPC agent. This is to better test scenarios that need information about how long-running RPCs, such as properly testing RPC timeouts and the profiler in all scenarios. This is done by overriding ProcessGroupAgent's `enqueueSend()` function to inject the timeout. Determining which messages to timeout is done similar to the existing `faulty_messages` by having the user specify a mapping of message to timeout. Added unit tests that verify RPC timeouts work with builtin + TorchScript functions, which was not tested before. ghstack-source-id: 103341662 Test Plan: Added unit tests in `FaultyRpcAgentTest`. Differential Revision: D21296537 fbshipit-source-id: 1dbc21aee14e49780272634e9cbb2b5a448f2896
1 parent 707e0e8 commit d639418

File tree

8 files changed

+210
-19
lines changed

8 files changed

+210
-19
lines changed

torch/csrc/distributed/rpc/process_group_agent.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ class ProcessGroupAgent : public RpcAgent {
9191
Message&& message,
9292
const float rpcTimeoutSeconds = kUnsetRpcTimeout) override;
9393

94+
// put SendWork into a queue and notify the worker thread
95+
virtual void enqueueSend(SendWork work);
96+
9497
private:
9598
using steady_clock_time_point =
9699
std::chrono::time_point<std::chrono::steady_clock>;
@@ -145,8 +148,6 @@ class ProcessGroupAgent : public RpcAgent {
145148
};
146149

147150
void collectNames();
148-
// put SendWork into a queue and notify the worker thread
149-
void enqueueSend(SendWork work);
150151
// handle a SendWork request. This serializes the payload inside the work
151152
// object, and sends the message to the receiver using the underlying
152153
// ProcessGroup.

torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ namespace torch {
44
namespace distributed {
55
namespace rpc {
66

7+
namespace {
8+
constexpr auto kSecToMsConversion = 1000;
9+
}
10+
711
std::string fromVec(const std::vector<char>& vec) {
812
return std::string(vec.begin(), vec.end());
913
}
@@ -14,14 +18,16 @@ FaultyProcessGroupAgent::FaultyProcessGroupAgent(
1418
int numSendRecvThreads,
1519
std::chrono::milliseconds rpcTimeout,
1620
const std::vector<std::string>& messagesToFail,
21+
const std::unordered_map<std::string, float>& messageTypesToDelay,
1722
int failNumSends)
1823
: ProcessGroupAgent(
1924
std::move(workerName),
2025
std::move(pg),
2126
numSendRecvThreads,
2227
rpcTimeout),
2328
failNumSends_(failNumSends),
24-
messageTypesToFail_(parseMessagesToFailInput(messagesToFail)) {}
29+
messageTypesToFail_(parseMessagesToFailInput(messagesToFail)),
30+
messageTypesToDelay_(parseMessagesToDelay(messageTypesToDelay)) {}
2531

2632
std::vector<MessageType> FaultyProcessGroupAgent::parseMessagesToFailInput(
2733
const std::vector<std::string>& messagesToFail) const {
@@ -30,21 +36,27 @@ std::vector<MessageType> FaultyProcessGroupAgent::parseMessagesToFailInput(
3036
// types. We will then check this list of types in the send function to
3137
// determine whether we should fail or not.
3238
std::vector<MessageType> messageTypesToFail;
39+
messageTypesToFail.reserve(messagesToFail.size());
3340
for (const auto& msgString : messagesToFail) {
34-
if (msgString == "RREF_FORK_REQUEST") {
35-
messageTypesToFail.emplace_back(MessageType::RREF_FORK_REQUEST);
36-
} else if (msgString == "RREF_CHILD_ACCEPT") {
37-
messageTypesToFail.emplace_back(MessageType::RREF_CHILD_ACCEPT);
38-
} else if (msgString == "RREF_USER_DELETE") {
39-
messageTypesToFail.emplace_back(MessageType::RREF_USER_DELETE);
40-
} else if (msgString == "CLEANUP_AUTOGRAD_CONTEXT_REQ") {
41-
messageTypesToFail.emplace_back(
42-
MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ);
43-
}
41+
messageTypesToFail.push_back(messageStringToType(msgString));
4442
}
4543
return messageTypesToFail;
4644
}
4745

46+
std::unordered_map<MessageType, float, std::hash<int>> FaultyProcessGroupAgent::
47+
parseMessagesToDelay(const std::unordered_map<std::string, float>&
48+
messageTypesToDelay) const {
49+
std::unordered_map<MessageType, float, std::hash<int>> delayMessages;
50+
for (const auto& messagePair : messageTypesToDelay) {
51+
float delay = messagePair.second;
52+
TORCH_CHECK(
53+
delay >= 0,
54+
"Delays passed to FaultyProcessGroupAgent must be non-negative.")
55+
delayMessages.insert({messageStringToType(messagePair.first), delay});
56+
}
57+
return delayMessages;
58+
}
59+
4860
std::shared_ptr<FutureMessage> FaultyProcessGroupAgent::send(
4961
const WorkerInfo& to,
5062
Message&& message,
@@ -76,13 +88,49 @@ std::shared_ptr<FutureMessage> FaultyProcessGroupAgent::send(
7688
}
7789
}
7890

91+
void FaultyProcessGroupAgent::enqueueSend(SendWork work) {
92+
float msgDelay = getDelayForMessage(work.message_.type());
93+
if (msgDelay != 0) {
94+
// Sleep for the specified delay for the message.
95+
std::this_thread::sleep_for(std::chrono::milliseconds(
96+
static_cast<int>(msgDelay * kSecToMsConversion)));
97+
}
98+
ProcessGroupAgent::enqueueSend(std::move(work));
99+
}
100+
79101
bool FaultyProcessGroupAgent::shouldFailMessage(MessageType type) const {
80102
// Return true if the input message type is in the messageTypesToFail_ list
81103
return (
82104
std::find(messageTypesToFail_.begin(), messageTypesToFail_.end(), type) !=
83105
messageTypesToFail_.end());
84106
}
85107

108+
float FaultyProcessGroupAgent::getDelayForMessage(MessageType type) const {
109+
const auto& it = messageTypesToDelay_.find(type);
110+
return it == messageTypesToDelay_.end() ? 0 : it->second;
111+
}
112+
113+
MessageType FaultyProcessGroupAgent::messageStringToType(
114+
const std::string& messageString) const {
115+
// Lazily constructed map that returns string to message type mapping
116+
static std::unordered_map<std::string, MessageType> msgMap = {
117+
{"RREF_FORK_REQUEST", MessageType::RREF_FORK_REQUEST},
118+
{"RREF_CHILD_ACCEPT", MessageType::RREF_CHILD_ACCEPT},
119+
{"RREF_USER_DELETE", MessageType::RREF_USER_DELETE},
120+
{"CLEANUP_AUTOGRAD_CONTEXT_REQ",
121+
MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ},
122+
{"PYTHON_REMOTE_CALL", MessageType::PYTHON_REMOTE_CALL},
123+
{"PYTHON_CALL", MessageType::PYTHON_CALL},
124+
{"SCRIPT_CALL", MessageType::SCRIPT_CALL},
125+
};
126+
const auto& it = msgMap.find(messageString);
127+
TORCH_CHECK(
128+
it != msgMap.end(),
129+
"No mapping to rpc::MessageType exists for ",
130+
messageString);
131+
return it->second;
132+
}
133+
86134
} // namespace rpc
87135
} // namespace distributed
88136
} // namespace torch

torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,20 @@ struct FaultyProcessGroupRpcBackendOptions
1414
float rpc_timeout,
1515
std::string init_method,
1616
std::vector<std::string> messages_to_fail,
17+
std::unordered_map<std::string, float> messages_to_delay,
1718
int num_fail_sends = 0)
1819
: ProcessGroupRpcBackendOptions(
1920
num_send_recv_threads,
2021
rpc_timeout,
2122
std::move(init_method)),
2223
messagesToFail(std::move(messages_to_fail)),
24+
messagesToDelay(std::move(messages_to_delay)),
2325
numFailSends(num_fail_sends) {
2426
TORCH_CHECK(numFailSends >= 0, "numFailSends should be non-negative");
2527
}
2628

2729
std::vector<std::string> messagesToFail;
30+
std::unordered_map<std::string, float> messagesToDelay;
2831
int numFailSends;
2932
};
3033

@@ -36,6 +39,7 @@ class FaultyProcessGroupAgent : public ProcessGroupAgent {
3639
int numSendRecvThreads,
3740
std::chrono::milliseconds rpcTimeout,
3841
const std::vector<std::string>& messagesToFail,
42+
const std::unordered_map<std::string, float>& messageTypesToDelay,
3943
int failNumSends = 0);
4044

4145
// Faulty send function for this class.
@@ -45,6 +49,9 @@ class FaultyProcessGroupAgent : public ProcessGroupAgent {
4549
const float rpcTimeoutSeconds =
4650
torch::distributed::rpc::kUnsetRpcTimeout) override;
4751

52+
// Overrides ProcessGroupAgent's enqueueSend to inject delays.
53+
void enqueueSend(SendWork work) override;
54+
4855
protected:
4956
// This function checks the messageTypesToFail_ to determine whether to use
5057
// the faulty send or not.
@@ -56,18 +63,32 @@ class FaultyProcessGroupAgent : public ProcessGroupAgent {
5663
std::vector<MessageType> parseMessagesToFailInput(
5764
const std::vector<std::string>& messagesToFail) const;
5865

66+
// Returns amount of time in seconds to delay sending of the given message
67+
// type.
68+
float getDelayForMessage(MessageType type) const;
69+
70+
// Parse message types that we should inject arbitrary delays for.
71+
std::unordered_map<MessageType, float, std::hash<int>> parseMessagesToDelay(
72+
const std::unordered_map<std::string, float>& messageTypesToDelay) const;
73+
5974
// Number of sends to intentionally fail before allowing one to succeed.
6075
const int failNumSends_;
6176

6277
// Vector of the MessageTypes that we must use the faulty send for. This is
6378
// parsed based on a list of strings passed in by the python tests.
6479
const std::vector<MessageType> messageTypesToFail_;
6580

81+
// Mapping of message types to amount we should delay send for in the ::send()
82+
// function.
83+
std::unordered_map<MessageType, float, std::hash<int>> messageTypesToDelay_;
84+
6685
// Map to track the number of sends we've failed for each RPC.
6786
std::unordered_map<std::string, int> failMessageCountMap_;
6887

6988
// Mutex to guard failMessageCountMap_
7089
std::mutex failMapMutex_;
90+
91+
MessageType messageStringToType(const std::string& messageString) const;
7192
};
7293
} // namespace rpc
7394
} // namespace distributed

torch/csrc/distributed/rpc/testing/init.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,28 @@ PyObject* faulty_agent_init(PyObject* /* unused */) {
3636
"FaultyProcessGroupRpcBackendOptions",
3737
rpc_module.attr("ProcessGroupRpcBackendOptions"))
3838
.def(
39-
py::init<int, float, std::string, std::vector<std::string>, int>(),
39+
py::init<
40+
int,
41+
float,
42+
std::string,
43+
std::vector<std::string>,
44+
std::unordered_map<std::string, float>,
45+
int>(),
4046
py::arg("num_send_recv_threads"),
4147
py::arg("rpc_timeout"),
4248
py::arg("init_method"),
4349
py::arg("messages_to_fail"),
50+
py::arg("messages_to_delay"),
4451
py::arg("num_fail_sends"))
4552
.def_readwrite(
4653
"num_send_recv_threads",
4754
&ProcessGroupRpcBackendOptions::numSendRecvThreads)
4855
.def_readwrite(
4956
"messages_to_fail",
5057
&FaultyProcessGroupRpcBackendOptions::messagesToFail)
58+
.def_readwrite(
59+
"messages_to_delay",
60+
&FaultyProcessGroupRpcBackendOptions::messagesToDelay)
5161
.def_readwrite(
5262
"num_fail_sends", &FaultyProcessGroupRpcBackendOptions::numFailSends);
5363

@@ -59,13 +69,15 @@ PyObject* faulty_agent_init(PyObject* /* unused */) {
5969
std::shared_ptr<::c10d::ProcessGroup>,
6070
int,
6171
std::chrono::milliseconds,
62-
std::vector<std::string>,
72+
const std::vector<std::string>&,
73+
const std::unordered_map<std::string, float>&,
6374
int>(),
6475
py::arg("name"),
6576
py::arg("process_group"),
6677
py::arg("num_send_recv_threads"),
6778
py::arg("rpc_timeout"),
6879
py::arg("messages_to_fail"),
80+
py::arg("messages_to_delay"),
6981
py::arg("failNumSends"))
7082
.def(
7183
"join",

torch/distributed/rpc/_testing/faulty_agent_backend_registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def _faulty_process_group_construct_rpc_backend_options_handler(
1212
init_method,
1313
num_send_recv_threads,
1414
messages_to_fail,
15+
messages_to_delay,
1516
num_fail_sends,
1617
**kwargs
1718
):
@@ -22,6 +23,7 @@ def _faulty_process_group_construct_rpc_backend_options_handler(
2223
init_method=init_method,
2324
num_send_recv_threads=num_send_recv_threads,
2425
messages_to_fail=messages_to_fail,
26+
messages_to_delay=messages_to_delay,
2527
num_fail_sends=num_fail_sends,
2628
)
2729

@@ -66,6 +68,7 @@ def _faulty_process_group_init_backend_handler(
6668
rpc_backend_options.num_send_recv_threads,
6769
timedelta(seconds=rpc_backend_options.rpc_timeout),
6870
rpc_backend_options.messages_to_fail,
71+
rpc_backend_options.messages_to_delay,
6972
rpc_backend_options.num_fail_sends,
7073
)
7174
except Exception as ex:

torch/testing/_internal/dist_utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, *args, **kwargs):
2828

2929

3030
def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True,
31-
faulty_messages=None):
31+
faulty_messages=None, messages_to_delay=None):
3232
"""
3333
We use this decorator for setting up and tearing down state since
3434
MultiProcessTestCase runs each `test*` method in a separate process and
@@ -54,6 +54,7 @@ def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True,
5454
setup_rpc=setup_rpc,
5555
clean_shutdown=clean_shutdown,
5656
faulty_messages=faulty_messages,
57+
messages_to_delay=messages_to_delay,
5758
)
5859

5960
@wraps(old_test_method)
@@ -70,7 +71,7 @@ def new_test_method(self, *arg, **kwargs):
7071
and self.rpc_backend
7172
== rpc.backend_registry.BackendType.FAULTY_PROCESS_GROUP
7273
):
73-
_build_faulty_backend_options(self, faulty_messages)
74+
_build_faulty_backend_options(self, faulty_messages, messages_to_delay)
7475

7576
if setup_rpc:
7677
rpc.init_rpc(
@@ -100,7 +101,7 @@ def new_test_method(self, *arg, **kwargs):
100101
num_send_recv_threads=8,
101102
)
102103

103-
def _build_faulty_backend_options(faulty_agent_fixture, faulty_messages):
104+
def _build_faulty_backend_options(faulty_agent_fixture, faulty_messages, messages_to_delay):
104105
'''
105106
Constructs the backend options object for the faulty process group agent
106107
based on the faulty_messages input to dist_init.
@@ -110,12 +111,18 @@ def _build_faulty_backend_options(faulty_agent_fixture, faulty_messages):
110111
if faulty_messages is not None
111112
else faulty_agent_fixture.retryable_message_types
112113
)
114+
messages_to_delay = (
115+
messages_to_delay
116+
if messages_to_delay is not None
117+
else faulty_agent_fixture.default_messages_to_delay
118+
)
113119
TEST_CONFIG.build_rpc_backend_options = lambda test_object: rpc.backend_registry.construct_rpc_backend_options(
114120
test_object.rpc_backend,
115121
init_method=test_object.init_method,
116122
num_send_recv_threads=8,
117123
num_fail_sends=faulty_agent_fixture.num_fail_sends,
118124
messages_to_fail=messages_to_fail,
125+
messages_to_delay=messages_to_delay,
119126
)
120127

121128

@@ -173,7 +180,7 @@ def get_timeout_error_regex(rpc_backend_name):
173180
should receive when an RPC has timed out. Useful for use with
174181
assertRaisesRegex() to ensure we have the right errors during timeout.
175182
"""
176-
if rpc_backend_name == "PROCESS_GROUP":
183+
if rpc_backend_name in ["PROCESS_GROUP", "FAULTY_PROCESS_GROUP"]:
177184
return "RPC ran for more than"
178185
else:
179186
return "(Timed out)|(Task expired)"

torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212
"RREF_USER_DELETE",
1313
"CLEANUP_AUTOGRAD_CONTEXT_REQ"]
1414

15+
# The following messages incur the corresponding delay in seconds while being
16+
# processed in FaultyProcessGroupAgent's enqueueSend() function.
17+
default_messages_to_delay = {
18+
"PYTHON_CALL": 1.5, # Python UDF
19+
"SCRIPT_CALL": 1.5, # Script/Builtin
20+
}
21+
1522
class FaultyRpcAgentTestFixture(RpcAgentTestFixture):
1623
@property
1724
def rpc_backend(self):
@@ -26,3 +33,7 @@ def retryable_message_types(self):
2633
@property
2734
def num_fail_sends(self):
2835
return 3
36+
37+
@property
38+
def default_messages_to_delay(self):
39+
return default_messages_to_delay

0 commit comments

Comments
 (0)