
Commit 43043bf

Addressed all pytype and mypy complaints in new code in the dev env
Signed-off-by: M Q <mingmelvinq@nvidia.com>
1 parent 24be6f4 commit 43043bf

File tree

5 files changed: +115 -104 lines changed


monai/deploy/core/app_context.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def update(self, args: Dict[str, str]):

         # TritonModel instances are just clients and must be connected to the Triton Inference Server
         # at the provided network location. In-process hosting of Triton Inference Server is not supported.
-        if self.triton_server_netloc:
+        if self.triton_server_netloc and self.models:
             for _, model in self.models.items():
                 if isinstance(model, TritonModel):
                     model.connect(self.triton_server_netloc, verbose=args.get("log_level", "INFO") == "DEBUG")
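
For context, a minimal sketch of what the added "and self.models" guard buys. This is a hypothetical simplification, not code from this commit: when no models have been loaded, self.models may be empty or None, so the connect loop is now skipped outright instead of iterating (or failing on) a missing dict.

# Hedged illustration only; AppContext is reduced to the two attributes involved.
from typing import Dict, Optional

class _MiniAppContext:
    def __init__(self, netloc: str, models: Optional[Dict[str, object]] = None):
        self.triton_server_netloc = netloc
        self.models = models

    def update(self) -> None:
        # Mirrors the changed condition: require both a server location and loaded models.
        if self.triton_server_netloc and self.models:
            for name, model in self.models.items():
                print(f"would connect {name} to {self.triton_server_netloc}")

_MiniAppContext("localhost:8000").update()                            # models is None: loop skipped, no AttributeError
_MiniAppContext("localhost:8000", {"seg_model": object()}).update()   # connects as before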

monai/deploy/core/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -23,6 +23,6 @@
 from .model import Model
 from .named_model import NamedModel
 from .torch_model import TorchScriptModel
-from .triton_model import TritonModel
+from .triton_model import TritonModel, TritonRemoteModel

 Model.register([TritonModel, NamedModel, TorchScriptModel, Model])

monai/deploy/core/models/model.py

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ def __init__(self, path: str, name: str = ""):
         else:
             self._name = Path(path).stem

-        self._predictor = None
+        self._predictor: Any = None

         # Add self to the list of models
         self._items: Dict[str, Model] = {self.name: self}
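
A short, hypothetical reproduction (not from this commit) of the class of mypy complaint the explicit Any annotation silences: with a bare "self._predictor = None", mypy infers the attribute type as None, so a later assignment of a real predictor object in a subclass is reported as an incompatible assignment. It assumes "from typing import Any" is available in model.py, as the changed line requires.

# Illustrative only.
from typing import Any

class Untyped:
    def __init__(self) -> None:
        self._predictor = None  # mypy infers type None; assigning anything else later is flagged

class Typed:
    def __init__(self) -> None:
        self._predictor: Any = None  # later assignments of any predictor type are accepted

    def connect(self, client: object) -> None:
        self._predictor = client  # OK under mypy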

monai/deploy/core/models/named_model.py

Lines changed: 1 addition & 3 deletions
@@ -12,9 +12,7 @@
 import logging
 from pathlib import Path

-from monai.deploy.core.models import ModelFactory
-
-from .model import Model
+from monai.deploy.core.models import Model, ModelFactory

 logger = logging.getLogger(__name__)

monai/deploy/core/models/triton_model.py

Lines changed: 111 additions & 98 deletions
@@ -52,6 +52,88 @@ def parse_triton_config_pbtxt(pbtxt_path) -> ModelConfig:
         raise ValueError(f"Failed to parse config file {pbtxt_path}") from e


+class TritonRemoteModel:
+    """A remote model that is hosted on a Triton Inference Server.
+
+    Args:
+        model_name (str): The name of the model.
+        netloc (str): The network location of the Triton Inference Server.
+        model_config (ModelConfig): The model config.
+        headers (dict): The headers to send to the Triton Inference Server.
+    """
+
+    def __init__(self, model_name, netloc, model_config, headers=None, **kwargs):
+        self._headers = headers
+        self._request_compression_algorithm = None
+        self._response_compression_algorithm = None
+        self._model_name = model_name
+        self._model_version = None
+        self._model_config = model_config
+        self._request_compression_algorithm = None
+        self._response_compression_algorithm = None
+        self._count = 0
+
+        try:
+            self._triton_client = httpclient.InferenceServerClient(url=netloc, verbose=kwargs.get("verbose", False))
+            logging.info(f"Created triton client: {self._triton_client}")
+        except Exception as e:
+            logging.error("channel creation failed: " + str(e))
+            raise
+
+    def __call__(self, data, **kwds):
+
+        self._count += 1
+        logging.info(f"{self.__class__.__name__}.__call__: {self._model_name} count: {self._count}")
+
+        inputs = []
+        outputs = []
+
+        # For now support only one input and one output
+        input_name = self._model_config.input[0].name
+        input_type = str.split(DataType.Name(self._model_config.input[0].data_type), "_")[1]  # remove the prefix
+        input_shape = list(self._model_config.input[0].dims)
+        data_shape = list(data.shape)
+        logging.info(f"Model config input data shape: {input_shape}")
+        logging.info(f"Actual input data shape: {data_shape}")
+
+        # The server side will handle the batching, and with dynamic batching
+        # the model config does not have the batch size in the input dims.
+        logging.info(f"Effective input_name: {input_name}, input_type: {input_type}, input_shape: {data_shape}")
+
+        inputs.append(httpclient.InferInput(input_name, data_shape, input_type))
+
+        # Move the input tensor to CPU
+        input0_data_np = data.detach().cpu().numpy()
+        logging.debug(f"Input data shape: {input0_data_np.shape}")
+
+        # Initialize the data
+        inputs[0].set_data_from_numpy(input0_data_np, binary_data=False)
+
+        output_name = self._model_config.output[0].name
+        outputs.append(httpclient.InferRequestedOutput(output_name, binary_data=True))
+
+        query_params = {f"{self._model_name}_count": self._count}
+        results = self._triton_client.infer(
+            self._model_name,
+            inputs,
+            outputs=outputs,
+            query_params=query_params,
+            headers=self._headers,
+            request_compression_algorithm=self._request_compression_algorithm,
+            response_compression_algorithm=self._response_compression_algorithm,
+        )
+
+        logging.info(f"Got results: {results.get_response()}")
+        output0_data = results.as_numpy(output_name)
+        logging.debug(f"as_numpy output0_data.shape: {output0_data.shape}")
+        logging.debug(f"as_numpy output0_data.dtype: {output0_data.dtype}")
+
+        # Convert numpy array to torch tensor as expected by the anticipated clients,
+        # e.g. monai sliding window inference
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        return torch.as_tensor(output0_data).to(device)  # from_numpy is fine too.
+
+
 class TritonModel(Model):
     """Represents Triton models in the model repository.

@@ -124,9 +206,7 @@ def __init__(self, path: str, name: str = ""):
                 f"Model name in config.pbtxt ({self._model_config.name}) does not match the folder name ({self._name})."
             )

-        self._netloc = None  # network location of the Triton Inference Server
-        self._predictor = None  # triton remote client
-
+        self._netloc: str = ""
         logging.info(f"Created Triton model: {self._name}")

     def connect(self, netloc: str, **kwargs):
@@ -137,36 +217,51 @@ def connect(self, netloc: str, **kwargs):
         """

         if not netloc:
-            if not self._predictor:
-                raise ValueError("Network location is required to connect to the Triton Inference Server.")
-            else:
-                logging.warning("No network location provided, using the last connected network location.")
+            raise ValueError("Network location is required to connect to the Triton Inference Server.")

-        if self._predictor and not self._netloc.casefold() == netloc.casefold():
+        if self._netloc and not self._netloc.casefold() == netloc.casefold():
             logging.warning(f"Reconnecting to a different Triton Inference Server at {netloc} from {self._netloc}.")

+        self._predictor = TritonRemoteModel(self._name, netloc, self._model_config, **kwargs)
         self._netloc = netloc
-        self._predictor = TritonRemoteModel(self._name, self._netloc, self._model_config, **kwargs)
+
         return self._predictor

     @property
     def model_config(self):
-        if not self._model_config:  # not expected to happen with the current implementation.
-            self._model_config = parse_triton_config_pbtxt(self._model_path / "config.pbtxt")
         return self._model_config

     @property
-    def predictor(self):
-        """Get the model's predictor (triton remote client)
+    def net_loc(self):
+        """Get the network location of the Triton Inference Server, i.e. "<host>:<port>".

         Returns:
-            the model's predictor
+            str: The network location of the Triton Inference Server.
         """
-        if self._predictor is None:
-            raise ValueError("Model is not connected to the Triton Inference Server.")

+        return self._netloc
+
+    @net_loc.setter
+    def net_loc(self, value: str):
+        """Set the network location of the Triton Inference Server, and cause a re-connect."""
+        if not value:
+            raise ValueError("Network location cannot be empty.")
+        self._netloc = value
+        # Reconnect to the Triton Inference Server at the new network location.
+        self.connect(value)
+
+    @property
+    def predictor(self):
+        if not self._predictor:
+            raise ValueError("Model is not connected to the Triton Inference Server.")
         return self._predictor

+    @predictor.setter
+    def predictor(self, predictor: TritonRemoteModel):
+        if not isinstance(predictor, TritonRemoteModel):
+            raise ValueError("Predictor must be an instance of TritonRemoteModel.")
+        self._predictor = predictor
+
     @classmethod
     def accept(cls, path: str) -> tuple[bool, str]:
         model_folder: Path = Path(path)
195290
logging.info(f"Model {model_folder.name} only has config.pbtxt in client workspace.")
196291

197292
return True, cls.model_type
198-
199-
200-
class TritonRemoteModel:
201-
"""A remote model that is hosted on a Triton Inference Server.
202-
203-
Args:
204-
model_name (str): The name of the model.
205-
netloc (str): The network location of the Triton Inference Server.
206-
model_config (ModelConfig): The model config.
207-
headers (dict): The headers to send to the Triton Inference Server.
208-
"""
209-
210-
def __init__(self, model_name, netloc, model_config, headers=None, **kwargs):
211-
self._headers = headers
212-
self._request_compression_algorithm = None
213-
self._response_compression_algorithm = None
214-
self._model_name = model_name
215-
self._model_version = None
216-
self._model_config = model_config
217-
self._request_compression_algorithm = None
218-
self._response_compression_algorithm = None
219-
self._count = 0
220-
221-
try:
222-
self._triton_client = httpclient.InferenceServerClient(url=netloc, verbose=kwargs.get("verbose", False))
223-
print(f"Created triton client: {self._triton_client}")
224-
except Exception as e:
225-
logging.error("channel creation failed: " + str(e))
226-
raise
227-
228-
def __call__(self, data, **kwds):
229-
230-
self._count += 1
231-
logging.info(f"{self.__class__.__name__}.__call__: {self._model_name} count: {self._count}")
232-
233-
inputs = []
234-
outputs = []
235-
236-
# For now support only one input and one output
237-
input_name = self._model_config.input[0].name
238-
input_type = str.split(DataType.Name(self._model_config.input[0].data_type), "_")[1] # remove the prefix
239-
input_shape = list(self._model_config.input[0].dims)
240-
data_shape = list(data.shape)
241-
logging.info(f"Model config input data shape: {input_shape}")
242-
logging.info(f"Actual input data shape: {data_shape}")
243-
244-
# The server side will handle the batching, and with dynamic batching
245-
# the model config does not have the batch size in the input dims.
246-
logging.info(f"Effective input_name: {input_name}, input_type: {input_type}, input_shape: {data_shape}")
247-
248-
inputs.append(httpclient.InferInput(input_name, data_shape, input_type))
249-
250-
# Move to tensor to CPU
251-
input0_data_np = data.detach().cpu().numpy()
252-
logging.debug(f"Input data shape: {input0_data_np.shape}")
253-
254-
# Initialize the data
255-
inputs[0].set_data_from_numpy(input0_data_np, binary_data=False)
256-
257-
output_name = self._model_config.output[0].name
258-
outputs.append(httpclient.InferRequestedOutput(output_name, binary_data=True))
259-
260-
query_params = {f"{self._model_name}_count": self._count}
261-
results = self._triton_client.infer(
262-
self._model_name,
263-
inputs,
264-
outputs=outputs,
265-
query_params=query_params,
266-
headers=self._headers,
267-
request_compression_algorithm=self._request_compression_algorithm,
268-
response_compression_algorithm=self._response_compression_algorithm,
269-
)
270-
271-
logging.info(f"Got results{results.get_response()}")
272-
output0_data = results.as_numpy(output_name)
273-
logging.debug(f"as_numpy output0_data.shape: {output0_data.shape}")
274-
logging.debug(f"as_numpy output0_data.dtype: {output0_data.dtype}")
275-
276-
# Convert numpy array to torch tensor as expected by the anticipated clients,
277-
# e.g. monai cliding window inference
278-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
279-
return torch.as_tensor(output0_data).to(device) # from_numpy is fine too.
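
Taken together, a hedged usage sketch of the reorganized classes. The folder path, server address, model name, and tensor shape below are illustrative assumptions, not values from this commit, and a Triton Inference Server must already be serving the model at the given address.

import torch

from monai.deploy.core.models import TritonModel

# Hypothetical local model repository folder whose config.pbtxt names the same model.
model = TritonModel("models/example_seg_model", name="example_seg_model")

# connect() creates the TritonRemoteModel client and returns it; verbose mirrors the DEBUG log-level check.
model.connect("localhost:8000", verbose=False)
print(model.net_loc)  # "localhost:8000"

# The predictor is callable like a local torch model; the input shape is illustrative only.
dummy_input = torch.rand(1, 1, 96, 96, 96)
output = model.predictor(dummy_input)
print(output.shape, output.device)

Assigning a new address via the net_loc setter (e.g. model.net_loc = "otherhost:8000") would trigger a reconnect through connect(), per the setter added in this commit.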
