diff --git a/openrouter_provider.py b/openrouter_provider.py index f804082..f04d6f9 100644 --- a/openrouter_provider.py +++ b/openrouter_provider.py @@ -11,8 +11,9 @@ import re import requests import json import time -from typing import List, Union, Generator, Iterator +from typing import List, Union, Generator, Iterator, Optional, Callable, Any, Awaitable, AsyncGenerator from pydantic import BaseModel, Field +import asyncio def _insert_citations(text: str, citations: list[str]) -> str: @@ -145,7 +146,7 @@ class Pipe: print(f"Error fetching generation details: {e}") return {} - def _report_api_call(self, generation_data: dict, user_email: str, model_id: str): + async def _report_api_call(self, generation_data: dict, user_email: str, model_id: str, __event_emitter__: Callable[[Any], Awaitable[None]]): """Report API call to upstream reporting service""" if not self.valves.REPORT_API_URL or not self.valves.REPORT_API_KEY: return @@ -188,11 +189,21 @@ class Pipe: print(f"Successfully reported API call for user {user_email}") else: print(f"Failed to report API call: {response.status_code}") - + + info = f"input: {input_tokens} | output: {output_tokens} | cost: {cost_usd:.6f}" + await __event_emitter__( + { + "type": "status", + "data": { + "description": info, + "done": True, + }, + } + ) except Exception as e: print(f"Error reporting API call: {e}") - def pipe(self, body: dict, __user__: dict, __metadata__: dict) -> Union[str, Generator, Iterator]: + async def pipe(self, body: dict, __user__: dict, __metadata__: dict, __event_emitter__: Callable[[Any], Awaitable[None]]) -> Union[str, AsyncGenerator]: """Process the request and handle reasoning tokens if supported""" # Clone the body for OpenRouter payload = body.copy() @@ -237,9 +248,9 @@ class Pipe: try: if body.get("stream", False): - return self.stream_response(url, headers, payload, user_email, model_id) + return self.stream_response(url, headers, payload, user_email, model_id, __event_emitter__) else: - return self.non_stream_response(url, headers, payload, user_email, model_id) + return await self.non_stream_response(url, headers, payload, user_email, model_id, __event_emitter__) except requests.exceptions.RequestException as e: print(f"Request failed: {e}") return f"Error: Request failed: {e}" @@ -247,7 +258,7 @@ class Pipe: print(f"Error in pipe method: {e}") return f"Error: {e}" - def non_stream_response(self, url, headers, payload, user_email, model_id): + async def non_stream_response(self, url, headers, payload, user_email, model_id, __event_emitter__: Callable[[Any], Awaitable[None]]): """Handle non-streaming responses and wrap reasoning in tags if present""" try: print( @@ -285,7 +296,7 @@ class Pipe: try: generation_data = self._fetch_generation_details(generation_id) if generation_data: - self._report_api_call(generation_data, user_email, model_id) + await self._report_api_call(generation_data, user_email, model_id, __event_emitter__) except Exception as e: print(f"Error reporting API call: {e}") return f"Error: {e}" @@ -322,7 +333,7 @@ class Pipe: print(f"Error in non_stream_response: {e}") return f"Error: {e}" - def stream_response(self, url, headers, payload, user_email, model_id): + async def stream_response(self, url, headers, payload, user_email, model_id, __event_emitter__: Callable[[Any], Awaitable[None]]): """Stream reasoning tokens in real-time with proper tag management""" try: response = requests.post( @@ -365,14 +376,14 @@ class Pipe: try: generation_data = self._fetch_generation_details(generation_id) if generation_data: - self._report_api_call(generation_data, user_email, model_id) + await self._report_api_call(generation_data, user_email, model_id, __event_emitter__) except Exception as e: print(f"Error reporting API call: {e}") - return f"Error: {e}" + yield f"Error: {e}" yield "" ## trick else: print(f"No generation ID found for reporting") - return f"Error: No generation ID found for reporting" + yield f"Error: No generation ID found for reporting" continue try: