Skip to content

Commit 0844f17

Browse files
authored
Support remote inference on Triton Inference Server with ease of use (#536)
* Adding requirements for Triton client impl Signed-off-by: M Q <mingmelvinq@nvidia.com> * Updated/added core classes to support Triton remote inference, and added a new example Signed-off-by: M Q <mingmelvinq@nvidia.com> * GitHub build server complains about conflicts for tritonclient[]>=2.54 for no specific reasons Signed-off-by: M Q <mingmelvinq@nvidia.com> * Fix flake8 complaints Signed-off-by: M Q <mingmelvinq@nvidia.com> * Fix pytype complaints by simplifying code Signed-off-by: M Q <mingmelvinq@nvidia.com> * Remove now unused imports Signed-off-by: M Q <mingmelvinq@nvidia.com> * Addressed all pytype and mypy complaint in new code in the dev env Signed-off-by: M Q <mingmelvinq@nvidia.com> * No complaint in local dev env, but on GitHub Signed-off-by: M Q <mingmelvinq@nvidia.com> * Add model confgi.pbtxt and example env settings Signed-off-by: M Q <mingmelvinq@nvidia.com> * Doc update Signed-off-by: M Q <mingmelvinq@nvidia.com> * update license dates Signed-off-by: M Q <mingmelvinq@nvidia.com> * Updated the copyright year of new files Signed-off-by: M Q <mingmelvinq@nvidia.com> --------- Signed-off-by: M Q <mingmelvinq@nvidia.com>
1 parent 7e0a610 commit 0844f17

File tree

14 files changed

+666
-69
lines changed

14 files changed

+666
-69
lines changed

docs/source/getting_started/examples.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
- dicom_series_to_image_app
1515
- breast_density_classifer_app
1616
- cchmc_ped_abd_ct_seg_app
17+
- ai_remote_infer_app
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from app import AIRemoteInferSpleenSegApp
13+
14+
if __name__ == "__main__":
    # Entry point when the application package is executed directly:
    # build the application object first, then start it.
    remote_infer_app = AIRemoteInferSpleenSegApp()
    remote_infer_app.run()
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright 2025 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
import logging
12+
from pathlib import Path
13+
14+
from pydicom.sr.codedict import codes # Required for setting SegmentDescription attributes.
15+
from spleen_seg_operator import SpleenSegOperator
16+
17+
from monai.deploy.conditions import CountCondition
18+
from monai.deploy.core import Application
19+
from monai.deploy.operators.dicom_data_loader_operator import DICOMDataLoaderOperator
20+
from monai.deploy.operators.dicom_seg_writer_operator import DICOMSegmentationWriterOperator, SegmentDescription
21+
from monai.deploy.operators.dicom_series_selector_operator import DICOMSeriesSelectorOperator
22+
from monai.deploy.operators.dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator
23+
from monai.deploy.operators.stl_conversion_operator import STLConversionOperator
24+
25+
26+
class AIRemoteInferSpleenSegApp(Application):
    """Spleen segmentation application: DICOM CT series in, DICOM SEG and STL mesh out.

    The processing graph built in ``compose`` is:

        DICOM loader -> series selector -> series-to-volume -> spleen segmentation

    The segmentation result then feeds both a DICOM Segmentation writer and an STL
    conversion operator. How inference is actually served (the class name suggests a remote
    server) is decided by the segmentation operator, not by this class.
    """

    def __init__(self, *args, **kwargs):
        """Creates an application instance."""

        super().__init__(*args, **kwargs)
        self._logger = logging.getLogger(f"{__name__}.{type(self).__name__}")

    def run(self, *args, **kwargs):
        """Runs the application, logging the start and the end of the run."""

        # This method calls the base class to run. Can be omitted if simply calling through.
        self._logger.info(f"Begin {self.run.__name__}")
        super().run(*args, **kwargs)
        self._logger.info(f"End {self.run.__name__}")

    def compose(self):
        """Creates the app specific operators and chain them up in the processing DAG."""

        # Use Commandline options over environment variables to init context.
        app_context = Application.init_app_context(self.argv)
        self._logger.debug(f"Begin {self.compose.__name__}")
        app_input_path = Path(app_context.input_path)
        app_output_path = Path(app_context.output_path)
        model_path = Path(app_context.model_path)

        self._logger.info(f"App input and output path: {app_input_path}, {app_output_path}")

        # instantiates the SDK built-in operator(s).
        study_loader_op = DICOMDataLoaderOperator(
            self, CountCondition(self, 1), input_folder=app_input_path, name="dcm_loader_op"
        )
        series_selector_op = DICOMSeriesSelectorOperator(self, rules=Sample_Rules_Text, name="series_selector_op")
        series_to_vol_op = DICOMSeriesToVolumeOperator(self, name="series_to_vol_op")

        # Model specific inference operator, supporting MONAI transforms.
        spleen_seg_op = SpleenSegOperator(
            self, app_context=app_context, model_name="spleen_ct", model_path=model_path, name="seg_op"
        )

        # Create DICOM Seg writer providing the required segment description for each segment with
        # the actual algorithm and the pertinent organ/tissue.
        # The segment_label, algorithm_name, and algorithm_version are limited to 64 chars.
        # https://dicom.nema.org/medical/dicom/current/output/chtml/part05/sect_6.2.html
        # User can Look up SNOMED CT codes at, e.g.
        # https://bioportal.bioontology.org/ontologies/SNOMEDCT

        _algorithm_name = "3D segmentation of the Spleen from a CT series"
        _algorithm_family = codes.DCM.ArtificialIntelligence
        _algorithm_version = "0.1.0"

        segment_descriptions = [
            SegmentDescription(
                segment_label="Spleen",
                segmented_property_category=codes.SCT.Organ,
                segmented_property_type=codes.SCT.Spleen,
                algorithm_name=_algorithm_name,
                algorithm_family=_algorithm_family,
                algorithm_version=_algorithm_version,
            ),
        ]

        custom_tags = {"SeriesDescription": "AI generated Seg, not for clinical use."}

        dicom_seg_writer = DICOMSegmentationWriterOperator(
            self,
            segment_descriptions=segment_descriptions,
            custom_tags=custom_tags,
            output_folder=app_output_path,
            name="dcm_seg_writer_op",
        )

        # Create the processing pipeline, by specifying the source and destination operators, and
        # ensuring the output from the former matches the input of the latter, in both name and type.
        self.add_flow(study_loader_op, series_selector_op, {("dicom_study_list", "dicom_study_list")})
        self.add_flow(
            series_selector_op, series_to_vol_op, {("study_selected_series_list", "study_selected_series_list")}
        )
        self.add_flow(series_to_vol_op, spleen_seg_op, {("image", "image")})

        # Note below the dicom_seg_writer requires two inputs, each coming from a source operator.
        self.add_flow(
            series_selector_op, dicom_seg_writer, {("study_selected_series_list", "study_selected_series_list")}
        )
        self.add_flow(spleen_seg_op, dicom_seg_writer, {("seg_image", "seg_image")})

        # Create the surface mesh STL conversion operator and add it to the app execution flow.
        # This step is optional: comment out the following couple of statements if no STL
        # output is needed.
        stl_conversion_op = STLConversionOperator(
            self, output_file=app_output_path.joinpath("stl/spleen.stl"), name="stl_conversion_op"
        )
        # The source port must be one that SpleenSegOperator.setup() declares. The segmentation
        # is published on "seg_image" (the same port wired to the DICOM Seg writer above); a port
        # named "pred" is not declared by that operator, so it cannot be used here.
        self.add_flow(spleen_seg_op, stl_conversion_op, {("seg_image", "image")})

        self._logger.debug(f"End {self.compose.__name__}")
117+
118+
119+
# This is a sample series selection rule in JSON, simply selecting CT series.
# If the study has more than 1 CT series, then all of them will be selected.
# Please see more detail in DICOMSeriesSelectorOperator.
# For a list of string values, e.g. "ImageType": ["PRIMARY", "ORIGINAL"], it is a match if all
# of the listed elements are present in the multi-value attribute of the DICOM series.
#
# The text must remain a valid JSON document, with a single top-level "selections" list.

Sample_Rules_Text = """
{
    "selections": [
        {
            "name": "CT Series",
            "conditions": {
                "StudyDescription": "(.*?)",
                "Modality": "(?i)CT",
                "SeriesDescription": "(.*?)",
                "ImageType": ["PRIMARY", "ORIGINAL"]
            }
        }
    ]
}
"""
140+
141+
if __name__ == "__main__":
    # Standalone execution: instantiate the application, then run it,
    # which is handy for testing the app on its own.
    standalone_app = AIRemoteInferSpleenSegApp()
    standalone_app.run()
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#!/bin/bash
# Copyright 2025 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example environment settings for the application.
# The shebang has to be the very first line of the file to have any effect, hence it is
# placed above the license header. Note that exported variables only reach the calling
# shell when this file is sourced rather than executed.
export HOLOSCAN_INPUT_PATH="inputs/spleen_ct_tcia"
export HOLOSCAN_MODEL_PATH="examples/apps/ai_remote_infer_app/models_client_side"
export HOLOSCAN_OUTPUT_PATH="output_spleen"
export HOLOSCAN_LOG_LEVEL=DEBUG  # TRACE can be used for verbose low-level logging
export TRITON_SERVER_NETLOC="localhost:8000"  # Triton server network location, host:port
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
platform: "pytorch_libtorch"
2+
3+
max_batch_size: 16 # The maximum batch size. 0 for no batching with full shape in dims
4+
5+
default_model_filename: "model_spleen_ct_segmentation_v1.ts" # The name of the TorchScript model file
6+
7+
input [
8+
{
9+
name: "INPUT_0" # The name of the input tensor (or should match the input tensor name in your model if used)
10+
data_type: TYPE_FP32 # Data type is FP32
11+
dims: [ 1, 96, 96, 96 ] # Input dimensions: [channels, width, height, depth], to be stacked as a batch
12+
}
13+
]
14+
15+
output [
16+
{
17+
name: "OUTPUT_0" # The name of the output tensor (match this with your TorchScript model's output name)
18+
data_type: TYPE_FP32 # Output is FP32
19+
dims: [ 2, 96, 96, 96 ] # Output dimensions: [channels, width, height, depth], stacked to match input batch size
20+
}
21+
]
22+
23+
version_policy: { latest: { num_versions: 1}} # Only serve the latest version, which is the default
24+
25+
instance_group [
26+
{
27+
kind: KIND_GPU # Specify the hardware type (GPU in this case)
28+
count: 1 # Number of instances created for each GPU listed in 'gpus' (adjust based on your resources)
29+
}
30+
]
31+
32+
dynamic_batching {
33+
preferred_batch_size: [ 4, 8, 16 ] # Preferred batch size(s) for dynamic batching. Matching the max_batch_size for sync calls.
34+
max_queue_delay_microseconds: 1000 # Max delay before processing the batch.
35+
}
36+
37+
# The initial calls to a loaded TorchScript model take extremely long.
38+
# Because of this long warm-up, Triton allows a model to be executed with the TorchScript optimized execution disabled, which is what the parameter below requests.
39+
parameters: {
40+
key: "DISABLE_OPTIMIZED_EXECUTION"
41+
value: {
42+
string_value: "true"
43+
}
44+
}
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# Copyright 2025 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import logging
13+
from pathlib import Path
14+
15+
from numpy import uint8
16+
17+
from monai.deploy.core import AppContext, ConditionType, Fragment, Operator, OperatorSpec
18+
from monai.deploy.operators.monai_seg_inference_operator import InfererType, InMemImageReader, MonaiSegInferenceOperator
19+
from monai.transforms import (
20+
Activationsd,
21+
AsDiscreted,
22+
Compose,
23+
EnsureChannelFirstd,
24+
EnsureTyped,
25+
Invertd,
26+
LoadImaged,
27+
Orientationd,
28+
SaveImaged,
29+
ScaleIntensityRanged,
30+
Spacingd,
31+
)
32+
33+
34+
class SpleenSegOperator(Operator):
    """Performs Spleen segmentation with a 3D image converted from a DICOM CT series.

    The operator composes the pre- and post-processing transforms and delegates the inference
    itself to a ``MonaiSegInferenceOperator`` created on each ``compute`` call.

    Named input:
        image: the in-memory image to segment.

    Named outputs:
        seg_image: the result returned by the delegated inference operator.
        saved_images_folder: the folder given to the image-saving transforms. This output does
            not require a receiver.
    """

    DEFAULT_OUTPUT_FOLDER = Path.cwd() / "output/saved_images_folder"

    def __init__(
        self,
        fragment: Fragment,
        *args,
        app_context: AppContext,
        model_path: Path,
        model_name: str,
        output_folder: Path = DEFAULT_OUTPUT_FOLDER,
        **kwargs,
    ):
        """Creates the operator and makes sure its output folder exists.

        Args:
            fragment: the fragment (application) this operator belongs to.
            app_context: the application context, forwarded to the inference operator.
            model_path: path of the model, forwarded to the inference operator.
            model_name: name of the model, forwarded to the inference operator.
            output_folder: folder for the images saved by the transforms. A plain string is
                accepted as well and converted to a ``Path``. The folder, including missing
                parents, is created if needed.
        """

        self.logger = logging.getLogger(f"{__name__}.{type(self).__name__}")
        self._input_dataset_key = "image"
        self._pred_dataset_key = "pred"

        self.model_path = model_path
        self.model_name = model_name
        # Coerce to Path so that a caller passing a str does not fail on the mkdir() below.
        self.output_folder = Path(output_folder)
        self.output_folder.mkdir(parents=True, exist_ok=True)
        self.app_context = app_context
        self.input_name_image = "image"
        self.output_name_seg = "seg_image"
        self.output_name_saved_images_folder = "saved_images_folder"

        # The base class has an attribute called fragment to hold the reference to the fragment object.
        # NOTE(review): the attributes above are deliberately set before this call; presumably the
        # base initializer triggers setup(), which reads the port names - confirm before reordering.
        super().__init__(fragment, *args, **kwargs)

    def setup(self, spec: OperatorSpec):
        """Declares the input port and the two output ports of this operator."""

        spec.input(self.input_name_image)
        spec.output(self.output_name_seg)
        spec.output(self.output_name_saved_images_folder).condition(
            ConditionType.NONE
        )  # Output not requiring a receiver

    def compute(self, op_input, op_output, context):
        """Receives the image, runs the delegated inference, and emits the results.

        Raises:
            ValueError: if no input image is received.
        """

        input_image = op_input.receive(self.input_name_image)
        if not input_image:
            raise ValueError("Input image is not found.")

        # This operator gets an in-memory Image object, so a specialized ImageReader is needed.
        _reader = InMemImageReader(input_image)

        # Both transform chains save their images into the same configured output folder.
        pre_transforms = self.pre_process(_reader, str(self.output_folder))
        post_transforms = self.post_process(pre_transforms, str(self.output_folder))

        # Delegates inference and saving output to the built-in operator.
        infer_operator = MonaiSegInferenceOperator(
            self.fragment,
            roi_size=(
                96,
                96,
                96,
            ),
            pre_transforms=pre_transforms,
            post_transforms=post_transforms,
            overlap=0.6,
            app_context=self.app_context,
            model_name=self.model_name,
            inferer=InfererType.SLIDING_WINDOW,
            sw_batch_size=4,
            model_path=self.model_path,
            name="monai_seg_remote_inference_op",
        )

        # Setting the keys used in the dictionary based transforms may change.
        infer_operator.input_dataset_key = self._input_dataset_key
        infer_operator.pred_dataset_key = self._pred_dataset_key

        # Now emit data to the output ports of this operator
        op_output.emit(infer_operator.compute_impl(input_image, context), self.output_name_seg)
        op_output.emit(self.output_folder, self.output_name_saved_images_folder)

    def pre_process(self, img_reader, out_dir: str = "./input_images") -> Compose:
        """Composes transforms for preprocessing input before predicting on a model.

        Args:
            img_reader: the image reader handed to the image loading transform.
            out_dir: folder where the loaded input image is saved; created if missing.

        Returns:
            The composed preprocessing transforms.
        """

        Path(out_dir).mkdir(parents=True, exist_ok=True)
        my_key = self._input_dataset_key

        return Compose(
            [
                LoadImaged(keys=my_key, reader=img_reader),
                EnsureChannelFirstd(keys=my_key),
                # The SaveImaged transform can be commented out to save 5 seconds.
                # Uncompressed NIfTI file, nii, is used favoring speed over size, but can be changed to nii.gz
                SaveImaged(
                    keys=my_key,
                    output_dir=out_dir,
                    output_postfix="",
                    resample=False,
                    output_ext=".nii",
                ),
                Orientationd(keys=my_key, axcodes="RAS"),
                Spacingd(keys=my_key, pixdim=[1.5, 1.5, 2.9], mode=["bilinear"]),
                ScaleIntensityRanged(keys=my_key, a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
                EnsureTyped(keys=my_key),
            ]
        )

    def post_process(self, pre_transforms: Compose, out_dir: str = "./prediction_output") -> Compose:
        """Composes transforms for postprocessing the prediction results.

        Args:
            pre_transforms: the preprocessing transforms, handed to the inverting transform.
            out_dir: folder where the segmentation image is saved; created if missing.

        Returns:
            The composed postprocessing transforms.
        """

        Path(out_dir).mkdir(parents=True, exist_ok=True)
        pred_key = self._pred_dataset_key

        return Compose(
            [
                Activationsd(keys=pred_key, softmax=True),
                Invertd(
                    keys=pred_key,
                    transform=pre_transforms,
                    orig_keys=self._input_dataset_key,
                    nearest_interp=False,
                    to_tensor=True,
                ),
                AsDiscreted(keys=pred_key, argmax=True),
                # The SaveImaged transform can be commented out to save 5 seconds.
                # Uncompressed NIfTI file, nii, is used favoring speed over size, but can be changed to nii.gz
                SaveImaged(
                    keys=pred_key,
                    output_dir=out_dir,
                    output_postfix="seg",
                    output_dtype=uint8,
                    resample=False,
                    output_ext=".nii",
                ),
            ]
        )

0 commit comments

Comments
 (0)