feat: add cost tracking

This commit is contained in:
2025-06-27 15:51:00 +08:00
parent f046bdfad2
commit c837be1ef1

View File

@ -11,8 +11,9 @@ import re
import requests
import json
import time
from typing import List, Union, Generator, Iterator
from typing import List, Union, Generator, Iterator, Optional, Callable, Any, Awaitable, AsyncGenerator
from pydantic import BaseModel, Field
import asyncio
def _insert_citations(text: str, citations: list[str]) -> str:
@ -145,7 +146,7 @@ class Pipe:
print(f"Error fetching generation details: {e}")
return {}
def _report_api_call(self, generation_data: dict, user_email: str, model_id: str):
async def _report_api_call(self, generation_data: dict, user_email: str, model_id: str, __event_emitter__: Callable[[Any], Awaitable[None]]):
"""Report API call to upstream reporting service"""
if not self.valves.REPORT_API_URL or not self.valves.REPORT_API_KEY:
return
@ -189,10 +190,20 @@ class Pipe:
else:
print(f"Failed to report API call: {response.status_code}")
info = f"input: {input_tokens} | output: {output_tokens} | cost: {cost_usd:.6f}"
await __event_emitter__(
{
"type": "status",
"data": {
"description": info,
"done": True,
},
}
)
except Exception as e:
print(f"Error reporting API call: {e}")
def pipe(self, body: dict, __user__: dict, __metadata__: dict) -> Union[str, Generator, Iterator]:
async def pipe(self, body: dict, __user__: dict, __metadata__: dict, __event_emitter__: Callable[[Any], Awaitable[None]]) -> Union[str, AsyncGenerator]:
"""Process the request and handle reasoning tokens if supported"""
# Clone the body for OpenRouter
payload = body.copy()
@ -237,9 +248,9 @@ class Pipe:
try:
if body.get("stream", False):
return self.stream_response(url, headers, payload, user_email, model_id)
return self.stream_response(url, headers, payload, user_email, model_id, __event_emitter__)
else:
return self.non_stream_response(url, headers, payload, user_email, model_id)
return await self.non_stream_response(url, headers, payload, user_email, model_id, __event_emitter__)
except requests.exceptions.RequestException as e:
print(f"Request failed: {e}")
return f"Error: Request failed: {e}"
@ -247,7 +258,7 @@ class Pipe:
print(f"Error in pipe method: {e}")
return f"Error: {e}"
def non_stream_response(self, url, headers, payload, user_email, model_id):
async def non_stream_response(self, url, headers, payload, user_email, model_id, __event_emitter__: Callable[[Any], Awaitable[None]]):
"""Handle non-streaming responses and wrap reasoning in <think> tags if present"""
try:
print(
@ -285,7 +296,7 @@ class Pipe:
try:
generation_data = self._fetch_generation_details(generation_id)
if generation_data:
self._report_api_call(generation_data, user_email, model_id)
await self._report_api_call(generation_data, user_email, model_id, __event_emitter__)
except Exception as e:
print(f"Error reporting API call: {e}")
return f"Error: {e}"
@ -322,7 +333,7 @@ class Pipe:
print(f"Error in non_stream_response: {e}")
return f"Error: {e}"
def stream_response(self, url, headers, payload, user_email, model_id):
async def stream_response(self, url, headers, payload, user_email, model_id, __event_emitter__: Callable[[Any], Awaitable[None]]):
"""Stream reasoning tokens in real-time with proper tag management"""
try:
response = requests.post(
@ -365,14 +376,14 @@ class Pipe:
try:
generation_data = self._fetch_generation_details(generation_id)
if generation_data:
self._report_api_call(generation_data, user_email, model_id)
await self._report_api_call(generation_data, user_email, model_id, __event_emitter__)
except Exception as e:
print(f"Error reporting API call: {e}")
return f"Error: {e}"
yield f"Error: {e}"
yield "" ## trick
else:
print(f"No generation ID found for reporting")
return f"Error: No generation ID found for reporting"
yield f"Error: No generation ID found for reporting"
continue
try: