AI 绘图 Web 应用开发 — 搭建在线绘图平台
AI 绘图 Web 应用开发:搭建在线绘图平台
经过八期学习,我们已经掌握了 AI 绘图的全部核心技术。但所有的能力都还停留在“命令行/本地界面”层面。如果你想做出别人可以直接使用的产品,就需要一个 Web 应用。
本期我们用 FastAPI(后端)+ 原生 HTML/CSS/JavaScript 前端(进阶部分另给出 Vue 3 写法)搭建一个完整的在线绘图平台,包含图片生成、历史记录等核心功能,以及用户认证、模型管理等进阶扩展。
这也是你从“会用工具”到“能做产品”的关键一步。
一、架构设计
1.1 整体架构
┌─────────────────────────────────────────────────┐
│ 前端 (SPA) │
│ HTML + CSS + JavaScript │
├─────────────────────────────────────────────────┤
│ HTTP / WebSocket │
├─────────────────────────────────────────────────┤
│ 后端 API (FastAPI) │
├─────────────────────────────────────────────────┤
│ 模型推理 │ 用户系统 │ 任务队列 │ 文件存储 │
├───────────┴──────────┴──────────┴──────────────┤
│ GPU / 模型 / 数据库 │
└─────────────────────────────────────────────────┘
1.2 数据流
用户操作              后端处理                 推理引擎
   │                     │                       │
   ├── 提交生成 ──→   创建任务 ───────────────→  模型推理
   │                     │                       │
   │ ←── 轮询状态 ───────┤ ←── 完成,返回 URL ────┘
   │ ←── 显示结果 ───────┘
   │
   └── 管理记录 ──→ 数据库 CRUD ──→ SQLite/PostgreSQL
1.3 文件结构
09-web-app/
├── app.py # 主应用
├── static/ # 前端静态文件
│ ├── index.html # 主页面
│ ├── style.css # 样式
│ └── app.js # 前端逻辑
├── templates/ # Jinja2 模板(可选)
├── routes/
│ ├── auth.py # 认证路由
│ ├── generate.py # 生成路由
│ └── history.py # 历史记录
├── models/
│ ├── user.py # 用户模型
│ └── generation.py # 生成记录模型
└── requirements.txt
二、后端实现
2.1 主应用
"""app.py : AI 绘图 Web 应用主入口"""
import os
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, HTMLResponse
import uvicorn

# Sub-routers (nothing in them depends on `app`, so they can be imported up front).
from routes.generate import router as generate_router
from routes.history import router as history_router

app = FastAPI(title="AI 绘图 Web 应用")

# Static frontend assets, resolved relative to this file so they are found
# regardless of the current working directory.
static_dir = Path(__file__).parent / "static"
static_dir.mkdir(exist_ok=True)
app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")

# Generated images. NOTE: this path is relative to the current working
# directory, matching the Path("web_outputs") that routes/generate.py writes
# to; change both together if you relocate it.
outputs_dir = Path("web_outputs")
outputs_dir.mkdir(exist_ok=True)
app.mount("/outputs", StaticFiles(directory=str(outputs_dir)), name="outputs")


@app.get("/")
async def index():
    """Serve the SPA entry page, or a JSON placeholder if it does not exist yet."""
    index_path = static_dir / "index.html"
    if index_path.is_file():
        # FileResponse sends the file asynchronously instead of doing a
        # blocking read_text() on the event loop for every page load.
        return FileResponse(index_path, media_type="text/html")
    return {"message": "AI 绘图 Web 应用"}


app.include_router(generate_router, prefix="/api")
app.include_router(history_router, prefix="/api")

if __name__ == "__main__":
    # reload=True requires the "module:attribute" import-string form.
    uvicorn.run("app:app", host="0.0.0.0", port=8080, reload=True)
2.2 生成路由
"""routes/generate.py : 生成相关 API"""
import uuid
import time
import asyncio
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, HTTPException, BackgroundTasks, Form
from pydantic import BaseModel, Field
router = APIRouter()
# Inference engine (singleton): one lazily-loaded pipeline shared by every request.
class InferenceEngine:
    """Wrap a Stable Diffusion pipeline that is loaded on first use.

    Routes may call :meth:`generate` from worker threads so that inference
    does not block the asyncio event loop. A diffusers pipeline is not
    documented as safe to drive from several threads at once, so model
    loading and inference are serialised with a re-entrant lock.
    """

    def __init__(self, model_id: str = "runwayml/stable-diffusion-v1-5",
                 device: str = "cuda"):
        """Create an engine; no model is loaded until it is first needed.

        Args:
            model_id: Hugging Face repo id passed to ``from_pretrained``.
            device: torch device for the pipeline and the seeded RNG
                (e.g. ``"cuda"``, ``"cuda:1"``, ``"cpu"``).
        """
        import threading

        self.model_id = model_id
        self.device = device
        self._pipe = None
        # RLock: generate() holds the lock while calling ensure_model(),
        # which acquires it again.
        self._lock = threading.RLock()

    def ensure_model(self):
        """Load the pipeline onto ``self.device`` if it is not loaded yet."""
        with self._lock:
            if self._pipe is not None:
                return
            from diffusers import StableDiffusionPipeline
            import torch

            # fp16 halves GPU memory; fp16 is poorly supported on CPU, so use fp32 there.
            dtype = torch.float16 if self.device.startswith("cuda") else torch.float32
            pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                safety_checker=None,
            ).to(self.device)
            pipe.enable_attention_slicing()
            # Publish only the fully configured pipeline.
            self._pipe = pipe

    def generate(self, prompt: str, negative_prompt: str = "",
                 steps: int = 20, cfg: float = 7.5,
                 width: int = 512, height: int = 512,
                 seed: int = -1) -> bytes:
        """Render one image and return it PNG-encoded.

        Args:
            prompt: positive text prompt.
            negative_prompt: things the image should avoid.
            steps: number of denoising steps.
            cfg: classifier-free guidance scale.
            width: output width in pixels.
            height: output height in pixels.
            seed: ``-1`` for a random seed; any other value seeds a
                ``torch.Generator`` on ``self.device`` for reproducible output.

        Returns:
            The generated image as PNG bytes.
        """
        import io

        with self._lock:
            self.ensure_model()
            generator = None
            if seed != -1:
                import torch
                generator = torch.Generator(self.device).manual_seed(seed)
            image = self._pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=steps,
                guidance_scale=cfg,
                width=width,
                height=height,
                generator=generator,
            ).images[0]
        # PNG encoding does not touch the pipeline, so it runs outside the lock.
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        return buf.getvalue()
# Shared engine instance; the model itself is loaded lazily on the first generate() call.
engine = InferenceEngine()
# Generation history (in-memory; use a database in production).
# Newest first (routes insert at index 0); lost on restart and never pruned.
generation_history: list[dict] = []
# task_id -> latest status dict of a /generate/async job; in-memory and never pruned.
task_progress: dict[str, dict] = {}
class GenerateRequest(BaseModel):
    """Request body shared by ``/generate`` and ``/generate/async``."""

    prompt: str = Field(..., min_length=1, max_length=1000)
    negative_prompt: str = ""
    steps: int = Field(20, ge=1, le=100)
    cfg: float = Field(7.5, ge=1.0, le=20.0)
    # Stable Diffusion pipelines reject sizes that are not divisible by 8
    # (the latent is 1/8 of the pixel size); validating here returns a 422
    # instead of a 500 raised from inside inference.
    width: int = Field(512, ge=256, le=2048, multiple_of=8)
    height: int = Field(512, ge=256, le=2048, multiple_of=8)
    # -1 means "random seed"; any other value makes the output reproducible.
    seed: int = Field(-1)
@router.post("/generate")
async def generate_image(req: GenerateRequest):
    """Generate one image synchronously and return its history record.

    Inference runs in a worker thread via ``asyncio.to_thread``; calling the
    blocking ``engine.generate`` directly inside this ``async def`` would stall
    the event loop and every other request until the image was finished.

    Raises:
        HTTPException(500): any error from inference or saving, with its message.
    """
    try:
        image_bytes = await asyncio.to_thread(
            engine.generate,
            req.prompt, req.negative_prompt,
            req.steps, req.cfg,
            req.width, req.height, req.seed,
        )
        # One id for both the record and the file name (as /generate/async does).
        record_id = uuid.uuid4().hex[:12]
        filename = f"gen_{record_id}.png"
        output_path = Path("web_outputs") / filename
        output_path.write_bytes(image_bytes)
        record = {
            "id": record_id,
            "prompt": req.prompt,
            "negative_prompt": req.negative_prompt,
            "params": {
                "steps": req.steps, "cfg": req.cfg,
                "width": req.width, "height": req.height,
                "seed": req.seed,
            },
            "image_url": f"/outputs/{filename}",
            "created_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        generation_history.insert(0, record)
        return {"status": "success", "data": record}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/generate/async")
async def generate_async(req: GenerateRequest,
                         background_tasks: BackgroundTasks):
    """Queue a generation and return its task id immediately.

    Poll ``GET /task/{task_id}`` for ``queued`` -> ``processing`` ->
    ``completed`` (with ``result``) or ``failed`` (with ``error``).
    """
    task_id = uuid.uuid4().hex[:12]
    task_progress[task_id] = {"status": "queued", "progress": 0}

    async def process():
        task_progress[task_id] = {"status": "processing", "progress": 0.3}
        try:
            # An async background task is awaited on the event loop itself.
            # Calling the blocking engine.generate directly here would freeze
            # the whole server (including the /task polling this endpoint
            # exists for) until inference finished, so run it in a thread.
            image_bytes = await asyncio.to_thread(
                engine.generate,
                req.prompt, req.negative_prompt,
                req.steps, req.cfg,
                req.width, req.height, req.seed,
            )
            task_progress[task_id] = {"status": "processing", "progress": 0.7,
                                      "message": "保存中..."}
            filename = f"async_{task_id}.png"
            output_path = Path("web_outputs") / filename
            output_path.write_bytes(image_bytes)
            # Same record shape as the synchronous /generate endpoint.
            record = {
                "id": task_id,
                "prompt": req.prompt,
                "negative_prompt": req.negative_prompt,
                "params": {
                    "steps": req.steps, "cfg": req.cfg,
                    "width": req.width, "height": req.height,
                    "seed": req.seed,
                },
                "image_url": f"/outputs/{filename}",
                "created_at": time.strftime("%Y-%m-%d %H:%M:%S"),
            }
            generation_history.insert(0, record)
            task_progress[task_id] = {
                "status": "completed",
                "progress": 1.0,
                "result": record,
            }
        except Exception as e:
            # Report the failure to pollers instead of losing it in the background.
            task_progress[task_id] = {
                "status": "failed",
                "error": str(e),
            }

    background_tasks.add_task(process)
    return {"task_id": task_id, "status": "queued"}
@router.get("/task/{task_id}")
async def get_task_progress(task_id: str):
    """Return the latest status dict of an async generation task.

    Raises:
        HTTPException(404): the id is unknown (the task table is in-memory).
    """
    status = task_progress.get(task_id)
    if status:
        return status
    raise HTTPException(404, "Task not found")
@router.get("/history")
async def get_history(limit: int = 20):
    """Return up to *limit* most recent generation records, newest first.

    A negative *limit* would slice from the end (``[:-5]`` means "all but
    the last five"), which no caller intends, so it is clamped to 0.
    """
    return {"records": generation_history[:max(limit, 0)]}
@router.get("/models")
async def get_models():
    """List the model options shown to the user.

    NOTE(review): this is a static catalogue. InferenceEngine only ever loads
    runwayml/stable-diffusion-v1-5 and GenerateRequest has no model field,
    so the sdxl/flux entries do not affect generation yet.
    """
    catalogue = (
        ("sd15", "Stable Diffusion v1.5", "sd"),
        ("sdxl", "Stable Diffusion XL", "sd"),
        ("flux", "FLUX.1 Schnell", "flux"),
    )
    models = [
        {"id": model_id, "name": label, "type": family}
        for model_id, label, family in catalogue
    ]
    return {"models": models}
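2.3 历史记录路由(可选扩展)
注意:2.2 中的生成接口只把记录写入内存列表 generation_history,并不会调用下面的 history_manager.add();若希望搜索、删除接口能看到新生成的图片,需要在生成成功后额外调用 history_manager.add(record)。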
"""routes/history.py : 生成历史记录 API"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
router = APIRouter()
# 扩展历史记录功能
class HistoryManager:
    """Generation history persisted to a JSON file, newest record first.

    NOTE: the generate routes in this tutorial keep their own in-memory
    ``generation_history`` and do not call :meth:`add`, so records appear
    here only once the generation code is wired to call it.
    """

    def __init__(self, db_path: str = "generation_history.json"):
        """Bind to *db_path* and load any records already stored there."""
        self.db_path = db_path
        self._load()

    def _load(self):
        """Read records from ``db_path``, or start empty if the file is missing."""
        import json
        from pathlib import Path
        p = Path(self.db_path)
        # Explicit UTF-8: _save writes non-ASCII prompts verbatim
        # (ensure_ascii=False) and the platform default encoding may not be
        # UTF-8 (e.g. on Windows).
        self.records = json.loads(p.read_text(encoding="utf-8")) if p.exists() else []

    def _save(self):
        """Persist records atomically: write a temp file, then replace the target.

        A crash mid-write therefore leaves the previous file intact instead
        of a truncated, unparseable one.
        """
        import json
        import os
        from pathlib import Path
        target = Path(self.db_path)
        tmp = target.with_name(target.name + ".tmp")
        tmp.write_text(
            json.dumps(self.records, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        os.replace(tmp, target)

    def add(self, record: dict):
        """Prepend *record* (newest first) and persist."""
        self.records.insert(0, record)
        self._save()

    def get_all(self, limit: int = 50) -> list:
        """Return up to *limit* newest records; a negative limit yields []."""
        return self.records[:max(limit, 0)]

    def delete(self, record_id: str) -> bool:
        """Remove every record whose ``id`` equals *record_id*.

        Returns:
            True if anything was removed (and the file rewritten), else False.
        """
        before = len(self.records)
        self.records = [r for r in self.records if r.get("id") != record_id]
        if len(self.records) < before:
            self._save()
            return True
        return False

    def search(self, query: str) -> list:
        """Return records whose prompt contains *query*, case-insensitively."""
        query = query.lower()
        # `or ""` also covers records whose prompt is stored as null.
        return [
            r for r in self.records
            if query in (r.get("prompt") or "").lower()
        ]
# Integrate with FastAPI: one module-level manager shared by the routes below.
# Built at import time, so the JSON file is read when this module is imported.
history_manager = HistoryManager()
@router.get("/history/search")
async def search_history(query: str = ""):
    """Search persisted history by prompt substring; an empty query lists all records."""
    matches = history_manager.search(query) if query else history_manager.get_all()
    return {"records": matches}
@router.delete("/history/{record_id}")
async def delete_history(record_id: str):
    """Delete one persisted record.

    Raises:
        HTTPException(404): no record has this id.
    """
    removed = history_manager.delete(record_id)
    if removed:
        return {"status": "deleted"}
    raise HTTPException(404, "Record not found")
三、前端实现
3.1 主页面
我们的前端使用纯 HTML/CSS/JavaScript,无需构建工具,开箱即用:
<!-- static/index.html : AI 绘图 Web 应用 -->
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AI 绘图在线工具</title>
<link rel="stylesheet" href="/static/style.css">
</head>
<body>
<div class="container">
<header>
<h1>🎨 AI 绘图在线工具</h1>
<p class="subtitle">输入描述文字,AI 为您生成图片</p>
</header>
<main>
<!-- 生成面板 -->
<section class="generate-panel">
<div class="input-group">
<label for="prompt">图片描述</label>
<textarea id="prompt" placeholder="描述你想生成的图片,如:a beautiful mountain landscape at sunset, highly detailed" rows="4"></textarea>
<span class="char-count" id="charCount">0 / 500</span>
</div>
<div class="input-group">
<label for="negative_prompt">负向提示词(可选)</label>
<textarea id="negative_prompt" placeholder="描述你不想要的内容,如:blurry, low quality, distorted" rows="2"></textarea>
</div>
<div class="params-grid">
<div class="param-item">
<label>步数</label>
<input type="range" id="steps" min="10" max="50" value="25">
<span id="stepsValue">25</span>
</div>
<div class="param-item">
<label>CFG</label>
<input type="range" id="cfg" min="1" max="20" value="7.5" step="0.5">
<span id="cfgValue">7.5</span>
</div>
<div class="param-item">
<label>宽度</label>
<select id="width">
<option value="512">512</option>
<option value="768" selected>768</option>
<option value="1024">1024</option>
</select>
</div>
<div class="param-item">
<label>高度</label>
<select id="height">
<option value="512">512</option>
<option value="768" selected>768</option>
<option value="1024">1024</option>
</select>
</div>
<div class="param-item">
<label>种子</label>
<input type="number" id="seed" value="-1" min="-1" max="999999">
<small>-1: 随机</small>
</div>
</div>
<button id="generateBtn" onclick="generate()">
<span class="btn-icon">✨</span> 生成图片
</button>
<div id="progress" class="progress-bar" style="display:none">
<div class="progress-fill" id="progressFill"></div>
<span id="progressText">生成中...</span>
</div>
</section>
<!-- 结果展示 -->
<section class="result-panel" id="resultPanel" style="display:none">
<h2>生成结果</h2>
<div class="image-container">
<img id="resultImage" alt="生成图片">
<div class="image-actions">
<button onclick="downloadImage()">⬇ 下载</button>
<button onclick="useAsImg2Img()">📷 以图生图</button>
</div>
</div>
<div class="result-info" id="resultInfo"></div>
</section>
<!-- 历史记录 -->
<section class="history-panel">
<h2>生成历史 <span id="historyCount" class="badge">0</span></h2>
<div id="historyGrid" class="history-grid"></div>
</section>
</main>
</div>
<script src="/static/app.js"></script>
</body>
</html>
3.2 样式
/* static/style.css : Web 应用样式 */
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #0f0c29, #302b63, #24243e);
color: #e0e0e0;
min-height: 100vh;
line-height: 1.6;
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 2rem;
}
/* 头部 */
header {
text-align: center;
margin-bottom: 3rem;
}
header h1 {
font-size: 2.5rem;
background: linear-gradient(90deg, #667eea, #764ba2);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin-bottom: 0.5rem;
}
.subtitle {
color: #9ca3af;
font-size: 1.1rem;
}
/* 生成面板 */
.generate-panel {
background: rgba(255, 255, 255, 0.05);
border-radius: 16px;
padding: 2rem;
margin-bottom: 2rem;
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 255, 255, 0.1);
}
.input-group {
margin-bottom: 1.5rem;
}
.input-group label {
display: block;
margin-bottom: 0.5rem;
font-weight: 500;
color: #c4b5fd;
}
textarea {
width: 100%;
background: rgba(0, 0, 0, 0.3);
border: 1px solid rgba(255, 255, 255, 0.1);
border-radius: 8px;
color: #e0e0e0;
padding: 0.8rem 1rem;
font-size: 0.95rem;
resize: vertical;
transition: border-color 0.2s;
}
textarea:focus {
outline: none;
border-color: #667eea;
}
.char-count {
display: block;
text-align: right;
color: #6b7280;
font-size: 0.85rem;
margin-top: 0.3rem;
}
/* 参数网格 */
.params-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
gap: 1.5rem;
margin-bottom: 1.5rem;
}
.param-item {
background: rgba(0, 0, 0, 0.2);
border-radius: 8px;
padding: 1rem;
}
.param-item label {
display: block;
margin-bottom: 0.5rem;
color: #a78bfa;
font-size: 0.9rem;
}
.param-item input[type="range"] {
width: 100%;
accent-color: #667eea;
}
.param-item select,
.param-item input[type="number"] {
width: 100%;
background: rgba(0, 0, 0, 0.3);
border: 1px solid rgba(255, 255, 255, 0.1);
border-radius: 6px;
color: #e0e0e0;
padding: 0.5rem;
}
.param-item small {
display: block;
color: #6b7280;
font-size: 0.8rem;
margin-top: 0.3rem;
}
/* 按钮 */
#generateBtn {
width: 100%;
padding: 1rem;
background: linear-gradient(90deg, #667eea, #764ba2);
color: white;
border: none;
border-radius: 8px;
font-size: 1.1rem;
font-weight: 600;
cursor: pointer;
transition: transform 0.2s, box-shadow 0.2s;
}
/* Lift/glow only while the button is actionable. The :disabled rule resets
   transform but not box-shadow, so hovering the disabled button during
   generation still showed the glow. */
#generateBtn:hover:not(:disabled) {
  transform: translateY(-2px);
  box-shadow: 0 4px 20px rgba(102, 126, 234, 0.4);
}
#generateBtn:disabled {
opacity: 0.6;
cursor: not-allowed;
transform: none;
}
.btn-icon {
margin-right: 0.5rem;
}
/* 进度条 */
.progress-bar {
margin-top: 1rem;
height: 24px;
background: rgba(255, 255, 255, 0.1);
border-radius: 12px;
overflow: hidden;
position: relative;
}
.progress-fill {
height: 100%;
background: linear-gradient(90deg, #667eea, #764ba2);
border-radius: 12px;
transition: width 0.3s;
width: 0%;
}
#progressText {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
font-size: 0.85rem;
font-weight: 500;
}
/* 结果面板 */
.result-panel {
background: rgba(255, 255, 255, 0.05);
border-radius: 16px;
padding: 2rem;
margin-bottom: 2rem;
border: 1px solid rgba(255, 255, 255, 0.1);
}
.result-panel h2 {
margin-bottom: 1rem;
color: #c4b5fd;
}
.image-container {
text-align: center;
margin-bottom: 1rem;
}
.image-container img {
max-width: 100%;
border-radius: 12px;
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3);
}
.image-actions {
margin-top: 1rem;
display: flex;
gap: 0.5rem;
justify-content: center;
}
.image-actions button {
padding: 0.5rem 1.5rem;
background: rgba(255, 255, 255, 0.1);
border: 1px solid rgba(255, 255, 255, 0.2);
border-radius: 6px;
color: #e0e0e0;
cursor: pointer;
transition: background 0.2s;
}
.image-actions button:hover {
background: rgba(255, 255, 255, 0.2);
}
/* pre-wrap renders the "\n" that app.js places between the prompt and the
   parameter line; overflow-wrap keeps long unbroken prompts inside the box. */
.result-info {
  font-size: 0.9rem;
  color: #9ca3af;
  padding: 1rem;
  background: rgba(0, 0, 0, 0.2);
  border-radius: 8px;
  white-space: pre-wrap;
  overflow-wrap: anywhere;
}
/* 历史记录 */
.history-panel {
background: rgba(255, 255, 255, 0.05);
border-radius: 16px;
padding: 2rem;
border: 1px solid rgba(255, 255, 255, 0.1);
}
.history-panel h2 {
margin-bottom: 1.5rem;
color: #c4b5fd;
}
.badge {
background: #667eea;
color: white;
padding: 0.2rem 0.6rem;
border-radius: 10px;
font-size: 0.8rem;
vertical-align: middle;
}
.history-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(150px, 1fr));
gap: 1rem;
}
.history-item {
position: relative;
border-radius: 8px;
overflow: hidden;
cursor: pointer;
transition: transform 0.2s;
aspect-ratio: 1;
}
.history-item:hover {
transform: scale(1.05);
}
.history-item img {
width: 100%;
height: 100%;
object-fit: cover;
}
.history-item .prompt-overlay {
position: absolute;
bottom: 0;
left: 0;
right: 0;
background: rgba(0, 0, 0, 0.7);
padding: 0.5rem;
font-size: 0.8rem;
opacity: 0;
transition: opacity 0.2s;
text-overflow: ellipsis;
white-space: nowrap;
overflow: hidden;
}
.history-item:hover .prompt-overlay {
opacity: 1;
}
/* 加载动画 */
.loading {
display: inline-block;
width: 20px;
height: 20px;
border: 2px solid rgba(255, 255, 255, 0.3);
border-top-color: white;
border-radius: 50%;
animation: spin 0.8s linear infinite;
margin-right: 0.5rem;
}
@keyframes spin {
to { transform: rotate(360deg); }
}
/* 响应式 */
@media (max-width: 768px) {
.container {
padding: 1rem;
}
header h1 {
font-size: 1.8rem;
}
.params-grid {
grid-template-columns: 1fr 1fr;
}
}
3.3 前端逻辑
/* static/app.js : 前端业务逻辑 */
// Mirror slider values and the prompt length into their labels as the user types.
// The counter advertises a 500-character limit but the textarea had no
// maxlength, so it could be exceeded; enforce the same limit from one constant.
(function bindParamDisplays() {
  const PROMPT_MAX_CHARS = 500;

  for (const [inputId, labelId] of [['steps', 'stepsValue'], ['cfg', 'cfgValue']]) {
    const label = document.getElementById(labelId);
    document.getElementById(inputId).addEventListener('input', (event) => {
      label.textContent = event.target.value;
    });
  }

  const promptInput = document.getElementById('prompt');
  const charCount = document.getElementById('charCount');
  promptInput.maxLength = PROMPT_MAX_CHARS;
  promptInput.addEventListener('input', () => {
    charCount.textContent = `${promptInput.value.length} / ${PROMPT_MAX_CHARS}`;
  });
})();
// Turn a failed response into a readable message. FastAPI sends
// {"detail": "..."} for HTTPException, {"detail": [{"msg": ...}, ...]} for
// request-validation (422) errors, and a reverse proxy may answer with a
// non-JSON page (e.g. an HTML 502).
async function readErrorMessage(resp) {
  let body;
  try {
    body = await resp.json();
  } catch (_) {
    return `生成失败 (HTTP ${resp.status})`;
  }
  const detail = body && body.detail;
  if (Array.isArray(detail)) {
    const msgs = detail.map((d) => d && d.msg).filter(Boolean);
    if (msgs.length) return msgs.join('; ');
  } else if (detail) {
    return String(detail);
  }
  return '生成失败';
}

// Generate an image through the synchronous /api/generate endpoint and show it.
async function generate() {
  const btn = document.getElementById('generateBtn');
  // The Ctrl+Enter shortcut calls generate() directly, bypassing the disabled
  // button, so refuse to start a second request while one is still running.
  if (btn.disabled) return;

  const progress = document.getElementById('progress');
  const progressFill = document.getElementById('progressFill');
  const progressText = document.getElementById('progressText');
  const resultPanel = document.getElementById('resultPanel');
  const prompt = document.getElementById('prompt').value.trim();
  if (!prompt) {
    alert('请输入图片描述');
    return;
  }

  // Lock the button and show the progress bar.
  btn.disabled = true;
  btn.innerHTML = '<span class="loading"></span> 生成中...';
  progress.style.display = 'block';
  progressFill.style.width = '10%';
  progressText.textContent = '正在提交...';

  try {
    // 1. Collect parameters. parseInt of an empty seed field is NaN, which
    //    JSON.stringify sends as null and the backend rejects with a 422;
    //    fall back to -1 ("random"), the field's default.
    const seedValue = parseInt(document.getElementById('seed').value, 10);
    const reqPayload = {
      prompt: prompt,
      negative_prompt: document.getElementById('negative_prompt').value,
      steps: parseInt(document.getElementById('steps').value, 10),
      cfg: parseFloat(document.getElementById('cfg').value),
      width: parseInt(document.getElementById('width').value, 10),
      height: parseInt(document.getElementById('height').value, 10),
      seed: Number.isNaN(seedValue) ? -1 : seedValue,
    };

    // 2. Call the synchronous API (kept simple; /api/generate/async plus
    //    /api/task/{id} polling is the non-blocking alternative).
    progressFill.style.width = '30%';
    progressText.textContent = '正在生成...';
    const resp = await fetch('/api/generate', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(reqPayload),
    });
    if (!resp.ok) {
      throw new Error(await readErrorMessage(resp));
    }
    const result = await resp.json();

    // 3. Show the result.
    progressFill.style.width = '100%';
    progressText.textContent = '完成!';
    setTimeout(() => {
      progress.style.display = 'none';
      resultPanel.style.display = 'block';
      document.getElementById('resultImage').src = result.data.image_url;
      document.getElementById('resultInfo').textContent =
        `Prompt: ${result.data.prompt}\n` +
        `步数: ${result.data.params.steps} | ` +
        `CFG: ${result.data.params.cfg} | ` +
        `分辨率: ${result.data.params.width}x${result.data.params.height} | ` +
        `种子: ${result.data.params.seed}`;
      // Refresh the history grid so the new image appears.
      loadHistory();
    }, 500);
  } catch (error) {
    progressFill.style.width = '0%';
    progressText.textContent = error.message || '生成失败';
    setTimeout(() => {
      progress.style.display = 'none';
    }, 3000);
  } finally {
    btn.disabled = false;
    btn.innerHTML = '<span class="btn-icon">✨</span> 生成图片';
  }
}
// Download the image currently shown in the result panel.
function downloadImage() {
  const img = document.getElementById('resultImage');
  // Before the first result the <img> has no src; an empty href resolves to
  // the current page, so the "download" would save the HTML page instead.
  if (!img.getAttribute('src')) return;
  const link = document.createElement('a');
  link.href = img.src;
  link.download = `ai_image_${Date.now()}.png`;
  link.click();
}

// index.html's "以图生图" button calls this handler; without a definition the
// click throws a ReferenceError. The backend exposes no img2img endpoint yet,
// so just tell the user.
function useAsImg2Img() {
  alert('以图生图功能尚未实现');
}
// Fetch /api/history and render it as a grid of clickable thumbnails.
async function loadHistory() {
  try {
    const resp = await fetch('/api/history');
    if (!resp.ok) {
      throw new Error(`HTTP ${resp.status}`);
    }
    const data = await resp.json();
    const grid = document.getElementById('historyGrid');
    const count = document.getElementById('historyCount');
    count.textContent = data.records.length;

    // Build nodes with the DOM API rather than an HTML string: prompts are
    // user-supplied and the history is shared by all visitors, so splicing
    // them into innerHTML / an inline onclick allowed stored XSS (escaping
    // only single quotes does not stop "<img onerror=...>" or a "\").
    const items = data.records.map((record) => {
      const item = document.createElement('div');
      item.className = 'history-item';
      item.addEventListener('click', () => showResult(record.image_url, record.prompt));

      const img = document.createElement('img');
      img.src = record.image_url;
      img.alt = record.prompt;
      img.loading = 'lazy';

      const overlay = document.createElement('div');
      overlay.className = 'prompt-overlay';
      overlay.textContent = record.prompt;

      item.append(img, overlay);
      return item;
    });
    grid.replaceChildren(...items);
  } catch (error) {
    console.error('加载历史失败:', error);
  }
}
// Show a history image and its prompt in the result panel, then scroll to it.
function showResult(imageUrl, prompt) {
  const panel = document.getElementById('resultPanel');
  panel.style.display = 'block';
  document.getElementById('resultImage').src = imageUrl;
  document.getElementById('resultInfo').textContent = 'Prompt: ' + prompt;
  panel.scrollIntoView({ behavior: 'smooth' });
}
// Initial load: render the history and bind the keyboard shortcut.
document.addEventListener('DOMContentLoaded', () => {
  loadHistory();

  // Ctrl+Enter (or Cmd+Enter on macOS, via metaKey) submits the prompt.
  document.getElementById('prompt').addEventListener('keydown', (e) => {
    if ((e.ctrlKey || e.metaKey) && e.key === 'Enter') {
      generate();
    }
  });
});
四、运行与使用
# 1. 安装依赖
pip install fastapi uvicorn diffusers transformers accelerate torch
# 2. 启动服务
python app.py
# 3. 浏览器打开
open http://localhost:8080
# 4. 输入 Prompt → 点击生成
五、进阶功能
5.1 使用 Vue/React 前端
上述纯 JS 版本可直接运行。如果你想用 SPA 框架搭建更复杂的界面,推荐 Vue 3(下面的示例假设页面已通过 <script> 引入 Vue 3 的全局构建版本,并提供 id="app" 的根元素及对应模板):
// Vue 3 rewrite of the main component (sketch). Uses the global `Vue` object
// and mounts on the element with id="app".
const { createApp, ref, reactive } = Vue;
createApp({
  setup() {
    const prompt = ref('');
    const params = reactive({
      steps: 25, cfg: 7.5,
      width: 768, height: 768,
      seed: -1,
    });
    const result = ref(null);
    const loading = ref(false);

    async function generate() {
      loading.value = true;
      try {
        const resp = await fetch('/api/generate', {
          method: 'POST',
          headers: { 'Content-Type': 'application/json' },
          body: JSON.stringify({
            prompt: prompt.value,
            ...params,
          }),
        });
        const body = await resp.json();
        // fetch only rejects on network failure; without this check an HTTP
        // error left `result` undefined and the user saw nothing.
        if (!resp.ok) {
          throw new Error(typeof body.detail === 'string' ? body.detail : '生成失败');
        }
        result.value = body.data;
      } catch (e) {
        alert(e.message);
      } finally {
        loading.value = false;
      }
    }

    return { prompt, params, result, loading, generate };
  }
}).mount('#app');
5.2 添加用户系统
# routes/auth.py : simple user authentication (in-memory demo)
import asyncio
import hashlib
import hmac
import secrets
import time
import uuid
from fastapi import APIRouter, HTTPException, Cookie, Response
from pydantic import BaseModel

router = APIRouter()

# In-memory stores (lost on restart; use a database in production).
users = {}     # username -> user record
sessions = {}  # session_id -> username

# PBKDF2 work factor; OWASP currently recommends 600k for PBKDF2-HMAC-SHA256.
_PBKDF2_ITERATIONS = 600_000


def _hash_password(password: str, salt: bytes) -> str:
    """Return the hex PBKDF2-HMAC-SHA256 digest of *password* under *salt*.

    Replaces a bare, unsalted SHA-256, which is fast to brute-force and maps
    identical passwords to identical hashes.
    """
    return hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), salt, _PBKDF2_ITERATIONS
    ).hex()


class UserRegister(BaseModel):
    """Registration request body."""
    username: str
    password: str


class UserLogin(BaseModel):
    """Login request body."""
    username: str
    password: str


@router.post("/register")
async def register(user: UserRegister):
    """Create a user with a salted password hash; 400 if the name is taken."""
    if user.username in users:
        raise HTTPException(400, "用户名已存在")
    salt = secrets.token_bytes(16)
    # PBKDF2 is deliberately slow; run it off the event loop.
    password_hash = await asyncio.to_thread(_hash_password, user.password, salt)
    # Re-check: another request may have registered this name while we hashed.
    if user.username in users:
        raise HTTPException(400, "用户名已存在")
    users[user.username] = {
        "username": user.username,
        "password_salt": salt.hex(),
        "password_hash": password_hash,
        "created_at": time.strftime("%Y-%m-%d"),
        "generation_count": 0,
    }
    return {"status": "registered"}


@router.post("/login")
async def login(user: UserLogin, response: Response):
    """Verify credentials and set an HttpOnly ``session_id`` cookie.

    Raises:
        HTTPException(401): unknown user or wrong password. Both cases use the
            same message so the endpoint does not reveal which usernames exist.
    """
    user_data = users.get(user.username)
    if user_data is None:
        raise HTTPException(401, "用户名或密码错误")
    salt = bytes.fromhex(user_data["password_salt"])
    candidate = await asyncio.to_thread(_hash_password, user.password, salt)
    # Constant-time comparison so response timing does not leak hash prefixes.
    if not hmac.compare_digest(candidate, user_data["password_hash"]):
        raise HTTPException(401, "用户名或密码错误")
    session_id = secrets.token_urlsafe(32)
    sessions[session_id] = user.username
    response.set_cookie(key="session_id", value=session_id, httponly=True)
    return {"status": "logged_in", "username": user.username}
5.3 WebSocket 实时推送
# Push real-time generation progress over WebSocket.
from fastapi import WebSocket, WebSocketDisconnect


class ConnectionManager:
    """Keep one open WebSocket per client_id and push progress messages to it."""

    def __init__(self):
        self.active_connections: dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept the handshake and register *websocket* under *client_id*."""
        await websocket.accept()
        self.active_connections[client_id] = websocket

    def disconnect(self, client_id: str):
        """Forget *client_id*'s socket; unknown ids are ignored."""
        self.active_connections.pop(client_id, None)

    async def send_progress(self, client_id: str, progress: float,
                            message: str = ""):
        """Send ``{"progress", "message"}`` to *client_id* if it is connected.

        Progress pushes are best-effort: a client that vanished mid-send is
        dropped rather than failing the generation that is reporting progress.
        """
        ws = self.active_connections.get(client_id)
        if ws is None:
            return
        try:
            await ws.send_json({
                "progress": progress,
                "message": message,
            })
        except Exception:
            # The socket closed between lookup and send (client disconnected);
            # remove it so later updates do not keep hitting a dead connection.
            self.disconnect(client_id)


manager = ConnectionManager()
@router.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Hold a progress channel open for *client_id* until the client leaves."""
    await manager.connect(websocket, client_id)
    try:
        while True:
            # Client messages are currently ignored; reading is what surfaces
            # the WebSocketDisconnect when the client goes away.
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    finally:
        # Unregister on every exit path, not only on a clean disconnect, so an
        # unexpected error cannot leave a dead socket in the manager.
        manager.disconnect(client_id)
六、部署
6.1 生产部署
# 使用 Gunicorn + Uvicorn(推荐)
pip install gunicorn
# 启动
gunicorn app:app \
--worker-class uvicorn.workers.UvicornWorker \
--workers 1 \
--bind 0.0.0.0:8080 \
--timeout 300
# 使用 Nginx 反向代理
# 参见第 8 期的 Nginx 配置
6.2 优化建议
性能优化:
├── 前端:图片懒加载、结果缓存、CDN
├── 后端:模型热加载、推理缓存、异步队列
├── 存储:OSS 存储图片(阿里云/七牛)、定期清理
└── 扩展:多 GPU 负载均衡、模型预热
用户体验:
├── 生成时显示实时进度
├── Prompt 历史自动补全
├── 参数模板保存/分享
├── 图片放大/对比功能
└── 移动端适配
下一期是终极实战:综合项目实战,我们将把前面九期的所有技术整合为一个企业级 AI 绘图平台。
代码仓库:https://gitee.com/genesisesNoun/ai-drawing-tutorial.git
对应目录:09-web-app/,包含完整的 Web 应用代码(后端 + 前端)。
夜雨聆风