|
24 | 24 | """
|
25 | 25 | import argparse
|
26 | 26 | import asyncio
|
| 27 | +import copy |
27 | 28 | import dataclasses
|
28 | 29 | import json
|
29 | 30 | import os
|
30 | 31 | import random
|
31 | 32 | import time
|
| 33 | +import uuid |
32 | 34 | import warnings
|
33 | 35 | from collections.abc import AsyncGenerator
|
34 | 36 | from dataclasses import dataclass
|
@@ -109,24 +111,43 @@ class SampleRequest:
|
109 | 111 |
|
110 | 112 | def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
111 | 113 | args: argparse.Namespace) -> list[SampleRequest]:
|
112 |
| - if args.dataset == 'json': |
| 114 | + if args.dataset == 'json' or args.dataset == 'json-unique': |
113 | 115 | if args.json_schema_path is None:
|
114 | 116 | dir_path = os.path.dirname(os.path.realpath(__file__))
|
115 | 117 | args.json_schema_path = os.path.join(dir_path,
|
116 | 118 | "structured_schemas",
|
117 | 119 | "structured_schema_1.json")
|
| 120 | + json_schemas = [] |
118 | 121 | with open(args.json_schema_path) as f:
|
119 | 122 | schema = json.load(f)
|
120 |
| - prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501 |
121 |
| - input_len = len(tokenizer(prompt).input_ids) |
122 |
| - print(f"Input length of the prompt: {input_len} tokens") |
| 123 | + |
| 124 | + if args.dataset == 'json-unique': |
| 125 | + json_schemas = [ |
| 126 | + copy.deepcopy(schema) for _ in range(args.num_prompts) |
| 127 | + ] |
| 128 | + for i in range(len(json_schemas)): |
| 129 | + json_schemas[i]["properties"][ |
| 130 | + f"__optional_field_{uuid.uuid4()}"] = { |
| 131 | + "type": |
| 132 | + "string", |
| 133 | + "description": |
| 134 | + "A unique optional field to avoid cached schemas" |
| 135 | + } |
| 136 | + |
| 137 | + def gen_prompt(index: int): |
| 138 | + schema = json_schemas[index % len(json_schemas)] |
| 139 | + return f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501 |
| 140 | + |
| 141 | + def get_schema(index: int): |
| 142 | + return json_schemas[index % len(json_schemas)] |
| 143 | + |
123 | 144 | requests = [
|
124 |
| - SampleRequest(prompt=prompt, |
125 |
| - prompt_len=input_len, |
| 145 | + SampleRequest(prompt=gen_prompt(i), |
| 146 | + prompt_len=len(tokenizer(gen_prompt(i)).input_ids), |
126 | 147 | expected_output_len=args.output_len,
|
127 |
| - schema=schema, |
| 148 | + schema=get_schema(i), |
128 | 149 | structure_type=args.structure_type)
|
129 |
| - for _ in range(args.num_prompts) |
| 150 | + for i in range(args.num_prompts) |
130 | 151 | ]
|
131 | 152 |
|
132 | 153 | elif args.dataset == "grammar":
|
@@ -821,10 +842,12 @@ def main(args: argparse.Namespace):
|
821 | 842 | default="/v1/completions",
|
822 | 843 | help="API endpoint.",
|
823 | 844 | )
|
824 |
| - parser.add_argument( |
825 |
| - "--dataset", |
826 |
| - default='json', |
827 |
| - choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench']) |
| 845 | + parser.add_argument("--dataset", |
| 846 | + default='json', |
| 847 | + choices=[ |
| 848 | + 'json', 'json-unique', 'grammar', 'regex', |
| 849 | + 'choice', 'xgrammar_bench' |
| 850 | + ]) |
828 | 851 | parser.add_argument("--json_schema_path",
|
829 | 852 | type=str,
|
830 | 853 | default=None,
|
@@ -966,11 +989,12 @@ def main(args: argparse.Namespace):
|
966 | 989 | type=float,
|
967 | 990 | default=1.0,
|
968 | 991 | help="Ratio of Structured Outputs requests")
|
969 |
| - parser.add_argument("--structured-output-backend", |
970 |
| - type=str, |
971 |
| - choices=["outlines", "lm-format-enforcer", "xgrammar"], |
972 |
| - default="xgrammar", |
973 |
| - help="Backend to use for structured outputs") |
| 992 | + parser.add_argument( |
| 993 | + "--structured-output-backend", |
| 994 | + type=str, |
| 995 | + choices=["outlines", "lm-format-enforcer", "xgrammar", "json-unique"], |
| 996 | + default="xgrammar", |
| 997 | + help="Backend to use for structured outputs") |
974 | 998 |
|
975 | 999 | args = parser.parse_args()
|
976 | 1000 | main(args)
|
0 commit comments