Files
ACMChat-Dashboard/main.py

325 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import FastAPI, Request, HTTPException, Depends, Form
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials, HTTPBasic, HTTPBasicCredentials
from starlette.middleware.sessions import SessionMiddleware
import os
from loguru import logger
import sys
import uvicorn
import json
import aiosqlite
import time
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from contextlib import asynccontextmanager
from pydantic import BaseModel
import secrets
# Runtime settings; each one can be overridden through an environment variable.
listen_host = os.environ.get("LISTEN_HOST", "0.0.0.0")
listen_port = os.environ.get("LISTEN_PORT", "8000")  # kept as str; int() is applied in the __main__ entry point
config_path = os.environ.get("CONFIG_PATH", "./config.json")
database_path = os.environ.get("DATABASE_PATH", "./data.sqlite3")
logging_level = os.environ.get("LOGGING_LEVEL", "TRACE")

# Swap loguru's default stderr sink for one that honours LOGGING_LEVEL.
logger.remove()
logger.add(sys.stderr, level=logging_level)

# FastAPI security schemes. security_bearer backs verify_report_key;
# security_basic is not referenced by any code visible here.
security_bearer = HTTPBearer()
security_basic = HTTPBasic()
class APICallRecord(BaseModel):
    """Request body for POST /api/record_api_call: one API call to be logged.

    Each field is written unchanged into the matching column of the
    ``api_calls`` table.
    """
    # Compared against int(time.time()) by the 24h statistics, so presumably
    # Unix epoch seconds -- TODO confirm with the reporting clients.
    timestamp: int
    model_id: str
    user_email: str
    input_tokens: int
    output_tokens: int
    # Cost of the call in US dollars (per the field name); no range check here.
    cost_usd: float
def normalize_path(path: str) -> str:
    """Resolve a configured path string.

    Rules:
      * starts with "/" or "~": treated as absolute; a leading "~" is expanded
        to the user's home directory.
      * starts with "./": resolved relative to the directory containing this
        script (not the current working directory).
      * anything else: returned unchanged, i.e. relative to the CWD.

    Args:
        path: The raw path from the environment or config.

    Returns:
        The resolved path string.
    """
    if path.startswith(("/", "~")):
        return os.path.expanduser(path)
    if not path.startswith("./"):
        return path
    script_home = os.path.dirname(os.path.abspath(__file__))
    remainder = path[len("./"):]
    return os.path.join(script_home, remainder)
# Resolve config_path and database_path: absolute / "~"-prefixed paths, "./"
# paths relative to this script, anything else relative to the CWD.
config_path = normalize_path(config_path)
database_path = normalize_path(database_path)

# Read the JSON config as UTF-8 explicitly; without an encoding argument open()
# uses the platform's locale encoding, which can mis-decode non-ASCII content.
with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)

# Shared secrets. Missing keys default to "", so make that visible in the logs.
admin_api_key = config.get("admin_api_key", "")
report_api_key = config.get("report_api_key", "")
if not admin_api_key:
    logger.warning("admin_api_key is not set in the config file")
if not report_api_key:
    logger.warning("report_api_key is not set in the config file")

# Key that signs the session cookies. A missing, null or empty value must not
# be used as a signing key; fall back to a random per-process key instead
# (sessions are then invalidated whenever the process restarts).
session_secret_key = config.get("session_secret_key") or ""
if not session_secret_key:
    logger.warning("session_secret_key is not set; using a random key, sessions will reset on restart")
    session_secret_key = secrets.token_hex(32)
async def init_database():
    """Create the SQLite schema if it does not exist yet.

    Creates the ``api_calls`` table and an index on its ``timestamp`` column.
    Both statements use IF NOT EXISTS, so running this on every startup is
    harmless.
    """
    async with aiosqlite.connect(database_path) as db:
        await db.execute("""
            CREATE TABLE IF NOT EXISTS api_calls (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp INTEGER NOT NULL,
                model_id TEXT NOT NULL,
                user_email TEXT NOT NULL,
                input_tokens INTEGER NOT NULL,
                output_tokens INTEGER NOT NULL,
                cost_usd REAL NOT NULL,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        """)
        # The dashboard filters on (timestamp >= ?) and sorts by timestamp;
        # without an index each of those queries scans the whole table.
        await db.execute(
            "CREATE INDEX IF NOT EXISTS idx_api_calls_timestamp ON api_calls (timestamp)"
        )
        await db.commit()
    logger.info("Database initialized successfully")
async def verify_admin_session(request: Request):
    """FastAPI dependency that guards the dashboard pages.

    Returns True when the session carries the ``authenticated`` flag; otherwise
    aborts the request with a 302 whose Location header points at /login.
    """
    logged_in = request.session.get("authenticated")
    if logged_in:
        return True
    raise HTTPException(status_code=302, headers={"Location": "/login"})
async def verify_report_key(credentials: HTTPAuthorizationCredentials = Depends(security_bearer)):
    """Validate the Bearer token sent to the reporting API.

    Args:
        credentials: The parsed ``Authorization: Bearer <token>`` header,
            provided by the ``security_bearer`` (HTTPBearer) dependency.

    Returns:
        The accepted token string.

    Raises:
        HTTPException: 403 when no report key is configured or the supplied
            token does not match it.
    """
    supplied = credentials.credentials
    # With no key configured nothing may be accepted; do not fall through to a
    # comparison against an empty string.
    if not report_api_key:
        raise HTTPException(status_code=403, detail="Invalid report API key")
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. compare_digest rejects non-ASCII str with TypeError, so
    # compare the UTF-8 encoded bytes instead.
    if not secrets.compare_digest(supplied.encode("utf-8"), report_api_key.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Invalid report API key")
    return supplied
async def get_24h_stats():
    """Aggregate API-call statistics for the last 24 hours.

    Returns:
        dict with ``total_chats`` (number of recorded calls, each counted as
        one chat), ``total_tokens`` (sum of input + output tokens) and
        ``total_cost`` (sum of cost_usd, rounded to 6 decimals). All three
        are zero when no calls fall inside the window.
    """
    since = int(time.time()) - 24 * 60 * 60
    async with aiosqlite.connect(database_path) as db:
        # One aggregate query instead of three: the window is scanned once,
        # and all figures come from the same snapshot even while new calls are
        # being recorded. COALESCE maps the NULL that SUM() yields for an
        # empty window to 0.
        cursor = await db.execute(
            """
            SELECT
                COUNT(*),
                COALESCE(SUM(input_tokens + output_tokens), 0),
                COALESCE(SUM(cost_usd), 0.0)
            FROM api_calls
            WHERE timestamp >= ?
            """,
            (since,),
        )
        total_chats, total_tokens, total_cost = await cursor.fetchone()
    return {
        "total_chats": total_chats,
        "total_tokens": total_tokens,
        "total_cost": round(total_cost, 6),
    }
async def get_recent_logs(page: int = 1, page_size: int = 100):
    """Return one page of API-call logs, newest first, plus pagination info.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; values below 1 are treated as 1.

    Returns:
        dict with ``logs`` (list of row dicts, each with a derived
        ``total_tokens`` and ``cost_usd`` rounded to 6 decimals) and
        ``pagination`` (current_page, total_pages, total_count, page_size,
        has_prev, has_next).
    """
    page = max(page, 1)
    # page_size 0 would divide by zero below, and a negative LIMIT means
    # "no upper bound" in SQLite, so clamp to at least one row per page.
    page_size = max(page_size, 1)
    offset = (page - 1) * page_size
    async with aiosqlite.connect(database_path) as db:
        cursor = await db.execute("SELECT COUNT(*) FROM api_calls")
        (total_count,) = await cursor.fetchone()
        # SQLite leaves the relative order of rows with equal timestamps
        # unspecified; `id` breaks the tie so rows are neither repeated nor
        # skipped across page boundaries.
        cursor = await db.execute(
            """
            SELECT timestamp, model_id, user_email, input_tokens, output_tokens, cost_usd
            FROM api_calls
            ORDER BY timestamp DESC, id DESC
            LIMIT ? OFFSET ?
            """,
            (page_size, offset),
        )
        logs = [
            {
                "timestamp": ts,
                "model_id": model_id,
                "user_email": user_email,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
                "cost_usd": round(cost_usd, 6),
            }
            async for ts, model_id, user_email, input_tokens, output_tokens, cost_usd in cursor
        ]
    total_pages = (total_count + page_size - 1) // page_size  # ceiling division
    return {
        "logs": logs,
        "pagination": {
            "current_page": page,
            "total_pages": total_pages,
            "total_count": total_count,
            "page_size": page_size,
            "has_prev": page > 1,
            "has_next": page < total_pages,
        },
    }
async def get_users_report():
    """Aggregate usage per user over every recorded API call.

    Returns:
        A list with one dict per distinct ``user_email`` (keys: email,
        total_calls, total_input_tokens, total_output_tokens, total_tokens,
        total_cost), ordered by total_cost from highest to lowest.
        ``total_cost`` is rounded to 6 decimal places.
    """
    report_sql = """
        SELECT
            user_email,
            COUNT(*) as total_calls,
            SUM(input_tokens) as total_input_tokens,
            SUM(output_tokens) as total_output_tokens,
            SUM(input_tokens + output_tokens) as total_tokens,
            SUM(cost_usd) as total_cost
        FROM api_calls
        GROUP BY user_email
        ORDER BY total_cost DESC
    """
    report = []
    async with aiosqlite.connect(database_path) as db:
        rows = await db.execute(report_sql)
        async for email, calls, input_sum, output_sum, token_sum, cost_sum in rows:
            report.append({
                "email": email,
                "total_calls": calls,
                "total_input_tokens": input_sum,
                "total_output_tokens": output_sum,
                "total_tokens": token_sum,
                "total_cost": round(cost_sum, 6),
            })
    return report
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: prepare the database, then serve, then log shutdown."""
    # --- startup: the schema must exist before the first request arrives ---
    await init_database()
    logger.info("Application started")

    yield  # requests are served while the generator is suspended here

    # --- shutdown ---
    logger.info("Application shutting down")
# The lifespan hook runs init_database() before requests are served.
app = FastAPI(lifespan=lifespan)

# Signed cookie sessions back the dashboard login (checked by verify_admin_session).
app.add_middleware(SessionMiddleware, secret_key=session_secret_key)

# Locate templates next to this script rather than in the current working
# directory (the same "./" rule normalize_path applies to config and database
# paths), so the app renders correctly when launched from any directory.
templates = Jinja2Templates(directory=normalize_path("./templates"))
# Routes
@app.get("/")
async def root(request: Request):
    """Send logged-in visitors to the dashboard and everyone else to the login page."""
    target = "/dashboard" if request.session.get("authenticated") else "/login"
    return RedirectResponse(url=target, status_code=302)
@app.get("/dashboard", response_class=HTMLResponse)
async def dashboard(request: Request, _: bool = Depends(verify_admin_session)):
    """Render the overview page with the rolling 24-hour totals (login required)."""
    context = {"request": request, "stats": await get_24h_stats()}
    return templates.TemplateResponse("dashboard.html", context)
@app.get("/dashboard/logs", response_class=HTMLResponse)
async def dashboard_logs(request: Request, page: int = 1, _: bool = Depends(verify_admin_session)):
    """Render one page (100 rows) of the call log; page numbers below 1 show page 1."""
    page_data = await get_recent_logs(page=max(page, 1), page_size=100)
    context = {
        "request": request,
        "logs": page_data["logs"],
        "pagination": page_data["pagination"],
    }
    return templates.TemplateResponse("logs.html", context)
@app.get("/dashboard/users_report", response_class=HTMLResponse)
async def dashboard_users_report(request: Request, _: bool = Depends(verify_admin_session)):
    """Render the per-user consumption table (login required)."""
    per_user = await get_users_report()
    return templates.TemplateResponse(
        "users_report.html",
        {"request": request, "users": per_user},
    )
@app.post("/api/record_api_call")
async def record_api_call(
    record: APICallRecord,
    _: str = Depends(verify_report_key)
):
    """Persist one reported API call.

    Requires a valid report key (Bearer token, checked by verify_report_key).

    Returns:
        ``{"status": "success", "message": ...}`` once the row is committed.

    Raises:
        HTTPException: 500 if the row could not be written to the database.
    """
    row = (
        record.timestamp,
        record.model_id,
        record.user_email,
        record.input_tokens,
        record.output_tokens,
        record.cost_usd,
    )
    # Keep the try body to the database work only.
    try:
        async with aiosqlite.connect(database_path) as db:
            await db.execute("""
                INSERT INTO api_calls (timestamp, model_id, user_email, input_tokens, output_tokens, cost_usd)
                VALUES (?, ?, ?, ?, ?, ?)
            """, row)
            await db.commit()
    except Exception as e:
        # logger.exception records the traceback, which logger.error discarded.
        logger.exception(f"Error recording API call: {e}")
        raise HTTPException(status_code=500, detail="Failed to record API call") from e
    logger.info(f"Recorded API call: {record.user_email} - {record.model_id} - ${record.cost_usd}")
    return {"status": "success", "message": "API call recorded successfully"}
@app.get("/login", response_class=HTMLResponse)
async def login_page(request: Request):
    """Render the empty login form."""
    context = {"request": request}
    return templates.TemplateResponse("login.html", context)
@app.post("/login")
async def login(request: Request, password: str = Form(...)):
    """Check the submitted dashboard password and start a session.

    On success, sets ``authenticated`` in the session and redirects (302) to
    /dashboard. Otherwise it re-renders login.html with an error message.
    """
    # An unset/empty admin key must never authorise anyone. Compare UTF-8
    # bytes: secrets.compare_digest raises TypeError for non-ASCII str
    # arguments, which would turn a mistyped password into a 500 error.
    password_ok = bool(admin_api_key) and secrets.compare_digest(
        password.encode("utf-8"), admin_api_key.encode("utf-8")
    )
    if password_ok:
        request.session["authenticated"] = True
        return RedirectResponse(url="/dashboard", status_code=302)
    return templates.TemplateResponse("login.html", {
        "request": request,
        "error": "Invalid password"
    })
@app.get("/logout")
async def logout(request: Request):
    """Forget the dashboard session and send the browser back to the login page."""
    session = request.session
    session.clear()
    return RedirectResponse(status_code=302, url="/login")
if __name__ == "__main__":
    # Serve the app directly with uvicorn; LISTEN_PORT arrives as a string.
    port_number = int(listen_port)
    uvicorn.run(app, host=listen_host, port=port_number)