Skip to content

Commit 606ef1d

Browse files
committed
Merge branch 'main' into redteam
2 parents 5830c0f + 425fb2a commit 606ef1d

33 files changed

+782
-457
lines changed

.devcontainer/devcontainer.json

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,18 @@
2929
"extensions": [
3030
"ms-python.python",
3131
"ms-python.vscode-pylance",
32+
"ms-python.vscode-python-envs",
3233
"charliermarsh.ruff",
3334
"mtxr.sqltools",
3435
"mtxr.sqltools-driver-pg",
36+
"esbenp.prettier-vscode",
37+
"mechatroner.rainbow-csv",
3538
"ms-vscode.vscode-node-azure-pack",
3639
"esbenp.prettier-vscode",
3740
"twixes.pypi-assistant",
38-
"ms-python.vscode-python-envs"
41+
"ms-python.vscode-python-envs",
42+
"teamsdevapp.vscode-ai-foundry",
43+
"ms-windows-ai-studio.windows-ai-studio"
3944
],
4045
// Set *default* container specific settings.json values on container create.
4146
"settings": {

.env.sample

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ POSTGRES_PASSWORD=postgres
55
POSTGRES_DATABASE=postgres
66
POSTGRES_SSL=disable
77

8-
# OPENAI_CHAT_HOST can be either azure, openai, or ollama:
8+
# OPENAI_CHAT_HOST can be either azure, openai, ollama, or github:
99
OPENAI_CHAT_HOST=azure
10-
# OPENAI_EMBED_HOST can be either azure or openai:
10+
# OPENAI_EMBED_HOST can be either azure, openai, ollama, or github:
1111
OPENAI_EMBED_HOST=azure
1212
# Needed for Azure:
1313
# You also need to `azd auth login` if running this locally
@@ -28,10 +28,17 @@ AZURE_OPENAI_KEY=
2828
OPENAICOM_KEY=YOUR-OPENAI-API-KEY
2929
OPENAICOM_CHAT_MODEL=gpt-3.5-turbo
3030
OPENAICOM_EMBED_MODEL=text-embedding-3-large
31-
OPENAICOM_EMBED_MODEL_DIMENSIONS=1024
31+
OPENAICOM_EMBED_DIMENSIONS=1024
3232
OPENAICOM_EMBEDDING_COLUMN=embedding_3l
3333
# Needed for Ollama:
3434
OLLAMA_ENDPOINT=http://host.docker.internal:11434/v1
3535
OLLAMA_CHAT_MODEL=llama3.1
3636
OLLAMA_EMBED_MODEL=nomic-embed-text
3737
OLLAMA_EMBEDDING_COLUMN=embedding_nomic
38+
# Needed for GitHub Models:
39+
GITHUB_TOKEN=YOUR-GITHUB-TOKEN
40+
GITHUB_BASE_URL=https://models.inference.ai.azure.com
41+
GITHUB_MODEL=gpt-4o
42+
GITHUB_EMBED_MODEL=text-embedding-3-large
43+
GITHUB_EMBED_DIMENSIONS=1024
44+
GITHUB_EMBEDDING_COLUMN=embedding_3l

.github/workflows/app-tests.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ jobs:
8585
architecture: x64
8686

8787
- name: Install uv
88-
uses: astral-sh/setup-uv@v5
88+
uses: astral-sh/setup-uv@v6
8989
with:
9090
enable-cache: true
9191
version: "0.4.20"
@@ -123,7 +123,7 @@ jobs:
123123
key: mypy${{ matrix.os }}-${{ matrix.python_version }}-${{ hashFiles('requirements-dev.txt', 'src/backend/requirements.txt', 'src/backend/pyproject.toml') }}
124124

125125
- name: Run MyPy
126-
run: python3 -m mypy .
126+
run: python3 -m mypy . --python-version ${{ matrix.python_version }}
127127

128128
- name: Run Pytest
129129
run: python3 -m pytest -s -vv --cov --cov-fail-under=85

.github/workflows/evaluate.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ jobs:
8282
python-version: '3.12'
8383

8484
- name: Install uv
85-
uses: astral-sh/setup-uv@v5
85+
uses: astral-sh/setup-uv@v6
8686
with:
8787
enable-cache: true
8888
version: "0.4.20"

evals/evaluate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ def get_openai_config() -> dict:
6666
# azure-ai-evaluate will call DefaultAzureCredential behind the scenes,
6767
# so we must be logged in to Azure CLI with the correct tenant
6868
openai_config["model"] = os.environ["AZURE_OPENAI_EVAL_MODEL"]
69+
elif os.environ.get("OPENAI_CHAT_HOST") == "ollama":
70+
raise NotImplementedError("Ollama is not supported. Switch to Azure or OpenAI.com")
71+
elif os.environ.get("OPENAI_CHAT_HOST") == "github":
72+
raise NotImplementedError("GitHub Models is not supported. Switch to Azure or OpenAI.com")
6973
else:
7074
logger.info("Using OpenAI Service with API Key from OPENAICOM_KEY")
7175
openai_config = {"api_key": os.environ["OPENAICOM_KEY"], "model": "gpt-4"}

evals/generate_ground_truth.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ def get_openai_client() -> tuple[Union[AzureOpenAI, OpenAI], str]:
101101
)
102102
model = os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"]
103103
elif OPENAI_CHAT_HOST == "ollama":
104-
raise NotImplementedError("Ollama OpenAI Service is not supported. Switch to Azure or OpenAI.com")
104+
raise NotImplementedError("Ollama is not supported. Switch to Azure or OpenAI.com")
105+
elif OPENAI_CHAT_HOST == "github":
106+
raise NotImplementedError("GitHub Models is not supported. Switch to Azure or OpenAI.com")
105107
else:
106108
logger.info("Using OpenAI Service with API Key from OPENAICOM_KEY")
107109
openai_client = OpenAI(api_key=os.environ["OPENAICOM_KEY"])

evals/safety_evaluation.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import argparse
22
import asyncio
3+
import datetime
34
import logging
45
import os
56
import pathlib
67
import sys
8+
from typing import Optional
79

810
import requests
911
from azure.ai.evaluation import AzureAIProject
@@ -52,7 +54,7 @@ async def callback(
5254
return {"messages": messages + [message]}
5355

5456

55-
async def run_simulator(target_url: str, max_simulations: int):
57+
async def run_simulator(target_url: str, max_simulations: int, scan_name: Optional[str] = None):
5658
credential = get_azure_credential()
5759
azure_ai_project: AzureAIProject = {
5860
"subscription_id": os.getenv("AZURE_SUBSCRIPTION_ID"),
@@ -64,26 +66,25 @@ async def run_simulator(target_url: str, max_simulations: int):
6466
credential=credential,
6567
risk_categories=[
6668
RiskCategory.Violence,
67-
# RiskCategory.HateUnfairness,
68-
# RiskCategory.Sexual,
69-
# RiskCategory.SelfHarm,
69+
RiskCategory.HateUnfairness,
70+
RiskCategory.Sexual,
71+
RiskCategory.SelfHarm,
7072
],
7173
num_objectives=1,
7274
)
75+
if scan_name is None:
76+
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
77+
scan_name = f"Safety evaluation {timestamp}"
7378
await model_red_team.scan(
7479
target=lambda messages, stream=False, session_state=None, context=None: callback(messages, target_url),
75-
scan_name="Advanced-Callback-Scan",
80+
scan_name=scan_name,
7681
attack_strategies=[
77-
AttackStrategy.EASY, # Group of easy complexity attacks
78-
# AttackStrategy.MODERATE, # Group of moderate complexity attacks
79-
# AttackStrategy.CharacterSpace, # Add character spaces
80-
# AttackStrategy.ROT13, # Use ROT13 encoding
81-
# AttackStrategy.UnicodeConfusable, # Use confusable Unicode characters
82-
# AttackStrategy.CharSwap, # Swap characters in prompts
83-
# AttackStrategy.Morse, # Encode prompts in Morse code
84-
# AttackStrategy.Leetspeak, # Use Leetspeak
85-
# AttackStrategy.Url, # Use URLs in prompts
86-
# AttackStrategy.Binary, # Encode prompts in binary
82+
AttackStrategy.DIFFICULT,
83+
AttackStrategy.Baseline,
84+
AttackStrategy.UnicodeConfusable, # Use confusable Unicode characters
85+
AttackStrategy.Morse, # Encode prompts in Morse code
86+
AttackStrategy.Leetspeak, # Use Leetspeak
87+
AttackStrategy.Url, # Use URLs in prompts
8788
],
8889
output_path="Advanced-Callback-Scan.json",
8990
)
@@ -97,28 +98,29 @@ async def run_simulator(target_url: str, max_simulations: int):
9798
parser.add_argument(
9899
"--max_simulations", type=int, default=200, help="Maximum number of simulations (question/response pairs)."
99100
)
101+
# argument for the name
102+
parser.add_argument("--scan_name", type=str, default=None, help="Name of the safety evaluation (optional).")
100103
args = parser.parse_args()
101104

102105
# Configure logging to show tracebacks for warnings and above
103106
logging.basicConfig(
104-
level=logging.DEBUG,
107+
level=logging.WARNING,
105108
format="%(message)s",
106109
datefmt="[%X]",
107110
handlers=[RichHandler(rich_tracebacks=False, show_path=True)],
108111
)
109112

110113
# Set urllib3 and azure libraries to WARNING level to see connection issues
111114
logging.getLogger("urllib3").setLevel(logging.WARNING)
112-
logging.getLogger("azure").setLevel(logging.DEBUG)
113-
logging.getLogger("RedTeamLogger").setLevel(logging.DEBUG)
115+
logging.getLogger("azure").setLevel(logging.WARNING)
114116

115117
# Set our application logger to INFO level
116118
logger.setLevel(logging.INFO)
117119

118120
load_azd_env()
119121

120122
try:
121-
asyncio.run(run_simulator(args.target_url, args.max_simulations))
123+
asyncio.run(run_simulator(args.target_url, args.max_simulations, args.scan_name))
122124
except Exception:
123125
logging.exception("Unhandled exception in safety evaluation")
124126
sys.exit(1)

infra/main.bicep

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ var webAppEnv = union(azureOpenAIKeyEnv, openAIComKeyEnv, [
302302
value: openAIEmbedHost
303303
}
304304
{
305-
name: 'OPENAICOM_EMBED_MODEL_DIMENSIONS'
305+
name: 'OPENAICOM_EMBED_DIMENSIONS'
306306
value: openAIEmbedHost == 'openaicom' ? '1024' : ''
307307
}
308308
{

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ lint.isort.known-first-party = ["fastapi_app"]
77

88
[tool.mypy]
99
check_untyped_defs = true
10-
python_version = 3.9
1110
exclude = [".venv/*"]
1211

1312
[tool.pytest.ini_options]

src/backend/fastapi_app/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ class State(TypedDict):
3434
@asynccontextmanager
3535
async def lifespan(app: fastapi.FastAPI) -> AsyncIterator[State]:
3636
context = await common_parameters()
37-
azure_credential = await get_azure_credential()
37+
azure_credential = None
38+
if (
39+
os.getenv("OPENAI_CHAT_HOST") == "azure"
40+
or os.getenv("OPENAI_EMBED_HOST") == "azure"
41+
or os.getenv("POSTGRES_HOST", "").endswith(".database.azure.com")
42+
):
43+
azure_credential = await get_azure_credential()
3844
engine = await create_postgres_engine_from_env(azure_credential)
3945
sessionmaker = await create_async_sessionmaker(engine)
4046
chat_client = await create_openai_chat_client(azure_credential)
@@ -53,6 +59,7 @@ def create_app(testing: bool = False):
5359
if not testing:
5460
load_dotenv(override=True)
5561
logging.basicConfig(level=logging.INFO)
62+
5663
# Turn off particularly noisy INFO level logs from Azure Core SDK:
5764
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)
5865
logging.getLogger("azure.identity").setLevel(logging.WARNING)

src/backend/fastapi_app/api_models.py

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from enum import Enum
22
from typing import Any, Optional
33

4-
from openai.types.chat import ChatCompletionMessageParam
5-
from pydantic import BaseModel
4+
from openai.types.responses import ResponseInputItemParam
5+
from pydantic import BaseModel, Field
66

77

88
class AIChatRoles(str, Enum):
@@ -36,19 +36,39 @@ class ChatRequestContext(BaseModel):
3636

3737

3838
class ChatRequest(BaseModel):
39-
messages: list[ChatCompletionMessageParam]
39+
messages: list[ResponseInputItemParam]
4040
context: ChatRequestContext
4141
sessionState: Optional[Any] = None
4242

4343

44+
class ItemPublic(BaseModel):
45+
id: int
46+
type: str
47+
brand: str
48+
name: str
49+
description: str
50+
price: float
51+
52+
def to_str_for_rag(self):
53+
return f"Name:{self.name} Description:{self.description} Price:{self.price} Brand:{self.brand} Type:{self.type}"
54+
55+
56+
class ItemWithDistance(ItemPublic):
57+
distance: float
58+
59+
def __init__(self, **data):
60+
super().__init__(**data)
61+
self.distance = round(self.distance, 2)
62+
63+
4464
class ThoughtStep(BaseModel):
4565
title: str
4666
description: Any
4767
props: dict = {}
4868

4969

5070
class RAGContext(BaseModel):
51-
data_points: dict[int, dict[str, Any]]
71+
data_points: dict[int, ItemPublic]
5272
thoughts: list[ThoughtStep]
5373
followup_questions: Optional[list[str]] = None
5474

@@ -69,27 +89,39 @@ class RetrievalResponseDelta(BaseModel):
6989
sessionState: Optional[Any] = None
7090

7191

72-
class ItemPublic(BaseModel):
73-
id: int
74-
type: str
75-
brand: str
76-
name: str
77-
description: str
78-
price: float
79-
80-
81-
class ItemWithDistance(ItemPublic):
82-
distance: float
83-
84-
def __init__(self, **data):
85-
super().__init__(**data)
86-
self.distance = round(self.distance, 2)
87-
88-
8992
class ChatParams(ChatRequestOverrides):
9093
prompt_template: str
9194
response_token_limit: int = 1024
9295
enable_text_search: bool
9396
enable_vector_search: bool
9497
original_user_query: str
95-
past_messages: list[ChatCompletionMessageParam]
98+
past_messages: list[ResponseInputItemParam]
99+
100+
101+
class Filter(BaseModel):
102+
column: str
103+
comparison_operator: str
104+
value: Any
105+
106+
107+
class PriceFilter(Filter):
108+
column: str = Field(default="price", description="The column to filter on (always 'price' for this filter)")
109+
comparison_operator: str = Field(description="The operator for price comparison ('>', '<', '>=', '<=', '=')")
110+
value: float = Field(description="The price value to compare against (e.g., 30.00)")
111+
112+
113+
class BrandFilter(Filter):
114+
column: str = Field(default="brand", description="The column to filter on (always 'brand' for this filter)")
115+
comparison_operator: str = Field(description="The operator for brand comparison ('=' or '!=')")
116+
value: str = Field(description="The brand name to compare against (e.g., 'AirStrider')")
117+
118+
119+
class SearchResults(BaseModel):
120+
query: str
121+
"""The original search query"""
122+
123+
items: list[ItemPublic]
124+
"""List of items that match the search query and filters"""
125+
126+
filters: list[Filter]
127+
"""List of filters applied to the search results"""

src/backend/fastapi_app/dependencies.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ async def common_parameters():
5151
openai_embed_model = os.getenv("OLLAMA_EMBED_MODEL") or "nomic-embed-text"
5252
openai_embed_dimensions = None
5353
embedding_column = os.getenv("OLLAMA_EMBEDDING_COLUMN") or "embedding_nomic"
54+
elif OPENAI_EMBED_HOST == "github":
55+
openai_embed_deployment = None
56+
openai_embed_model = os.getenv("GITHUB_EMBED_MODEL") or "text-embedding-3-large"
57+
openai_embed_dimensions = int(os.getenv("GITHUB_EMBED_DIMENSIONS", 1024))
58+
embedding_column = os.getenv("GITHUB_EMBEDDING_COLUMN") or "embedding_3l"
5459
else:
5560
openai_embed_deployment = None
5661
openai_embed_model = os.getenv("OPENAICOM_EMBED_MODEL") or "text-embedding-3-large"
@@ -63,6 +68,9 @@ async def common_parameters():
6368
openai_chat_deployment = None
6469
openai_chat_model = os.getenv("OLLAMA_CHAT_MODEL") or "phi3:3.8b"
6570
openai_embed_model = os.getenv("OLLAMA_EMBED_MODEL") or "nomic-embed-text"
71+
elif OPENAI_CHAT_HOST == "github":
72+
openai_chat_deployment = None
73+
openai_chat_model = os.getenv("GITHUB_MODEL") or "gpt-4o"
6674
else:
6775
openai_chat_deployment = None
6876
openai_chat_model = os.getenv("OPENAICOM_CHAT_MODEL") or "gpt-3.5-turbo"

0 commit comments

Comments
 (0)