@@ -52,6 +52,88 @@ def parse_triton_config_pbtxt(pbtxt_path) -> ModelConfig:
        raise ValueError(f"Failed to parse config file {pbtxt_path}") from e


+class TritonRemoteModel:
+    """A remote model that is hosted on a Triton Inference Server.
+
+    Args:
+        model_name (str): The name of the model.
+        netloc (str): The network location of the Triton Inference Server.
+        model_config (ModelConfig): The model config.
+        headers (dict): The headers to send to the Triton Inference Server.
+    """
+
+    def __init__(self, model_name, netloc, model_config, headers=None, **kwargs):
+        self._headers = headers
+        self._request_compression_algorithm = None
+        self._response_compression_algorithm = None
+        self._model_name = model_name
+        self._model_version = None
+        self._model_config = model_config
+        self._request_compression_algorithm = None
+        self._response_compression_algorithm = None
+        self._count = 0
+
+        try:
+            self._triton_client = httpclient.InferenceServerClient(url=netloc, verbose=kwargs.get("verbose", False))
+            logging.info(f"Created triton client: {self._triton_client}")
+        except Exception as e:
+            logging.error("channel creation failed: " + str(e))
+            raise
+
+    def __call__(self, data, **kwds):
+
+        self._count += 1
+        logging.info(f"{self.__class__.__name__}.__call__: {self._model_name} count: {self._count}")
+
+        inputs = []
+        outputs = []
+
+        # For now support only one input and one output
+        input_name = self._model_config.input[0].name
+        input_type = str.split(DataType.Name(self._model_config.input[0].data_type), "_")[1]  # remove the prefix
+        input_shape = list(self._model_config.input[0].dims)
+        data_shape = list(data.shape)
+        logging.info(f"Model config input data shape: {input_shape}")
+        logging.info(f"Actual input data shape: {data_shape}")
+
+        # The server side will handle the batching, and with dynamic batching
+        # the model config does not have the batch size in the input dims.
+        logging.info(f"Effective input_name: {input_name}, input_type: {input_type}, input_shape: {data_shape}")
+
+        inputs.append(httpclient.InferInput(input_name, data_shape, input_type))
+
+        # Move the tensor to CPU
+        input0_data_np = data.detach().cpu().numpy()
+        logging.debug(f"Input data shape: {input0_data_np.shape}")
+
+        # Initialize the data
+        inputs[0].set_data_from_numpy(input0_data_np, binary_data=False)
+
+        output_name = self._model_config.output[0].name
+        outputs.append(httpclient.InferRequestedOutput(output_name, binary_data=True))
+
+        query_params = {f"{self._model_name}_count": self._count}
+        results = self._triton_client.infer(
+            self._model_name,
+            inputs,
+            outputs=outputs,
+            query_params=query_params,
+            headers=self._headers,
+            request_compression_algorithm=self._request_compression_algorithm,
+            response_compression_algorithm=self._response_compression_algorithm,
+        )
+
+        logging.info(f"Got results{results.get_response()}")
+        output0_data = results.as_numpy(output_name)
+        logging.debug(f"as_numpy output0_data.shape: {output0_data.shape}")
+        logging.debug(f"as_numpy output0_data.dtype: {output0_data.dtype}")
+
+        # Convert numpy array to torch tensor as expected by the anticipated clients,
+        # e.g. MONAI sliding window inference
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        return torch.as_tensor(output0_data).to(device)  # from_numpy is fine too.
+
+

class TritonModel(Model):
    """Represents Triton models in the model repository.
@@ -124,9 +206,7 @@ def __init__(self, path: str, name: str = ""):
                f"Model name in config.pbtxt ({self._model_config.name}) does not match the folder name ({self._name})."
            )

-        self._netloc = None  # network location of the Triton Inference Server
-        self._predictor = None  # triton remote client
-
+        self._netloc: str = ""
        logging.info(f"Created Triton model: {self._name}")

    def connect(self, netloc: str, **kwargs):
@@ -137,36 +217,51 @@ def connect(self, netloc: str, **kwargs):
        """

        if not netloc:
-            if not self._predictor:
-                raise ValueError("Network location is required to connect to the Triton Inference Server.")
-            else:
-                logging.warning("No network location provided, using the last connected network location.")
+            raise ValueError("Network location is required to connect to the Triton Inference Server.")

-        if self._predictor and not self._netloc.casefold() == netloc.casefold():
+        if self._netloc and not self._netloc.casefold() == netloc.casefold():
            logging.warning(f"Reconnecting to a different Triton Inference Server at {netloc} from {self._netloc}.")

+        self._predictor = TritonRemoteModel(self._name, netloc, self._model_config, **kwargs)
        self._netloc = netloc
-        self._predictor = TritonRemoteModel(self._name, self._netloc, self._model_config, **kwargs)
+
        return self._predictor

    @property
    def model_config(self):
-        if not self._model_config:  # not expected to happen with the current implementation.
-            self._model_config = parse_triton_config_pbtxt(self._model_path / "config.pbtxt")
        return self._model_config

    @property
-    def predictor(self):
-        """Get the model's predictor (triton remote client)
+    def net_loc(self):
+        """Get the network location of the Triton Inference Server, i.e. "<host>:<port>".

        Returns:
-            the model's predictor
+            str: The network location of the Triton Inference Server.
        """
-        if self._predictor is None:
-            raise ValueError("Model is not connected to the Triton Inference Server.")

+        return self._netloc
+
+    @net_loc.setter
+    def net_loc(self, value: str):
+        """Set the network location of the Triton Inference Server and reconnect."""
+        if not value:
+            raise ValueError("Network location cannot be empty.")
+        self._netloc = value
+        # Reconnect to the Triton Inference Server at the new network location.
+        self.connect(value)
+
+    @property
+    def predictor(self):
+        if not self._predictor:
+            raise ValueError("Model is not connected to the Triton Inference Server.")
        return self._predictor

+    @predictor.setter
+    def predictor(self, predictor: TritonRemoteModel):
+        if not isinstance(predictor, TritonRemoteModel):
+            raise ValueError("Predictor must be an instance of TritonRemoteModel.")
+        self._predictor = predictor
+
    @classmethod
    def accept(cls, path: str) -> tuple[bool, str]:
        model_folder: Path = Path(path)
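
The connect flow and the new net_loc/predictor accessors introduced above might be used roughly as follows; the folder path, model name, and network locations are hypothetical:

```python
import torch

# Path to a hypothetical Triton model folder containing config.pbtxt.
model = TritonModel(path="models/segmentation_model", name="segmentation_model")

# connect() now requires an explicit "<host>:<port>" and returns the remote client.
predictor = model.connect("localhost:8000")

# Assigning net_loc reconnects to the new server location.
model.net_loc = "triton.internal:8000"

# The predictor property raises until connect() has been called at least once.
output = model.predictor(torch.rand(1, 1, 96, 96, 96))  # shape is an assumption
```
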
@@ -195,85 +290,3 @@ def accept(cls, path: str) -> tuple[bool, str]:
            logging.info(f"Model {model_folder.name} only has config.pbtxt in client workspace.")

        return True, cls.model_type
-
-
-class TritonRemoteModel:
-    """A remote model that is hosted on a Triton Inference Server.
-
-    Args:
-        model_name (str): The name of the model.
-        netloc (str): The network location of the Triton Inference Server.
-        model_config (ModelConfig): The model config.
-        headers (dict): The headers to send to the Triton Inference Server.
-    """
-
-    def __init__(self, model_name, netloc, model_config, headers=None, **kwargs):
-        self._headers = headers
-        self._request_compression_algorithm = None
-        self._response_compression_algorithm = None
-        self._model_name = model_name
-        self._model_version = None
-        self._model_config = model_config
-        self._request_compression_algorithm = None
-        self._response_compression_algorithm = None
-        self._count = 0
-
-        try:
-            self._triton_client = httpclient.InferenceServerClient(url=netloc, verbose=kwargs.get("verbose", False))
-            print(f"Created triton client: {self._triton_client}")
-        except Exception as e:
-            logging.error("channel creation failed: " + str(e))
-            raise
-
-    def __call__(self, data, **kwds):
-
-        self._count += 1
-        logging.info(f"{self.__class__.__name__}.__call__: {self._model_name} count: {self._count}")
-
-        inputs = []
-        outputs = []
-
-        # For now support only one input and one output
-        input_name = self._model_config.input[0].name
-        input_type = str.split(DataType.Name(self._model_config.input[0].data_type), "_")[1]  # remove the prefix
-        input_shape = list(self._model_config.input[0].dims)
-        data_shape = list(data.shape)
-        logging.info(f"Model config input data shape: {input_shape}")
-        logging.info(f"Actual input data shape: {data_shape}")
-
-        # The server side will handle the batching, and with dynamic batching
-        # the model config does not have the batch size in the input dims.
-        logging.info(f"Effective input_name: {input_name}, input_type: {input_type}, input_shape: {data_shape}")
-
-        inputs.append(httpclient.InferInput(input_name, data_shape, input_type))
-
-        # Move to tensor to CPU
-        input0_data_np = data.detach().cpu().numpy()
-        logging.debug(f"Input data shape: {input0_data_np.shape}")
-
-        # Initialize the data
-        inputs[0].set_data_from_numpy(input0_data_np, binary_data=False)
-
-        output_name = self._model_config.output[0].name
-        outputs.append(httpclient.InferRequestedOutput(output_name, binary_data=True))
-
-        query_params = {f"{self._model_name}_count": self._count}
-        results = self._triton_client.infer(
-            self._model_name,
-            inputs,
-            outputs=outputs,
-            query_params=query_params,
-            headers=self._headers,
-            request_compression_algorithm=self._request_compression_algorithm,
-            response_compression_algorithm=self._response_compression_algorithm,
-        )
-
-        logging.info(f"Got results{results.get_response()}")
-        output0_data = results.as_numpy(output_name)
-        logging.debug(f"as_numpy output0_data.shape: {output0_data.shape}")
-        logging.debug(f"as_numpy output0_data.dtype: {output0_data.dtype}")
-
-        # Convert numpy array to torch tensor as expected by the anticipated clients,
-        # e.g. monai cliding window inference
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        return torch.as_tensor(output0_data).to(device)  # from_numpy is fine too.