Files
variet_llm/scripts/auto_tune_gemma4_256k.py

340 lines
9.4 KiB
Python

"""
Gemma4 26B-A4B Comprehensive Auto-Tuner | 256K Context | RTX 3060 12GB
Phase 1: -ngl sweep (GPU layers)
Phase 2: -t / -tb sweep (CPU threads)
Phase 3: -ub / -b sweep (batch sizes)
Phase 4: --cache-type-k/v sweep (KV cache precision)
Phase 5: --no-mmap, --poll, --prio sweep (misc)
Each phase fixes the best from previous phases.
"""
import subprocess
import time
import json
import urllib.request
import sys
import os
import itertools
# Switch console output to UTF-8 so the non-ASCII markers this tuner prints
# (e.g. "★") can be encoded on a legacy-codepage console. A stdout object
# without reconfigure() is simply left as it is.
try:
    sys.stdout.reconfigure(encoding='utf-8')
except AttributeError:
    pass

# Server endpoint and on-disk locations. The paths are relative, so they are
# resolved against the current working directory.
BASE_URL = "http://127.0.0.1:8000"
LLAMA_SERVER = r"llama_bin_run\llama-server.exe"
MODEL = r"models\gemma-4-26B-A4B-it-Q4_K_M.gguf"
CONTEXT = 262144  # 256K tokens (256 * 1024)

# Timed runs per configuration, and the generation cap of each run.
BENCHMARK_RUNS = 3
BENCHMARK_TOKENS = 200

# Starting point for the sweeps (from previous tuning at -c 4096).
# Key order is preserved into every result dict written to JSON.
BEST = dict(
    ngl=22,
    t=8,
    tb=8,
    ub=512,
    b=2048,
    ctk="q4_0",
    ctv="q4_0",
    fa="on",
    mlock=True,
    mmap=True,
    prio=2,
    poll=50,
)

# Every successful test_config() result, in the order it was produced.
ALL_RESULTS = []
def kill_server():
    """Force-kill every process named llama-server.exe, then pause 4 seconds.

    ``taskkill /F /IM`` matches by image name, so this also stops
    llama-server instances this script did not launch. taskkill's output is
    captured and discarded, and a non-zero exit status is not treated as an
    error. The pause presumably lets the old server release its port and
    VRAM before the next launch.
    """
    taskkill_argv = ["taskkill", "/F", "/IM", "llama-server.exe"]
    subprocess.run(taskkill_argv, capture_output=True)
    time.sleep(4)
def build_cmd(cfg):
    """Return the llama-server argv list for the tuning config ``cfg``.

    Fixed parts: the LLAMA_SERVER executable, the MODEL and CONTEXT
    constants, one parallel slot (``-np 1``) and port 8000 on all
    interfaces. Every other value comes from ``cfg``. The boolean keys
    ``mlock`` and ``mmap`` become the ``--mlock`` / ``--no-mmap`` switches.
    """
    # NOTE(review): the tuner itself only talks to 127.0.0.1 (BASE_URL), so
    # binding 0.0.0.0 also exposes this server (no --api-key is passed) to
    # the local network while tuning; 127.0.0.1 would suffice here.
    option_pairs = (
        ("--model", MODEL),
        ("-ngl", str(cfg["ngl"])),
        ("-c", str(CONTEXT)),
        ("-np", "1"),
        ("-fa", cfg["fa"]),
        ("--cache-type-k", cfg["ctk"]),
        ("--cache-type-v", cfg["ctv"]),
        ("-ub", str(cfg["ub"])),
        ("-b", str(cfg["b"])),
        ("-t", str(cfg["t"])),
        ("-tb", str(cfg["tb"])),
        ("--prio", str(cfg["prio"])),
        ("--poll", str(cfg["poll"])),
        ("--port", "8000"),
        ("--host", "0.0.0.0"),
    )
    argv = [LLAMA_SERVER]
    for flag, value in option_pairs:
        argv.extend((flag, value))

    # Bare switches, appended only when enabled (mlock first, then no-mmap).
    switches = [("--mlock", cfg["mlock"]), ("--no-mmap", not cfg["mmap"])]
    argv.extend(flag for flag, enabled in switches if enabled)
    return argv
def start_server(cfg, log_path=None):
    """Launch llama-server for ``cfg`` in the background and return its handle.

    Nothing in this script reads the server's console output, so it must not
    go to ``subprocess.PIPE``. An unread pipe eventually fills, and a child
    that writes to a full pipe blocks, which can stall the server mid-benchmark.
    Output is therefore discarded, or written to ``log_path`` when given.

    Args:
        cfg: tuning config dict, turned into an argv list by ``build_cmd``.
        log_path: optional file that receives the server's stdout and stderr
            (overwritten on each start). ``None`` discards the output.

    Returns:
        The ``subprocess.Popen`` object for the server process.
    """
    cmd = build_cmd(cfg)
    if log_path is None:
        return subprocess.Popen(
            cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT
        )
    with open(log_path, "wb") as log:
        # The child gets its own copy of the handle, so closing ours when the
        # with-block exits does not cut off the server's logging.
        return subprocess.Popen(cmd, stdout=log, stderr=subprocess.STDOUT)
def wait_for_server(timeout=180):
    """Poll ``{BASE_URL}/health`` until the server reports ready.

    Args:
        timeout: maximum number of seconds to keep polling. A value of 0 or
            less returns False without making any request.

    Returns:
        True as soon as ``/health`` answers with a JSON object whose
        ``status`` is ``"ok"``; False if that does not happen in time.
    """
    import http.client  # HTTPException: malformed or truncated HTTP replies

    # Monotonic deadline: unaffected by system clock adjustments.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            req = urllib.request.Request(f"{BASE_URL}/health")
            with urllib.request.urlopen(req, timeout=3) as resp:
                data = json.loads(resp.read())
            if isinstance(data, dict) and data.get("status") == "ok":
                return True
        except (OSError, ValueError, http.client.HTTPException):
            # Not accepting connections yet, a non-2xx reply (HTTPError is an
            # OSError), a socket timeout, or a non-JSON body: keep polling.
            # Narrow on purpose so Ctrl+C (KeyboardInterrupt) still aborts.
            pass
        time.sleep(2)
    return False
def run_benchmark(max_tokens=BENCHMARK_TOKENS):
    """Send one fixed chat-completion request and return end-to-end tokens/s.

    The rate is ``usage.completion_tokens`` from the response divided by the
    time spent sending the request and reading and decoding the reply. It
    therefore includes HTTP overhead and any prompt processing, not only
    token generation.

    Args:
        max_tokens: generation cap sent to the server.

    Returns:
        Completion tokens per second. 0 if no time elapsed; also 0 when the
        response carries no usable token count.

    Raises:
        Any error raised by ``urllib.request.urlopen`` (connection, HTTP
        status, timeout) or ``json.loads`` (malformed body) propagates.
    """
    payload = json.dumps({
        "model": "local-model",
        "messages": [{"role": "user", "content": "Count from 1 to 50, writing each number on a new line."}],
        "max_tokens": max_tokens,
        "temperature": 0.0
    }).encode("utf-8")
    req = urllib.request.Request(
        f"{BASE_URL}/v1/chat/completions",
        data=payload,
        headers={"Content-Type": "application/json"}
    )
    # perf_counter is monotonic and high-resolution. time.time() can jump if
    # the system clock is adjusted mid-request and distort the measurement.
    start = time.perf_counter()
    with urllib.request.urlopen(req, timeout=300) as resp:
        result = json.loads(resp.read())
    elapsed = time.perf_counter() - start
    # Treat a missing *or null* usage / token count as zero tokens.
    usage = result.get("usage") or {}
    ct = usage.get("completion_tokens") or 0
    return ct / elapsed if elapsed > 0 else 0
def get_vram(gpu_index=0):
    """Return ``(used, total)`` GPU memory in MiB as reported by nvidia-smi.

    nvidia-smi prints one ``"used, total"`` line per GPU. The original code
    split the whole output on ","; on a multi-GPU machine that merged lines
    and silently fell back to ``(0, 0)``. Only the requested line is parsed now.

    Args:
        gpu_index: which per-GPU line to read (default 0, the first listed).

    Returns:
        ``(used, total)`` as ints, or ``(0, 0)`` if nvidia-smi is missing,
        times out, or its output cannot be parsed.
    """
    try:
        r = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used,memory.total",
             "--format=csv,noheader,nounits"],
            capture_output=True, text=True, timeout=5,
        )
        line = r.stdout.strip().splitlines()[gpu_index]
        used, total = (int(part.strip()) for part in line.split(","))
        return used, total
    except (OSError, subprocess.SubprocessError, ValueError, IndexError):
        # Best-effort: missing binary (OSError), timeout (SubprocessError),
        # unparsable text (ValueError), or no such GPU line (IndexError).
        return 0, 0
def test_config(cfg, label=""):
    """Start llama-server with ``cfg``, benchmark it, and record the result.

    Steps:
    1. Kill any running llama-server and start a new one.
    2. Wait for it to become healthy.
    3. Send a warmup request.
    4. Run the benchmark BENCHMARK_RUNS times; a failed run is printed and
       skipped.

    The server process is always killed and reaped before returning, even
    when an exception escapes.

    Args:
        cfg: tuning config (same keys as BEST).
        label: short description used in log lines; ``str(cfg)`` if empty.

    Returns:
        ``{**cfg, "avg_tps", "best_tps", "vram_used", "vram_total", "label"}``,
        also appended to ALL_RESULTS. Returns None if the server never became
        healthy or every benchmark run failed.
    """
    kill_server()
    desc = label or str(cfg)
    print(f" [{desc}] Starting server...")
    proc = start_server(cfg)
    try:
        if not wait_for_server():
            print(f" [{desc}] FAILED to start")
            return None
        vram_used, vram_total = get_vram()
        print(f" [{desc}] VRAM: {vram_used}/{vram_total} MiB | ", end="", flush=True)
        # Warmup: best-effort and not recorded; presumably absorbs one-time
        # first-request costs so they don't skew the timed runs.
        try:
            run_benchmark(max_tokens=20)
        except Exception:
            pass
        speeds = []
        for _ in range(BENCHMARK_RUNS):
            try:
                speeds.append(run_benchmark())
            except Exception as e:
                print(f"ERR({e}) ", end="", flush=True)
    finally:
        proc.kill()
        try:
            # Reap the process so it has actually exited before the next start.
            proc.wait(timeout=30)
        except subprocess.TimeoutExpired:
            pass
        # Close the stdout pipe, if the server was started with one, so the
        # handle doesn't leak across the many configurations tested.
        if proc.stdout is not None:
            proc.stdout.close()
    if not speeds:
        print("ALL FAILED")
        return None
    avg = sum(speeds) / len(speeds)
    best = max(speeds)
    print(f"AVG: {avg:.2f} t/s | BEST: {best:.2f} t/s")
    result = {**cfg, "avg_tps": avg, "best_tps": best,
              "vram_used": vram_used, "vram_total": vram_total, "label": label}
    ALL_RESULTS.append(result)
    return result
def phase_sweep(phase_name, param_name, values, base_cfg):
    """Benchmark ``base_cfg`` with one parameter, or a group of them, varied.

    Args:
        phase_name: heading printed for this phase.
        param_name: a single cfg key, or a list/tuple of keys that are set
            together from each entry of ``values``.
        values: candidate values. With several keys, each entry is a
            sequence zipped with ``param_name``.
        base_cfg: config the candidates are applied on top of. It is copied,
            never mutated.

    Returns:
        The result dict with the highest ``avg_tps`` (the first one wins a
        tie), or None if no candidate produced a result.
    """
    print(f"\n{'='*70}")
    print(f" PHASE: {phase_name}")
    print(f" Sweeping: {param_name} = {values}")
    print(f"{'='*70}")
    # A tuple of keys is a group too; it must not be used as one dict key.
    grouped = isinstance(param_name, (list, tuple))
    best_result = None
    for val in values:
        pairs = list(zip(param_name, val)) if grouped else [(param_name, val)]
        cfg = {**base_cfg, **dict(pairs)}
        label = " | ".join(f"{p}={v}" for p, v in pairs)
        r = test_config(cfg, label)
        if r and (best_result is None or r["avg_tps"] > best_result["avg_tps"]):
            best_result = r
    if best_result:
        # Separator between label and speed: without it, "ngl=22" followed by
        # "35.12" printed as "ngl=2235.12 t/s".
        print(f"\n ★ Phase winner: {best_result['label']} → {best_result['avg_tps']:.2f} t/s")
    else:
        print("\n No configuration in this phase completed a benchmark.")
    return best_result
def _print_final_config(cfg):
    """Print the tuned settings in ``cfg`` as a labelled summary."""
    print()
    print("=" * 70)
    print(" FINAL OPTIMAL CONFIGURATION")
    print("=" * 70)
    print(f" ngl: {cfg['ngl']}")
    print(f" threads: -t {cfg['t']} -tb {cfg['tb']}")
    print(f" batch: -ub {cfg['ub']} -b {cfg['b']}")
    print(f" kv cache: -ctk {cfg['ctk']} -ctv {cfg['ctv']}")
    print(f" flash: -fa {cfg['fa']}")
    print(f" mlock: {'yes' if cfg['mlock'] else 'no'}")
    print(f" mmap: {'yes' if cfg['mmap'] else 'no (--no-mmap)'}")
    print(f" prio: {cfg['prio']}")
    print(f" poll: {cfg['poll']}")
    print()


def _final_verification(cfg, runs=5):
    """Restart llama-server with ``cfg`` and print per-run and summary speeds.

    A server that never becomes healthy, or a failing run, is reported
    rather than silently ignored. The server is always killed and reaped.
    """
    print(f" Running final verification ({runs} runs)...")
    kill_server()
    proc = start_server(cfg)
    speeds = []
    try:
        if not wait_for_server():
            print(" Final verification FAILED: server did not start")
            return
        try:
            run_benchmark(max_tokens=20)  # warmup; result discarded
        except Exception:
            pass
        for i in range(1, runs + 1):
            try:
                tps = run_benchmark()
            except Exception as e:
                print(f" Run {i}: ERR({e})")
                continue
            speeds.append(tps)
            print(f" Run {i}: {tps:.2f} t/s")
    finally:
        proc.kill()
        try:
            proc.wait(timeout=30)
        except subprocess.TimeoutExpired:
            pass
        if proc.stdout is not None:
            proc.stdout.close()
    if speeds:
        avg = sum(speeds) / len(speeds)
        best = max(speeds)
        print(f"\n ★ FINAL: AVG {avg:.2f} t/s | BEST {best:.2f} t/s")
    else:
        print("\n FINAL: no verification run succeeded")


def _recommended_command(cfg):
    """Return the llama-server command line for ``cfg`` as a single string."""
    parts = [
        f"llama-server --model {MODEL}",
        f"-ngl {cfg['ngl']} -c {CONTEXT}",
        f"-t {cfg['t']} -tb {cfg['tb']}",
        f"-ub {cfg['ub']} -b {cfg['b']}",
        f"-fa {cfg['fa']}",
        f"--cache-type-k {cfg['ctk']} --cache-type-v {cfg['ctv']}",
        f"--prio {cfg['prio']} --poll {cfg['poll']}",
    ]
    if cfg["mlock"]:
        parts.append("--mlock")
    if not cfg["mmap"]:
        parts.append("--no-mmap")
    parts.append("--port 8000 --host 0.0.0.0")
    return " ".join(parts)


def _save_results(results, path="scripts/tune_results_gemma4_256k.json"):
    """Write ``results`` as indented JSON to ``path`` and print where it went.

    The parent directory is created if missing, so a long tuning run is not
    lost to a FileNotFoundError at the very end.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=str)
    print(f"\n Full results saved: {path}")


def main():
    """Run the five tuning phases, verify the winner, and report it.

    Each phase sweeps one parameter group, starting from BEST updated with
    the winners of all earlier phases.
    1. Print the final config.
    2. Re-benchmark it.
    3. Print the recommended command line.

    Everything collected in ALL_RESULTS is written to JSON at the end. This
    also happens when the run crashes or is interrupted part-way.
    """
    print("=" * 70)
    print(" Gemma4 26B-A4B COMPREHENSIVE Auto-Tuner")
    print(" 256K Context | RTX 3060 12GB")
    print("=" * 70)
    print()
    cfg = dict(BEST)
    # (phase title, cfg key or list of keys swept together, candidate values)
    phases = [
        # Phase 1: -ngl (already done, quick verify top 3)
        ("GPU Layers (-ngl)", "ngl", [22, 21, 20]),
        # Phase 2: CPU threads (-t, -tb)
        ("CPU Threads (-t, -tb)", ["t", "tb"], [
            (2, 2), (4, 4), (4, 8), (6, 6), (6, 8),
            (8, 8), (8, 12), (10, 10), (12, 12), (16, 16),
        ]),
        # Phase 3: batch sizes (-ub, -b)
        ("Batch Sizes (-ub, -b)", ["ub", "b"], [
            (128, 512), (256, 1024), (256, 2048),
            (512, 1024), (512, 2048), (512, 4096),
            (1024, 2048), (1024, 4096),
        ]),
        # Phase 4: KV cache precision
        ("KV Cache Type (-ctk, -ctv)", ["ctk", "ctv"], [
            ("q4_0", "q4_0"),
            ("q8_0", "q8_0"),
            ("q4_0", "q8_0"),
            ("f16", "f16"),
        ]),
        # Phase 5: misc (mmap, poll, prio)
        ("Misc (mmap, poll, prio)", ["mmap", "poll", "prio"], [
            (True, 50, 2),   # baseline
            (False, 50, 2),  # no-mmap
            (True, 0, 2),    # no polling
            (True, 100, 2),  # max polling
            (True, 50, 3),   # realtime priority
            (False, 0, 3),   # no-mmap + no-poll + realtime
        ]),
    ]
    try:
        for title, params, values in phases:
            winner = phase_sweep(title, params, values, cfg)
            if winner:
                # Carry this phase's best values into all later phases.
                keys = [params] if isinstance(params, str) else params
                for key in keys:
                    cfg[key] = winner[key]
        _print_final_config(cfg)
        _final_verification(cfg, runs=5)
        print()
        print(" Recommended command:")
        print(f" {_recommended_command(cfg)}")
        print("=" * 70)
    finally:
        # Persist whatever was collected, even after an error or Ctrl+C.
        _save_results(ALL_RESULTS)
if __name__ == "__main__":
    # Entry point when run as a script; importing the module runs nothing.
    # main() returns None, so the process exits with status 0 on success.
    sys.exit(main())