Skip to content

Commit 9085aab

Browse files
authored
[benchmarks] Add option to use unique jsonschema for each request (#14457)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
1 parent 8d5aa46 commit 9085aab

File tree

1 file changed

+41
-17
lines changed

1 file changed

+41
-17
lines changed

benchmarks/benchmark_serving_structured_output.py

Lines changed: 41 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -24,11 +24,13 @@
2424
"""
2525
import argparse
2626
import asyncio
27+
import copy
2728
import dataclasses
2829
import json
2930
import os
3031
import random
3132
import time
33+
import uuid
3234
import warnings
3335
from collections.abc import AsyncGenerator
3436
from dataclasses import dataclass
@@ -109,24 +111,43 @@ class SampleRequest:
109111

110112
def sample_requests(tokenizer: PreTrainedTokenizerBase,
111113
args: argparse.Namespace) -> list[SampleRequest]:
112-
if args.dataset == 'json':
114+
if args.dataset == 'json' or args.dataset == 'json-unique':
113115
if args.json_schema_path is None:
114116
dir_path = os.path.dirname(os.path.realpath(__file__))
115117
args.json_schema_path = os.path.join(dir_path,
116118
"structured_schemas",
117119
"structured_schema_1.json")
120+
json_schemas = []
118121
with open(args.json_schema_path) as f:
119122
schema = json.load(f)
120-
prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
121-
input_len = len(tokenizer(prompt).input_ids)
122-
print(f"Input length of the prompt: {input_len} tokens")
123+
124+
if args.dataset == 'json-unique':
125+
json_schemas = [
126+
copy.deepcopy(schema) for _ in range(args.num_prompts)
127+
]
128+
for i in range(len(json_schemas)):
129+
json_schemas[i]["properties"][
130+
f"__optional_field_{uuid.uuid4()}"] = {
131+
"type":
132+
"string",
133+
"description":
134+
"An unique optional field to avoid cached schemas"
135+
}
136+
137+
def gen_prompt(index: int):
138+
schema = json_schemas[index % len(json_schemas)]
139+
return f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
140+
141+
def get_schema(index: int):
142+
return json_schemas[index % len(json_schemas)]
143+
123144
requests = [
124-
SampleRequest(prompt=prompt,
125-
prompt_len=input_len,
145+
SampleRequest(prompt=gen_prompt(i),
146+
prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
126147
expected_output_len=args.output_len,
127-
schema=schema,
148+
schema=get_schema(i),
128149
structure_type=args.structure_type)
129-
for _ in range(args.num_prompts)
150+
for i in range(args.num_prompts)
130151
]
131152

132153
elif args.dataset == "grammar":
@@ -821,10 +842,12 @@ def main(args: argparse.Namespace):
821842
default="/v1/completions",
822843
help="API endpoint.",
823844
)
824-
parser.add_argument(
825-
"--dataset",
826-
default='json',
827-
choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench'])
845+
parser.add_argument("--dataset",
846+
default='json',
847+
choices=[
848+
'json', 'json-unique', 'grammar', 'regex',
849+
'choice', 'xgrammar_bench'
850+
])
828851
parser.add_argument("--json_schema_path",
829852
type=str,
830853
default=None,
@@ -966,11 +989,12 @@ def main(args: argparse.Namespace):
966989
type=float,
967990
default=1.0,
968991
help="Ratio of Structured Outputs requests")
969-
parser.add_argument("--structured-output-backend",
970-
type=str,
971-
choices=["outlines", "lm-format-enforcer", "xgrammar"],
972-
default="xgrammar",
973-
help="Backend to use for structured outputs")
992+
parser.add_argument(
993+
"--structured-output-backend",
994+
type=str,
995+
choices=["outlines", "lm-format-enforcer", "xgrammar", "json-unique"],
996+
default="xgrammar",
997+
help="Backend to use for structured outputs")
974998

975999
args = parser.parse_args()
9761000
main(args)

0 commit comments

Comments
 (0)