fa45d8aa5f
- health_checklist.json: 192.168.1.122→node122
- ocr_client.py: docstring IP→node122
- docs/market-data-requirements.md: IP→node122
- 所有API调用通过ProxyHandler({})绕过系统代理
Privoxy对node122:18003返回500,直连正常
1403 lines
48 KiB
Python
1403 lines
48 KiB
Python
"""
|
|
CRUD ENDPOINTS FOR PROMPTS
|
|
"""
|
|
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, cast
|
|
|
|
from fastapi import (
|
|
APIRouter,
|
|
Depends,
|
|
File,
|
|
HTTPException,
|
|
Request,
|
|
Response,
|
|
UploadFile,
|
|
)
|
|
from pydantic import BaseModel
|
|
|
|
from litellm._logging import verbose_proxy_logger
|
|
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
|
|
from litellm.proxy.auth.auth_utils import is_request_body_safe
|
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
|
from litellm.proxy.common_utils.path_utils import safe_filename
|
|
from litellm.repositories.table_repositories import PromptRepository
|
|
from litellm.types.prompts.init_prompts import (
|
|
ListPromptsResponse,
|
|
PromptInfo,
|
|
PromptInfoResponse,
|
|
PromptLiteLLMParams,
|
|
PromptSpec,
|
|
PromptTemplateBase,
|
|
)
|
|
from litellm.types.proxy.prompt_endpoints import TestPromptRequest
|
|
|
|
# FastAPI router that every prompt-management endpoint in this module registers on.
router = APIRouter()
|
|
|
|
|
|
def get_base_prompt_id(prompt_id: str) -> str:
    """
    Extract the base prompt ID by stripping the version suffix if present.

    A version suffix is ``.v<digits>`` or ``_v<digits>`` at the very END of the
    ID (``.v`` is checked first, matching the ``<base>.v<N>`` IDs this module
    generates). Only a trailing, purely numeric suffix is stripped, so IDs that
    merely contain ``.v`` / ``_v`` elsewhere (e.g. ``"customer_verification"``
    or ``"my.video"``) are no longer truncated at the first occurrence.

    Args:
        prompt_id: Prompt ID that may include version suffix (e.g., "jack_success.v1" or "jack_success_v1")

    Returns:
        Base prompt ID without version suffix (e.g., "jack_success"). The ID is
        returned unchanged when it has no numeric version suffix or when
        stripping would leave an empty base.

    Examples:
        >>> get_base_prompt_id("jack_success.v1")
        'jack_success'
        >>> get_base_prompt_id("jack_success_v1")
        'jack_success'
        >>> get_base_prompt_id("jack_success")
        'jack_success'
        >>> get_base_prompt_id("customer_verification.v3")
        'customer_verification'
    """
    for separator in (".v", "_v"):
        # rpartition splits on the LAST occurrence, so earlier ".v"/"_v"
        # substrings stay part of the base ID.
        base, found, suffix = prompt_id.rpartition(separator)
        if found and base and suffix.isdecimal():
            return base
    return prompt_id
|
|
|
|
|
|
def get_version_number(prompt_id: str) -> int:
    """
    Extract the version number from a versioned prompt ID.

    Only a trailing ``.v<digits>`` or ``_v<digits>`` suffix counts as a version
    (``.v`` checked first). Earlier ``.v`` / ``_v`` substrings inside the base
    ID (e.g. ``"my.video.v2"``) no longer hide the real suffix.

    Args:
        prompt_id: Prompt ID that may include version suffix (e.g., "jack_success.v2" or "jack_success_v2")

    Returns:
        Version number (defaults to 1 if no version suffix or invalid format)

    Examples:
        >>> get_version_number("jack_success.v2")
        2
        >>> get_version_number("jack_success_v2")
        2
        >>> get_version_number("jack_success")
        1
        >>> get_version_number("my.video.v2")
        2
    """
    for separator in (".v", "_v"):
        # Look only at the last occurrence so the suffix, not the base, is parsed.
        base, found, suffix = prompt_id.rpartition(separator)
        if found and base and suffix.isdecimal():
            return int(suffix)
    return 1
|
|
|
|
|
|
def construct_versioned_prompt_id(prompt_id: str, version: Optional[int] = None) -> str:
    """
    Build a versioned prompt ID of the form ``<base>.v<version>``.

    Any version suffix already on ``prompt_id`` is removed first (via
    ``get_base_prompt_id``), so the new version replaces it instead of being
    appended to it.

    Args:
        prompt_id: Base or already-versioned prompt ID (e.g., "jack_success" or "jack_success.v2")
        version: Version number to apply; when None, ``prompt_id`` is returned untouched

    Returns:
        The versioned prompt ID, or the original ``prompt_id`` when ``version`` is None.

    Examples:
        >>> construct_versioned_prompt_id("jack_success", 4)
        'jack_success.v4'
        >>> construct_versioned_prompt_id("jack_success", None)
        'jack_success'
        >>> construct_versioned_prompt_id("jack_success.v2", 4)
        'jack_success.v4'
    """
    if version is not None:
        stripped_id = get_base_prompt_id(prompt_id)
        return "{}.v{}".format(stripped_id, version)
    return prompt_id
|
|
|
|
|
|
def get_latest_version_prompt_id(prompt_id: str, all_prompt_ids: Dict[str, Any]) -> str:
    """
    Resolve a prompt ID to the highest-versioned ID that shares its base.

    Args:
        prompt_id: Base prompt ID or versioned prompt ID (e.g., "jack" or "jack.v2")
        all_prompt_ids: Mapping whose keys are the available prompt IDs (values unused)

    Returns:
        The key with the same base ID and the largest version number (ties are
        broken by the larger ID string), or ``prompt_id`` unchanged when no key
        shares its base ID.

    Examples:
        >>> ids = {"jack.v1": {}, "jack.v2": {}, "jack.v3": {}}
        >>> get_latest_version_prompt_id("jack", ids)
        'jack.v3'
        >>> get_latest_version_prompt_id("jack.v1", ids)
        'jack.v3'
        >>> get_latest_version_prompt_id("simple", {"simple": {}})
        'simple'
    """
    wanted_base = get_base_prompt_id(prompt_id=prompt_id)

    # (version, id) pairs so that max() orders by version first, then by ID.
    candidates = [
        (get_version_number(prompt_id=candidate_id), candidate_id)
        for candidate_id in all_prompt_ids
        if get_base_prompt_id(prompt_id=candidate_id) == wanted_base
    ]
    if not candidates:
        return prompt_id

    _, newest_id = max(candidates)
    return newest_id
|
|
|
|
|
|
def get_latest_prompt_versions(prompts: List[PromptSpec]) -> List[PromptSpec]:
    """
    Keep only the highest-versioned PromptSpec for each base prompt ID.

    Prompts are grouped by ``get_base_prompt_id``. Within a group, the first
    prompt seen is kept unless a later one has a strictly greater version.
    Output order follows the first appearance of each base ID.

    Args:
        prompts: List of PromptSpec objects

    Returns:
        List of PromptSpec objects with one (the latest) entry per base prompt ID
    """
    # base ID -> (version, prompt); the version is cached to avoid re-parsing.
    newest_by_base: Dict[str, tuple] = {}

    for candidate in prompts:
        base = get_base_prompt_id(prompt_id=candidate.prompt_id)
        candidate_version = get_version_number(prompt_id=candidate.prompt_id)
        current = newest_by_base.get(base)
        if current is None or candidate_version > current[0]:
            newest_by_base[base] = (candidate_version, candidate)

    return [kept_prompt for _, kept_prompt in newest_by_base.values()]
|
|
|
|
|
|
async def get_next_version_for_prompt(
    prisma_client, prompt_id: str, environment: str = "development"
) -> int:
    """
    Get the next version number for a prompt in a specific environment.

    Only the single highest-versioned row is fetched (ordered by ``version``
    descending, ``take=1``) instead of loading every stored version, including
    its params payload, just to compute a maximum.

    NOTE(review): this read is not atomic with the caller's subsequent insert,
    so concurrent creates may compute the same number unless the DB enforces
    uniqueness — confirm against the schema.

    Args:
        prisma_client: Prisma database client
        prompt_id: Base prompt ID
        environment: The environment to check versions for

    Returns:
        Next version number (1 if no versions exist, max_version + 1 otherwise)
    """
    latest_rows = await PromptRepository(prisma_client).table.find_many(
        where={"prompt_id": prompt_id, "environment": environment},
        order={"version": "desc"},
        take=1,
    )
    if not latest_rows:
        return 1
    return latest_rows[0].version + 1
|
|
|
|
|
|
def create_versioned_prompt_spec(db_prompt) -> PromptSpec:
    """
    Helper function to create a PromptSpec with versioned prompt_id from a DB prompt entry.

    Args:
        db_prompt: The DB prompt object (from prisma); must support ``model_dump()``.

    Returns:
        PromptSpec with versioned prompt_id (e.g., "chat_prompt.v1"). A NULL
        version is treated as 1 and a NULL environment as "development".
    """
    import json

    prompt_dict = db_prompt.model_dump()
    base_prompt_id = prompt_dict["prompt_id"]

    # model_dump() emits every field, so NULL columns appear as None values
    # rather than missing keys; a plain dict.get(key, default) would never
    # apply its default (yielding IDs like "chat_prompt.vNone").
    version = prompt_dict.get("version")
    if version is None:
        version = 1
    environment = prompt_dict.get("environment")
    if environment is None:
        environment = "development"
    created_by = prompt_dict.get("created_by")

    # litellm_params is stored as a JSON string by the create/update endpoints;
    # any non-string value is passed through as-is (presumably already a dict).
    litellm_params_data = prompt_dict.get("litellm_params")
    if isinstance(litellm_params_data, str):
        litellm_params_data = json.loads(litellm_params_data)
    litellm_params = PromptLiteLLMParams(**litellm_params_data)

    # prompt_info is optional; fall back to a plain "db" prompt type.
    prompt_info_data = prompt_dict.get("prompt_info")
    if prompt_info_data:
        if isinstance(prompt_info_data, str):
            prompt_info_data = json.loads(prompt_info_data)
        prompt_info = PromptInfo(**prompt_info_data)
    else:
        prompt_info = PromptInfo(prompt_type="db")

    return PromptSpec(
        prompt_id=f"{base_prompt_id}.v{version}",
        litellm_params=litellm_params,
        prompt_info=prompt_info,
        created_at=prompt_dict.get("created_at"),
        updated_at=prompt_dict.get("updated_at"),
        environment=environment,
        created_by=created_by,
    )
|
|
|
|
|
|
class Prompt(BaseModel):
    """Request body for creating (POST /prompts) or updating (PUT /prompts/{prompt_id}) a prompt."""

    # Prompt identifier; create_prompt stores it as the DB prompt_id
    # (update_prompt uses the path parameter instead).
    prompt_id: str
    # Integration parameters; persisted as a JSON string via model_dump_json().
    litellm_params: PromptLiteLLMParams
    # Optional metadata. Its `environment` selects the target environment
    # (the endpoints fall back to "development"); when omitted,
    # PromptInfo(prompt_type="db") is stored.
    prompt_info: Optional[PromptInfo] = None
|
|
|
|
|
|
class PatchPromptRequest(BaseModel):
    """Request body for PATCH /prompts/{prompt_id}; every field is optional."""

    # Replacement LiteLLM parameters; None presumably means "leave unchanged"
    # (the PATCH handler is not fully visible here — confirm).
    litellm_params: Optional[PromptLiteLLMParams] = None
    # Replacement prompt metadata; None presumably means "leave unchanged".
    prompt_info: Optional[PromptInfo] = None
|
|
|
|
|
|
def _to_base_id_prompt_spec(prompt: PromptSpec) -> PromptSpec:
    """Copy ``prompt`` with its version suffix stripped from prompt_id and the version moved into ``version``."""
    return PromptSpec(
        prompt_id=get_base_prompt_id(prompt_id=prompt.prompt_id),
        litellm_params=prompt.litellm_params,
        prompt_info=prompt.prompt_info,
        created_at=prompt.created_at,
        updated_at=prompt.updated_at,
        version=get_version_number(prompt_id=prompt.prompt_id),
        environment=prompt.environment,
        created_by=prompt.created_by,
    )


@router.get(
    "/prompts/list",
    tags=["Prompt Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ListPromptsResponse,
)
async def list_prompts(
    environment: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List the prompts that are available on the proxy server

    👉 [Prompt docs](https://docs.litellm.ai/docs/proxy/prompt_management)

    - Keys whose metadata lists `prompts` see exactly those prompt IDs.
    - Proxy admins see the latest version of every prompt.
    - Everyone else gets an empty list.

    Returned `prompt_id`s have the version suffix stripped; the version number
    is reported in `version`. Pass `?environment=<name>` to filter by environment.

    Example Request:
    ```bash
    curl -X GET "http://localhost:4000/prompts/list" -H "Authorization: Bearer <your_api_key>"
    ```

    Example Response:
    ```json
    {
        "prompts": [
            {
                "prompt_id": "my_prompt_id",
                "litellm_params": {
                    "prompt_id": "my_prompt_id",
                    "prompt_integration": "dotprompt",
                    "prompt_directory": "/path/to/prompts"
                },
                "prompt_info": {
                    "prompt_type": "config"
                },
                "version": 1,
                "created_at": "2023-11-09T12:34:56.789Z",
                "updated_at": "2023-11-09T12:34:56.789Z"
            }
        ]
    }
    ```
    """
    from litellm.proxy.prompts.prompt_registry import IN_MEMORY_PROMPT_REGISTRY

    registry_prompts = IN_MEMORY_PROMPT_REGISTRY.IN_MEMORY_PROMPTS

    # Keys scoped to an explicit prompt list only see those prompts (exact ID match).
    key_metadata = user_api_key_dict.metadata
    if key_metadata is not None:
        allowed_prompt_ids = cast(
            Optional[List[str]], key_metadata.get("prompts", None)
        )
        if allowed_prompt_ids is not None:
            key_prompts = [
                registry_prompts[allowed_id]
                for allowed_id in allowed_prompt_ids
                if allowed_id in registry_prompts
            ]
            if environment:
                key_prompts = [p for p in key_prompts if p.environment == environment]
            return ListPromptsResponse(
                prompts=[_to_base_id_prompt_spec(p) for p in key_prompts]
            )

    # Proxy admins see the latest version of every prompt.
    if user_api_key_dict.user_role is not None and (
        user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
        or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
    ):
        all_prompts = list(registry_prompts.values())
        if environment:
            all_prompts = [p for p in all_prompts if p.environment == environment]
        latest_prompts = get_latest_prompt_versions(prompts=all_prompts)
        return ListPromptsResponse(
            prompts=[_to_base_id_prompt_spec(p) for p in latest_prompts]
        )

    return ListPromptsResponse(prompts=[])
|
|
|
|
|
|
@router.get(
    "/prompts/{prompt_id}/versions",
    tags=["Prompt Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=ListPromptsResponse,
)
async def get_prompt_versions(
    prompt_id: str,
    environment: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get all versions of a specific prompt by base prompt ID

    👉 [Prompt docs](https://docs.litellm.ai/docs/proxy/prompt_management)

    Proxy admins only. Versions come from the database when it is connected,
    otherwise from the in-memory registry, and are returned newest first. Each
    entry uses the base `prompt_id`; the version number is in `version`.
    Pass `?environment=<name>` to restrict to one environment.

    Example Request:
    ```bash
    curl -X GET "http://localhost:4000/prompts/jack_success/versions" \\
    -H "Authorization: Bearer <your_api_key>"
    ```

    Example Response:
    ```json
    {
        "prompts": [
            {
                "prompt_id": "jack_success",
                "version": 2,
                "litellm_params": {...},
                "prompt_info": {"prompt_type": "db"},
                "created_at": "2023-11-09T13:45:12.345Z",
                "updated_at": "2023-11-09T13:45:12.345Z"
            },
            {
                "prompt_id": "jack_success",
                "version": 1,
                "litellm_params": {...},
                "prompt_info": {"prompt_type": "db"},
                "created_at": "2023-11-09T12:34:56.789Z",
                "updated_at": "2023-11-09T12:34:56.789Z"
            }
        ]
    }
    ```
    """
    from litellm.proxy.prompts.prompt_registry import IN_MEMORY_PROMPT_REGISTRY
    from litellm.proxy.proxy_server import prisma_client

    # Only allow proxy admins to view version history
    if user_api_key_dict.user_role is None or (
        user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
        and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
    ):
        raise HTTPException(
            status_code=403, detail="Only proxy admins can view prompt versions"
        )

    base_prompt_id = get_base_prompt_id(prompt_id=prompt_id)

    # (source spec, version number) pairs, from the DB when available,
    # otherwise from the in-memory registry.
    sources: List[Any] = []
    if prisma_client is not None:
        where_clause: Dict[str, Any] = {"prompt_id": base_prompt_id}
        if environment:
            where_clause["environment"] = environment
        db_prompts = await PromptRepository(prisma_client).table.find_many(
            where=where_clause,
            order={"version": "desc"},
        )
        for db_prompt in db_prompts:
            # Read the version from the row itself instead of re-parsing it out
            # of the generated "<base>.v<N>" ID, which is ambiguous when the
            # base ID itself contains ".v" or "_v".
            row_version = db_prompt.version if db_prompt.version is not None else 1
            sources.append(
                (create_versioned_prompt_spec(db_prompt=db_prompt), row_version)
            )
    else:
        # Fallback: in-memory registry (no DB)
        for prompt in IN_MEMORY_PROMPT_REGISTRY.IN_MEMORY_PROMPTS.values():
            if get_base_prompt_id(prompt_id=prompt.prompt_id) != base_prompt_id:
                continue
            if environment is not None and prompt.environment != environment:
                continue
            sources.append((prompt, get_version_number(prompt_id=prompt.prompt_id)))

    versioned_prompts = [
        PromptSpec(
            prompt_id=base_prompt_id,
            litellm_params=spec.litellm_params,
            prompt_info=spec.prompt_info,
            created_at=spec.created_at,
            updated_at=spec.updated_at,
            version=version_number,
            environment=spec.environment,
            created_by=spec.created_by,
        )
        for spec, version_number in sources
    ]
    # Newest first. DB rows already arrive in this order and the sort is stable.
    versioned_prompts.sort(key=lambda p: p.version or 1, reverse=True)

    if not versioned_prompts:
        raise HTTPException(
            status_code=404, detail=f"No versions found for prompt ID {base_prompt_id}"
        )

    return ListPromptsResponse(prompts=versioned_prompts)
|
|
|
|
|
|
def _get_prompt_template(
    prompt_spec: PromptSpec, base_prompt_id: str
) -> Optional[PromptTemplateBase]:
    """
    Resolve the raw prompt template from dotprompt content or the in-memory registry.

    - If ``litellm_params.dotprompt_content`` is set, it is parsed and returned
      under ``base_prompt_id``.
    - Otherwise, for a registered dotprompt callback holding exactly one
      template, that template is returned under its own ID.

    This is best-effort. Any failure is logged at debug level and yields None,
    so callers can still respond without the template.
    """
    from litellm.proxy.prompts.prompt_registry import IN_MEMORY_PROMPT_REGISTRY

    try:
        dotprompt_content = prompt_spec.litellm_params.dotprompt_content
        if dotprompt_content:
            from litellm.integrations.dotprompt import (
                _get_prompt_data_from_dotprompt_content,
            )

            parsed = _get_prompt_data_from_dotprompt_content(dotprompt_content)
            if parsed:
                return PromptTemplateBase(
                    litellm_prompt_id=base_prompt_id,
                    content=parsed.get("content", ""),
                    metadata=parsed.get("metadata"),
                )
        else:
            prompt_callback = IN_MEMORY_PROMPT_REGISTRY.get_prompt_callback_by_id(
                prompt_spec.prompt_id
            )
            if prompt_callback is not None:
                integration_name = prompt_callback.integration_name
                if integration_name == "dotprompt":
                    from litellm.integrations.dotprompt.dotprompt_manager import (
                        DotpromptManager,
                    )

                    if isinstance(prompt_callback, DotpromptManager):
                        template = (
                            prompt_callback.prompt_manager.get_all_prompts_as_json()
                        )
                        # Only unambiguous when the manager holds a single template.
                        if template is not None and len(template) == 1:
                            template_id = list(template.keys())[0]
                            return PromptTemplateBase(
                                litellm_prompt_id=template_id,
                                content=template[template_id]["content"],
                                metadata=template[template_id]["metadata"],
                            )
    except Exception as e:
        # Deliberately non-fatal, but no longer silent.
        verbose_proxy_logger.debug(
            "Could not resolve raw prompt template for %s: %s", base_prompt_id, e
        )
    return None
|
|
|
|
|
|
@router.get(
    "/prompts/{prompt_id}",
    tags=["Prompt Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PromptInfoResponse,
)
@router.get(
    "/prompts/{prompt_id}/info",
    tags=["Prompt Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PromptInfoResponse,
)
async def get_prompt_info(
    prompt_id: str,
    environment: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get detailed information about a specific prompt by ID, including prompt content

    👉 [Prompt docs](https://docs.litellm.ai/docs/proxy/prompt_management)

    `prompt_id` may be a base ID ("my_prompt_id") or a versioned ID
    ("my_prompt_id.v2").

    - With `?environment=<name>`, the prompt is read from the database within
      that environment: the exact version when `prompt_id` has a version
      suffix, otherwise the latest version.
    - Without `environment`, the in-memory registry is used, falling back to
      the latest registered version of the base ID.

    The caller must be a proxy admin. If the key restricts `prompts`,
    `prompt_id` must also be in that list.

    Example Request:
    ```bash
    curl -X GET "http://localhost:4000/prompts/my_prompt_id/info" \\
    -H "Authorization: Bearer <your_api_key>"
    ```

    Example Response:
    ```json
    {
        "prompt_spec": {
            "prompt_id": "my_prompt_id",
            "litellm_params": {
                "prompt_id": "my_prompt_id",
                "prompt_integration": "dotprompt",
                "prompt_directory": "/path/to/prompts"
            },
            "prompt_info": {
                "prompt_type": "config"
            },
            "version": 1,
            "environment": "development",
            "created_at": "2023-11-09T12:34:56.789Z",
            "updated_at": "2023-11-09T12:34:56.789Z"
        },
        "raw_prompt_template": {
            "litellm_prompt_id": "my_prompt_id",
            "content": "System: You are a helpful assistant. User: {{user_message}}",
            "metadata": {"model": "gpt-4"}
        },
        "environments": ["development"]
    }
    ```
    """
    from litellm.proxy.prompts.prompt_registry import IN_MEMORY_PROMPT_REGISTRY
    from litellm.proxy.proxy_server import prisma_client

    ## CHECK IF USER HAS ACCESS TO PROMPT
    # A key whose metadata lists "prompts" may only request IDs in that list
    # (exact string match on the requested prompt_id, versioned or not).
    prompts: Optional[List[str]] = None
    if user_api_key_dict.metadata is not None:
        prompts = cast(
            Optional[List[str]], user_api_key_dict.metadata.get("prompts", None)
        )
        if prompts is not None and prompt_id not in prompts:
            raise HTTPException(status_code=400, detail=f"Prompt {prompt_id} not found")
    # In addition to the key's prompt list, the caller must be a proxy admin.
    if user_api_key_dict.user_role is not None and (
        user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
        or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
    ):
        pass
    else:
        raise HTTPException(
            status_code=403,
            detail=f"You are not authorized to access this prompt. Your role - {user_api_key_dict.user_role}, Your key's prompts - {prompts}",
        )

    base_prompt_id = get_base_prompt_id(prompt_id=prompt_id)

    # Query all environments this prompt exists in (lightweight: distinct on environment)
    all_environments: List[str] = []
    if prisma_client is not None:
        all_prompt_rows = await PromptRepository(prisma_client).table.find_many(
            where={"prompt_id": base_prompt_id},
            distinct=["environment"],
        )
        # De-duplicated and sorted; empty/None environment values are skipped.
        all_environments = sorted(
            set(row.environment for row in all_prompt_rows if row.environment)
        )

    # If environment is specified, find the version in that environment from DB
    # If prompt_id has a version suffix (e.g., "testprompt.v2"), fetch that specific version
    # Otherwise fetch the latest version in that environment
    prompt_spec = None
    # None means "no explicit version requested" (prompt_id was already a base ID).
    requested_version = (
        get_version_number(prompt_id=prompt_id) if prompt_id != base_prompt_id else None
    )
    if environment and prisma_client is not None:
        where_clause: Dict[str, Any] = {
            "prompt_id": base_prompt_id,
            "environment": environment,
        }
        if requested_version is not None:
            where_clause["version"] = requested_version
        # Highest matching version only.
        env_prompts = await PromptRepository(prisma_client).table.find_many(
            where=where_clause,
            order={"version": "desc"},
            take=1,
        )
        if env_prompts:
            prompt_spec = create_versioned_prompt_spec(db_prompt=env_prompts[0])

    # Fallback: use in-memory registry (no environment filter)
    # Only taken when no environment was requested at all. With an environment
    # that has no match (or no DB), the request ends in the "not found" error
    # below. NOTE(review): an empty-string environment skips both paths.
    if prompt_spec is None and environment is None:
        prompt_spec = IN_MEMORY_PROMPT_REGISTRY.get_prompt_by_id(prompt_id)
        if prompt_spec is None:
            # e.g. a base ID whose registry entries are stored as "<id>.v<N>"
            latest_prompt_id = get_latest_version_prompt_id(
                prompt_id=prompt_id,
                all_prompt_ids=IN_MEMORY_PROMPT_REGISTRY.IN_MEMORY_PROMPTS,
            )
            prompt_spec = IN_MEMORY_PROMPT_REGISTRY.get_prompt_by_id(latest_prompt_id)

    # NOTE(review): "not found" is reported as 400 here, while update/delete use 404.
    if prompt_spec is None:
        raise HTTPException(
            status_code=400,
            detail=f"Prompt {prompt_id} not found"
            + (f" in environment {environment}" if environment else ""),
        )

    # Extract version number from the prompt_id
    version_number = get_version_number(prompt_id=prompt_spec.prompt_id)

    # Create a copy of the prompt spec with the base prompt ID (stripped of version)
    prompt_spec_response = PromptSpec(
        prompt_id=get_base_prompt_id(prompt_id=prompt_spec.prompt_id),
        litellm_params=prompt_spec.litellm_params,
        prompt_info=prompt_spec.prompt_info,
        created_at=prompt_spec.created_at,
        updated_at=prompt_spec.updated_at,
        version=version_number,
        environment=prompt_spec.environment,
        created_by=prompt_spec.created_by,
    )

    # Get prompt content (best-effort; may be None)
    # The lookup uses the versioned spec, not the display copy.
    prompt_template = _get_prompt_template(prompt_spec, base_prompt_id)

    return PromptInfoResponse(
        prompt_spec=prompt_spec_response,
        raw_prompt_template=prompt_template,
        environments=all_environments,
    )
|
|
|
|
|
|
@router.post(
    "/prompts",
    tags=["Prompt Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def create_prompt(
    request: Prompt,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new prompt

    👉 [Prompt docs](https://docs.litellm.ai/docs/proxy/prompt_management)

    Proxy admins only; requires a connected database. The prompt is stored as
    the next version of `prompt_id` in the environment given by
    `prompt_info.environment` (default "development"). It is then registered
    in memory as "<prompt_id>.v<version>".

    Note: prompts whose `prompt_info.prompt_type` is "config" cannot later be
    updated or deleted through this API, so use "db" for API-managed prompts.

    Example Request:
    ```bash
    curl -X POST "http://localhost:4000/prompts" \\
    -H "Authorization: Bearer <your_api_key>" \\
    -H "Content-Type: application/json" \\
    -d '{
        "prompt_id": "my_prompt",
        "litellm_params": {
            "prompt_id": "json_prompt",
            "prompt_integration": "dotprompt",
            ### EITHER prompt_directory OR prompt_data MUST BE PROVIDED
            "prompt_directory": "/path/to/dotprompt/folder",
            "prompt_data": {"json_prompt": {"content": "This is a prompt", "metadata": {"model": "gpt-4"}}}
        },
        "prompt_info": {
            "prompt_type": "db"
        }
    }'
    ```
    """

    from litellm.proxy.prompts.prompt_registry import IN_MEMORY_PROMPT_REGISTRY
    from litellm.proxy.proxy_server import prisma_client

    # Only allow proxy admins to create prompts
    if user_api_key_dict.user_role is None or (
        user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
        and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
    ):
        raise HTTPException(
            status_code=403, detail="Only proxy admins can create prompts"
        )

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    try:
        # Extract environment from request
        environment = (
            request.prompt_info.environment
            if request.prompt_info and request.prompt_info.environment
            else "development"
        )

        # Get next version number
        new_version = await get_next_version_for_prompt(
            prisma_client=prisma_client,
            prompt_id=request.prompt_id,
            environment=environment,
        )

        # Store prompt in db with version
        prompt_db_entry = await PromptRepository(prisma_client).table.create(
            data={
                "prompt_id": request.prompt_id,
                "version": new_version,
                "environment": environment,
                "created_by": user_api_key_dict.user_id,
                "litellm_params": request.litellm_params.model_dump_json(),
                "prompt_info": (
                    request.prompt_info.model_dump_json()
                    if request.prompt_info
                    else PromptInfo(prompt_type="db").model_dump_json()
                ),
            }
        )

        # Create versioned prompt spec
        prompt_spec = create_versioned_prompt_spec(db_prompt=prompt_db_entry)

        # Initialize the prompt
        initialized_prompt = IN_MEMORY_PROMPT_REGISTRY.initialize_prompt(
            prompt=prompt_spec, config_file_path=None
        )

        if initialized_prompt is None:
            raise HTTPException(status_code=500, detail="Failed to initialize prompt")

        return initialized_prompt

    except HTTPException:
        # Pass deliberate HTTP errors through unchanged. Previously the generic
        # handler below re-wrapped them, mangling the detail into "500: ...".
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error creating prompt: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.put(
    "/prompts/{prompt_id}",
    tags=["Prompt Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_prompt(
    prompt_id: str,
    request: Prompt,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an existing prompt

    👉 [Prompt docs](https://docs.litellm.ai/docs/proxy/prompt_management)

    Updates never overwrite rows. The request body is stored as the next
    version of the base prompt ID (any ".vN"/"_vN" suffix on the path is
    ignored) in the environment given by `prompt_info.environment` (default
    "development"). Requires a proxy-admin key and a connected database.
    Responds 404 if no version of the prompt exists and 400 for config prompts.

    Example Request:
    ```bash
    curl -X PUT "http://localhost:4000/prompts/my_prompt_id" \\
    -H "Authorization: Bearer <your_api_key>" \\
    -H "Content-Type: application/json" \\
    -d '{
        "prompt_id": "my_prompt",
        "litellm_params": {
            "prompt_id": "my_prompt",
            "prompt_integration": "dotprompt",
            "prompt_directory": "/path/to/prompts"
        },
        "prompt_info": {
            "prompt_type": "db"
        }
    }'
    ```
    """
    from litellm.proxy.prompts.prompt_registry import IN_MEMORY_PROMPT_REGISTRY
    from litellm.proxy.proxy_server import prisma_client

    caller_role = user_api_key_dict.user_role
    is_proxy_admin = caller_role is not None and (
        caller_role == LitellmUserRoles.PROXY_ADMIN
        or caller_role == LitellmUserRoles.PROXY_ADMIN.value
    )
    if not is_proxy_admin:
        raise HTTPException(
            status_code=403, detail="Only proxy admins can update prompts"
        )

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    try:
        # "jack_success.v1" -> "jack_success"
        base_prompt_id = get_base_prompt_id(prompt_id=prompt_id)

        incoming_info = request.prompt_info
        target_environment = "development"
        if incoming_info and incoming_info.environment:
            target_environment = incoming_info.environment

        # Any version, in any environment, counts as "exists".
        stored_versions = await PromptRepository(prisma_client).table.find_many(
            where={"prompt_id": base_prompt_id}
        )
        if not stored_versions:
            raise HTTPException(
                status_code=404,
                detail=f"Prompt with ID {base_prompt_id} not found",
            )

        # Config-file prompts are read-only through this API.
        registered = IN_MEMORY_PROMPT_REGISTRY.get_prompt_by_id(prompt_id)
        if registered and registered.prompt_info.prompt_type == "config":
            raise HTTPException(
                status_code=400,
                detail="Cannot update config prompts.",
            )

        next_version = await get_next_version_for_prompt(
            prisma_client=prisma_client,
            prompt_id=base_prompt_id,
            environment=target_environment,
        )

        params_json = request.litellm_params.model_dump_json()
        if incoming_info:
            info_json = incoming_info.model_dump_json()
        else:
            info_json = PromptInfo(prompt_type="db").model_dump_json()

        new_row = await PromptRepository(prisma_client).table.create(
            data={
                "prompt_id": base_prompt_id,
                "version": next_version,
                "environment": target_environment,
                "created_by": user_api_key_dict.user_id,
                "litellm_params": params_json,
                "prompt_info": info_json,
            }
        )

        reloaded = IN_MEMORY_PROMPT_REGISTRY.initialize_prompt(
            prompt=create_versioned_prompt_spec(db_prompt=new_row),
            config_file_path=None,
        )
        if reloaded is None:
            raise HTTPException(status_code=500, detail="Failed to update prompt")
        return reloaded

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating prompt: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.delete(
    "/prompts/{prompt_id}",
    tags=["Prompt Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_prompt(
    prompt_id: str,
    environment: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete a prompt

    👉 [Prompt docs](https://docs.litellm.ai/docs/proxy/prompt_management)

    Deletes every stored version of the prompt's base ID from the database and
    the in-memory registry. Pass `?environment=<name>` to delete only the
    versions in that environment. Proxy admins only; config prompts cannot be
    deleted.

    Example Request:
    ```bash
    curl -X DELETE "http://localhost:4000/prompts/my_prompt_id" \\
    -H "Authorization: Bearer <your_api_key>"
    ```

    Example Response:
    ```json
    {
        "message": "Prompt my_prompt_id deleted successfully"
    }
    ```
    """
    from litellm.proxy.prompts.prompt_registry import IN_MEMORY_PROMPT_REGISTRY
    from litellm.proxy.proxy_server import prisma_client

    # Only allow proxy admins to delete prompts
    if user_api_key_dict.user_role is None or (
        user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
        and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
    ):
        raise HTTPException(
            status_code=403, detail="Only proxy admins can delete prompts"
        )

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    try:
        # Try to get prompt directly first
        existing_prompt = IN_MEMORY_PROMPT_REGISTRY.get_prompt_by_id(prompt_id)

        # If not found, try to find the latest version
        # (e.g. a base ID whose registry entries are stored as "<id>.v<N>").
        if existing_prompt is None:
            latest_prompt_id = get_latest_version_prompt_id(
                prompt_id=prompt_id,
                all_prompt_ids=IN_MEMORY_PROMPT_REGISTRY.IN_MEMORY_PROMPTS,
            )
            existing_prompt = IN_MEMORY_PROMPT_REGISTRY.get_prompt_by_id(
                latest_prompt_id
            )
            # Use the resolved prompt_id for deletion
            prompt_id = latest_prompt_id

        # The prompt must be present in the in-memory registry to be deletable.
        if existing_prompt is None:
            raise HTTPException(
                status_code=404, detail=f"Prompt with ID {prompt_id} not found"
            )

        if existing_prompt.prompt_info.prompt_type == "config":
            raise HTTPException(
                status_code=400,
                detail="Cannot delete config prompts.",
            )

        # Get the base prompt ID (without version suffix) for database deletion
        base_prompt_id = get_base_prompt_id(prompt_id=prompt_id)

        # Build delete filter; scope to environment if provided
        delete_where: Dict[str, Any] = {"prompt_id": base_prompt_id}
        if environment:
            delete_where["environment"] = environment

        # Delete versions from the database (scoped to environment if provided)
        # NOTE(review): success is reported even if no row matched the filter.
        await PromptRepository(prisma_client).table.delete_many(where=delete_where)

        # Remove matching prompts from memory — scope to environment if provided
        if environment:
            # Collect the IDs first so the dict is not mutated while iterating it.
            prompts_to_delete = [
                pid
                for pid, prompt in IN_MEMORY_PROMPT_REGISTRY.IN_MEMORY_PROMPTS.items()
                if get_base_prompt_id(prompt_id=pid) == base_prompt_id
                and prompt.environment == environment
            ]
            for pid in prompts_to_delete:
                del IN_MEMORY_PROMPT_REGISTRY.IN_MEMORY_PROMPTS[pid]
                if pid in IN_MEMORY_PROMPT_REGISTRY.prompt_id_to_custom_prompt:
                    del IN_MEMORY_PROMPT_REGISTRY.prompt_id_to_custom_prompt[pid]
        else:
            IN_MEMORY_PROMPT_REGISTRY.delete_prompts_by_base_id(base_prompt_id)

        env_msg = f" from {environment}" if environment else ""
        return {"message": f"Prompt {base_prompt_id} deleted successfully{env_msg}"}

    except HTTPException as e:
        raise e
    except Exception as e:
        verbose_proxy_logger.exception(f"Error deleting prompt: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
def _reload_prompt_in_registry(
    registry: Any, versioned_id: str, updated_prompt_spec: PromptSpec
) -> PromptSpec:
    """
    Replace a prompt's in-memory registration with a freshly built spec.

    Any existing entry keyed by ``versioned_id`` is dropped from both
    ``registry.IN_MEMORY_PROMPTS`` and ``registry.prompt_id_to_custom_prompt``.
    Then ``registry.initialize_prompt`` is called with the new spec and no
    config file path.

    Returns:
        The spec returned by ``registry.initialize_prompt``.

    Raises:
        HTTPException: status 500 when ``initialize_prompt`` returns ``None``.
    """
    # Purge stale copies first so the re-initialization starts from a clean slot.
    for lookup in (registry.IN_MEMORY_PROMPTS, registry.prompt_id_to_custom_prompt):
        if versioned_id in lookup:
            del lookup[versioned_id]

    reloaded = registry.initialize_prompt(
        prompt=updated_prompt_spec, config_file_path=None
    )
    if reloaded is not None:
        return reloaded
    raise HTTPException(status_code=500, detail="Failed to patch prompt")
|
|
|
|
|
|
@router.patch(
    "/prompts/{prompt_id}",
    tags=["Prompt Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def patch_prompt(
    prompt_id: str,
    request: PatchPromptRequest,
    environment: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Partially update an existing prompt

    👉 [Prompt docs](https://docs.litellm.ai/docs/proxy/prompt_management)

    This endpoint allows updating specific fields of a prompt without sending the entire object.
    Only the following fields can be updated:
    - litellm_params: LiteLLM parameters for the prompt
    - prompt_info: Additional information about the prompt

    Example Request:
    ```bash
    curl -X PATCH "http://localhost:4000/prompts/my_prompt_id" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "prompt_info": {
                "prompt_type": "db"
            }
        }'
    ```
    """
    from litellm.proxy.prompts.prompt_registry import IN_MEMORY_PROMPT_REGISTRY
    from litellm.proxy.proxy_server import prisma_client

    # Authorization: only proxy admins may patch. The role may be carried either
    # as the enum member or as its raw string value, so accept both.
    caller_role = user_api_key_dict.user_role
    admin_roles = (LitellmUserRoles.PROXY_ADMIN, LitellmUserRoles.PROXY_ADMIN.value)
    if caller_role is None or caller_role not in admin_roles:
        raise HTTPException(
            status_code=403, detail="Only proxy admins can patch prompts"
        )

    if prisma_client is None:
        raise HTTPException(
            status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
        )

    try:
        base_prompt_id = get_base_prompt_id(prompt_id=prompt_id)
        env = environment or "development"

        # Filter on (prompt_id, environment). A versioned ID (one that differs from
        # its base ID) additionally pins the version via get_version_number;
        # otherwise the newest version in this environment is selected.
        lookup_filter: Dict[str, Any] = {
            "prompt_id": base_prompt_id,
            "environment": env,
        }
        if prompt_id != base_prompt_id:
            requested_version = get_version_number(prompt_id=prompt_id)
            if requested_version is not None:
                lookup_filter["version"] = requested_version

        matching_rows = await PromptRepository(prisma_client).table.find_many(
            where=lookup_filter,
            order={"version": "desc"},
            take=1,
        )
        if not matching_rows:
            raise HTTPException(
                status_code=404,
                detail=f"Prompt with ID {base_prompt_id} not found in environment {env}",
            )
        row = matching_rows[0]

        versioned_id = f"{base_prompt_id}.v{row.version}"
        cached_prompt = IN_MEMORY_PROMPT_REGISTRY.get_prompt_by_id(versioned_id)

        # Merge baseline: prefer the in-memory copy, otherwise rebuild the spec
        # from the database row. Config-defined prompts are read-only.
        if cached_prompt:
            if cached_prompt.prompt_info.prompt_type == "config":
                raise HTTPException(
                    status_code=400,
                    detail="Cannot update config prompts.",
                )
            baseline = cached_prompt
        else:
            baseline = create_versioned_prompt_spec(db_prompt=row)

        # Fields left out of the PATCH body keep their current values.
        new_litellm_params = request.litellm_params
        if new_litellm_params is None:
            new_litellm_params = baseline.litellm_params

        new_prompt_info = request.prompt_info
        if new_prompt_info is None:
            new_prompt_info = baseline.prompt_info

        if new_litellm_params is None:
            raise HTTPException(status_code=400, detail="litellm_params cannot be None")

        update_data: Dict[str, Any] = {
            "litellm_params": new_litellm_params.model_dump_json(),
            "prompt_info": new_prompt_info.model_dump_json(),
        }
        # NOTE(review): this overwrites `created_by` with the patching user, so the
        # column ends up recording the last editor - confirm this is intended.
        if user_api_key_dict.user_id:
            update_data["created_by"] = user_api_key_dict.user_id

        # Update by primary key so exactly the row located above is modified.
        patched_row = await PromptRepository(prisma_client).table.update(
            where={"id": row.id},
            data=update_data,
        )

        patched_spec = create_versioned_prompt_spec(db_prompt=patched_row)
        return _reload_prompt_in_registry(
            IN_MEMORY_PROMPT_REGISTRY, versioned_id, patched_spec
        )

    except HTTPException:
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error patching prompt: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.post(
    "/prompts/test",
    tags=["Prompt Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def test_prompt(
    request: TestPromptRequest,
    fastapi_request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Test a prompt by rendering it with variables and executing an LLM call.

    This endpoint allows testing prompts before saving them to the database.
    The response is always streamed.

    👉 [Prompt docs](https://docs.litellm.ai/docs/proxy/prompt_management)

    Example Request:
    ```bash
    curl -X POST "http://localhost:4000/prompts/test" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "dotprompt_content": "---\\nmodel: gpt-4o\\ntemperature: 0.7\\n---\\n\\nUser: Hello {{name}}",
            "prompt_variables": {
                "name": "World"
            }
        }'
    ```
    """
    # `BaseModel` (used for the result check below) comes from the module-level
    # pydantic import.
    from litellm.integrations.dotprompt.dotprompt_manager import DotpromptManager
    from litellm.integrations.dotprompt.prompt_manager import (
        PromptManager,
        PromptTemplate,
    )
    from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
    from litellm.proxy.proxy_server import (
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        select_data_generator,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    try:
        # Split the dotprompt document into its frontmatter and template body.
        prompt_manager = PromptManager()
        frontmatter, template_content = prompt_manager._parse_frontmatter(
            content=request.dotprompt_content
        )

        # PromptTemplate exposes `model` and `optional_params` derived from the
        # frontmatter, so parameter extraction is not duplicated here.
        template = PromptTemplate(
            content=template_content, metadata=frontmatter, template_id="test_prompt"
        )

        if not template.model:
            raise HTTPException(
                status_code=400, detail="Model is required in dotprompt metadata"
            )

        # Always render the template so its system messages are available even
        # when the caller supplies its own conversation history.
        # NOTE(review): the template text is caller-supplied and is rendered with
        # `prompt_manager.jinja_env`. If that environment is not a Jinja sandbox,
        # this is a server-side template injection vector - confirm it is sandboxed.
        variables = request.prompt_variables or {}
        rendered_content = prompt_manager.jinja_env.from_string(
            template_content
        ).render(**variables)

        # Turn the rendered text into chat messages using DotpromptManager's parser.
        dotprompt_manager = DotpromptManager()
        rendered_messages = dotprompt_manager._convert_to_messages(
            rendered_content=rendered_content
        )

        if not rendered_messages:
            raise HTTPException(
                status_code=400, detail="No messages found in rendered prompt"
            )

        if request.conversation_history:
            # Keep the template's system messages; the supplied history replaces
            # the rendered user/assistant turns.
            system_messages = [
                msg for msg in rendered_messages if msg.get("role") == "system"
            ]
            messages = system_messages + request.conversation_history
        else:
            messages = rendered_messages  # type: ignore[assignment]

        # Copy so the template's own parameter dict is not mutated.
        optional_params = template.optional_params.copy()
        # This endpoint always streams its response.
        optional_params["stream"] = True

        data = {
            "model": template.model,
            "messages": messages,
        }
        data.update(optional_params)

        # Apply the same request-body safety checks as regular completion calls.
        is_request_body_safe(
            request_body=data,
            general_settings=general_settings,
            llm_router=llm_router,
            model=data.get("model", ""),
        )

        # Route through the standard proxy pipeline (auth hooks, routing, logging).
        base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
        result = await base_llm_response_processor.base_process_llm_request(
            request=fastapi_request,
            fastapi_response=fastapi_response,
            user_api_key_dict=user_api_key_dict,
            route_type="acompletion",
            proxy_logging_obj=proxy_logging_obj,
            llm_router=llm_router,
            general_settings=general_settings,
            proxy_config=proxy_config,
            select_data_generator=select_data_generator,
            model=None,
            user_model=user_model,
            user_temperature=user_temperature,
            user_request_timeout=user_request_timeout,
            user_max_tokens=user_max_tokens,
            user_api_base=user_api_base,
            version=version,
        )

        if isinstance(result, BaseModel):
            return result.model_dump(exclude_none=True, exclude_unset=True)
        return result

    except HTTPException:
        raise
    except ValueError as e:
        # Validation problems with the caller's input surface as 400s.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        verbose_proxy_logger.exception(f"Error testing prompt: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.post(
    "/utils/dotprompt_json_converter",
    tags=["prompts", "utils"],
    dependencies=[Depends(user_api_key_auth)],
)
async def convert_prompt_file_to_json(
    file: UploadFile = File(...),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> Dict[str, Any]:
    """
    Convert a .prompt file to JSON format.

    This endpoint accepts a .prompt file upload and returns the equivalent JSON representation
    that can be stored in a database or used programmatically.

    Returns an object with 'prompt_id' (derived from the uploaded filename) and
    'json_data' (the converted JSON structure with 'content' and 'metadata' fields).
    """
    from litellm.integrations.dotprompt.prompt_manager import PromptManager

    # Validate file extension before doing any work.
    if not file.filename or not file.filename.endswith(".prompt"):
        raise HTTPException(status_code=400, detail="File must have .prompt extension")

    try:
        file_content = await file.read()

        # PromptManager.prompt_file_to_json takes a filesystem path, so the upload
        # is staged in a private temporary directory. TemporaryDirectory removes
        # the directory and its contents on exit, even if a step below raises
        # (including safe_filename or write_bytes).
        with tempfile.TemporaryDirectory() as temp_dir:
            # safe_filename prevents path traversal via the client-supplied name.
            temp_file_path = Path(temp_dir) / safe_filename(file.filename)
            temp_file_path.write_bytes(file_content)

            # A throwaway PromptManager is enough for a one-off conversion.
            prompt_manager = PromptManager()
            json_data = prompt_manager.prompt_file_to_json(temp_file_path)

            # The prompt ID is the (sanitized) filename without its extension.
            prompt_id = temp_file_path.stem

        return {
            "prompt_id": prompt_id,
            "json_data": json_data,
        }

    except HTTPException:
        # Keep deliberate HTTP errors (status code and detail) intact rather than
        # re-wrapping them as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error converting prompt file: {str(e)}"
        ) from e
|