feat: add cost tracking
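
The change converts the pipe entry points to async and threads Open WebUI's `__event_emitter__` through to `_report_api_call`, so token usage and cost are surfaced in the chat UI as a status event after the generation details are fetched. Below is a minimal sketch of that emission pattern, assuming the usual event-emitter contract (an async callable taking a dict with "type" and "data" keys); the helper name and its arguments are illustrative, not part of the commit.

    # Sketch of the cost-status emission this commit adds. emit_cost_status and
    # the token/cost arguments are illustrative stand-ins for values computed
    # inside _report_api_call.
    import asyncio
    from typing import Any, Awaitable, Callable


    async def emit_cost_status(
        __event_emitter__: Callable[[Any], Awaitable[None]],
        input_tokens: int,
        output_tokens: int,
        cost_usd: float,
    ) -> None:
        # Format the usage summary shown in the UI status line.
        info = f"input: {input_tokens} | output: {output_tokens} | cost: {cost_usd:.6f}"
        await __event_emitter__(
            {
                "type": "status",
                "data": {"description": info, "done": True},
            }
        )


    if __name__ == "__main__":
        # Stand-in emitter that just prints the event, for local testing.
        async def fake_emitter(event: Any) -> None:
            print(event)

        asyncio.run(emit_cost_status(fake_emitter, 1200, 350, 0.004210))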

2025-06-27 15:51:00 +08:00
parent f046bdfad2
commit c837be1ef1
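
The streaming path is likewise converted to an async generator, and errors inside it are now yielded rather than returned, since an async generator cannot return a value. A small sketch of that pattern follows, assuming the pipe returns the generator object directly to the caller; the function and argument names are illustrative.

    # Sketch of the streaming error-handling change: yield error text as a chunk
    # instead of returning it, so the UI still receives something renderable.
    import asyncio
    from typing import AsyncGenerator


    async def stream_chunks(fail: bool = False) -> AsyncGenerator[str, None]:
        try:
            if fail:
                raise RuntimeError("upstream request failed")
            yield "partial "
            yield "response"
        except Exception as e:
            # An async generator cannot `return f"Error: {e}"`, so yield it.
            yield f"Error: {e}"


    async def main() -> None:
        async for chunk in stream_chunks(fail=True):
            print(chunk, end="")
        print()


    asyncio.run(main())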

@@ -11,8 +11,9 @@ import re
 import requests
 import json
 import time
-from typing import List, Union, Generator, Iterator
+from typing import List, Union, Generator, Iterator, Optional, Callable, Any, Awaitable, AsyncGenerator
 from pydantic import BaseModel, Field
+import asyncio


 def _insert_citations(text: str, citations: list[str]) -> str:
@@ -145,7 +146,7 @@ class Pipe:
             print(f"Error fetching generation details: {e}")
             return {}

-    def _report_api_call(self, generation_data: dict, user_email: str, model_id: str):
+    async def _report_api_call(self, generation_data: dict, user_email: str, model_id: str, __event_emitter__: Callable[[Any], Awaitable[None]]):
         """Report API call to upstream reporting service"""
         if not self.valves.REPORT_API_URL or not self.valves.REPORT_API_KEY:
             return
@@ -189,10 +190,20 @@ class Pipe:
             else:
                 print(f"Failed to report API call: {response.status_code}")

+            info = f"input: {input_tokens} | output: {output_tokens} | cost: {cost_usd:.6f}"
+            await __event_emitter__(
+                {
+                    "type": "status",
+                    "data": {
+                        "description": info,
+                        "done": True,
+                    },
+                }
+            )
         except Exception as e:
             print(f"Error reporting API call: {e}")

-    def pipe(self, body: dict, __user__: dict, __metadata__: dict) -> Union[str, Generator, Iterator]:
+    async def pipe(self, body: dict, __user__: dict, __metadata__: dict, __event_emitter__: Callable[[Any], Awaitable[None]]) -> Union[str, AsyncGenerator]:
         """Process the request and handle reasoning tokens if supported"""
         # Clone the body for OpenRouter
         payload = body.copy()
@@ -237,9 +248,9 @@ class Pipe:
         try:
             if body.get("stream", False):
-                return self.stream_response(url, headers, payload, user_email, model_id)
+                return self.stream_response(url, headers, payload, user_email, model_id, __event_emitter__)
             else:
-                return self.non_stream_response(url, headers, payload, user_email, model_id)
+                return await self.non_stream_response(url, headers, payload, user_email, model_id, __event_emitter__)
         except requests.exceptions.RequestException as e:
             print(f"Request failed: {e}")
             return f"Error: Request failed: {e}"
@@ -247,7 +258,7 @@ class Pipe:
             print(f"Error in pipe method: {e}")
             return f"Error: {e}"

-    def non_stream_response(self, url, headers, payload, user_email, model_id):
+    async def non_stream_response(self, url, headers, payload, user_email, model_id, __event_emitter__: Callable[[Any], Awaitable[None]]):
         """Handle non-streaming responses and wrap reasoning in <think> tags if present"""
         try:
             print(
@@ -285,7 +296,7 @@ class Pipe:
                 try:
                     generation_data = self._fetch_generation_details(generation_id)
                     if generation_data:
-                        self._report_api_call(generation_data, user_email, model_id)
+                        await self._report_api_call(generation_data, user_email, model_id, __event_emitter__)
                 except Exception as e:
                     print(f"Error reporting API call: {e}")
                     return f"Error: {e}"
@@ -322,7 +333,7 @@ class Pipe:
             print(f"Error in non_stream_response: {e}")
             return f"Error: {e}"

-    def stream_response(self, url, headers, payload, user_email, model_id):
+    async def stream_response(self, url, headers, payload, user_email, model_id, __event_emitter__: Callable[[Any], Awaitable[None]]):
         """Stream reasoning tokens in real-time with proper tag management"""
         try:
             response = requests.post(
@@ -365,14 +376,14 @@ class Pipe:
                             try:
                                 generation_data = self._fetch_generation_details(generation_id)
                                 if generation_data:
-                                    self._report_api_call(generation_data, user_email, model_id)
+                                    await self._report_api_call(generation_data, user_email, model_id, __event_emitter__)
                             except Exception as e:
                                 print(f"Error reporting API call: {e}")
-                                return f"Error: {e}"
+                                yield f"Error: {e}"
                             yield ""  ## trick
                         else:
                             print(f"No generation ID found for reporting")
-                            return f"Error: No generation ID found for reporting"
+                            yield f"Error: No generation ID found for reporting"
                         continue

                     try: