first version
This commit is contained in:
276
main.py
276
main.py
@ -1,6 +1,276 @@
|
||||
def main():
|
||||
print("Hello from acmchat-dashboard!")
|
||||
from fastapi import FastAPI, Request, HTTPException, Depends, Form
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials, HTTPBasic, HTTPBasicCredentials
|
||||
import os
|
||||
from loguru import logger
|
||||
import sys
|
||||
import uvicorn
|
||||
import json
|
||||
import aiosqlite
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
from contextlib import asynccontextmanager
|
||||
from pydantic import BaseModel
|
||||
import secrets
|
||||
|
||||
listen_port = os.getenv("LISTEN_PORT", "8000")
|
||||
listen_host = os.getenv("LISTEN_HOST", "0.0.0.0")
|
||||
config_path = os.getenv("CONFIG_PATH", "./config.json")
|
||||
database_path = os.getenv("DATABASE_PATH", "./data.sqlite3")
|
||||
logging_level = os.getenv("LOGGING_LEVEL", "TRACE")
|
||||
|
||||
# Configure loguru properly - remove default handler and add new one with correct level
|
||||
logger.remove() # Remove the default handler
|
||||
logger.add(sys.stderr, level=logging_level) # Add new handler with specified level
|
||||
|
||||
security_bearer = HTTPBearer()
|
||||
security_basic = HTTPBasic()
|
||||
|
||||
class APICallRecord(BaseModel):
|
||||
timestamp: int
|
||||
model_id: str
|
||||
user_email: str
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
cost_usd: float
|
||||
|
||||
def normalize_path(path: str) -> str:
|
||||
"""
|
||||
处理路径,支持绝对路径、用户目录和相对路径
|
||||
|
||||
Args:
|
||||
path: 输入路径
|
||||
|
||||
Returns:
|
||||
处理后的路径
|
||||
"""
|
||||
if path.startswith("/") or path.startswith("~"):
|
||||
# 绝对路径,如果是~开头则展开为用户目录
|
||||
return os.path.expanduser(path)
|
||||
elif path.startswith("./"):
|
||||
# 相对于脚本本身的相对路径
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
return os.path.join(script_dir, path[2:]) # 去掉'./'前缀
|
||||
else:
|
||||
# 其他情况保持原样(相对于pwd的相对路径)
|
||||
return path
|
||||
|
||||
# 处理config_path和database_path路径
|
||||
config_path = normalize_path(config_path)
|
||||
database_path = normalize_path(database_path)
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
admin_api_key = config.get("admin_api_key", "")
|
||||
report_api_key = config.get("report_api_key", "")
|
||||
|
||||
async def init_database():
|
||||
"""Initialize the database and create tables if they don't exist"""
|
||||
async with aiosqlite.connect(database_path) as db:
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS api_calls (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp INTEGER NOT NULL,
|
||||
model_id TEXT NOT NULL,
|
||||
user_email TEXT NOT NULL,
|
||||
input_tokens INTEGER NOT NULL,
|
||||
output_tokens INTEGER NOT NULL,
|
||||
cost_usd REAL NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
await db.commit()
|
||||
logger.info("Database initialized successfully")
|
||||
|
||||
async def verify_admin_basic_auth(credentials: HTTPBasicCredentials = Depends(security_basic)):
|
||||
"""Verify admin credentials for dashboard access using Basic HTTP authentication"""
|
||||
# Check username and password
|
||||
is_correct_username = secrets.compare_digest(credentials.username, "admin")
|
||||
is_correct_password = secrets.compare_digest(credentials.password, admin_api_key)
|
||||
|
||||
if not (is_correct_username and is_correct_password):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid credentials",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
return credentials.username
|
||||
|
||||
async def verify_report_key(credentials: HTTPAuthorizationCredentials = Depends(security_bearer)):
|
||||
"""Verify report API key for API access using Bearer authentication"""
|
||||
if credentials.credentials != report_api_key:
|
||||
raise HTTPException(status_code=403, detail="Invalid report API key")
|
||||
return credentials.credentials
|
||||
|
||||
async def get_24h_stats():
|
||||
"""Get statistics for the past 24 hours"""
|
||||
twenty_four_hours_ago = int(time.time()) - (24 * 60 * 60)
|
||||
|
||||
async with aiosqlite.connect(database_path) as db:
|
||||
# Total chats (assuming each API call is a chat)
|
||||
cursor = await db.execute(
|
||||
"SELECT COUNT(*) FROM api_calls WHERE timestamp >= ?",
|
||||
(twenty_four_hours_ago,)
|
||||
)
|
||||
total_chats = (await cursor.fetchone())[0]
|
||||
|
||||
# Total tokens
|
||||
cursor = await db.execute(
|
||||
"SELECT SUM(input_tokens + output_tokens) FROM api_calls WHERE timestamp >= ?",
|
||||
(twenty_four_hours_ago,)
|
||||
)
|
||||
result = await cursor.fetchone()
|
||||
total_tokens = result[0] if result[0] is not None else 0
|
||||
|
||||
# Total cost
|
||||
cursor = await db.execute(
|
||||
"SELECT SUM(cost_usd) FROM api_calls WHERE timestamp >= ?",
|
||||
(twenty_four_hours_ago,)
|
||||
)
|
||||
result = await cursor.fetchone()
|
||||
total_cost = result[0] if result[0] is not None else 0.0
|
||||
|
||||
return {
|
||||
"total_chats": total_chats,
|
||||
"total_tokens": total_tokens,
|
||||
"total_cost": round(total_cost, 4)
|
||||
}
|
||||
|
||||
async def get_recent_logs(limit: int = 100):
|
||||
"""Get recent API call logs"""
|
||||
async with aiosqlite.connect(database_path) as db:
|
||||
cursor = await db.execute("""
|
||||
SELECT timestamp, model_id, user_email, input_tokens, output_tokens, cost_usd
|
||||
FROM api_calls
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""", (limit,))
|
||||
|
||||
logs = []
|
||||
async for row in cursor:
|
||||
logs.append({
|
||||
"timestamp": row[0],
|
||||
"datetime": datetime.fromtimestamp(row[0]).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"model_id": row[1],
|
||||
"user_email": row[2],
|
||||
"input_tokens": row[3],
|
||||
"output_tokens": row[4],
|
||||
"total_tokens": row[3] + row[4],
|
||||
"cost_usd": round(row[5], 4)
|
||||
})
|
||||
|
||||
return logs
|
||||
|
||||
async def get_users_report():
|
||||
"""Get user consumption statistics"""
|
||||
async with aiosqlite.connect(database_path) as db:
|
||||
cursor = await db.execute("""
|
||||
SELECT
|
||||
user_email,
|
||||
COUNT(*) as total_calls,
|
||||
SUM(input_tokens) as total_input_tokens,
|
||||
SUM(output_tokens) as total_output_tokens,
|
||||
SUM(input_tokens + output_tokens) as total_tokens,
|
||||
SUM(cost_usd) as total_cost
|
||||
FROM api_calls
|
||||
GROUP BY user_email
|
||||
ORDER BY total_cost DESC
|
||||
""")
|
||||
|
||||
users = []
|
||||
async for row in cursor:
|
||||
users.append({
|
||||
"email": row[0],
|
||||
"total_calls": row[1],
|
||||
"total_input_tokens": row[2],
|
||||
"total_output_tokens": row[3],
|
||||
"total_tokens": row[4],
|
||||
"total_cost": round(row[5], 4)
|
||||
})
|
||||
|
||||
return users
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for startup and shutdown events"""
|
||||
# Startup
|
||||
await init_database()
|
||||
logger.info("Application started")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("Application shutting down")
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Set up templates
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
|
||||
# Routes
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Redirect to dashboard"""
|
||||
return RedirectResponse(url="/dashboard", status_code=302)
|
||||
|
||||
@app.get("/dashboard", response_class=HTMLResponse)
|
||||
async def dashboard(request: Request, _: str = Depends(verify_admin_basic_auth)):
|
||||
"""Main dashboard page"""
|
||||
stats = await get_24h_stats()
|
||||
return templates.TemplateResponse("dashboard.html", {
|
||||
"request": request,
|
||||
"stats": stats
|
||||
})
|
||||
|
||||
@app.get("/dashboard/logs", response_class=HTMLResponse)
|
||||
async def dashboard_logs(request: Request, _: str = Depends(verify_admin_basic_auth)):
|
||||
"""Dashboard logs page"""
|
||||
logs = await get_recent_logs(200) # Get more logs for the logs page
|
||||
return templates.TemplateResponse("logs.html", {
|
||||
"request": request,
|
||||
"logs": logs
|
||||
})
|
||||
|
||||
@app.get("/dashboard/users_report", response_class=HTMLResponse)
|
||||
async def dashboard_users_report(request: Request, _: str = Depends(verify_admin_basic_auth)):
|
||||
"""Dashboard users report page"""
|
||||
users = await get_users_report()
|
||||
return templates.TemplateResponse("users_report.html", {
|
||||
"request": request,
|
||||
"users": users
|
||||
})
|
||||
|
||||
@app.post("/api/record_api_call")
|
||||
async def record_api_call(
|
||||
record: APICallRecord,
|
||||
_: str = Depends(verify_report_key)
|
||||
):
|
||||
"""Record an API call"""
|
||||
try:
|
||||
async with aiosqlite.connect(database_path) as db:
|
||||
await db.execute("""
|
||||
INSERT INTO api_calls (timestamp, model_id, user_email, input_tokens, output_tokens, cost_usd)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
record.timestamp,
|
||||
record.model_id,
|
||||
record.user_email,
|
||||
record.input_tokens,
|
||||
record.output_tokens,
|
||||
record.cost_usd
|
||||
))
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"Recorded API call: {record.user_email} - {record.model_id} - ${record.cost_usd}")
|
||||
return {"status": "success", "message": "API call recorded successfully"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error recording API call: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to record API call")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
uvicorn.run(app, host=listen_host, port=int(listen_port))
|
||||
|
Reference in New Issue
Block a user