8
8
from typing import Optional
9
9
10
10
import requests
11
- from azure .ai .evaluation import AzureAIProject
12
11
from azure .ai .evaluation .red_team import AttackStrategy , RedTeam , RiskCategory
13
12
from azure .identity import AzureDeveloperCliCredential
14
13
from dotenv_azd import load_azd_env
15
- from rich .logging import RichHandler
16
-
17
- logger = logging .getLogger ("ragapp" )
18
-
19
- # Configure logging to capture and display warnings with tracebacks
20
- logging .captureWarnings (True ) # Capture warnings as log messages
21
14
22
15
# Directory that contains this script; the scan output path is built relative to it.
root_dir = pathlib .Path (__file__ ).parent
23
16
24
17
25
18
def get_azure_credential():
    """Create the Azure Developer CLI credential used by the scan.

    Reads ``AZURE_TENANT_ID`` from the environment. When it is set, the
    credential is created for that tenant; otherwise it is created without
    a tenant id (reported in the output as the "home tenant").

    Returns:
        An ``AzureDeveloperCliCredential`` configured with a 60-second
        process timeout.
    """
    tenant_id = os.getenv("AZURE_TENANT_ID")
    if tenant_id:
        # print() does not interpolate %-style arguments the way logging
        # calls do, so the tenant id is formatted into the message here.
        print(f"Setting up Azure credential using AzureDeveloperCliCredential with tenant_id {tenant_id}")
        return AzureDeveloperCliCredential(tenant_id=tenant_id, process_timeout=60)

    print("Setting up Azure credential using AzureDeveloperCliCredential for home tenant")
    return AzureDeveloperCliCredential(process_timeout=60)
34
27
35
28
36
def callback(
    question: str,
    target_url: str = "http://127.0.0.1:8000/chat",
    timeout: float = 120.0,
):
    """Send a single question to the chat endpoint and return the reply text.

    Args:
        question: Text sent as the only user message of the request.
        target_url: URL of the chat endpoint that receives the POST request.
        timeout: Seconds to wait for the endpoint before giving up, so an
            unresponsive target cannot stall the caller indefinitely.

    Returns:
        The ``content`` of the ``message`` object in the JSON response, or a
        string of the form ``"Error received: ..."`` when the response
        carries an ``error`` key.

    Raises:
        requests.exceptions.RequestException: If the request fails or times out.
        ValueError: If the response body is not valid JSON.
    """
    headers = {"Content-Type": "application/json"}
    payload = {
        "messages": [{"content": question, "role": "user"}],
        "stream": False,
        "context": {
            "overrides": {"use_advanced_flow": False, "top": 3, "retrieval_mode": "hybrid", "temperature": 0.3}
        },
    }
    reply = requests.post(target_url, headers=headers, json=payload, timeout=timeout).json()

    # Look up exactly the key that was tested for, so the error path cannot
    # itself fail with a KeyError.
    if "error" in reply:
        return f"Error received: {reply['error']}"
    return reply["message"]["content"]
55
48
56
49
57
async def run_redteaming(target_url: str, questions_per_category: int = 1, scan_name: Optional[str] = None):
    """Run a red-teaming scan against the chat endpoint and save the results.

    Args:
        target_url: URL of the chat endpoint; each generated query is
            forwarded to it through ``callback``.
        questions_per_category: Passed to ``RedTeam`` as ``num_objectives``.
        scan_name: Name of the scan. When ``None``, a name of the form
            ``"Safety evaluation <timestamp>"`` is generated.

    Raises:
        ValueError: If ``AZURE_AI_FOUNDRY`` or ``AZURE_AI_PROJECT`` is not
            set in the environment.
    """
    foundry_name = os.getenv("AZURE_AI_FOUNDRY")
    project_name = os.getenv("AZURE_AI_PROJECT")
    # Fail early: a missing value would otherwise be formatted into the
    # project URL as the literal text "None".
    missing = [
        name
        for name, value in (("AZURE_AI_FOUNDRY", foundry_name), ("AZURE_AI_PROJECT", project_name))
        if not value
    ]
    if missing:
        raise ValueError(f"Missing required environment variable(s): {', '.join(missing)}")

    model_red_team = RedTeam(
        azure_ai_project=f"https://{foundry_name}.services.ai.azure.com/api/projects/{project_name}",
        credential=get_azure_credential(),
        risk_categories=[
            RiskCategory.Violence,
            RiskCategory.HateUnfairness,
            RiskCategory.Sexual,
            RiskCategory.SelfHarm,
        ],
        num_objectives=questions_per_category,
    )

    if scan_name is None:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        scan_name = f"Safety evaluation {timestamp}"

    # Make sure the output directory exists before the scan starts, so the
    # results are not lost to a missing folder once the scan has finished.
    output_dir = root_dir / "redteams"
    output_dir.mkdir(parents=True, exist_ok=True)

    await model_red_team.scan(
        scan_name=scan_name,
        output_path=f"{output_dir}/{scan_name}.json",
        attack_strategies=[
            AttackStrategy.Baseline,
            # Easy Complexity:
            AttackStrategy.Morse,
            AttackStrategy.UnicodeConfusable,
            AttackStrategy.Url,
            # Moderate Complexity:
            AttackStrategy.Tense,
            # Difficult Complexity:
            AttackStrategy.Compose([AttackStrategy.Tense, AttackStrategy.Url]),
        ],
        target=lambda query: callback(query, target_url),
    )
91
85
92
86
@@ -96,31 +90,17 @@ async def run_simulator(target_url: str, max_simulations: int, scan_name: Option
96
90
"--target_url" , type = str , default = "http://127.0.0.1:8000/chat" , help = "Target URL for the callback."
97
91
)
98
92
parser .add_argument (
99
- "--max_simulations" , type = int , default = 200 , help = "Maximum number of simulations (question/response pairs)."
93
+ "--questions_per_category" ,
94
+ type = int ,
95
+ default = 1 ,
96
+ help = "Number of questions per risk category to ask during the scan." ,
100
97
)
101
- # argument for the name
102
98
parser .add_argument ("--scan_name" , type = str , default = None , help = "Name of the safety evaluation (optional)." )
103
99
args = parser .parse_args ()
104
100
105
- # Configure logging to show tracebacks for warnings and above
106
- logging .basicConfig (
107
- level = logging .WARNING ,
108
- format = "%(message)s" ,
109
- datefmt = "[%X]" ,
110
- handlers = [RichHandler (rich_tracebacks = False , show_path = True )],
111
- )
112
-
113
- # Set urllib3 and azure libraries to WARNING level to see connection issues
114
- logging .getLogger ("urllib3" ).setLevel (logging .WARNING )
115
- logging .getLogger ("azure" ).setLevel (logging .WARNING )
116
-
117
- # Set our application logger to INFO level
118
- logger .setLevel (logging .INFO )
119
-
120
101
load_azd_env ()
121
-
122
102
try :
123
- asyncio .run (run_simulator (args .target_url , args .max_simulations , args .scan_name ))
103
+ asyncio .run (run_redteaming (args .target_url , args .questions_per_category , args .scan_name ))
124
104
except Exception :
125
105
logging .exception ("Unhandled exception in safety evaluation" )
126
106
sys .exit (1 )
0 commit comments