from lmnr import evaluate, LaminarDataset
from openai import OpenAI
import json
import os
from dotenv import load_dotenv
load_dotenv()
client = OpenAI()
# Same tools as production - consistency is key
tools = [
    {
        "type": "function",
        "function": {
            "name": "query_database",
            "description": "Execute SQL queries to retrieve data from the database",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "SQL query to execute"},
                    "database": {"type": "string", "enum": ["analytics", "sales", "users"]}
                },
                "required": ["query", "database"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "create_visualization",
            "description": "Create charts and graphs from data",
            "parameters": {
                "type": "object",
                "properties": {
                    "data": {"type": "string", "description": "Data to visualize"},
                    "chart_type": {"type": "string", "enum": ["line", "bar", "pie", "scatter"]},
                    "title": {"type": "string", "description": "Chart title"}
                },
                "required": ["data", "chart_type"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "compare_periods",
            "description": "Compare metrics across different time periods",
            "parameters": {
                "type": "object",
                "properties": {
                    "metric": {"type": "string", "description": "Metric to compare"},
                    "period1": {"type": "string", "description": "First time period"},
                    "period2": {"type": "string", "description": "Second time period"}
                },
                "required": ["metric", "period1", "period2"]
            }
        }
    }
]

def data_analysis_agent(data):
    """Executor function that tests new system prompts"""
    # Get the original messages from the dataset
    original_messages = data["input"]

    # Create new messages with an improved system prompt.
    # This is where you test prompt improvements!
    messages = [
        {
            "role": "system",
            "content": """You are a data analysis assistant. Use the available tools to help users analyze their data and generate insights. Always:
1. Query the appropriate database first
2. Create visualizations when helpful
3. Provide clear summaries of findings
4. Compare time periods when relevant
IMPORTANT: Always start by understanding what data you need, then query it, then process it."""
        }
    ]

    # Add all non-system messages from the original conversation
    for msg in original_messages:
        if msg["role"] != "system":
            messages.append(msg)

    response = client.chat.completions.create(
        model="o4-mini",
        messages=messages,
        tools=tools,
        tool_choice="auto"
    )
    return response

def evaluate_tool_selection(output, target):
    """Evaluator that checks whether the expected tools were selected"""
    # Extract the tool calls the model actually made
    actual_tool_calls = []
    if hasattr(output.choices[0].message, 'tool_calls') and output.choices[0].message.tool_calls:
        actual_tool_calls = [
            {
                "name": call.function.name,
                "arguments": json.loads(call.function.arguments)
            }
            for call in output.choices[0].message.tool_calls
        ]

    # Extract the expected tool calls from the target, which may arrive in several shapes
    if isinstance(target, dict):
        expected_tool_calls = target.get("output", [])
        # Unwrap a nested list structure: [[{tool_objects}]] -> [{tool_objects}]
        if expected_tool_calls and isinstance(expected_tool_calls[0], list):
            expected_tool_calls = expected_tool_calls[0]
    elif isinstance(target, list):
        expected_tool_calls = target
    else:
        # If target is something else (e.g. a JSON string), try to parse it
        try:
            if isinstance(target, str):
                parsed_target = json.loads(target)
                expected_tool_calls = parsed_target.get("output", []) if isinstance(parsed_target, dict) else parsed_target
            else:
                expected_tool_calls = []
        except json.JSONDecodeError:
            expected_tool_calls = []

    # Check whether all expected tools were called
    expected_tool_names = [tool["name"] for tool in expected_tool_calls]
    actual_tool_names = [tool["name"] for tool in actual_tool_calls]

    # Simple binary score: 1 if every expected tool was called, 0 otherwise
    for expected_tool in expected_tool_names:
        if expected_tool not in actual_tool_names:
            return 0
    return 1
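
# Optional follow-up: a stricter evaluator that compares tool arguments as well as tool
# names. This is only a sketch; it assumes the target is a flat list of
# {"name": ..., "arguments": {...}} dicts (the simplest shape handled above). If your
# dataset uses the nested "output" shape, reuse the same target-parsing logic as
# evaluate_tool_selection. Register it in the evaluators dict below to enable it.
def evaluate_tool_arguments(output, target):
    """Return 1 if every expected call is matched by an actual call with the same name
    whose arguments contain all expected key/value pairs, otherwise 0."""
    actual_tool_calls = []
    if getattr(output.choices[0].message, "tool_calls", None):
        actual_tool_calls = [
            {"name": call.function.name, "arguments": json.loads(call.function.arguments)}
            for call in output.choices[0].message.tool_calls
        ]
    expected_tool_calls = target if isinstance(target, list) else []
    for expected in expected_tool_calls:
        matched = any(
            actual["name"] == expected["name"]
            and all(actual["arguments"].get(k) == v for k, v in expected.get("arguments", {}).items())
            for actual in actual_tool_calls
        )
        if not matched:
            return 0
    return 1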

# Run the evaluation
evaluate(
    data=LaminarDataset("eval_dataset"),
    executor=data_analysis_agent,
    evaluators={
        "tool_selection": evaluate_tool_selection
    },
    project_api_key=os.environ["LMNR_PROJECT_API_KEY"],
    group_name="improved_system_prompt_v1"
)