Skip to content

Commit 6833cd1

Browse files
committed
adding decorator for cross compile flag
1 parent 1a389df commit 6833cd1

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

py/torch_tensorrt/_compile.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
import torch.fx
1111
from torch_tensorrt._enums import dtype
12-
from torch_tensorrt._features import ENABLED_FEATURES
12+
from torch_tensorrt._features import ENABLED_FEATURES, needs_cross_compile
1313
from torch_tensorrt._Input import Input
1414
from torch_tensorrt.dynamo import _defaults
1515
from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
@@ -301,6 +301,7 @@ def compile(
301301
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
302302

303303

304+
@needs_cross_compile
304305
def cross_compile_for_windows(
305306
module: torch.nn.Module,
306307
file_path: str,
@@ -525,6 +526,7 @@ def convert_method_to_trt_engine(
525526
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
526527

527528

529+
@needs_cross_compile
528530
def load_cross_compiled_exported_program(file_path: str = "") -> Any:
529531
"""
530532
Load an ExportedProgram file in Windows which was previously cross compiled in Linux

py/torch_tensorrt/dynamo/_compiler.py

+4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torch.fx.node import Target
1212
from torch_tensorrt._Device import Device
1313
from torch_tensorrt._enums import EngineCapability, dtype
14+
from torch_tensorrt._features import needs_cross_compile
1415
from torch_tensorrt._Input import Input
1516
from torch_tensorrt.dynamo import _defaults, partitioning
1617
from torch_tensorrt.dynamo._DryRunTracker import (
@@ -49,6 +50,7 @@
4950
logger = logging.getLogger(__name__)
5051

5152

53+
@needs_cross_compile
5254
def cross_compile_for_windows(
5355
exported_program: ExportedProgram,
5456
inputs: Optional[Sequence[Sequence[Any]]] = None,
@@ -1190,6 +1192,7 @@ def convert_exported_program_to_serialized_trt_engine(
11901192
return serialized_engine
11911193

11921194

1195+
@needs_cross_compile
11931196
def save_cross_compiled_exported_program(
11941197
gm: torch.fx.GraphModule,
11951198
file_path: str,
@@ -1211,6 +1214,7 @@ def save_cross_compiled_exported_program(
12111214
logger.debug(f"successfully saved the module for windows at {file_path}")
12121215

12131216

1217+
@needs_cross_compile
12141218
def load_cross_compiled_exported_program(file_path: str = "") -> Any:
12151219
"""
12161220
Load an ExportedProgram file in Windows which was previously cross compiled in Linux

0 commit comments

Comments
 (0)