机器学习入门:SVM中SMO算法说明
他们说:数据和特征决定了模型的上限,而算法决定了你能否以最高效的路径、最稳健的姿态,去触达这个上限

目录
-
SVM 基础回顾 -
SMO 算法原理 -
具体示例数据 -
SMO 迭代计算过程 -
完整代码实现
SVM 基础回顾
优化问题
对于线性可分的 SVM,我们的优化问题是:
拉格朗日对偶问题
引入拉格朗日乘子 ,得到对偶问题:
对于线性核:
决策函数
SMO 算法原理
核心思想
SMO(Sequential Minimal Optimization)由 John Platt 于 1998 年提出,核心思想是:
-
每次只优化两个拉格朗日乘子 和 -
固定其他所有乘子不变 -
解析求解这个两变量的二次规划问题 -
迭代直到收敛
为什么选择两个变量?
-
约束条件 至少需要两个变量才能有意义地更新 -
两个变量的子问题可以解析求解,无需调用数值优化库 -
计算效率高,适合大规模数据
算法步骤1. 初始化所有 α_i = 02. 重复以下步骤直到收敛: a. 选择第一个变量 α_i(违反 KKT 条件最严重的) b. 选择第二个变量 α_j(使 |E_i - E_j| 最大的) c. 计算 α_j 的更新范围 [L, H] d. 计算最优的 α_j^new e. 计算 α_i^new f. 更新阈值 b g. 更新误差缓存 E_k3. 返回最终的 α 和 b
具体示例数据
数据集设计
我们创建一个简单的 2 维数据,包含 6 个样本,分为两类:
|
|
|
|
|
|---|---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
数据可视化类别 +1 (○) 类别 -1 (×) ↑x₂ | 5 | ×(5) | ×(6) 4 | ×(4) | 3 | | ○(3) 2 | ○(2) | ○(1) 1 | +----+----+----+----+----+→ x₁ 1 2 3 4 5
参数设置
-
惩罚参数 C = 1.0(软间隔) -
容差 ε = 0.001(收敛判断) -
线性核:
SMO 迭代计算过程
初始化
初始误差缓存:
|
|
|
|
|---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
第一轮迭代
步骤 1:选择第一个变量 α₁
检查 KKT 条件:
对于所有 的样本,检查 :
-
样本 1: ❌ 违反 -
样本 2: ❌ 违反 -
样本 3: ❌ 违反 -
样本 4: ❌ 违反 -
样本 5: ❌ 违反 -
样本 6: ❌ 违反
选择违反最严重的,这里都相同,选样本 1作为第一个变量。
步骤 2:选择第二个变量 α₂
选择使 最大的 :
-
← 最大 -
← 最大 -
← 最大
选择样本 4作为第二个变量(第一个最大的)。
步骤 3:计算更新范围 [L, H]
由于 (+1 ≠ -1),使用公式:
步骤 4:计算核函数值
计算所需的核函数值(线性核 ):
步骤 5:计算 η
步骤 6:计算 α₄^new
检查范围:,即 ✓
所以
步骤 7:计算 α₁^new
根据约束 :
步骤 8:更新 b
计算 和 :
由于 且 ,取平均:
步骤 9:更新误差缓存
更新所有 ,其中:
当前
计算各样本的 :
样本 1:
样本 2:
样本 3:
样本 4:
样本 5:
样本 6:
第一轮结果汇总
|
|
|
|---|---|
|
|
|
|
|
|
|
|
|
第二轮迭代
步骤 1:选择第一个变量
检查 KKT 条件违反情况:
|
|
|
|
|
|
|
|---|---|---|---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
选择违反最严重的:样本 2()
步骤 2:选择第二个变量
计算 :
-
← 最大
选择样本 4作为第二个变量。
步骤 3:计算更新范围 [L, H]
,异号:
步骤 4:计算核函数值
步骤 5:计算 η
步骤 6:计算 α₄^new
检查范围:,即 ✓
步骤 7:计算 α₂^new
步骤 8:更新 b
第二轮结果汇总
|
|
|
|---|---|
|
|
|
|
|
|
继续迭代…
重复上述过程,直到:
-
所有样本都满足 KKT 条件(在容差 ε 内) -
或者达到最大迭代次数
最终结果(收敛后)
假设经过若干轮迭代后收敛,得到:
|
|
|
|
|
|
|
|---|---|---|---|---|---|
|
|
|
|
|
|
|
支持向量
支持向量是 的样本:
-
样本 1:,, -
样本 4:,,
决策边界
决策函数:
决策边界():
完整代码实现
"""SMO 算法完整实现 - 修正正负类边界问题"""import numpy as npimport matplotlib.pyplot as pltclass SMO:def __init__(self, C=1.0, tol=0.001, max_iter=2000):"""初始化 SMO 算法(优化收敛参数)参数:C: 惩罚参数(增大C可增强边界拟合)tol: 容差(更小的容差提高精度)max_iter: 最大迭代次数(增加迭代次数保证收敛)"""self.C = Cself.tol = tolself.max_iter = max_iterself.alpha = Noneself.b = 0.0self.w = Noneself.X_train = Noneself.y_train = Nonedef linear_kernel(self, x1, x2):"""线性核函数(向量化实现,提高精度)"""return np.dot(x1, x2.T) if len(x1.shape) > 1 else np.dot(x1, x2)def compute_kernel_matrix(self, X):"""向量化计算核函数矩阵(减少数值误差)"""return self.linear_kernel(X, X)def compute_error(self, K, k):"""优化误差计算(减少重复计算)"""f_k = np.sum(self.alpha * self.y_train * K[:, k]) + self.breturn f_k - self.y_train[k]def select_j(self, i, E_i, E):"""改进的j选择策略(优先选误差符号相反的样本)"""valid_indices = [k for k in range(len(E)) if k != i and 0 <= self.alpha[k] <= self.C]if not valid_indices:return np.random.choice([k for k in range(len(E)) if k != i])# 优先选择与E_i符号相反的样本(加速收敛)opposite_sign = [k for k in valid_indices if E[k] * E_i < 0]if opposite_sign:return max(opposite_sign, key=lambda k: abs(E_i - E[k]))# 否则选误差差最大的return max(valid_indices, key=lambda k: abs(E_i - E[k]))def clip_alpha(self, alpha_j, L, H):"""裁剪 α_j 到 [L, H] 范围"""return np.clip(alpha_j, L, H)def fit(self, X, y, verbose=True):"""优化的SMO训练逻辑(保证充分收敛)"""self.X_train = np.array(X, dtype=np.float64)self.y_train = np.array(y, dtype=np.float64)n_samples, n_features = self.X_train.shape# 初始化self.alpha = np.zeros(n_samples, dtype=np.float64)self.b = 0.0K = self.compute_kernel_matrix(self.X_train)E = np.array([self.compute_error(K, k) for k in range(n_samples)])iter_count = 0alpha_changed = 0examine_all = Truewhile (iter_count < self.max_iter) and (alpha_changed > 0 or examine_all):alpha_changed = 0# 选择遍历的样本集if examine_all:indices = range(n_samples)else:indices = [i for i in range(n_samples) if 0 < self.alpha[i] < self.C]for i in indices:E_i = E[i]# 严格的KKT条件检查(核心优化)y_i_E_i = self.y_train[i] * E_iviolate_kkt = (y_i_E_i < -self.tol and self.alpha[i] < self.C) or \(y_i_E_i > self.tol and self.alpha[i] > 0)if violate_kkt:# 选择第二个变量j = self.select_j(i, E_i, E)E_j = E[j]alpha_i_old = self.alpha[i].copy()alpha_j_old = self.alpha[j].copy()# 计算L和Hif self.y_train[i] != self.y_train[j]:L = max(0.0, self.alpha[j] - self.alpha[i])H = min(self.C, self.C + self.alpha[j] - self.alpha[i])else:L = max(0.0, self.alpha[i] + self.alpha[j] - self.C)H = min(self.C, self.alpha[i] + self.alpha[j])if L == H:continue# 计算eta(数值稳定性优化)eta = 2.0 * K[i, j] - K[i, i] - K[j, j]if eta >= -1e-8: # 允许微小的非负eta,避免提前终止continue# 更新alpha jself.alpha[j] -= self.y_train[j] * (E_i - E_j) / etaself.alpha[j] = self.clip_alpha(self.alpha[j], L, H)# 检查alpha j是否有足够大的变化if abs(self.alpha[j] - alpha_j_old) < 1e-6:continue# 更新alpha idelta_j = self.alpha[j] - alpha_j_olddelta_i = -self.y_train[i] * self.y_train[j] * delta_jself.alpha[i] += delta_i# 更新b(更精准的计算)b_old = self.bb1 = self.b - E_i - self.y_train[i] * delta_i * K[i, i] - self.y_train[j] * delta_j * K[i, j]b2 = self.b - E_j - self.y_train[i] * delta_i * K[i, j] - self.y_train[j] * delta_j * K[j, j]if 1e-8 < self.alpha[i] < self.C - 1e-8:self.b = b1elif 1e-8 < self.alpha[j] < self.C - 1e-8:self.b = b2else:self.b = (b1 + b2) / 2.0# 更新误差缓存E[i] = self.compute_error(K, i)E[j] = self.compute_error(K, j)alpha_changed += 1iter_count += 1# 切换遍历模式if examine_all:examine_all = Falseelif alpha_changed == 0:examine_all = Trueif verbose and iter_count % 50 == 0:print(f"迭代 {iter_count}, 改变的 α 对数:{alpha_changed}")# 计算w(高精度)self.w = np.sum(self.alpha[:, np.newaxis] * self.y_train[:, np.newaxis] * self.X_train, axis=0)# 输出收敛信息support_vectors = self.alpha > 1e-5print(f"\n训练完成!迭代次数:{iter_count}")print(f"支持向量个数:{np.sum(support_vectors)}")print(f"决策边界参数:w={self.w.round(4)}, b={self.b.round(4)}")return selfdef decision_function(self, X):"""优化的决策函数计算(向量化,提高精度)"""X = np.array(X, dtype=np.float64)# 向量化计算所有样本的核函数值kernel_vals = self.linear_kernel(self.X_train, X) # (n_train, n_test)# 计算决策值:(alpha * y) @ kernel_vals + bscores = np.dot(self.alpha * self.y_train, kernel_vals) + self.breturn scoresdef predict(self, X):"""预测(基于优化的决策函数)"""return np.sign(self.decision_function(X))# ============== 示例使用(验证正负类边界) ==============if __name__ == "__main__":# 创建更易区分的示例数据(避免样本共线)X = np.array([[1.0, 2.0], # +1[2.0, 1.0], # +1[2.0, 3.0], # +1[4.0, 3.0], # -1[5.0, 2.0], # -1[5.0, 4.0], # -1])y = np.array([1, 1, 1, -1, -1, -1])# 训练SMO(增大C值增强拟合)smo = SMO(C=10.0, tol=1e-4, max_iter=2000)smo.fit(X, y, verbose=True)# 输出核心结果print("\n" + "=" * 60)print("训练结果详情")print("=" * 60)print(f"α 值:{smo.alpha.round(4)}")print(f"决策边界:{smo.w[0].round(4)}*x1 + {smo.w[1].round(4)}*x2 + {smo.b.round(4)} = 0")# 可视化(重点修正边界绘制)plt.figure(figsize=(10, 8))# 1. 绘制样本点(区分支持向量和普通样本)sv_mask = smo.alpha > 1e-5for i in range(len(X)):if y[i] == 1:# 正类样本:支持向量用实心圆,普通样本用空心圆marker = 'o' if sv_mask[i] else 'o'fillstyle = 'full' if sv_mask[i] else 'none'plt.scatter(X[i, 0], X[i, 1], c='blue', marker=marker, fillstyle=fillstyle,s=200 if sv_mask[i] else 100, label='正类' if i == 0 else '')else:# 负类样本:支持向量用实心叉,普通样本用空心叉marker = 'x' if sv_mask[i] else 'x'plt.scatter(X[i, 0], X[i, 1], c='red', marker=marker,s=200 if sv_mask[i] else 100, label='负类' if i == 3 else '')# 2. 绘制网格和决策边界(修正参数)x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1x2_min, x2_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx1, xx2 = np.meshgrid(np.linspace(x1_min, x1_max, 200),np.linspace(x2_min, x2_max, 200))grid = np.c_[xx1.ravel(), xx2.ravel()]# 计算决策函数值scores = smo.decision_function(grid).reshape(xx1.shape)# 绘制决策边界(0级)和正负类边界(±1级)contour = plt.contour(xx1, xx2, scores,levels=[-1, 0, 1],colors=['red', 'black', 'blue'],linestyles=['dashed', 'solid', 'dashed'],linewidths=[2, 3, 2])# 添加边界标签plt.clabel(contour, inline=True, fontsize=10, fmt={-1: '负类边界', 0: '决策边界', 1: '正类边界'})# 3. 美化图表plt.xlabel('特征 1 (x₁)', fontsize=12)plt.ylabel('特征 2 (x₂)', fontsize=12)plt.title('SMO-SVM 正负类边界(修正版)', fontsize=14)plt.legend(loc='best', fontsize=10)plt.grid(True, alpha=0.3)plt.axis('equal')# 保存并显示plt.tight_layout()plt.savefig('smo_correct_boundary.png', dpi=200)plt.show()# 验证预测结果print("\n" + "=" * 60)print("预测验证")print("=" * 60)test_points = np.array([[1.5, 2.5], [4.5, 3.5], [3.0, 2.5]])predictions = smo.predict(test_points)for i, (point, pred) in enumerate(zip(test_points, predictions)):print(f"测试点 {point}: 预测类别 = {int(pred)} (决策值 = {smo.decision_function([point])[0].round(4)})")
关键公式总结
1. α 更新公式
其中
2. 约束范围 [L, H]
异号 ():
同号 ():
3. b 更新公式
4. KKT 条件
参考资料
-
Platt, J. (1998). “Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines” -
Bishop, C. M. (2006). “Pattern Recognition and Machine Learning” -
李航 (2012). 《统计学习方法》
夜雨聆风