避免机器学习中的过拟合与欠拟合(十二)

在机器学习领域,构建一个优秀的模型就像培养一个出色的学生。我们希望模型不仅能记住训练数据(课本上的例题),还能在面对新数据(考试中的新题)时表现出色。然而,在实际操作中,模型往往会遇到两种典型问题:欠拟合过拟合。本文将深入探讨这两种现象及其背后的理论基础——偏差方差,并提供一些实用的诊断和应对策略。

一、核心概念:模型的表现与“拟合”状态

1. 欠拟合

欠拟合 是指模型过于简单,无法捕捉数据中的关键模式。这就像一个学生只学会了加法,却要解决微积分问题。

  • 表现:模型在训练数据上的表现很差,误差大,准确率低。
  • 原因:模型复杂度不足,特征提取不够,或者训练不充分。
  • 类比:用一条直线(一次多项式)去拟合有明显弯曲趋势的数据。

示例

import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import mean_squared_error

# 生成数据
X = np.linspace(0, 10, 20)
y_true = np.sin(X)
y_noise = np.random.randn(20) * 0.3
y = y_true + y_noise

# 创建一次多项式特征
poly = PolynomialFeatures(degree=1)
X_poly1 = poly.fit_transform(X.reshape(-1, 1))

# 训练模型
model_under = LinearRegression()
model_under.fit(X_poly1, y)

# 预测
y_pred_under = model_under.predict(X_poly1)

# 计算均方误差
mse_train_under = mean_squared_error(y, y_pred_under)
print(f"欠拟合模型在训练集上的均方误差 (MSE): {mse_train_under:.4f}")

2. 恰到好处的拟合

这是理想的状态。模型既能够捕捉数据中的关键模式,又不会过度学习噪声,从而在训练集和测试集上都能表现良好。

  • 表现:模型在训练集和测试集上的误差都较低,且两者接近。
  • 类比:用一个适当阶数的多项式(例如3阶)来拟合数据。

示例

# 创建三次多项式特征
poly = PolynomialFeatures(degree=3)
X_poly3 = poly.fit_transform(X.reshape(-1, 1))

# 训练模型
model_good = LinearRegression()
model_good.fit(X_poly3, y)

# 预测
y_pred_good = model_good.predict(X_poly3)

# 计算均方误差
mse_train_good = mean_squared_error(y, y_pred_good)
print(f"良好拟合模型在训练集上的均方误差 (MSE): {mse_train_good:.4f}")

3. 过拟合

过拟合 是指模型过于复杂,不仅学习了数据中的真实规律,还“记住”了训练数据中的随机噪声和异常值。这就像一个学生记住了每一个例题的细节,但无法应对新的问题。

  • 表现:模型在训练数据上的表现极好(误差极小),但在新的、未见过的数据上表现急剧下降,泛化能力差。
  • 原因:模型复杂度过高,训练数据量太少。
  • 类比:用一个非常高阶的多项式(例如15阶)去拟合数据,使得曲线穿过了几乎每一个数据点,变得极度扭曲。

示例

import matplotlib.pyplot as plt

# 创建十五次多项式特征
poly = PolynomialFeatures(degree=15)
X_poly15 = poly.fit_transform(X.reshape(-1, 1))

# 训练模型
model_over = LinearRegression()
model_over.fit(X_poly15, y)

# 预测
y_pred_over = model_over.predict(X_poly15)

# 计算均方误差
mse_train_over = mean_squared_error(y, y_pred_over)
print(f"过拟合模型在训练集上的均方误差 (MSE): {mse_train_over:.4f}")

# 绘制图形
plt.figure(figsize=(15, 4))

plt.subplot(1, 3, 1)
plt.scatter(X, y, alpha=0.6)
plt.plot(X, y_pred_under, color='red', linewidth=2, label='欠拟合 (1阶)')
plt.plot(X, y_true, color='green', linestyle='--', label='真实规律')
plt.title(f'欠拟合\n训练MSE: {mse_train_under:.4f}')
plt.legend()
plt.grid(True)

plt.subplot(1, 3, 2)
plt.scatter(X, y, alpha=0.6)
plt.plot(X, y_pred_good, color='red', linewidth=2, label='良好拟合 (3阶)')
plt.plot(X, y_true, color='green', linestyle='--', label='真实规律')
plt.title(f'良好拟合\n训练MSE: {mse_train_good:.4f}')
plt.legend()
plt.grid(True)

plt.subplot(1, 3, 3)
plt.scatter(X, y, alpha=0.6)
plt.plot(X, y_pred_over, color='red', linewidth=2, label='过拟合 (15阶)')
plt.plot(X, y_true, color='green', linestyle='--', label='真实规律')
plt.title(f'过拟合\n训练MSE: {mse_train_over:.4f}')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

从图中可以看到:

  • 欠拟合(左):红色直线完全无法捕捉数据的波动趋势。
  • 良好拟合(中):红色曲线大致遵循了绿色真实规律的趋势。
  • 过拟合(右):红色曲线剧烈波动,试图穿过每一个蓝色散点,包括噪声点,完全失去了正弦曲线的光滑形态。

二、理论基石:偏差与方差分解

偏差和方差为我们理解过拟合与欠拟合提供了理论框架。它们描述了模型误差的两个不同来源。

1. 偏差

  • 定义:模型预测值的期望(即平均预测值)与真实值之间的差距。反映了模型本身的系统性错误,即模型对问题本质的假设是否有误。
  • 高偏差的表现:模型过于简单,无法刻画数据特征,导致欠拟合。无论用什么数据训练,结果都偏离真实值。
  • 例子:始终用“房价=面积×1000”这个简单线性模型来预测各种房子,忽略了地段、楼层等重要因素,这就是高偏差。

2. 方差

  • 定义:模型预测值自身的波动范围。反映了模型对训练数据中随机噪声的敏感程度。
  • 高方差的表现:模型过于复杂,对训练数据中的微小变化(包括噪声)反应过度,导致过拟合。换一组数据训练,得到的模型可能完全不同。
  • 例子:一个深度神经网络,如果不对其进行任何约束,它可能会为每一套独特的训练数据生成一套完全不同的、极度复杂的预测规则,这就是高方差。

3. 偏差-方差权衡

这是一个机器学习中的核心权衡。我们无法同时最小化偏差和方差。

  • 增加模型复杂度:通常可以降低偏差(模型能力变强),但会增加方差(更容易学到噪声)。
  • 降低模型复杂度:通常可以降低方差(模型更稳定),但会增加偏差(模型能力变弱)。

我们的目标是找到一个平衡点,使得总误差最小。

三、诊断与应对策略

1. 诊断方法:学习曲线

学习曲线是绘制模型在训练集验证集上的性能(如误差)随训练样本数模型复杂度变化的曲线。通过观察学习曲线,我们可以判断模型是否处于欠拟合或过拟合状态。

示例

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split, learning_curve
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures, StandardScaler
import warnings

warnings.filterwarnings('ignore')

plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'PingFang SC', 'Heiti TC', 'WenQuanYi Micro Hei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 0.3

data = load_diabetes()
X, y = data.data[:, np.newaxis, 2], data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

def plot_learning_curve(estimator, title, X, y, cv=5, train_sizes=np.linspace(0.1, 1.0, 10)):
    """
    绘制学习曲线

    参数:
    estimator: 模型估计器
    title: 图表标题
    X: 特征数据
    y: 目标变量
    cv: 交叉验证折数
    train_sizes: 训练样本比例
    """
    train_sizes_abs, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=cv, scoring='neg_mean_squared_error',
        train_sizes=train_sizes, random_state=42, n_jobs=-1
    )

    train_scores_mean = -train_scores.mean(axis=1)
    train_scores_std = train_scores.std(axis=1)
    test_scores_mean = -test_scores.mean(axis=1)
    test_scores_std = test_scores.std(axis=1)

    plt.figure(figsize=(10, 6))
    plt.fill_between(train_sizes_abs, train_scores_mean - train_scores_std, train_scores_mean + train_scores_std, alpha=0.1, color='r')
    plt.fill_between(train_sizes_abs, test_scores_mean - test_scores_std, test_scores_mean + test_scores_std, alpha=0.1, color='g')
    plt.plot(train_sizes_abs, train_scores_mean, 'o-', color='r', linewidth=2, markersize=8, label='训练集 MSE')
    plt.plot(train_sizes_abs, test_scores_mean, 'o-', color='g', linewidth=2, markersize=8, label='验证集 MSE')
    plt.xlabel('训练样本数量', fontsize=12)
    plt.ylabel('均方误差 (MSE)', fontsize=12)

    plt.title(title, fontsize=14, pad=20)
    plt.legend(loc='upper right', fontsize=11)
    plt.tight_layout()
    plt.show()

# 创建模型
model = make_pipeline(PolynomialFeatures(degree=1), StandardScaler(), LinearRegression())

# 绘制学习曲线
plot_learning_curve(model, '学习曲线 (1阶多项式)', X_train, y_train)

2. 应对策略

  • 增加数据量:更多的训练数据可以帮助模型更好地学习数据的内在规律,减少过拟合。
  • 减少模型复杂度:简化模型结构,减少特征数量,避免模型过于复杂。
  • 正则化:引入正则化项,如L1或L2正则化,惩罚模型的复杂度,防止过拟合。
  • 交叉验证:使用交叉验证技术评估模型的泛化能力,选择最佳的模型参数。

总结

理解过拟合与欠拟合,以及偏差与方差的概念,对于构建高性能的机器学习模型至关重要。通过学习曲线等工具,我们可以诊断模型的状态,并采取相应的策略进行调整。希望本文的内容能帮助你在机器学习实践中更加得心应手。