325 lines
11 KiB
Python
325 lines
11 KiB
Python
from fastapi import FastAPI, Request, HTTPException, Depends, Form
|
||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.templating import Jinja2Templates
|
||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials, HTTPBasic, HTTPBasicCredentials
|
||
from starlette.middleware.sessions import SessionMiddleware
|
||
import os
|
||
from loguru import logger
|
||
import sys
|
||
import uvicorn
|
||
import json
|
||
import aiosqlite
|
||
import time
|
||
from datetime import datetime, timedelta
|
||
from typing import Optional, List, Dict, Any
|
||
from contextlib import asynccontextmanager
|
||
from pydantic import BaseModel
|
||
import secrets
|
||
|
||
# Runtime settings; each one can be overridden through an environment variable.
listen_port = os.getenv("LISTEN_PORT", "8000")
listen_host = os.getenv("LISTEN_HOST", "0.0.0.0")
config_path = os.getenv("CONFIG_PATH", "./config.json")
database_path = os.getenv("DATABASE_PATH", "./data.sqlite3")
# Loguru's built-in level names ("TRACE", "DEBUG", "INFO", ...) are upper-case
# and matched case-sensitively, so normalise the value: LOGGING_LEVEL=info would
# otherwise make logger.add() raise ValueError at startup.
logging_level = os.getenv("LOGGING_LEVEL", "TRACE").strip().upper()

# Replace loguru's default stderr handler with one filtered at the configured level.
logger.remove()
logger.add(sys.stderr, level=logging_level)

# FastAPI credential extractors. security_bearer backs the report-API key check;
# security_basic is not referenced in the code visible here.
security_bearer = HTTPBearer()
security_basic = HTTPBasic()
|
||
|
||
class APICallRecord(BaseModel):
    """Request body of POST /api/record_api_call: one API call to store in ``api_calls``."""

    # When the call happened, as an integer. Presumably Unix epoch seconds, since
    # get_24h_stats compares it with int(time.time()); confirm with the reporter.
    timestamp: int
    # Identifier of the model that served the call.
    model_id: str
    # Email of the user the call is attributed to.
    user_email: str
    # Token counts reported for the call.
    input_tokens: int
    output_tokens: int
    # Cost of the call in US dollars.
    cost_usd: float
|
||
|
||
def normalize_path(path: str) -> str:
    """Resolve a configured path string.

    Rules:
      * starts with "/" or "~": treated as absolute; a leading "~" is expanded
        to the user's home directory;
      * starts with "./": taken relative to the directory containing this script;
      * anything else: returned unchanged (relative to the current working directory).

    Args:
        path: The raw path string.

    Returns:
        The resolved path.
    """
    if path.startswith(("/", "~")):
        return os.path.expanduser(path)
    if not path.startswith("./"):
        return path
    here = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(here, path[len("./"):])
|
||
|
||
# Resolve config_path and database_path ("~"/absolute, script-relative "./",
# or left relative to the working directory).
config_path = normalize_path(config_path)
database_path = normalize_path(database_path)

with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)

# Admin login password and report-API bearer key; missing keys become "".
admin_api_key = config.get("admin_api_key", "")
report_api_key = config.get("report_api_key", "")

# Secret used to sign the session cookie. A missing *or empty* value falls back
# to a random per-process key: an empty secret is public knowledge, so anyone
# could sign a cookie with it. The trade-off of the fallback is that sessions
# do not survive a restart.
session_secret_key = config.get("session_secret_key") or secrets.token_hex(32)
if not config.get("session_secret_key"):
    logger.warning("session_secret_key not configured; using a random key (sessions reset on restart)")
|
||
|
||
async def init_database():
    """Create the ``api_calls`` table and its timestamp index if they are missing.

    Both statements use ``IF NOT EXISTS``, so this is safe to run on every startup.
    """
    async with aiosqlite.connect(database_path) as db:
        await db.execute("""
            CREATE TABLE IF NOT EXISTS api_calls (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp INTEGER NOT NULL,
                model_id TEXT NOT NULL,
                user_email TEXT NOT NULL,
                input_tokens INTEGER NOT NULL,
                output_tokens INTEGER NOT NULL,
                cost_usd REAL NOT NULL,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        """)
        # The dashboard filters on (WHERE timestamp >= ?) and sorts by timestamp;
        # without an index every such query scans the whole table.
        await db.execute(
            "CREATE INDEX IF NOT EXISTS idx_api_calls_timestamp ON api_calls (timestamp)"
        )
        await db.commit()
    logger.info("Database initialized successfully")
|
||
|
||
async def verify_admin_session(request: Request):
    """FastAPI dependency that guards the dashboard pages.

    Returns True when the session carries a truthy ``authenticated`` flag.
    Otherwise the request is aborted with a 302 whose Location is ``/login``.
    """
    is_logged_in = request.session.get("authenticated")
    if is_logged_in:
        return True
    raise HTTPException(status_code=302, headers={"Location": "/login"})
|
||
|
||
async def verify_report_key(credentials: HTTPAuthorizationCredentials = Depends(security_bearer)):
    """Authorize report-API calls sent with ``Authorization: Bearer <report_api_key>``.

    Returns:
        The presented token, on success.

    Raises:
        HTTPException: 403 when no report key is configured or the token does not match.
    """
    supplied = credentials.credentials
    # Compare UTF-8 bytes in constant time (same primitive as the login check).
    # This avoids leaking match progress through timing, and unlike comparing
    # ``str`` values it does not raise TypeError on non-ASCII input.
    # An unset (empty) key never authorizes anything.
    if not report_api_key or not secrets.compare_digest(
        supplied.encode("utf-8"), report_api_key.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid report API key")
    return supplied
|
||
|
||
async def get_24h_stats():
    """Aggregate API-call statistics for the past 24 hours.

    Returns:
        dict with:
          * ``total_chats``: number of recorded calls (each call counts as one chat);
          * ``total_tokens``: sum of input + output tokens;
          * ``total_cost``: sum of ``cost_usd``, rounded to 6 decimals.
        When the window contains no calls, the totals are 0 / 0.0.
    """
    since = int(time.time()) - 24 * 60 * 60

    async with aiosqlite.connect(database_path) as db:
        # A single aggregate query means one table pass, and all three figures
        # come from the same snapshot. Separate queries could observe a row
        # inserted between them.
        cursor = await db.execute(
            """
            SELECT COUNT(*),
                   COALESCE(SUM(input_tokens + output_tokens), 0),
                   COALESCE(SUM(cost_usd), 0.0)
            FROM api_calls
            WHERE timestamp >= ?
            """,
            (since,),
        )
        total_chats, total_tokens, total_cost = await cursor.fetchone()

    return {
        "total_chats": total_chats,
        "total_tokens": total_tokens,
        "total_cost": round(total_cost, 6),
    }
|
||
|
||
async def get_recent_logs(page: int = 1, page_size: int = 100):
    """Return one page of API-call logs, newest first, together with pagination metadata.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; must be at least 1.

    Returns:
        dict with ``logs`` (one dict per row, including a derived ``total_tokens``)
        and ``pagination`` (current/total pages, total rows, page size, prev/next flags).

    Raises:
        ValueError: If ``page_size`` is less than 1. Previously 0 divided by zero,
            and a negative value became an unlimited SQLite LIMIT.
    """
    if page_size < 1:
        raise ValueError(f"page_size must be >= 1, got {page_size}")
    page = max(page, 1)
    offset = (page - 1) * page_size

    async with aiosqlite.connect(database_path) as db:
        cursor = await db.execute("SELECT COUNT(*) FROM api_calls")
        (total_count,) = await cursor.fetchone()

        # Timestamps have whole-second granularity, so ties are common. ``id``
        # breaks them, making the order deterministic so rows neither repeat
        # nor disappear between pages.
        cursor = await db.execute(
            """
            SELECT timestamp, model_id, user_email, input_tokens, output_tokens, cost_usd
            FROM api_calls
            ORDER BY timestamp DESC, id DESC
            LIMIT ? OFFSET ?
            """,
            (page_size, offset),
        )
        rows = await cursor.fetchall()

    logs = [
        {
            "timestamp": ts,
            "model_id": model_id,
            "user_email": email,
            "input_tokens": tokens_in,
            "output_tokens": tokens_out,
            "total_tokens": tokens_in + tokens_out,
            "cost_usd": round(cost, 6),
        }
        for ts, model_id, email, tokens_in, tokens_out, cost in rows
    ]

    total_pages = (total_count + page_size - 1) // page_size  # ceiling division

    return {
        "logs": logs,
        "pagination": {
            "current_page": page,
            "total_pages": total_pages,
            "total_count": total_count,
            "page_size": page_size,
            "has_prev": page > 1,
            "has_next": page < total_pages,
        },
    }
|
||
|
||
async def get_users_report():
    """Summarise consumption per user, most expensive user first.

    Returns:
        A list of dicts, one per distinct ``user_email``. Each has the call count,
        the input/output/total token sums, and the total cost rounded to 6 decimals.
    """
    query = """
        SELECT
            user_email,
            COUNT(*) as total_calls,
            SUM(input_tokens) as total_input_tokens,
            SUM(output_tokens) as total_output_tokens,
            SUM(input_tokens + output_tokens) as total_tokens,
            SUM(cost_usd) as total_cost
        FROM api_calls
        GROUP BY user_email
        ORDER BY total_cost DESC
    """
    async with aiosqlite.connect(database_path) as db:
        cursor = await db.execute(query)
        rows = await cursor.fetchall()

    return [
        {
            "email": email,
            "total_calls": calls,
            "total_input_tokens": tokens_in,
            "total_output_tokens": tokens_out,
            "total_tokens": tokens_all,
            "total_cost": round(cost, 6),
        }
        for email, calls, tokens_in, tokens_out, tokens_all, cost in rows
    ]
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Before the app starts serving, the database schema is created if needed.
    After the app stops serving, the shutdown is logged.
    """
    # -- startup --
    await init_database()
    logger.info("Application started")

    yield  # the application serves requests while suspended here

    # -- shutdown --
    logger.info("Application shutting down")
|
||
|
||
# ASGI application; ``lifespan`` runs the startup/shutdown work.
app = FastAPI(lifespan=lifespan)

# Cookie-based sessions back the admin login state stored in ``request.session``.
app.add_middleware(SessionMiddleware, secret_key=session_secret_key)

# Jinja2 templates for the HTML pages. The directory is resolved relative to
# this script, like the default "./config.json" and "./data.sqlite3" paths,
# so the app no longer depends on the current working directory.
templates = Jinja2Templates(directory=normalize_path("./templates"))
|
||
|
||
# Routes
@app.get("/")
async def root(request: Request):
    """Send logged-in visitors to the dashboard and everyone else to the login page."""
    target = "/dashboard" if request.session.get("authenticated") else "/login"
    return RedirectResponse(url=target, status_code=302)
|
||
|
||
@app.get("/dashboard", response_class=HTMLResponse)
async def dashboard(request: Request, _: bool = Depends(verify_admin_session)):
    """Render the overview page with the last-24-hour statistics (admin session required)."""
    context = {"request": request, "stats": await get_24h_stats()}
    return templates.TemplateResponse("dashboard.html", context)
|
||
|
||
@app.get("/dashboard/logs", response_class=HTMLResponse)
async def dashboard_logs(request: Request, page: int = 1, _: bool = Depends(verify_admin_session)):
    """Render one 100-row page of the call log; page numbers below 1 show page 1."""
    page_data = await get_recent_logs(page=max(page, 1), page_size=100)
    context = {
        "request": request,
        "logs": page_data["logs"],
        "pagination": page_data["pagination"],
    }
    return templates.TemplateResponse("logs.html", context)
|
||
|
||
@app.get("/dashboard/users_report", response_class=HTMLResponse)
async def dashboard_users_report(request: Request, _: bool = Depends(verify_admin_session)):
    """Render the per-user consumption report (admin session required)."""
    context = {"request": request, "users": await get_users_report()}
    return templates.TemplateResponse("users_report.html", context)
|
||
|
||
@app.post("/api/record_api_call")
async def record_api_call(
    record: APICallRecord,
    _: str = Depends(verify_report_key),
):
    """Persist one reported API call (requires the report key as a Bearer token).

    Returns:
        A success payload once the row is committed.

    Raises:
        HTTPException: 500 if the record could not be written to the database.
    """
    # Only the database work is guarded, so a failure is attributable to the write.
    try:
        async with aiosqlite.connect(database_path) as db:
            await db.execute("""
                INSERT INTO api_calls (timestamp, model_id, user_email, input_tokens, output_tokens, cost_usd)
                VALUES (?, ?, ?, ?, ?, ?)
            """, (
                record.timestamp,
                record.model_id,
                record.user_email,
                record.input_tokens,
                record.output_tokens,
                record.cost_usd,
            ))
            await db.commit()
    except Exception as e:
        # logger.exception keeps the traceback in the server log. The client
        # only sees a generic message, and the cause stays chained.
        logger.exception("Error recording API call: {}", e)
        raise HTTPException(status_code=500, detail="Failed to record API call") from e

    logger.info(
        "Recorded API call: {} - {} - ${}",
        record.user_email, record.model_id, record.cost_usd,
    )
    return {"status": "success", "message": "API call recorded successfully"}
|
||
|
||
@app.get("/login", response_class=HTMLResponse)
async def login_page(request: Request):
    """Show the admin login form."""
    context = {"request": request}
    return templates.TemplateResponse("login.html", context)
|
||
|
||
@app.post("/login")
async def login(request: Request, password: str = Form(...)):
    """Check the submitted admin password and start a session on success.

    On success the session is marked authenticated and the user is redirected
    (302) to /dashboard. On failure the login page is re-rendered with an error.
    """
    # Compare UTF-8 bytes in constant time. secrets.compare_digest raises
    # TypeError for non-ASCII ``str`` arguments, which previously turned such
    # a password (or key) into a 500. An unset (empty) admin key never logs anyone in.
    if admin_api_key and secrets.compare_digest(
        password.encode("utf-8"), admin_api_key.encode("utf-8")
    ):
        request.session["authenticated"] = True
        return RedirectResponse(url="/dashboard", status_code=302)

    return templates.TemplateResponse("login.html", {
        "request": request,
        "error": "Invalid password",
    })
|
||
|
||
@app.get("/logout")
async def logout(request: Request):
    """End the admin session and return to the login page."""
    session = request.session
    session.clear()
    return RedirectResponse(url="/login", status_code=302)
|
||
|
||
if __name__ == "__main__":
    # Run the app directly under uvicorn, bound to LISTEN_HOST:LISTEN_PORT.
    uvicorn.run(app, port=int(listen_port), host=listen_host)
|