diff --git a/examples/apps/NGRVNG.safetensors b/examples/apps/NGRVNG.safetensors new file mode 100644 index 0000000000..0fe8121f51 Binary files /dev/null and b/examples/apps/NGRVNG.safetensors differ diff --git a/examples/apps/flux-demo.py b/examples/apps/flux-demo.py new file mode 100644 index 0000000000..e7b4686cd7 --- /dev/null +++ b/examples/apps/flux-demo.py @@ -0,0 +1,155 @@ +import time + +import gradio as gr +import torch +import torch_tensorrt +from diffusers import FluxPipeline + +DEVICE = "cuda:0" +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.float16, +) +pipe.to(torch.float16) +backbone = pipe.transformer + + +batch_size = 2 +BATCH = torch.export.Dim("batch", min=1, max=8) + +# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model. +# To see this recommendation, you can try exporting using min=1, max=4096 +dynamic_shapes = { + "hidden_states": {0: BATCH}, + "encoder_hidden_states": {0: BATCH}, + "pooled_projections": {0: BATCH}, + "timestep": {0: BATCH}, + "txt_ids": {}, + "img_ids": {}, + "guidance": {0: BATCH}, + "joint_attention_kwargs": {}, + "return_dict": None, +} + +settings = { + "strict": False, + "allow_complex_guards_as_runtime_asserts": True, + "enabled_precisions": {torch.float32}, + "truncate_double": True, + "min_block_size": 1, + "use_fp32_acc": True, + "use_explicit_typing": True, + "debug": False, + "use_python_runtime": True, + "immutable_weights": False, + "enable_cuda_graph": True, +} +backbone.to(DEVICE) +trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) +trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes) +pipe.transformer = trt_gm +pipe.to(DEVICE) + + +def generate_image(prompt, inference_step, batch_size=2): + start_time = time.time() + image = pipe( + prompt, + output_type="pil", + num_inference_steps=inference_step, + num_images_per_prompt=batch_size, + ).images + end_time = time.time() + return image, end_time - start_time + + +generate_image(["Test"], 2) +torch.cuda.empty_cache() + + +def model_change(model): + if model == "Torch Model": + pipe.transformer = backbone + backbone.to(DEVICE) + else: + backbone.to("cpu") + pipe.transformer = trt_gm + torch.cuda.empty_cache() + + +def load_lora(path): + + pipe.load_lora_weights( + path, + adapter_name="lora1", + ) + pipe.set_adapters(["lora1"], adapter_weights=[1]) + pipe.fuse_lora() + pipe.unload_lora_weights() + print("LoRA loaded! 
Begin refitting") + generate_image(["Test"], 2) + print("Refitting Finished!") + + +# Create Gradio interface +with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo: + gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT") + + with gr.Row(): + with gr.Column(): + # Input components + prompt_input = gr.Textbox( + label="Prompt", placeholder="Enter your prompt here...", lines=3 + ) + model_dropdown = gr.Dropdown( + choices=["Torch Model", "Torch-TensorRT Accelerated Model"], + value="Torch-TensorRT Accelerated Model", + label="Model Variant", + ) + + lora_upload_path = gr.Textbox( + label="LoRA Path", + placeholder="Enter the LoRA checkpoint path here", + value="/home/TensorRT/examples/apps/NGRVNG.safetensors", + lines=2, + ) + num_steps = gr.Slider( + minimum=20, maximum=100, value=20, step=1, label="Inference Steps" + ) + batch_size = gr.Slider( + minimum=1, maximum=8, value=1, step=1, label="Batch Size" + ) + + generate_btn = gr.Button("Generate Image") + load_lora_btn = gr.Button("Load LoRA") + + with gr.Column(): + # Output component + output_image = gr.Gallery(label="Generated Image") + time_taken = gr.Textbox( + label="Generation Time (seconds)", interactive=False + ) + + # Connect the button to the generation function + model_dropdown.change(model_change, inputs=[model_dropdown]) + load_lora_btn.click( + fn=load_lora, + inputs=[ + lora_upload_path, + ], + ) + + # Update generate button click to include time output + generate_btn.click( + fn=generate_image, + inputs=[ + prompt_input, + num_steps, + batch_size, + ], + outputs=[output_image, time_taken], + ) + +# Launch the interface +if __name__ == "__main__": + demo.launch() diff --git a/examples/apps/flux-quantization-fp32.py b/examples/apps/flux-quantization-fp32.py new file mode 100644 index 0000000000..6cd226294d --- /dev/null +++ b/examples/apps/flux-quantization-fp32.py @@ -0,0 +1,168 @@ +# %% +# Import the following libraries +# ----------------------------- +import re + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +import torch +import torch_tensorrt +from diffusers import FluxPipeline +from diffusers.models.attention_processor import Attention +from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel +from modelopt.torch.quantization.utils import export_torch_mode +from torch.export._trace import _export +from transformers import AutoModelForCausalLM + +# %% +DEVICE = "cuda:0" +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.float32, +) +pipe.transformer = FluxTransformer2DModel( + num_layers=1, num_single_layers=1, guidance_embeds=True +) + +pipe.to(DEVICE).to(torch.float32) +# Store the config and transformer backbone +config = pipe.transformer.config +# global backbone +backbone = pipe.transformer +backbone.eval() + + +def filter_func(name): + pattern = re.compile( + r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*" + ) + return pattern.match(name) is not None + + +def generate_image(pipe, prompt, image_name): + seed = 42 + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(seed), + ).images[0] + image.save(f"{image_name}.png") + print(f"Image generated using {image_name} model saved as {image_name}.png") + + +generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code") + +# %% +# Quantization + + +def do_calibrate( + 
pipe, + prompt: str, +) -> None: + """ + Run calibration steps on the pipeline using the given prompts. + """ + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(0), + ).images[0] + + +def forward_loop(mod): + # Switch the pipeline's backbone, run calibration + pipe.transformer = mod + do_calibrate( + pipe=pipe, + prompt="test", + ) + + +ptq_config = mtq.FP8_DEFAULT_CFG +backbone = mtq.quantize(backbone, ptq_config, forward_loop) +mtq.disable_quantizer(backbone, filter_func) + + +# %% +# Export the backbone using torch.export +# -------------------------------------------------- +# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes with a ``batch_size=2`` +# due to `0/1 specialization `_ + +batch_size = 2 +BATCH = torch.export.Dim("batch", min=1, max=2) +SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512) +# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model. +# To see this recommendation, you can try exporting using min=1, max=4096 +IMG_ID = torch.export.Dim("img_id", min=3586, max=4096) +dynamic_shapes = { + "hidden_states": {0: BATCH}, + "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN}, + "pooled_projections": {0: BATCH}, + "timestep": {0: BATCH}, + "txt_ids": {0: SEQ_LEN}, + "img_ids": {0: IMG_ID}, + "guidance": {0: BATCH}, + "joint_attention_kwargs": {}, + "return_dict": None, +} +# The guidance factor is of type torch.float32 +dummy_inputs = { + "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float32).to( + DEVICE + ), + "encoder_hidden_states": torch.randn( + (batch_size, 512, 4096), dtype=torch.float32 + ).to(DEVICE), + "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float32).to( + DEVICE + ), + "timestep": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE), + "txt_ids": torch.randn((512, 3), dtype=torch.float32).to(DEVICE), + "img_ids": torch.randn((4096, 3), dtype=torch.float32).to(DEVICE), + "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE), + "joint_attention_kwargs": {}, + "return_dict": False, +} + +# This will create an exported program which is going to be compiled with Torch-TensorRT +with export_torch_mode(): + ep = _export( + backbone, + args=(), + kwargs=dummy_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + +with torch_tensorrt.logging.debug(): + trt_gm = torch_tensorrt.dynamo.compile( + ep, + inputs=dummy_inputs, + enabled_precisions={torch.float8_e4m3fn}, + truncate_double=True, + min_block_size=1, + debug=False, + use_python_runtime=True, + immutable_weights=True, + offload_module_to_cpu=True, + ) + + +del ep +pipe.transformer = trt_gm +pipe.transformer.config = config + + +# %% +trt_gm.device = torch.device(DEVICE) +# Function which generates images from the flux pipeline + +for _ in range(2): + generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code") + +# For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB diff --git a/examples/apps/flux-quantization.py b/examples/apps/flux-quantization.py new file mode 100644 index 0000000000..f7f35bc378 --- /dev/null +++ b/examples/apps/flux-quantization.py @@ -0,0 +1,203 @@ +# %% +# Import the following libraries +# ----------------------------- +# Load the ModelOpt-modified model architecture and weights using Huggingface APIs +# Add argument parsing for 
dtype selection +import argparse +import re + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +import torch +import torch_tensorrt +from diffusers import FluxPipeline +from diffusers.models.attention_processor import Attention +from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel +from modelopt.torch.quantization.utils import export_torch_mode +from torch.export._trace import _export +from transformers import AutoModelForCausalLM + +parser = argparse.ArgumentParser( + description="Run Flux quantization with different dtypes" +) +parser.add_argument( + "--dtype", + choices=["fp8", "int8"], + default="int8", + help="Quantization data type to use (fp8 or int8)", +) + +args = parser.parse_args() + +# Update enabled precisions based on dtype argument +if args.dtype == "fp8": + enabled_precisions = {torch.float8_e4m3fn, torch.float16} + ptq_config = mtq.FP8_DEFAULT_CFG +else: # int8 + enabled_precisions = {torch.int8, torch.float16} + ptq_config = mtq.INT8_DEFAULT_CFG + ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None +print(f"\nUsing {args.dtype} quantization") +# %% +DEVICE = "cuda:0" +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.float16, +) +# pipe.transformer = FluxTransformer2DModel( +# num_layers=1, num_single_layers=1, guidance_embeds=True +# ) + +pipe.to(DEVICE).to(torch.float16) +# Store the config and transformer backbone +config = pipe.transformer.config +# global backbone +backbone = pipe.transformer +backbone.eval() + + +def filter_func(name): + pattern = re.compile( + r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*" + ) + return pattern.match(name) is not None + + +def generate_image(pipe, prompt, image_name): + seed = 42 + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(seed), + ).images[0] + image.save(f"{image_name}.png") + print(f"Image generated using {image_name} model saved as {image_name}.png") + + +generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code") + +# %% +# Quantization + + +def do_calibrate( + pipe, + prompt: str, +) -> None: + """ + Run calibration steps on the pipeline using the given prompts. + """ + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(0), + ).images[0] + + +def forward_loop(mod): + # Switch the pipeline's backbone, run calibration + pipe.transformer = mod + do_calibrate( + pipe=pipe, + prompt="test", + ) + + +backbone = mtq.quantize(backbone, ptq_config, forward_loop) +mtq.disable_quantizer(backbone, filter_func) + +batch_size = 2 +BATCH = torch.export.Dim("batch", min=1, max=2) +SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512) +# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model. 
+# To see this recommendation, you can try exporting using min=1, max=4096 +IMG_ID = torch.export.Dim("img_id", min=3586, max=4096) +dynamic_shapes = { + "hidden_states": {0: BATCH}, + "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN}, + "pooled_projections": {0: BATCH}, + "timestep": {0: BATCH}, + "txt_ids": {0: SEQ_LEN}, + "img_ids": {0: IMG_ID}, + "guidance": {0: BATCH}, + "joint_attention_kwargs": {}, + "return_dict": None, +} +# The guidance factor is of type torch.float32 +dummy_inputs = { + "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to( + DEVICE + ), + "encoder_hidden_states": torch.randn( + (batch_size, 512, 4096), dtype=torch.float16 + ).to(DEVICE), + "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to( + DEVICE + ), + "timestep": torch.tensor([1.0] * batch_size, dtype=torch.float16).to(DEVICE), + "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE), + "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE), + "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to(DEVICE), + "joint_attention_kwargs": {}, + "return_dict": False, +} + +# This will create an exported program which is going to be compiled with Torch-TensorRT +with export_torch_mode(): + ep = _export( + backbone, + args=(), + kwargs=dummy_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + +with torch_tensorrt.logging.debug(): + trt_gm = torch_tensorrt.dynamo.compile( + ep, + inputs=dummy_inputs, + enabled_precisions=enabled_precisions, + truncate_double=True, + min_block_size=1, + debug=False, + use_python_runtime=True, + immutable_weights=True, + offload_module_to_cpu=True, + ) + + +del ep +pipe.transformer = trt_gm +pipe.transformer.config = config + + +# %% +trt_gm.device = torch.device(DEVICE) +# Function which generates images from the flux pipeline +generate_image(pipe, ["A golden retriever"], "dog_code2") + + +def benchmark(prompt, inference_step, batch_size=2, iterations=1): + from time import time + + start = time() + for i in range(iterations): + image = pipe( + prompt, + output_type="pil", + num_inference_steps=inference_step, + num_images_per_prompt=batch_size, + ).images + end = time() + print("Time Elapse for", iterations, "iterations:", end - start) + print("Average Latency Per Step:", (end - start) / inference_step / iterations) + return image + + +print(f"Benchmark Original PyTorch Module Latency ({args.dtype})") +benchmark(["Test"], 50, iterations=3) + +# For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB diff --git a/examples/dynamo/mutable_torchtrt_module_example.py b/examples/dynamo/mutable_torchtrt_module_example.py index f264b8a8d3..665bda1b51 100644 --- a/examples/dynamo/mutable_torchtrt_module_example.py +++ b/examples/dynamo/mutable_torchtrt_module_example.py @@ -22,6 +22,7 @@ import torch import torch_tensorrt as torch_trt import torchvision.models as models +from diffusers import DiffusionPipeline np.random.seed(5) torch.manual_seed(5) @@ -31,7 +32,7 @@ # Initialize the Mutable Torch TensorRT Module with settings. # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ settings = { - "use_python": False, + "use_python_runtime": False, "enabled_precisions": {torch.float32}, "immutable_weights": False, } @@ -40,7 +41,6 @@ mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings) # You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module. 
mutable_module(*inputs) - # %% # Make modifications to the mutable module. # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -73,13 +73,12 @@ # Stable Diffusion with Huggingface # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -from diffusers import DiffusionPipeline with torch.no_grad(): settings = { "use_python_runtime": True, "enabled_precisions": {torch.float16}, - "debug": True, + "debug": False, "immutable_weights": False, } @@ -106,6 +105,7 @@ "text_embeds": {0: BATCH}, "time_ids": {0: BATCH}, }, + "return_dict": None, } pipe.unet.set_expected_dynamic_shape_range( args_dynamic_shapes, kwargs_dynamic_shapes diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py index 66a1a70964..51202528c5 100644 --- a/examples/dynamo/refit_engine_example.py +++ b/examples/dynamo/refit_engine_example.py @@ -101,6 +101,7 @@ ) # Check the output +model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs) for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): assert torch.allclose( diff --git a/examples/dynamo/torch_export_flux_dev.py b/examples/dynamo/torch_export_flux_dev.py index 3891fcbb9a..9dcd073f73 100644 --- a/examples/dynamo/torch_export_flux_dev.py +++ b/examples/dynamo/torch_export_flux_dev.py @@ -112,6 +112,8 @@ min_block_size=1, use_fp32_acc=True, use_explicit_typing=True, + use_python_runtime=True, + immutable_weights=False, ) # %% @@ -120,13 +122,13 @@ # Release the GPU memory occupied by the exported program and the pipe.transformer # Set the transformer in the Flux pipeline to the Torch-TRT compiled model -del ep -backbone.to("cpu") pipe.to(DEVICE) -torch.cuda.empty_cache() +backbone.to("cpu") pipe.transformer = trt_gm +del ep +torch.cuda.empty_cache() pipe.transformer.config = config - +trt_gm.device = torch.device("cuda") # %% # Image generation using prompt # --------------------------- diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 6928347baa..50b3a32a87 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -37,6 +37,7 @@ pre_export_lowering, ) from torch_tensorrt.dynamo.utils import ( + CPU_DEVICE, get_flat_args_with_check, get_output_metadata, parse_graph_io, @@ -421,6 +422,7 @@ def compile( enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING, tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL, l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, + offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -550,15 +552,6 @@ def compile( "`immutable_weights` must be False when `refit_identical_engine_weights` is True." ) - if ( - not immutable_weights - and not refit_identical_engine_weights - and enable_weight_streaming - ): - raise ValueError( - "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. 
This issue was reported on https://github.com/pytorch/TensorRT/issues/3305" - ) - if ( "enable_cross_compile_for_windows" in kwargs.keys() and kwargs["enable_cross_compile_for_windows"] @@ -674,6 +667,7 @@ def compile( "enable_weight_streaming": enable_weight_streaming, "tiling_optimization_level": tiling_optimization_level, "l2_limit_for_tiling": l2_limit_for_tiling, + "offload_module_to_cpu": offload_module_to_cpu, } settings = CompilationSettings(**compilation_options) @@ -684,12 +678,17 @@ def compile( ) gm = exported_program.module() + # Move the weights in the state_dict to CPU logger.debug("Input graph: " + str(gm.graph)) # Apply lowering on the graph module gm = post_lowering(gm, settings) logger.debug("Lowered Input graph: " + str(gm.graph)) - + if offload_module_to_cpu: + exported_program.module().to(CPU_DEVICE) + logger.info( + "The model is offloaded to CPU during compilation. If you want to keep the model on GPU, set offload_module_to_cpu=False." + ) trt_gm = compile_module( gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache ) @@ -820,6 +819,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: trt_modules = {} # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those + for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule @@ -833,6 +833,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: str(name), str(submodule.graph), ) + submodule.to(torch.cuda.current_device()) continue if name not in submodule_node_dict: @@ -891,7 +892,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: parse_graph_io(submodule, subgraph_data) dryrun_tracker.tensorrt_graph_count += 1 dryrun_tracker.per_subgraph_data.append(subgraph_data) - + torch.cuda.empty_cache() # Create TRT engines from submodule if not settings.dryrun: trt_module = convert_module( diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 379a196e2e..aafd1072f4 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -49,6 +49,7 @@ TILING_OPTIMIZATION_LEVEL = "none" L2_LIMIT_FOR_TILING = -1 USE_DISTRIBUTED_MODE_TRACE = False +OFFLOAD_MODULE_TO_CPU = False def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index c128e9cc82..6498f8dc57 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -2,6 +2,7 @@ import collections.abc import copy +import gc import logging from typing import Any, List, Optional, Sequence, Tuple @@ -35,7 +36,9 @@ TorchTensorRTModule, ) from torch_tensorrt.dynamo.utils import ( + CPU_DEVICE, check_module_output, + delete_module, get_model_device, get_torch_inputs, set_log_level, @@ -109,7 +112,9 @@ def construct_refit_mapping( def construct_refit_mapping_from_weight_name_map( - weight_name_map: dict[Any, Any], state_dict: dict[Any, Any] + weight_name_map: dict[Any, Any], + state_dict: dict[Any, Any], + settings: CompilationSettings, ) -> dict[Any, Any]: engine_weight_map = {} for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items(): @@ -120,7 +125,9 @@ def construct_refit_mapping_from_weight_name_map( # If weights is not in sd, we can leave it unchanged continue else: - engine_weight_map[engine_weight_name] = state_dict[sd_weight_name] + engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to( + to_torch_device(settings.device) + ) 
engine_weight_map[engine_weight_name] = ( engine_weight_map[engine_weight_name] @@ -163,7 +170,7 @@ def _refit_single_trt_engine_with_gm( "constant_mapping", {} ) # type: ignore mapping = construct_refit_mapping_from_weight_name_map( - weight_name_map, new_gm.state_dict() + weight_name_map, new_gm.state_dict(), settings ) constant_mapping_with_type = {} @@ -309,42 +316,68 @@ def refit_module_weights( get_decompositions(settings.enable_experimental_decompositions) ) new_gm = new_weight_module.module() + logger.debug("Input graph: " + str(new_gm.graph)) # Apply lowering on the graph module new_gm = post_lowering(new_gm, settings) - logger.info("Compilation Settings: %s\n", settings) + logger.debug("Lowered Input graph: " + str(new_gm.graph)) # Set torch-executed ops - CONVERTERS.set_disallowed_targets(settings.torch_executed_ops) + CONVERTERS.set_compilation_settings(settings) + + # Check the number of supported operations in the graph + num_supported_ops, total_ops = partitioning.get_graph_converter_support( + new_gm, settings.debug, settings.torch_executed_ops + ) + + if num_supported_ops == 0 or ( + num_supported_ops < settings.min_block_size and not settings.dryrun + ): + logger.warning( + f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. " + f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}" + ) + return new_gm + else: + logger.debug( + f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph." + ) # If specified, try using the fast partitioner and fall back to the global one on failure if settings.use_fast_partitioner: try: + logger.info("Partitioning the graph via the fast partitioner") new_partitioned_module, supported_ops = partitioning.fast_partition( new_gm, verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + skip_fusion=(num_supported_ops == total_ops), ) + except torch.fx.passes.splitter_base.FxNetSplitterInternalError: logger.error( "Partitioning failed on the subgraph with fast partition. See trace above. 
" - + "Retrying with global partition.", + "Retrying with global partition.", exc_info=True, ) settings.use_fast_partitioner = False if not settings.use_fast_partitioner: + logger.info("Partitioning the graph via the global partitioner") new_partitioned_module, supported_ops = partitioning.global_partition( new_gm, verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, ) + # Done Partition if inline_module: # Preprocess the partitioned module to be in the same format as the inline module inline_torch_modules(new_partitioned_module) @@ -361,7 +394,7 @@ def refit_module_weights( # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those - + new_weight_module.module().to(CPU_DEVICE) for name, new_submodule in new_partitioned_module.named_children(): # Refit each submodule # Extract engine from the submodule @@ -464,26 +497,33 @@ def refit_module_weights( settings=settings, weight_name_map=None, ) + delete_module(new_submodule) # clear EXCLUDE_WEIGHTS flag serialization_config = engine.create_serialization_config() serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) serialized_engine = engine.serialize_with_config(serialization_config) - if isinstance( - compiled_submodule, (PythonTorchTensorRTModule, TorchTensorRTModule) - ): + if isinstance(compiled_submodule, PythonTorchTensorRTModule): + compiled_submodule.serialized_engine = bytes(serialized_engine) + elif isinstance(compiled_submodule, TorchTensorRTModule): compiled_submodule.engine = None # Clear the engine for TorchTensorRTModule, otherwise it won't be updated compiled_submodule.serialized_engine = bytes(serialized_engine) compiled_submodule.setup_engine() - elif inline_module: new_engine_info = list(engine_info) new_engine_info[ENGINE_IDX] = bytes(serialized_engine) refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) setattr(compiled_module, f"{name}_engine", refitted_engine) + del engine + gc.collect() + torch.cuda.empty_cache() + + delete_module(new_partitioned_module) + if verify_output and arg_inputs is not None: + new_gm.to(torch.cuda.current_device()) if check_module_output( new_module=new_gm, refitted_module=compiled_module, @@ -491,6 +531,7 @@ def refit_module_weights( kwarg_inputs=torch_kwarg_inputs, ): logger.info("Refitting Succeed!") + new_gm.to(CPU_DEVICE) else: if weight_name_map: logger.warning( @@ -506,6 +547,7 @@ def refit_module_weights( in_place=in_place, ) logger.error("Refitting Failed! The outputs do not match.") + new_gm.to(CPU_DEVICE) else: logger.info("Refitting Completed! 
Output verification skipped.") diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index d9b0e05e4d..97c02f34fb 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -25,6 +25,7 @@ MAX_AUX_STREAMS, MIN_BLOCK_SIZE, NUM_AVG_TIMING_ITERS, + OFFLOAD_MODULE_TO_CPU, OPTIMIZATION_LEVEL, PASS_THROUGH_BUILD_FAILURES, REFIT_IDENTICAL_ENGINE_WEIGHTS, @@ -140,6 +141,7 @@ class CompilationSettings: tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE + offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU _SETTINGS_TO_BE_ENGINE_INVARIANT = ( diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 17f2fccbff..73cb685808 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -44,7 +44,7 @@ get_trt_tensor, to_torch, ) -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, delete_module, to_torch_device from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER @@ -491,15 +491,11 @@ def _save_weight_mapping(self) -> None: _LOGGER.info("Building weight name mapping...") # Stage 1: Name mapping torch_device = to_torch_device(self.compilation_settings.device) - gm_is_on_cuda = get_model_device(self.module).type == "cuda" - if not gm_is_on_cuda: - # If the model original position is on CPU, move it GPU - sd = { - k: v.reshape(-1).to(torch_device) - for k, v in self.module.state_dict().items() - } - else: - sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()} + sd = { + k: v.reshape(-1).to(torch_device) + for k, v in self.module.state_dict().items() + } + weight_name_map: dict[str, Any] = {} np_map = {} constant_mapping = {} @@ -737,7 +733,8 @@ def run( self._create_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) - + if self.compilation_settings.offload_module_to_cpu: + delete_module(self.module) serialized_engine = self.builder.build_serialized_network( self.ctx.net, builder_config ) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 1fed1f9a1f..1d86c65d89 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -597,7 +597,9 @@ def aten_ops_neg( ) else: - @dynamo_tensorrt_converter(torch.ops.tensorrt.quantize_op.default) + @dynamo_tensorrt_converter( + torch.ops.tensorrt.quantize_op.default, supports_dynamic_shapes=True + ) def aten_ops_quantize_op( ctx: ConversionContext, target: Target, @@ -617,6 +619,38 @@ def aten_ops_quantize_op( ) +def attention_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: + # Currently, `attn_mask` is not supported + return args_bounds_check(node.args, 3) is None + + +@dynamo_tensorrt_converter( + torch.nn.functional.scaled_dot_product_attention, + capability_validator=attention_validator, + supports_dynamic_shapes=True, +) +def tensorrt_scaled_dot_product_attention( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return 
impl.attention.scaled_dot_product_attention( + ctx, + target, + SourceIR.TORCHTRT_LOWERED, + name, + args[0], + args[1], + args[2], + args_bounds_check(args, 5, False), + kwargs.get("scale", None), + ) + + @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True) def aten_ops_squeeze( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index df580b1516..75f7492591 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -2,6 +2,7 @@ activation, addmm, arange, + attention, cast, cat, condition, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/attention.py b/py/torch_tensorrt/dynamo/conversion/impl/attention.py new file mode 100644 index 0000000000..9cc4a30ccf --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/attention.py @@ -0,0 +1,165 @@ +import math +from typing import Optional, Union + +import numpy as np +import tensorrt as trt +from torch.fx.node import Target +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import ( + SourceIR, + cast_trt_tensor, + get_trt_tensor, +) +from torch_tensorrt.fx.types import TRTTensor + + +def tril( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + # the lower triangle of the tensor means the rows greater than and equal to the cols + row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0) + col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1) + rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col) + arange_tensor = impl.arange.arange( + ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1 + ) + # get the rows + row_tensor = impl.elementwise.trunc_div( + ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col + ) + # get the cols + col_tensor = impl.elementwise.fmod( + ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col + ) + cond = impl.elementwise.ge( + ctx, target, source_ir, name + "_ge", row_tensor, col_tensor + ) + return impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape", cond, [row, col] + ) + + +def scaled_dot_product_attention( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + query: TRTTensor, + key: TRTTensor, + value: TRTTensor, + is_causal: bool, + scale: Optional[float], +) -> TRTTensor: + # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + mm = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_mm", + query, + key, + other_matrix_op=trt.MatrixOperation.TRANSPOSE, + ) + if scale is None: + scale = query.shape[-1] + if scale < 0: + # dynamic shape + scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1) + sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale) + else: + # static shape + sqrt_scaled = math.sqrt(scale) + scaled = impl.elementwise.div( + ctx, + target, + source_ir, + name + "_scale", + mm, + sqrt_scaled, + ) + else: + scaled = impl.elementwise.mul( + ctx, + target, + source_ir, + name + 
"_scale", + mm, + scale, + ) + + if is_causal: + L, S = query.shape[-2], key.shape[-2] + if L >= 0 and S >= 0: + # static shape + attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype)) + temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0)) + attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf")) + attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias") + else: + # if any of the L or S is dynamic shape + if L < 0: + L = impl.shape.shape( + ctx, target, source_ir, name + "_shape_0", query, -2 + ) + if S < 0: + S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2) + + LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S) + + # this is to generate a tensor which has shape (L, S), type is int32 + arange_tensor = impl.arange.arange( + ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1 + ) + shape_tensor = impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S] + ) + + # since we want our attn_bias to be in float32, so cast it to float32 + shape_tensor = cast_trt_tensor( + ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir + ) + + # initialize the attn_bias as the zeros tensor + attn_bias = impl.elementwise.mul( + ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0 + ) + + # generate the mask tensor + tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor) + temp_mask = impl.unary.logical_not( + ctx, target, source_ir, name + "_logical_not", tril_tensor + ) + inf_tensor = impl.elementwise.mul( + ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf") + ) + cond = impl.elementwise.eq( + ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True) + ) + # mask out the certain part of the attn_bias + attn_bias = impl.condition.select( + ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond + ) + + scaled = impl.elementwise.add( + ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias + ) + + softmax = impl.normalization.softmax( + ctx, target, source_ir, name + "_softmax", scaled, -1, False + ) + out = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_out", + softmax, + value, + ) + + return out diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index e472ed3092..a3af535028 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -6,12 +6,31 @@ from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_torch from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTTensor +def get_ir(target: Target) -> SourceIR: + target_module = getattr(target, "__module__", "None") + if any( + target_module.startswith(prefix) + for prefix in ("torch.ops.aten", "torch._ops.aten") + ): + return SourceIR.ATEN + elif any( + target_module.startswith(prefix) + for prefix in ("torch.ops.prims", "torch._ops.prims") + ): + return SourceIR.PRIM + elif target_module.startswith("torch.nn"): + return SourceIR.NN + + return SourceIR.UNKNOWN + + def quantize( ctx: 
ConversionContext, target: Target, @@ -44,26 +63,41 @@ def quantize( elif num_bits == 8 and exponent_bits == 4: max_bound = 448 - amax = to_torch(amax, None) - scale = torch.divide(amax, max_bound) - scale = get_trt_tensor(ctx, scale, name + "_scale") + if not isinstance(amax, trt.ITensor): + amax = to_torch(amax, None) + scale = torch.divide(amax, max_bound) + scale = get_trt_tensor(ctx, scale, name + "_scale") + else: + scale = impl.elementwise.div( + ctx, + target, + get_ir(target), + name, + amax, + max_bound, + ) + scale = get_trt_tensor(ctx, scale, name + "_scale") + # Add Q node - quantize_layer = ctx.net.add_quantize(input_tensor, scale) if num_bits == 8 and exponent_bits == 0: - quantize_layer.set_output_type(0, trt.DataType.INT8) + dtype = trt.DataType.INT8 elif num_bits == 8 and exponent_bits == 4: - quantize_layer.set_output_type(0, trt.DataType.FP8) + dtype = trt.DataType.FP8 + + if not isinstance(input_tensor, TRTTensor): + input_tensor = get_trt_tensor(ctx, input_tensor, name + "_quantize_input") + + quantize_layer = ctx.net.add_quantize(input_tensor, scale, dtype) set_layer_name(quantize_layer, target, name + "_quantize", source_ir) q_output = quantize_layer.get_output(0) # Add DQ node - dequantize_layer = ctx.net.add_dequantize(q_output, scale) + dequantize_layer = ctx.net.add_dequantize( + q_output, scale, output_type=input_tensor.dtype + ) set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) - if num_bits == 8 and exponent_bits == 0: - dequantize_layer.precision = trt.DataType.INT8 - elif num_bits == 8 and exponent_bits == 4: - # Set DQ layer precision to FP8 - dequantize_layer.precision = trt.DataType.FP8 + dequantize_layer.precision = dtype + dq_output = dequantize_layer.get_output(0) return dq_output diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index b4165477ed..1cd47e4bd8 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -1,6 +1,6 @@ import logging from enum import Enum, auto -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import torch from torch._decomp import register_decomposition @@ -440,130 +440,130 @@ def view_decomposition(x: torch.Tensor, size: List[torch.SymInt]) -> torch.Tenso return aten._reshape_copy.default(x, size) -@register_torch_trt_decomposition( - aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS -) -def scaled_dot_product_attention_decomposition( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - *, - scale: Optional[float] = None, - enable_gqa: bool = False, -) -> torch.Tensor: - L, S = query.size(-2), key.size(-2) - device = query.device - attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device) - - if is_causal: - assert attn_mask is None, "attn_mask must be None when is_causal=True" - temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0) - attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf")) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf")) - else: - attn_bias = attn_mask + attn_bias - - if enable_gqa: - key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) - value = value.repeat_interleave(query.size(-3) // 
value.size(-3), -3) - - attn_weight = query @ key.transpose(-2, -1) - - if scale is None: - scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int)) - attn_weight = attn_weight / scale - else: - attn_weight = attn_weight * scale - - attn_weight = attn_weight + attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1) - return attn_weight @ value - - -@register_torch_trt_decomposition( - aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS -) -def scaled_dot_product_flash_attention_decomposition( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - dropout_p: float = 0.0, - is_causal: bool = False, - return_debug_mask: bool = False, - *, - scale: Optional[float] = None, -) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.SymInt, - torch.SymInt, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - attn = scaled_dot_product_attention_decomposition( - query, key, value, None, dropout_p, is_causal, scale=scale - ) - return attn, None, None, None, 0, 0, None, None, None - - -@register_torch_trt_decomposition( - aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS -) -def scaled_dot_product_efficient_attention_decomposition( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_bias: Optional[torch.Tensor], - compute_log_sumexp: bool, - dropout_p: float = 0.0, - is_causal: bool = False, - *, - scale: Optional[float] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - attn = scaled_dot_product_attention_decomposition( - query, key, value, attn_bias, dropout_p, is_causal, scale=scale - ) - return attn, None, None, None - - -@register_torch_trt_decomposition( - aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS -) -def scaled_dot_product_cudnn_attention_decomposition( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_bias: Optional[torch.Tensor], - compute_log_sumexp: bool, - dropout_p: float = 0.0, - is_causal: bool = False, - return_debug_mask: bool = False, - *, - scale: Optional[float] = None, -) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.SymInt, - torch.SymInt, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - attn = scaled_dot_product_attention_decomposition( - query, key, value, attn_bias, dropout_p, is_causal, scale=scale - ) - return attn, None, None, None, 0, 0, None, None, None +# @register_torch_trt_decomposition( +# aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS +# ) +# def scaled_dot_product_attention_decomposition( +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# attn_mask: Optional[torch.Tensor] = None, +# dropout_p: float = 0.0, +# is_causal: bool = False, +# *, +# scale: Optional[float] = None, +# enable_gqa: bool = False, +# ) -> torch.Tensor: +# L, S = query.size(-2), key.size(-2) +# device = query.device +# attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device) + +# if is_causal: +# assert attn_mask is None, "attn_mask must be None when is_causal=True" +# temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0) +# attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf")) + +# if attn_mask is not None: +# if attn_mask.dtype == torch.bool: +# attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf")) +# else: +# attn_bias = attn_mask + attn_bias + +# if enable_gqa: +# key = key.repeat_interleave(query.size(-3) // 
key.size(-3), -3) +# value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + +# attn_weight = query @ key.transpose(-2, -1) + +# if scale is None: +# scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int)) +# attn_weight = attn_weight / scale +# else: +# attn_weight = attn_weight * scale + +# attn_weight = attn_weight + attn_bias +# attn_weight = torch.softmax(attn_weight, dim=-1) +# return attn_weight @ value + + +# @register_torch_trt_decomposition( +# aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS +# ) +# def scaled_dot_product_flash_attention_decomposition( +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# dropout_p: float = 0.0, +# is_causal: bool = False, +# return_debug_mask: bool = False, +# *, +# scale: Optional[float] = None, +# ) -> Tuple[ +# torch.Tensor, +# torch.Tensor, +# torch.Tensor, +# torch.Tensor, +# torch.SymInt, +# torch.SymInt, +# torch.Tensor, +# torch.Tensor, +# torch.Tensor, +# ]: +# attn = scaled_dot_product_attention_decomposition( +# query, key, value, None, dropout_p, is_causal, scale=scale +# ) +# return attn, None, None, None, 0, 0, None, None, None + + +# @register_torch_trt_decomposition( +# aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS +# ) +# def scaled_dot_product_efficient_attention_decomposition( +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# attn_bias: Optional[torch.Tensor], +# compute_log_sumexp: bool, +# dropout_p: float = 0.0, +# is_causal: bool = False, +# *, +# scale: Optional[float] = None, +# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +# attn = scaled_dot_product_attention_decomposition( +# query, key, value, attn_bias, dropout_p, is_causal, scale=scale +# ) +# return attn, None, None, None + + +# @register_torch_trt_decomposition( +# aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS +# ) +# def scaled_dot_product_cudnn_attention_decomposition( +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# attn_bias: Optional[torch.Tensor], +# compute_log_sumexp: bool, +# dropout_p: float = 0.0, +# is_causal: bool = False, +# return_debug_mask: bool = False, +# *, +# scale: Optional[float] = None, +# ) -> Tuple[ +# torch.Tensor, +# torch.Tensor, +# torch.Tensor, +# torch.Tensor, +# torch.SymInt, +# torch.SymInt, +# torch.Tensor, +# torch.Tensor, +# torch.Tensor, +# ]: +# attn = scaled_dot_product_attention_decomposition( +# query, key, value, attn_bias, dropout_p, is_causal, scale=scale +# ) +# return attn, None, None, None, 0, 0, None, None, None @register_torch_trt_decomposition( diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 2ecc45ecf3..553151da7a 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -9,6 +9,7 @@ from .constant_folding import constant_fold from .fuse_distributed_ops import fuse_distributed_ops from .fuse_prims_broadcast import fuse_prims_broadcast +from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention from .pass_manager import DynamoPassManager from .remove_assert_nodes import remove_assert_nodes from .remove_detach import remove_detach @@ -23,6 +24,7 @@ repair_input_as_output, fuse_prims_broadcast, replace_max_pool_with_indices, + lower_scaled_dot_product_attention, remove_assert_nodes, 
accumulate_fp32_matmul, remove_num_users_is_0_nodes, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 6ebefc5509..d3c6d6b4f1 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -101,4 +101,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # TODO: Update this function when quantization is added def is_impure(self, node: torch.fx.node.Node) -> bool: + if node.target == torch.ops.tensorrt.quantize_op.default: + return True return False diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py new file mode 100644 index 0000000000..40fd587615 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py @@ -0,0 +1,169 @@ +import copy +import logging +import operator +from typing import Callable, Sequence, Tuple + +import torch +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) +REPLACEABLE_ATEN_OPS = { + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, +} + + +def lower_scaled_dot_product_attention( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Replace specific versions of scaled_dot_product_attention with an equivalent + implementation which can be easily converted to TRT + """ + original_fns, replacement = scaled_dot_product_attention_replacement() + replaced_nodes = [] + # For each original function, search for it in the graph and replace + for original in original_fns: + replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters( + gm, + original, + replacement, + ignore_literals=True, + ) + + if replaced_nodes: + # Repair instances which use the kwargs field (specifically the "scale" kwarg) + # Also repair instances which specified the is_causal or attn_bias fields + for match in replaced_nodes: + attention_node_replaced = None + # Seek the attention operator being replaced + for node in match.nodes_map: + if node.target in REPLACEABLE_ATEN_OPS: + attention_node_replaced = match.nodes_map[node] + break + + assert attention_node_replaced is not None + assert len(match.replacements) == 1 + + new_attention_node = match.replacements[0] + + assert ( + new_attention_node.target + == torch.nn.functional.scaled_dot_product_attention + ) + + # Copy the metadata of the replaced attention node to the new node + # TODO: Investigate why there are multiple FakeTensors in the metadata. + # We only use the first one as it contains the output shape information for this node. + if "val" in attention_node_replaced.meta: + new_attention_node.meta["val"] = copy.copy( + attention_node_replaced.meta["val"][0] + ) + + # If the attention operator had keyword-args, copy them to the new node + if attention_node_replaced.kwargs: + new_attention_node.kwargs = {**attention_node_replaced.kwargs} + + # Set default args in new node: + # Tensor? 
attn_mask=None, float dropout_p=0.0, bool is_causal=False + new_attention_node.args = new_attention_node.args + (None, 0.0, False) + + # The `is_causal` argument was specified + if ( + ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_flash_attention.default + ) + and args_bounds_check(attention_node_replaced.args, 4, False) + ) or ( + ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ) + and args_bounds_check(attention_node_replaced.args, 6, False) + ): + new_attention_node.args = ( + new_attention_node.args[:5] + (True,) + new_attention_node.args[6:] + ) + + # The `attn_bias` argument was specified + if ( + attention_node_replaced.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ) and args_bounds_check(attention_node_replaced.args, 3) is not None: + new_attention_node.args = ( + new_attention_node.args[:3] + + attention_node_replaced.args[3] + + new_attention_node.args[4:] + ) + + gm = clean_up_graph_after_modifications(gm) + logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}") + + return gm + + +def scaled_dot_product_attention_replacement() -> Tuple[ + Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], +]: + """Constructs the original and replacement functions for efficient attention""" + + # Efficient Attention original graph + def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, + k, + v, + None, + False, + ) + out = operator.getitem(outputs, 0) + return out + + # Flash Attention original graph + def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + ) + out = operator.getitem(outputs, 0) + return out + + # Efficient Attention w/Scale original graph + def efficient_scale( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, + k, + v, + None, + False, + scale=1.0, + ) + out = operator.getitem(outputs, 0) + return out + + # Flash Attention w/Scale original graph + def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + scale=1.0, + ) + out = operator.getitem(outputs, 0) + return out + + # Replacement graph consists of the functional version of scaled_dot_product_attention + def replacement( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention(query, key, value) + + return (efficient, flash, efficient_scale, flash_scale), replacement diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index eaeb6a8c28..88eff5757b 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -2,17 +2,16 @@ import logging from copy import deepcopy from enum import Enum, auto -from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Union +from typing import Any, Dict, Iterator, Optional, Union import numpy as np import torch -from torch.fx.node import Target +import 
torch_tensorrt +from torch.export._trace import _export from torch_tensorrt._Device import Device -from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._compiler import compile as dynamo_compile from torch_tensorrt.dynamo._refit import refit_module_weights -from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.utils import ( check_output_equal, to_torch_device, @@ -63,35 +62,11 @@ def __init__( pytorch_model: torch.nn.Module, *, device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE, - disable_tf32: bool = _defaults.DISABLE_TF32, - assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT, - sparse_weights: bool = _defaults.SPARSE_WEIGHTS, - enabled_precisions: Set[ - Union[torch.dtype, dtype] - ] = _defaults.ENABLED_PRECISIONS, - engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - immutable_weights: bool = False, - debug: bool = _defaults.DEBUG, - num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, - workspace_size: int = _defaults.WORKSPACE_SIZE, - dla_sram_size: int = _defaults.DLA_SRAM_SIZE, - dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, - dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, - truncate_double: bool = _defaults.TRUNCATE_DOUBLE, - require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, - min_block_size: int = _defaults.MIN_BLOCK_SIZE, - torch_executed_ops: Optional[Collection[Target]] = None, - torch_executed_modules: Optional[List[str]] = None, - pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, - max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, - version_compatible: bool = _defaults.VERSION_COMPATIBLE, - optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME, - use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, - enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - dryrun: bool = _defaults.DRYRUN, - hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, - timing_cache_path: str = _defaults.TIMING_CACHE_PATH, + immutable_weights: bool = False, + strict: bool = True, + allow_complex_guards_as_runtime_asserts: bool = False, + weight_streaming_budget: Optional[int] = None, **kwargs: Any, ) -> None: """ @@ -154,53 +129,35 @@ def __init__( self.exp_program: Any = None self.arg_inputs: tuple[Any, ...] = tuple() self.kwarg_inputs: dict[str, Any] = {} - device = to_torch_tensorrt_device(device) - enabled_precisions = {dtype._from(p) for p in enabled_precisions} + self.additional_settings = kwargs + self.strict = strict + self.allow_complex_guards_as_runtime_asserts = ( + allow_complex_guards_as_runtime_asserts + ) + self.use_python_runtime = use_python_runtime + self.trt_device = to_torch_tensorrt_device(device) assert ( not immutable_weights - ), "`immutable_weights` has to be False for a MutableTorchTensorRTModule." 
- compilation_options = { - "enabled_precisions": ( - enabled_precisions - if enabled_precisions - else _defaults.ENABLED_PRECISIONS - ), - "debug": debug, - "device": device, - "assume_dynamic_shape_support": assume_dynamic_shape_support, - "workspace_size": workspace_size, - "min_block_size": min_block_size, - "torch_executed_ops": ( - torch_executed_ops if torch_executed_ops is not None else set() - ), - "pass_through_build_failures": pass_through_build_failures, - "max_aux_streams": max_aux_streams, - "version_compatible": version_compatible, - "optimization_level": optimization_level, - "use_python_runtime": use_python_runtime, - "truncate_double": truncate_double, - "use_fast_partitioner": use_fast_partitioner, - "num_avg_timing_iters": num_avg_timing_iters, - "enable_experimental_decompositions": enable_experimental_decompositions, - "require_full_compilation": require_full_compilation, - "disable_tf32": disable_tf32, - "sparse_weights": sparse_weights, - "immutable_weights": immutable_weights, - "engine_capability": engine_capability, - "dla_sram_size": dla_sram_size, - "dla_local_dram_size": dla_local_dram_size, - "dla_global_dram_size": dla_global_dram_size, - "dryrun": dryrun, - "hardware_compatible": hardware_compatible, - "timing_cache_path": timing_cache_path, - } + ), "`immutable_weights` has to be False for a MutableTorchTensorRTModule." + self.arg_dynamic_shapes: Optional[tuple[Any]] = None self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None - - self.settings = CompilationSettings(**compilation_options) + self.serializable_dynamic_shapes_dims: dict[str, tuple[str, int, int]] = {} self.run_info: Optional[tuple[Any, ...]] = None self.state_dict_metadata: dict[str, torch.Size] = {} self._store_state_dict_metadata() + self.enable_weight_streaming = ( + kwargs["enable_weight_streaming"] + if "enable_weight_streaming" in kwargs + else False + ) + self.weight_streaming_ctx = None + self.weight_streaming_budget = weight_streaming_budget + if self.enable_weight_streaming: + if weight_streaming_budget is None: + logger.warning( + "Weight streaming budget is not set. Using automatic weight streaming budget" + ) cls = self.__class__ self.__class__ = type( @@ -293,7 +250,7 @@ def update_refit_condition(self) -> None: # to determine whether refit/recompilation is needed. If the output is the same, no further process needed. if self.run_info: args, kwargs, result = self.run_info - self.original_model.to(to_torch_device(self.settings.device)) + self.original_model.to(to_torch_device(self.trt_device)) new_result = self.original_model(*args, **kwargs) self.original_model.cpu() torch.cuda.empty_cache() @@ -325,17 +282,17 @@ def refit_gm(self) -> None: MutableTorchTensorRTModule automatically catches weight value updates and call this function to refit the module. If it fails to catch the changes, please call this function manually to update the TRT graph module.
""" - self.original_model.to(to_torch_device(self.settings.device)) + if self.exp_program is None: - self.exp_program = torch.export.export( - self.original_model, self.arg_inputs, kwargs=self.kwarg_inputs - ) + self.original_model.to(to_torch_device(self.trt_device)) + self.exp_program = self.get_exported_program() else: self.exp_program._state_dict = ( MutableTorchTensorRTModule._transform_state_dict( self.original_model.state_dict() ) ) + self.exp_program.module().to(to_torch_device(self.trt_device)) self.gm = refit_module_weights( self.gm, self.exp_program, @@ -345,9 +302,28 @@ def refit_gm(self) -> None: in_place=True, ) - self.original_model.cpu() + self.original_model.to("cpu") torch.cuda.empty_cache() + def get_exported_program(self) -> torch.export.ExportedProgram: + if self.allow_complex_guards_as_runtime_asserts: + return _export( + self.original_model, + self.arg_inputs, + kwargs=self.kwarg_inputs, + dynamic_shapes=self._get_total_dynamic_shapes(), + strict=self.strict, + allow_complex_guards_as_runtime_asserts=self.allow_complex_guards_as_runtime_asserts, + ) + else: + return torch.export.export( + self.original_model, + self.arg_inputs, + kwargs=self.kwarg_inputs, + dynamic_shapes=self._get_total_dynamic_shapes(), + strict=self.strict, + ) + def compile(self) -> None: """ (Re)compile the TRT graph module using the PyTorch module. @@ -356,25 +332,37 @@ def compile(self) -> None: If it fails to catch the changes, please call this function manually to recompile the TRT graph module. """ # Export the module - self.original_model.to(to_torch_device(self.settings.device)) - self.exp_program = torch.export.export( - self.original_model, - self.arg_inputs, - kwargs=self.kwarg_inputs, - dynamic_shapes=self._get_total_dynamic_shapes(), - ) + self.original_model.to(to_torch_device(self.trt_device)) + self.exp_program = self.get_exported_program() self.gm = dynamo_compile( self.exp_program, arg_inputs=self.arg_inputs, kwarg_inputs=self.kwarg_inputs, - **self.settings.__dict__, + immutable_weights=False, + use_python_runtime=self.use_python_runtime, + **self.additional_settings, ) - self.original_model.cpu() + self.original_model.to("cpu") torch.cuda.empty_cache() + if self.enable_weight_streaming: + self.set_weight_streaming_ctx(self.weight_streaming_budget) + + def set_weight_streaming_ctx(self, requested_budget: Optional[int] = None) -> None: + """ + Set the weight streaming budget. If budget is not set, then automatic weight streaming budget + is used. + """ + self.weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(self.gm) + requested_budget = ( + requested_budget + if requested_budget is not None + else self.weight_streaming_ctx.get_automatic_weight_streaming_budget() + ) + self.weight_streaming_ctx.device_budget = requested_budget def _validate_inputs(self, *args: Any, **kwargs: Any) -> None: - if not self.arg_inputs: + if not self.arg_inputs and not self.kwarg_inputs: logger.info("First time compilation initiated. 
This may take some time.") self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE) self._store_inputs(args, kwargs) @@ -491,14 +479,24 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: self._store_state_dict_metadata() self.refit_state.set_state(RefitFlag.LIVE) + weight_streaming_ctx = ( + self.weight_streaming_ctx if self.enable_weight_streaming else None + ) result = self.gm(*args, **kwargs) # Storing inputs and outputs for verification when the state is unknown self.run_info = (args, kwargs, result) return result - def to(self, device: str) -> None: - logger.warning("Original PyTorch model is moved. CPU offload may failed.") - self.original_model.to(device) + def to(self, *args: Any, **kwargs: Any) -> None: + logger.warning( + "Trying to move the original PyTorch model. This will cause CPU offloading to fail and increase GPU memory usage. " + + "If this is absolutely necessary, please call module.pytorch_model.to(...) \n" + + "The model is still on the original device." + ) + + @property + def device(self) -> torch.device: + return to_torch_device(self.trt_device) def __deepcopy__(self, memo: Any) -> Any: cls = self.__class__ @@ -624,18 +622,58 @@ def _check_tensor_shapes_with_dynamic_shapes( return True + def serialize_dynamic_shapes(self) -> None: + dims = self.serializable_dynamic_shapes_dims + + def recursively_serialize_dynamic_shape(obj: Any) -> None: + if isinstance(obj, dict): + for axis, v in obj.items(): + if isinstance(v, torch.export.dynamic_shapes._Dim): + name = str(v).split("'")[1].split(".")[-1] + # We use the string of the hash as the unique identifier of the Dim object + dims.setdefault(str(hash(v)), (name, v.min, v.max)) + obj[axis] = str(hash(v)) + else: + recursively_serialize_dynamic_shape(v) + if isinstance(obj, (tuple, list)): + for v in obj: + recursively_serialize_dynamic_shape(v) + + recursively_serialize_dynamic_shape(self.arg_dynamic_shapes) + recursively_serialize_dynamic_shape(self.kwarg_dynamic_shapes) + + def deserialize_dynamic_shapes(self) -> None: + dims = self.serializable_dynamic_shapes_dims + + def recursively_deserialize_dynamic_shape(obj: Any) -> None: + if isinstance(obj, dict): + for axis, v in obj.items(): + if isinstance(v, str): + obj[axis] = torch.export.Dim( + dims[v][0], min=dims[v][1], max=dims[v][2] + ) + else: + recursively_deserialize_dynamic_shape(v) + if isinstance(obj, (tuple, list)): + for v in obj: + recursively_deserialize_dynamic_shape(v) + + recursively_deserialize_dynamic_shape(self.arg_dynamic_shapes) + recursively_deserialize_dynamic_shape(self.kwarg_dynamic_shapes) + @staticmethod def save(module: Any, path: str) -> None: # Cast the object back to MutableTorchTensorRTModule to save assert ( - not module.settings.use_python_runtime + not module.use_python_runtime ), "Python runtime does not support serialization. Save failed."
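+ # torch.save below pickles the whole module, so the live pytorch_model reference and the exported program are detached first and restored afterwards, and the dynamic-shape Dims are converted to a picklable form via serialize_dynamic_shapes().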
module.init_finished = False module.__class__ = MutableTorchTensorRTModule exp_program = module.exp_program module.pytorch_model = None module.exp_program = None - torch.save(module, path) + module.serialize_dynamic_shapes() + torch.save(module, path, pickle_protocol=4) # Restore deleted attributes module.exp_program = exp_program module.pytorch_model = _make_refit_change_trigger( @@ -658,7 +696,7 @@ def load(path: str) -> Any: module.pytorch_model = _make_refit_change_trigger( module.original_model, module.refit_state ) - module.original_model.to(to_torch_device(module.settings.device)) + module.original_model.to(to_torch_device(module.device)) module.exp_program = torch.export.export( module.original_model, module.arg_inputs, kwargs=module.kwarg_inputs ) @@ -669,6 +707,7 @@ def load(path: str) -> Any: (cls, module.original_model.__class__), {}, ) + module.deserialize_dynamic_shapes() module.init_finished = True return module diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index e4018ae95c..e75f5149ba 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -419,7 +419,9 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) - """ Returns the dtype of torch.tensor or FakeTensor. For symbolic integers, we return int64 """ - if isinstance(tensor, (torch.Tensor, FakeTensor, int, float, bool)): + if isinstance(tensor, (torch.Tensor, FakeTensor)): + return tensor.dtype + elif isinstance(tensor, (int, float, bool)): return torch.tensor(tensor).dtype elif isinstance(tensor, torch.SymInt): return torch.int64 @@ -791,6 +793,8 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype] output_dtypes.append(dtype.float32) else: output_dtypes.append(dtype._from(output_meta.dtype)) + elif isinstance(output_meta, torch.SymInt): + output_dtypes.append(dtype.int64) elif "tensor_meta" in output.meta: output_meta = output.meta["tensor_meta"] output_dtypes.append(dtype._from(output_meta.dtype)) diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py index 346132145e..de0a7b9fdf 100644 --- a/py/torch_tensorrt/runtime/_cudagraphs.py +++ b/py/torch_tensorrt/runtime/_cudagraphs.py @@ -69,48 +69,16 @@ def __init__(self, compiled_module: torch.nn.Module) -> None: self.old_mode = _PY_RT_CUDAGRAPHS self.compiled_module = compiled_module self.cudagraphs_module: Optional[CudaGraphsTorchTensorRTModule] = None + self.old_module = None - def __enter__(self) -> torch.nn.Module: - global _PY_RT_CUDAGRAPHS - - num_torch_module = 0 - num_trt_module = 0 - for name, module in self.compiled_module.named_children(): - # need to disable cudagraphs if any model requires output allocator - if ( - hasattr(module, "requires_output_allocator") - and module.requires_output_allocator - ): - raise RuntimeError( - "The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs." 
- ) - if "_run_on_acc" in name: - num_trt_module += 1 - elif "_run_on_gpu" in name: - num_torch_module += 1 - - if num_torch_module > 0: - # Set whole cudagraphs mode and returns wrapped module - _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS - # Set new mode for C++ - if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: - torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS) + def __enter__(self) -> torch.nn.Module | torch.fx.GraphModule: - logger.debug( - "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule" - ) - self.cudagraphs_module = CudaGraphsTorchTensorRTModule(self.compiled_module) - return self.cudagraphs_module - else: - if num_trt_module > 0: - logger.debug("No graph breaks detected, using runtime cudagraphs mode") - else: - logger.debug( - "Please consider dynamo if there is graph breaks. Using runtime cudagraphs mode" - ) - # Enable cudagraphs for TRT submodule - set_cudagraphs_mode(True) + if isinstance(self.compiled_module, torch_tensorrt.MutableTorchTensorRTModule): + self.old_module = self.compiled_module.gm + self.compiled_module.gm = get_cuda_graph_module(self.compiled_module.gm) return self.compiled_module + else: + return get_cuda_graph_module(self.compiled_module) def __exit__(self, *args: Any) -> None: # Set cudagraphs back to old mode @@ -118,6 +86,52 @@ def __exit__(self, *args: Any) -> None: # __del__ is not entirely predictable, so we reset cudagraph here if self.cudagraphs_module: self.cudagraphs_module._reset_captured_graph() + if self.old_module: # MutableTorchTRTModule + self.compiled_module.gm = self.old_module + + +def get_cuda_graph_module( + compiled_module: torch.fx.GraphModule, +) -> torch.nn.Module | torch.fx.GraphModule: + global _PY_RT_CUDAGRAPHS + + num_torch_module = 0 + num_trt_module = 0 + for name, module in compiled_module.named_children(): + # need to disable cudagraphs if any model requires output allocator + if ( + hasattr(module, "requires_output_allocator") + and module.requires_output_allocator + ): + raise RuntimeError( + "The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs." + ) + if "_run_on_acc" in name: + num_trt_module += 1 + elif "_run_on_gpu" in name: + num_torch_module += 1 + + if num_torch_module > 0: + # Set whole cudagraphs mode and returns wrapped module + _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS + # Set new mode for C++ + if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: + torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS) + + logger.debug( + "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule" + ) + return CudaGraphsTorchTensorRTModule(compiled_module) + else: + if num_trt_module > 0: + logger.debug("No graph breaks detected, using runtime cudagraphs mode") + else: + logger.debug( + "Please consider dynamo if there is graph breaks. 
Using runtime cudagraphs mode" + ) + # Enable cudagraphs for TRT submodule + set_cudagraphs_mode(True) + return compiled_module def enable_cudagraphs( diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index a0b3292c29..a18fee7c44 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -93,7 +93,7 @@ def test_refit_one_engine_with_weightmap(): enabled_precisions = {torch.float} debug = False min_block_size = 1 - use_python_runtime = False + use_python_runtime = True exp_program = torch.export.export(model, tuple(inputs)) exp_program2 = torch.export.export(model2, tuple(inputs)) @@ -117,6 +117,7 @@ def test_refit_one_engine_with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -167,6 +168,7 @@ def test_refit_one_engine_no_map_with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -192,7 +194,7 @@ def test_refit_one_engine_with_wrong_weightmap(): enabled_precisions = {torch.float} debug = False min_block_size = 1 - use_python_runtime = False + use_python_runtime = True exp_program = torch.export.export(model, tuple(inputs)) exp_program2 = torch.export.export(model2, tuple(inputs)) @@ -221,6 +223,7 @@ def test_refit_one_engine_with_wrong_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -249,7 +252,7 @@ def test_refit_one_engine_bert_with_weightmap(): enabled_precisions = {torch.float} debug = False min_block_size = 1 - use_python_runtime = False + use_python_runtime = True exp_program = torch.export.export(model, tuple(inputs)) exp_program2 = torch.export.export(model2, tuple(inputs)) @@ -272,6 +275,7 @@ def test_refit_one_engine_bert_with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -294,7 +298,7 @@ def test_refit_one_engine_bert_with_weightmap(): "TorchScript Frontend is not available", ) @pytest.mark.unit -def test_refit_one_engine_inline_runtime__with_weightmap(): +def test_refit_one_engine_inline_runtime_with_weightmap(): trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") model = models.resnet18(pretrained=False).eval().to("cuda") model2 = models.resnet18(pretrained=True).eval().to("cuda") @@ -326,6 +330,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -370,6 +375,7 @@ def test_refit_one_engine_python_runtime_with_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -441,6 +447,7 @@ def forward(self, x): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -489,6 +496,7 @@ def test_refit_one_engine_without_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -540,6 +548,7 @@ def test_refit_one_engine_bert_without_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -594,6 
+603,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -638,6 +648,7 @@ def test_refit_one_engine_python_runtime_without_weightmap(): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -709,6 +720,7 @@ def forward(self, x): ) # Check the output + model2.to("cuda") expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm( *inputs ) @@ -763,6 +775,7 @@ def forward(self, x): ) # Check the output + model.to("cuda") pyt_outputs, trt_outputs = exp_program.module()(*inputs), trt_gm(*inputs) for pyt_output, trt_output in zip(pyt_outputs, trt_outputs): assertions.assertTrue( diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py index c07e04b6a4..d9105f5a75 100644 --- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py +++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py @@ -75,7 +75,7 @@ def test_check_input_shape_dynamic(): @pytest.mark.unit -def test_model_complex_dynamic_shape(): +def test_model_complex_dynamic_shape_with_saving(): device = "cuda:0" class Model(torch.nn.Module): @@ -111,6 +111,13 @@ def forward(self, a, b, c=None): # Run inference trt_gm(*inputs, **kwargs) + try: + save_path = os.path.join(tempfile.gettempdir(), "mutable_module.pkl") + torch_trt.MutableTorchTensorRTModule.save(mutable_module, save_path) + model = torch_trt.MutableTorchTensorRTModule.load(save_path) + except Exception as e: + assert False, "Module saving and reloading with dynamic shape failed." + inputs_2 = [torch.rand(10, 9).to(device)] kwargs_2 = { "b": torch.rand(9, 30).to(device), diff --git a/tools/perf/Flux/benchmark.sh b/tools/perf/Flux/benchmark.sh new file mode 100644 index 0000000000..68486abea8 --- /dev/null +++ b/tools/perf/Flux/benchmark.sh @@ -0,0 +1,6 @@ +# TODO: Enter the HF token +huggingface-cli login --token HF_TOKEN + +python flux_quantization.py --dtype fp8 > fp8_benchmark.txt +python flux_quantization.py --dtype int8 > int8_benchmark.txt +python flux_perf.py > fp16_benchmark.txt \ No newline at end of file diff --git a/tools/perf/Flux/create_env.sh b/tools/perf/Flux/create_env.sh new file mode 100644 index 0000000000..9390020214 --- /dev/null +++ b/tools/perf/Flux/create_env.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +git config --global --add safe.directory /home/TensorRT + +# Install bazel +apt install apt-transport-https curl gnupg -y +curl -fsSL https://bazel.build/bazel-release.pub.gpg | gpg --dearmor >bazel-archive-keyring.gpg +mv bazel-archive-keyring.gpg /usr/share/keyrings +echo "deb [arch=amd64 signed-by=/usr/share/keyrings/bazel-archive-keyring.gpg] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list + + +apt update && apt install bazel-7.2.1 +apt install bazel +bazel +cd /home/TensorRT + +python -m pip install --pre -e .
--extra-index-url https://download.pytorch.org/whl/nightly/cu128 +pip install tensorrt==10.9.0.34 --force-reinstall + +pip3 install --pre torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 + + +pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" protobuf=="5.29.3" + +pip install notebook +pip install gradio safetensors peft pyinstrument +pip install nvidia-modelopt onnx torchprofile pulp onnxruntime diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py new file mode 100644 index 0000000000..e5e7dceecd --- /dev/null +++ b/tools/perf/Flux/flux_perf.py @@ -0,0 +1,93 @@ +from time import time + +import torch +import torch_tensorrt +from diffusers import FluxPipeline + +for i in range(torch.cuda.device_count()): + print(torch.cuda.get_device_properties(i).name) + +DEVICE = "cuda:0" +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.float32, +) +pipe.to(DEVICE).to(torch.float32) +backbone = pipe.transformer + + +batch_size = 2 +BATCH = torch.export.Dim("batch", min=1, max=8) + +# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model. +# To see this recommendation, you can try exporting using min=1, max=4096 +dynamic_shapes = { + "hidden_states": {0: BATCH}, + "encoder_hidden_states": {0: BATCH}, + "pooled_projections": {0: BATCH}, + "timestep": {0: BATCH}, + "txt_ids": {}, + "img_ids": {}, + "guidance": {0: BATCH}, + "joint_attention_kwargs": {}, + "return_dict": None, +} + +settings = { + "strict": False, + "allow_complex_guards_as_runtime_asserts": True, + "enabled_precisions": {torch.float32}, + "truncate_double": True, + "min_block_size": 1, + "use_fp32_acc": True, + "use_explicit_typing": True, + "debug": False, + "use_python_runtime": True, + "immutable_weights": False, +} + + +def generate_image(prompt, inference_step, batch_size=2, benchmark=False, iterations=1): + + start = time() + for i in range(iterations): + image = pipe( + prompt, + output_type="pil", + num_inference_steps=inference_step, + num_images_per_prompt=batch_size, + ).images + end = time() + if benchmark: + print("Time Elapse for", iterations, "iterations:", end - start) + print("Average Latency Per Step:", (end - start) / inference_step / iterations) + return image + + +generate_image(["Test"], 2) +print("Benchmark Original PyTorch Module Latency (float32)") +generate_image(["Test"], 50, benchmark=True, iterations=3) + +pipe.to(torch.float16) +print("Benchmark Original PyTorch Module Latency (float16)") +generate_image(["Test"], 50, benchmark=True, iterations=3) + + +trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings) +trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes) +pipe.transformer = trt_gm + +start = time() +generate_image(["Test"], 2) +end = time() +print("Time Elapse compilation:", end - start) +print() +print("Benchmark TRT Accelerated Latency") +generate_image(["Test"], 50, benchmark=True, iterations=3) +torch.cuda.empty_cache() + + +with torch_tensorrt.runtime.enable_cudagraphs(trt_gm): + generate_image(["Test"], 2) + print("Benchmark TRT Accelerated Latency with Cuda Graph") + generate_image(["Test"], 50, benchmark=True, iterations=3) diff --git a/tools/perf/Flux/flux_quantization.py b/tools/perf/Flux/flux_quantization.py new file mode 100644 index 0000000000..e90e29e2fe --- /dev/null +++ b/tools/perf/Flux/flux_quantization.py @@ -0,0 +1,202 @@ +# %% +# Import the following libraries +# 
----------------------------- +# Load the ModelOpt-modified model architecture and weights using Huggingface APIs +# Add argument parsing for dtype selection +import argparse +import re + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +import torch +import torch_tensorrt +from diffusers import FluxPipeline +from diffusers.models.attention_processor import Attention +from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel +from modelopt.torch.quantization.utils import export_torch_mode +from torch.export._trace import _export +from transformers import AutoModelForCausalLM + +parser = argparse.ArgumentParser( + description="Run Flux quantization with different dtypes" +) +parser.add_argument( + "--dtype", + choices=["fp8", "int8"], + default="fp8", + help="Quantization data type to use (fp8 or int8)", +) + +args = parser.parse_args() + +# Update enabled precisions based on dtype argument +if args.dtype == "fp8": + enabled_precisions = {torch.float8_e4m3fn, torch.float16} + ptq_config = mtq.FP8_DEFAULT_CFG +else: # int8 + enabled_precisions = {torch.int8, torch.float16} + ptq_config = mtq.INT8_DEFAULT_CFG + ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None +print(f"\nUsing {args.dtype} quantization") +# %% +DEVICE = "cuda:0" +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.float16, +) +# pipe.transformer = FluxTransformer2DModel( +# num_layers=1, num_single_layers=1, guidance_embeds=True +# ) + +pipe.to(DEVICE).to(torch.float16) +# Store the config and transformer backbone +config = pipe.transformer.config +# global backbone +backbone = pipe.transformer +backbone.eval() + + +def filter_func(name): + pattern = re.compile( + r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*" + ) + return pattern.match(name) is not None + + +def generate_image(pipe, prompt, image_name): + seed = 42 + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(seed), + ).images[0] + image.save(f"{image_name}.png") + print(f"Image generated using {image_name} model saved as {image_name}.png") + + +def benchmark(prompt, inference_step, batch_size=2, iterations=1): + from time import time + + start = time() + for i in range(iterations): + image = pipe( + prompt, + output_type="pil", + num_inference_steps=inference_step, + num_images_per_prompt=batch_size, + ).images + end = time() + print("Time Elapse for", iterations, "iterations:", end - start) + print("Average Latency Per Step:", (end - start) / inference_step / iterations) + return image + + +# %% +# Quantization + + +def do_calibrate( + pipe, + prompt: str, +) -> None: + """ + Run calibration steps on the pipeline using the given prompts. + """ + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(0), + ).images[0] + + +def forward_loop(mod): + # Switch the pipeline's backbone, run calibration + pipe.transformer = mod + do_calibrate( + pipe=pipe, + prompt="test", + ) + + +backbone = mtq.quantize(backbone, ptq_config, forward_loop) +mtq.disable_quantizer(backbone, filter_func) + +batch_size = 2 +BATCH = torch.export.Dim("batch", min=1, max=2) +SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512) +# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model. 
+# To see this recommendation, you can try exporting using min=1, max=4096 +IMG_ID = torch.export.Dim("img_id", min=3586, max=4096) +dynamic_shapes = { + "hidden_states": {0: BATCH}, + "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN}, + "pooled_projections": {0: BATCH}, + "timestep": {0: BATCH}, + "txt_ids": {0: SEQ_LEN}, + "img_ids": {0: IMG_ID}, + "guidance": {0: BATCH}, + "joint_attention_kwargs": {}, + "return_dict": None, +} +# The guidance factor is of type torch.float32 +dummy_inputs = { + "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to( + DEVICE + ), + "encoder_hidden_states": torch.randn( + (batch_size, 512, 4096), dtype=torch.float16 + ).to(DEVICE), + "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to( + DEVICE + ), + "timestep": torch.tensor([1.0] * batch_size, dtype=torch.float16).to(DEVICE), + "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE), + "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE), + "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to(DEVICE), + "joint_attention_kwargs": {}, + "return_dict": False, +} + +# This will create an exported program which is going to be compiled with Torch-TensorRT +with export_torch_mode(): + ep = _export( + backbone, + args=(), + kwargs=dummy_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + +with torch_tensorrt.logging.debug(): + trt_gm = torch_tensorrt.dynamo.compile( + ep, + inputs=dummy_inputs, + enabled_precisions=enabled_precisions, + truncate_double=True, + min_block_size=1, + debug=False, + use_python_runtime=True, + immutable_weights=True, + offload_module_to_cpu=True, + ) + + +del ep +pipe.transformer = trt_gm +pipe.transformer.config = config + + +# %% +trt_gm.device = torch.device(DEVICE) +# Function which generates images from the flux pipeline +generate_image(pipe, ["A golden retriever"], "dog_code2") + + +print(f"Benchmark TRT Module Latency at ({args.dtype})") +benchmark(["Test"], 50, batch_size=2, iterations=3) +print() + +# For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB
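
The snippet below is a minimal usage sketch, not part of the diff above, showing how the MutableTorchTensorRTModule pieces touched by this change set fit together: construction with the trimmed-down option set, registering an expected dynamic shape range before the first forward call triggers compilation, running under the CUDA graphs context manager, and the save/load round trip that relies on the new dynamic-shape serialization. The toy model, save path, and chosen settings are assumptions for illustration only; it requires a CUDA device with TensorRT installed, and the C++ runtime is kept so that save()/load() is permitted.

import torch
import torch_tensorrt


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(64, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


model = TinyModel().eval().to("cuda")

# immutable_weights must remain False so the engine stays refittable;
# use_python_runtime must be False if save()/load() will be used later.
trt_module = torch_tensorrt.MutableTorchTensorRTModule(
    model,
    use_python_runtime=False,
    immutable_weights=False,
    enabled_precisions={torch.float32},
    min_block_size=1,
)

# Register the expected dynamic batch dimension before the first call compiles the engine.
BATCH = torch.export.Dim("batch", min=1, max=8)
trt_module.set_expected_dynamic_shape_range(({0: BATCH},), {})

x = torch.randn(4, 64, device="cuda")
out = trt_module(x)  # first call exports the model and builds the TRT engine

# Optional: run steady-state inference under the CUDA graphs context manager,
# which swaps in a CUDA-graph-wrapped graph module for the duration of the block.
with torch_tensorrt.runtime.enable_cudagraphs(trt_module):
    out = trt_module(x)

# Save/load round trip; dynamic-shape Dims are converted to a picklable form internally.
torch_tensorrt.MutableTorchTensorRTModule.save(trt_module, "/tmp/tiny_trt_module.pkl")
reloaded = torch_tensorrt.MutableTorchTensorRTModule.load("/tmp/tiny_trt_module.pkl")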