乐于分享
好东西不私藏

对前一篇文档(mnist数据训练)的补充

对前一篇文档(mnist数据训练)的补充

1. 人工智能太强大了

在经过chat-gpt的帮助下,我得以窥见问题的所在,首先肯定的是我昨天在上一篇文章中进行的猜测是正确的,正确率表现不高果然是因为训练的样本太少,而导致神经网络权重不那么适配。在使用了完整的mnist训练集后,再对完整的mnist测试集进行测试,正确率都能达到95%以上。

2. 给出训练集代码和测试集代码

import numpy as np
from neur_network import neuralNetwork

# 加载 IDX 文件
def load_mnist_images(filename):
    import struct
    with open(filename, 'rb') as f:
        magic, num_images, rows, cols = struct.unpack('>IIII', f.read(16))
        data = np.fromfile(f, dtype=np.uint8).reshape(num_images, rows * cols)
    return data.astype(np.float32) / 255.0

def load_mnist_labels(filename):
    import struct
    with open(filename, 'rb') as f:
        magic, num_labels = struct.unpack('>II', f.read(8))
        data = np.fromfile(f, dtype=np.uint8)
    return data.astype(np.int32)

# 加载训练数据
x_train = load_mnist_images('train-images.idx3-ubyte')
y_train = load_mnist_labels('train-labels.idx1-ubyte')

# 参数设置
input_nodes = 784
hidden_nodes = 100
output_nodes = 10
learning_rate = 0.2
epochs = 5

network = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)

# 训练
print("开始训练...")
for epoch in range(epochs):
    # 每轮打乱样本顺序
    indices = np.random.permutation(len(x_train))
    for idx in indices:
        # 取一个样本
        scaled_pixels = 0.01 + x_train[idx] * 0.99   # 注意 x_train 已在 [0,1] 范围,映射到 [0.01,0.99]
        targets = np.zeros(output_nodes) + 0.01
        targets[y_train[idx]] = 0.99
        network.train(scaled_pixels, targets)
    print(f"Epoch {epoch+1}/{epochs} 完成")

print("训练结束")

# 保存权重(可选)
np.savez('mnist_weights.npz', wih=network.wih, who=network.who)

在训练集代码中,相对于之前的代码,变化的是加载文件的区别,神经网络还是不变。

import numpy as np
import matplotlib.pyplot as plt
from neur_network import neuralNetwork   # 直接导入类,避免导入变量问题

# ==================== 1. 参数设置 ====================
input_nodes = 784
hidden_nodes = 100
output_nodes = 10
learning_rate = 0.2

# 创建网络
network = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)

# ==================== 2. 加载训练好的权重 ====================
data = np.load('mnist_weights.npz')
network.wih = data['wih']
network.who = data['who']
print("权重加载完成")

# ==================== 4. 测试(使用测试文件的数据) ====================
# 加载 IDX 文件
def load_mnist_images(filename):
    import struct
    with open(filename, 'rb') as f:
        magic, num_images, rows, cols = struct.unpack('>IIII', f.read(16))
        data = np.fromfile(f, dtype=np.uint8).reshape(num_images, rows * cols)
    return data.astype(np.float32) / 255.0

def load_mnist_labels(filename):
    import struct
    with open(filename, 'rb') as f:
        magic, num_labels = struct.unpack('>II', f.read(8))
        data = np.fromfile(f, dtype=np.uint8)
    return data.astype(np.int32)

x_test = load_mnist_images('t10k-images.idx3-ubyte')
y_test = load_mnist_labels('t10k-labels.idx1-ubyte')

correct = 0
for i in range(len(x_test)):
    scaled = 0.01 + x_test[i] * 0.99
    output = network.query(scaled)
    if np.argmax(output) == y_test[i]:
        correct += 1
print(f"准确率: {correct/len(x_test)*100:.2f}%")

测试集代码变化的也是加载文件的区别。

3. 对新代码中的images和labels进行说明

在加载 MNIST 数据时,imageslabels是两个完全不同的文件,其内容和用途区别如下:

对比项

images 文件(如train-images.idx3-ubyte

labels 文件(如train-labels.idx1-ubyte

存储内容

手写数字的图像像素值(28×28 = 784 个像素)

每个图像对应的真实数字标签(0~9 的整数)

文件格式

IDX3 格式:包含魔数、图像数量、行数、列数,之后是像素数据(每个像素占 1 字节)

IDX1 格式:包含魔数、标签数量,之后是标签数据(每个标签占 1 字节)

数据结构

形状为(样本数, 784)

的 numpy 数组,每个值范围 0~255(归一化后常变为 [0,1])

形状为(样本数,)

的 numpy 数组,每个值是 0~9 的整数

用途

作为神经网络的输入特征inputs

作为神经网络的监督信号targets

),用于计算损失和反向传播

加载函数返回

load_mnist_images()

返回图像数据矩阵

load_mnist_labels()

返回标签向量

举例

  • 第一张图像的像素数据(images[0])是一个 784 维向量,代表手写数字的灰度图。
  • 对应的标签(labels[0])是一个整数,例如5,表示这张图像上写的是数字“5”。

训练时,你需要同时使用这两个文件:用images作为网络输入,用labels构造 one-hot 向量作为目标输出。

注意:两个文件中的样本顺序必须一致——第 i 个标签对应第 i 张图像。加载函数已经通过相同的读取顺序保证了这一点。

4. 对于自己手写数字并拍照上传的图片模型为什么只预测8

代码里最可能的问题不是“模型只会识别 8”,而是:

4.1. 手写图片预处理和 MNIST 训练数据分布不一致

上传的图片:

  •  是手机拍照 
  •  有阴影 
  •  背景不是纯白 
  •  数字很细 
  •  数字位置偏 
  •  周围留白很大 

而 MNIST 的训练数据是:

  •  纯黑背景 
  •  纯白数字 
  •  数字已经居中 
  •  28×28 
  •  灰度分布很标准 

所以网络实际上看到的不是“3 / 4 / 9”,而是一堆奇怪的灰度块。

这会导致:

  •  输出层某个神经元(这里是 8)长期激活最高 
  •  无论输入什么都预测成 8 

4.2. 代码中的关键问题

当前预处理:

img = Image.open(image_path).convert('L')
img = ImageOps.invert(img)
bbox = img.getbbox()

这里有一个严重问题:

照片背景并不是纯白。

拍照后背景有灰色阴影:

  •  invert 后 
  •  背景变成“非零灰色” 
  • getbbox() 会认为整张图都有内容 

于是:

img.crop(bbox)

实际上根本没裁掉背景。

最后网络输入的是:

  •  大面积灰背景 
  •  很小的数字 

而不是 MNIST 风格数字。


4.3. 为什么总是预测成 8

因为:

输入特征几乎恒定

对于神经网络来说:3 ≈ 4 ≈ 9

因为背景噪声占了大部分像素。


4.4.  输出层偏置倾向于某个数字

MNIST 中:

  •  8 的笔画复杂 
  •  激活区域多 

很多“脏输入”都会被误认为 8。

这是经典现象。


5. 应该这样修复

5.1. 第一件事:二值化

在 invert 后增加:

img = img.point(lambda p: 255 if p > 80 else 0)

完整:

img = Image.open(image_path).convert('L')

# 反色
img = ImageOps.invert(img)

# 二值化(非常关键)
img = img.point(lambda p: 255 if p > 80 else 0)

# 裁剪
bbox = img.getbbox()
if bbox:
    img = img.crop(bbox)

5.2. 第二件事:显示预处理后的图片

现在根本不知道模型实际看到什么。

加:canvas.show()

或者:

import matplotlib.pyplot as plt
plt.imshow(canvas, cmap=’gray’)
plt.show()

会立刻发现:现在输入给模型的图片可能全是灰背景。


5.3. 第三件事:图片太大了

上传的数字:

  •  占画面比例太小 
  •  笔画太细 

MNIST 数字一般占:画面高度的 70% 左右

现在可能只有:10%

网络当然识别不了。


5.4. 第四件事:建议改成 OpenCV 预处理

现在 PIL 处理太弱。

建议:

import cv2

用:

  •  阈值分割 
  •  找轮廓 
  •  居中 
  •  resize 

这才接近真正 OCR。


6. 代码还有一个隐藏问题

这里:

img.resize((20,20))

会强行拉伸。

如果数字本来很瘦:

  •  3 
  •  1 
  •  7 

会严重变形。

MNIST 是保持比例缩放的。


7. 正确版本(推荐直接替换 preprocess)

def preprocess_image(image_path):
    img = Image.open(image_path).convert('L')

    # 反色
    img = ImageOps.invert(img)

    # 二值化
    img = img.point(lambda p: 255 if p > 80 else 0)

    # 裁剪
    bbox = img.getbbox()
    if bbox:
        img = img.crop(bbox)

    # 保持比例缩放
    img.thumbnail((20, 20))

    # 创建黑色画布
    canvas = Image.new('L', (28, 28), 0)

    # 居中
    x = (28 - img.size[0]) // 2
    y = (28 - img.size[1]) // 2

    canvas.paste(img, (x, y))

    # 查看预处理结果
    import matplotlib.pyplot as plt
    plt.imshow(canvas, cmap='gray')
    plt.show()

    # 归一化
    pixels = np.array(canvas, dtype=np.float32)

    scaled = 0.01 + (pixels / 255.0) * 0.99

    return scaled.flatten()

8. 还有一个重要问题:模型太弱

现在是:

784 -> 100 -> 10

纯手写 BP 网络。

没有:

  •  CNN 
  •  卷积 
  •  池化 
  •  BatchNorm 

这种模型:

  •  对 MNIST 可以 
  •  对真实拍照手写几乎不行 

因为它没有平移不变性。


9. 最后总结

现在的问题本质是:

不是“模型不会识别”

而是:“输入图片根本不像 MNIST”

导致:所有输入都被映射到同一种特征

于是:全部预测成 8

这是手写数字识别里最经典的问题之一:

训练集分布 ≠ 测试集分布

把预处理修好以后:

  •  3 
  •  4 
  •  9 

识别率会立刻提升很多。