409 lines
14 KiB
Python
409 lines
14 KiB
Python
"""
Variet LLM Engine v1.0
━━━━━━━━━━━━━━━━━━━━━━
FastAPI reverse-proxy + llama-server process manager.

Architecture:
    Client (Machine B) ──► Engine :8000 ──► llama-server :8080 (internal)

    - /v1/*                  → transparent proxy to llama-server (OpenAI-compatible)
    - /engine/status         → current model & state
    - /engine/models         → available roles
    - /engine/switch/{role}  → hot-swap model (POST)
    - /engine/health         → health check
"""
|
|
|
|
import os
|
|
import json
|
|
import time
|
|
import socket
|
|
import asyncio
|
|
import logging
|
|
import subprocess
|
|
from pathlib import Path
|
|
from contextlib import asynccontextmanager
|
|
|
|
import httpx
|
|
import uvicorn
|
|
from fastapi import FastAPI, Request, BackgroundTasks
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from starlette.background import BackgroundTask
|
|
|
|
# ── Logging ──────────────────────────────────────────────
|
|
# Root logger configuration: INFO level, short time-of-day timestamps.
logging.basicConfig(
    format="[%(asctime)s] %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)

# Logger shared by every function in this module.
log = logging.getLogger("variet-engine")
|
|
|
|
# ── Paths ────────────────────────────────────────────────
|
|
# Project root is two levels above this file; config and logs live under it.
ROOT_DIR = Path(__file__).resolve().parent.parent
CONFIG_FILE = ROOT_DIR.joinpath("config", "engine_models.json")
LOG_DIR = ROOT_DIR.joinpath("logs")
|
|
|
|
|
|
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
# Engine State
|
|
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
|
|
class EngineState:
    """Mutable record of the managed llama-server process and the active role."""

    def __init__(self):
        # Lifecycle marker, one of: starting | loading | ready | error
        self.state: str = "starting"
        # Key into config["roles"] for the model currently launched, if any.
        self.current_role: str | None = None
        # Handle of the llama-server child process, if one was started.
        self.process: subprocess.Popen | None = None
        # Parsed contents of CONFIG_FILE; empty until load_config() runs.
        self.config: dict = {}
        # Epoch seconds; reset whenever a llama-server process is started.
        self.boot_time: float = time.time()
        # Open file object receiving llama-server stdout/stderr, if any.
        self._log_file = None

    def load_config(self):
        """Parse CONFIG_FILE as JSON into ``self.config`` and log the role count."""
        raw_text = CONFIG_FILE.read_text(encoding="utf-8")
        self.config = json.loads(raw_text)
        role_count = len(self.config.get("roles", {}))
        log.info(f"Config loaded: {role_count} roles available")

    @property
    def internal_url(self) -> str:
        """Base URL of the internal llama-server, built from the loaded config."""
        host = self.config["llama_server"]["internal_host"]
        port = self.config["llama_server"]["internal_port"]
        return "http://{}:{}".format(host, port)

    @property
    def role_info(self) -> dict:
        """Config entry of the current role, or ``{}`` when none is set or known."""
        if not self.current_role:
            return {}
        known_roles = self.config.get("roles", {})
        if self.current_role not in known_roles:
            return {}
        return known_roles[self.current_role]
|
|
|
|
|
|
# Module-level singleton holding the engine's state; the process helpers and
# request handlers in this file read and mutate it directly.
engine = EngineState()
|
|
|
|
|
|
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
# Process Management
|
|
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
|
|
def _is_port_free(port: int) -> bool:
|
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
return s.connect_ex(("127.0.0.1", port)) != 0
|
|
|
|
|
|
def kill_llama_server():
    """Stop llama-server, release its log file and wait for the port to free up.

    Steps, in order:
      1. Terminate the tracked ``engine.process`` (escalating to ``kill()``
         after 5 s) and reap it.
      2. On Windows, force-kill every ``llama-server.exe`` by image name to
         catch orphans from earlier runs.
      3. Close ``engine._log_file`` if one is open.
      4. Poll the internal port for up to ~10 s until nothing listens on it.

    This function blocks (``wait`` / ``time.sleep``); it never raises for a
    process that is already gone.
    """
    log.info("Stopping llama-server...")

    # Kill our tracked process first
    tracked = engine.process
    if tracked and tracked.poll() is None:
        tracked.terminate()
        try:
            tracked.wait(timeout=5)
        except subprocess.TimeoutExpired:
            tracked.kill()
            # Reap the killed child so it does not linger as a zombie.
            try:
                tracked.wait(timeout=5)
            except subprocess.TimeoutExpired:
                log.warning("llama-server did not exit after kill()")

    # Also kill any orphaned llama-server. taskkill exists only on Windows;
    # the argument list avoids going through a shell.
    if os.name == "nt":
        try:
            subprocess.run(
                ["taskkill", "/F", "/IM", "llama-server.exe"],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=False,
            )
        except OSError as exc:
            log.warning(f"Could not run taskkill: {exc}")

    # Close log file (best effort, but never hide non-I/O errors)
    log_file = engine._log_file
    if log_file:
        try:
            log_file.close()
        except OSError as exc:
            log.warning(f"Could not close llama-server log file: {exc}")
        finally:
            engine._log_file = None

    # Wait for port to free
    port = engine.config.get("llama_server", {}).get("internal_port", 8080)
    for _ in range(20):
        if _is_port_free(port):
            break
        time.sleep(0.5)
    else:
        log.warning(f"Port {port} still occupied after 10s")
|
|
|
|
|
|
def build_command(role: str) -> list[str]:
    """Build the llama-server command line for ``role`` from the loaded config.

    Args:
        role: Key into ``engine.config["roles"]``.

    Returns:
        The argv list: executable, ``--model``, ``--port``, ``--host``,
        followed by the role's ``args`` list.

    Raises:
        KeyError: If the role or a required config key is missing.
    """
    role_cfg = engine.config["roles"][role]
    server_cfg = engine.config["llama_server"]

    # Paths in the config are resolved against ROOT_DIR.
    llama_path = str(ROOT_DIR / server_cfg["path"])
    model_path = str(ROOT_DIR / role_cfg["model_path"])

    cmd = [
        llama_path,
        "--model", model_path,
        "--port", str(server_cfg["internal_port"]),
        "--host", server_cfg["internal_host"],
    ]

    # args is a flat list of CLI arguments, passed through in order. JSON may
    # carry numbers (e.g. 8192); stringify so the result really is list[str].
    cmd.extend(str(arg) for arg in role_cfg.get("args", []))

    return cmd
|
|
|
|
|
|
def start_llama_server(role: str):
    """Launch llama-server for ``role`` and record it on ``engine``.

    The child's stdout and stderr are appended to ``logs/llama-server.log``.
    On success ``engine.process``, ``engine._log_file``, ``engine.current_role``
    and ``engine.boot_time`` are updated. This does not wait for the server to
    become healthy.

    Raises:
        KeyError: If ``role`` is not configured.
        OSError: If the log file cannot be opened or the process cannot start.
    """
    role_cfg = engine.config["roles"][role]
    cmd = build_command(role)

    log.info(f"Starting [{role}] {role_cfg['display_name']}")
    log.info(f"CMD: {' '.join(map(str, cmd))}")

    LOG_DIR.mkdir(parents=True, exist_ok=True)

    # Release a handle left open by an earlier start so it is not leaked.
    stale = engine._log_file
    if stale is not None:
        try:
            stale.close()
        except OSError as exc:
            log.warning(f"Could not close previous llama-server log file: {exc}")
        engine._log_file = None

    log_file = open(LOG_DIR / "llama-server.log", "a", encoding="utf-8")
    try:
        process = subprocess.Popen(
            cmd, stdout=log_file, stderr=subprocess.STDOUT, cwd=str(ROOT_DIR)
        )
    except Exception:
        # Launch failed: do not leave the log handle open behind us.
        log_file.close()
        raise

    engine._log_file = log_file
    engine.process = process
    engine.current_role = role
    engine.boot_time = time.time()
|
|
|
|
|
|
async def wait_for_health(timeout: int = 300) -> bool:
    """Poll llama-server ``/health`` until it reports ready.

    Every 2 s the endpoint is queried; the server counts as ready when it
    answers 200 with a JSON object whose ``status`` is ``"ok"`` or ``"ready"``.

    Args:
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True once the server is ready (``engine.state`` becomes ``"ready"``).
        False if the process exits or the timeout elapses
        (``engine.state`` becomes ``"error"``).
    """
    url = f"{engine.internal_url}/health"
    # Monotonic clock: immune to wall-clock adjustments during long loads.
    start = time.monotonic()

    async with httpx.AsyncClient() as client:
        while time.monotonic() - start < timeout:
            # Check if process died
            if engine.process and engine.process.poll() is not None:
                log.error(f"llama-server exited with code {engine.process.returncode}")
                engine.state = "error"
                return False

            try:
                resp = await client.get(url, timeout=3.0)
                if resp.status_code == 200:
                    data = resp.json()
                    if isinstance(data, dict) and data.get("status") in ("ok", "ready"):
                        elapsed = round(time.monotonic() - start, 1)
                        log.info(f"llama-server READY in {elapsed}s")
                        engine.state = "ready"
                        return True
            except (httpx.HTTPError, ValueError):
                # Not reachable yet, connection dropped mid-load, or a non-JSON
                # body: all mean "not ready", so keep polling instead of
                # aborting and leaving the engine stuck in "loading".
                pass

            await asyncio.sleep(2)

    log.error(f"llama-server failed to become healthy within {timeout}s")
    engine.state = "error"
    return False
|
|
|
|
|
|
async def perform_switch(role: str):
    """Run the full hot-swap sequence: kill → start → wait for health.

    ``engine.state`` is ``"loading"`` for the duration. If stopping or
    launching fails, the state is set to ``"error"`` so later switch requests
    are not rejected forever as "already in progress".

    Args:
        role: Configured role to switch to.
    """
    engine.state = "loading"
    log.info(f"═══ HOT-SWAP: switching to [{role}] ═══")

    try:
        # kill_llama_server blocks (process wait + port polling with sleeps),
        # so run it in a worker thread to keep the event loop responsive.
        await asyncio.to_thread(kill_llama_server)
        start_llama_server(role)
    except Exception:
        log.exception(f"Hot-swap to [{role}] failed before llama-server could start")
        engine.state = "error"
        return

    await wait_for_health()
|
|
|
|
|
|
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
# FastAPI App
|
|
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup: boot the default model. Shutdown: stop llama-server.

    Startup loads the config, clears any running llama-server, launches the
    default role and begins health polling in the background, so the app
    starts accepting requests while the model is still loading.
    """
    LOG_DIR.mkdir(exist_ok=True)
    engine.load_config()
    default_role = engine.config.get("default_role", "fast")

    engine.state = "loading"
    kill_llama_server()
    start_llama_server(default_role)

    # Keep a reference: the event loop holds tasks only weakly, so an
    # unreferenced task may be garbage-collected before it finishes.
    health_task = asyncio.create_task(wait_for_health())

    try:
        yield  # App is running
    finally:
        log.info("Shutting down engine...")
        if not health_task.done():
            health_task.cancel()
        # Collect the task's outcome (result, error or cancellation) quietly.
        await asyncio.gather(health_task, return_exceptions=True)
        kill_llama_server()
|
|
|
|
|
|
# ASGI application; the lifespan handler boots the default model on startup
# and stops llama-server on shutdown.
app = FastAPI(title="Variet LLM Engine", version="1.0", lifespan=lifespan)
|
|
|
|
|
|
# ── Engine Management Endpoints ──────────────────────────
|
|
|
|
@app.get("/engine/health")
async def health():
    """Minimal health probe: the engine's state and the current role."""
    return dict(state=engine.state, role=engine.current_role)
|
|
|
|
|
|
def _get_arg_value(args_list: list, flag: str):
|
|
"""Extract a value from a CLI args list by flag name."""
|
|
try:
|
|
idx = args_list.index(flag)
|
|
return args_list[idx + 1]
|
|
except (ValueError, IndexError):
|
|
return None
|
|
|
|
|
|
@app.get("/engine/status")
async def status():
    """Report the engine state together with details of the current role."""
    role_cfg = engine.role_info
    cli_args = role_cfg.get("args", [])

    payload = {"state": engine.state, "role": engine.current_role}
    payload["display_name"] = role_cfg.get("display_name", "Unknown")
    payload["measured_tps"] = role_cfg.get("measured_tps")
    # Context size is read from the "-c" entry of the role's CLI args.
    payload["context_size"] = _get_arg_value(cli_args, "-c")
    # Seconds since the current llama-server process was started.
    payload["uptime_seconds"] = round(time.time() - engine.boot_time, 1)
    return payload
|
|
|
|
|
|
@app.get("/engine/models")
async def models():
    """List every configured role with its display name, TPS and context size."""
    configured = engine.config.get("roles", {})
    catalogue = {
        role_name: {
            "display_name": role_cfg["display_name"],
            "measured_tps": role_cfg.get("measured_tps"),
            "context_size": _get_arg_value(role_cfg.get("args", []), "-c"),
        }
        for role_name, role_cfg in configured.items()
    }
    return {"current": engine.current_role, "roles": catalogue}
|
|
|
|
|
|
@app.post("/engine/switch/{role}")
async def switch(role: str, background_tasks: BackgroundTasks):
    """Request a hot-swap to ``role``; the swap itself runs as a background task.

    The config file is re-read on every call so newly added roles are visible.

    Responses:
        200 ``already_active`` – the role is already loaded and ready.
        200 ``switching``      – swap accepted and scheduled.
        400 – unknown role (body lists the available ones).
        409 – another switch is still in progress.
        500 – the config file could not be re-read.
    """
    try:
        engine.load_config()
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError is a ValueError; the previous config stays in effect.
        log.error(f"Config reload failed: {exc}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to reload engine config: {exc}"},
        )

    roles = engine.config.get("roles", {})

    if role not in roles:
        return JSONResponse(
            status_code=400,
            content={"error": f"Unknown role '{role}'", "available": list(roles)},
        )

    if engine.state == "loading":
        return JSONResponse(
            status_code=409,
            content={"error": "A model switch is already in progress."},
        )

    if role == engine.current_role and engine.state == "ready":
        return {"status": "already_active", "role": role}

    target = roles[role]
    previous_role = engine.current_role

    # Claim the "loading" state now. The background task only starts after the
    # response is sent, so without this a second request arriving in between
    # would pass the 409 check and schedule a competing switch.
    engine.state = "loading"
    background_tasks.add_task(perform_switch, role)

    return {
        "status": "switching",
        "from_role": previous_role,
        "to_role": role,
        "to_model": target["display_name"],
        "eta_seconds": 180 if role == "ultra" else 30,
    }
|
|
|
|
|
|
# ── Reverse Proxy ────────────────────────────────────────
|
|
|
|
@app.api_route(
    "/v1/{path:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
)
async def proxy(request: Request, path: str):
    """Transparently proxy all /v1/* requests to llama-server.

    While the engine is not ``"ready"`` a 503 with ``Retry-After`` is returned.
    Otherwise the request is forwarded and the upstream response is streamed
    back chunk by chunk.

    Responses generated here (not by llama-server):
        503 – engine not ready.
        502 – could not connect to llama-server (engine state becomes "error").
        500 – any other failure while forwarding.
    """
    if engine.state != "ready":
        eta = 180 if engine.current_role == "ultra" else 30
        return JSONResponse(
            status_code=503,
            headers={"Retry-After": str(eta)},
            content={
                "error": {
                    "message": f"Engine is loading model ({engine.current_role}). Retry in ~{eta}s.",
                    "type": "engine_loading",
                    "code": "service_unavailable",
                },
                "state": engine.state,
            },
        )

    target_url = f"{engine.internal_url}/v1/{path}"
    body = await request.body()

    # Forward headers, dropping those httpx must recompute for the new request
    fwd_headers = {
        k: v
        for k, v in request.headers.items()
        if k.lower() not in ("host", "content-length", "transfer-encoding")
    }

    client = httpx.AsyncClient(timeout=7200.0)  # 2h — dense models may need extended time
    resp = None
    try:
        req = client.build_request(
            method=request.method,
            url=target_url,
            headers=fwd_headers,
            content=body,
            params=dict(request.query_params),
        )
        resp = await client.send(req, stream=True)

        async def stream_and_close():
            # try/finally so the upstream response and the client are released
            # even when the downstream client disconnects mid-stream.
            try:
                # aiter_bytes() yields content-decoded data, which matches the
                # removal of Content-Encoding from the response headers below.
                async for chunk in resp.aiter_bytes():
                    yield chunk
            finally:
                await resp.aclose()
                await client.aclose()

        # Forward response headers, dropping framing/encoding headers that no
        # longer describe the re-streamed body
        resp_headers = {
            k: v
            for k, v in resp.headers.items()
            if k.lower() not in ("transfer-encoding", "content-length", "content-encoding")
        }

        return StreamingResponse(
            stream_and_close(),
            status_code=resp.status_code,
            headers=resp_headers,
        )
    except httpx.ConnectError:
        await client.aclose()
        engine.state = "error"
        return JSONResponse(
            status_code=502,
            content={"error": "llama-server connection failed. Model may have crashed."},
        )
    except Exception as e:
        if resp is not None:
            await resp.aclose()
        await client.aclose()
        return JSONResponse(status_code=500, content={"error": str(e)})
|
|
|
|
|
|
# ── Root info ────────────────────────────────────────────
|
|
|
|
@app.get("/")
async def root():
    """Service banner: name, version, state, active model and endpoint map."""
    endpoint_map = {
        "inference": "/v1/chat/completions",
        "status": "/engine/status",
        "models": "/engine/models",
        "switch": "/engine/switch/{role}",
        "health": "/engine/health",
    }
    active_model = engine.role_info.get("display_name", "None")
    return {
        "name": "Variet LLM Engine",
        "version": "1.0",
        "state": engine.state,
        "current_model": active_model,
        "endpoints": endpoint_map,
    }
|
|
|
|
|
|
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
# Main
|
|
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
|
|
if __name__ == "__main__":
    # The config is read here only to learn the bind address; lifespan()
    # loads it again when the app starts.
    engine.load_config()

    # Tolerate a missing "engine" section: both of its keys have defaults.
    engine_cfg = engine.config.get("engine", {})
    ext_port = engine_cfg.get("external_port", 8000)
    ext_host = engine_cfg.get("external_host", "0.0.0.0")

    log.info("══════════════════════════════════════")
    log.info(" Variet LLM Engine v1.0")
    log.info(f" Listening on {ext_host}:{ext_port}")
    log.info(f" Default role: {engine.config.get('default_role', 'fast')}")
    log.info("══════════════════════════════════════")

    uvicorn.run(app, host=ext_host, port=ext_port, log_level="info")
|