轻量级超分的双频域协同:深入源码解析 DMNet 架构设计
轻量级超分的双频域协同:深入源码解析 DMNet 架构设计

在边缘计算设备(如手机、无人机)上部署图像超分辨率(SR)模型时,算法工程师常常面临一个权衡:如何在极其有限的计算资源(低参数量、低 FLOPs)下,尽可能恢复出高保真的图像细节?
现有的轻量级超分网络大多局限于空间域(Spatial Domain)进行操作,这天然限制了模型的感受野,导致难以捕捉全局结构。为了打破这一限制,研究者开始引入频域(如傅里叶域或小波域)信息。
由北京邮电大学等机构发表在IEEE TRANSACTIONS ON MULTIMEDIA的论文《Dual-domain Modulation Network for Lightweight Image Super-Resolution》(DMNet)提出了一种极具工程启发性的方案:在小波域中进行局部特征重构,在傅里叶域中进行全局结构监督。
本文将结合论文作者开源的源码,深入拆解 DMNet 的底层逻辑,并提供核心模块的可插拔代码与详细注释,探讨其在工业界落地的潜力与局限。
一、 核心动机:为什么必须是“双频域”协同?
许多读者在阅读频域相关的论文时,常有一个疑问:既然小波变换和傅里叶变换都能把图像转到频域,为什么 DMNet 非要两个一起用?只用一个不行吗?
这里我们可以用一个“画肖像”的生活比喻来解释:
傅里叶变换(Fourier)就像是画家退后三步,眯着眼睛看整幅画的“宏观比例”(比如五官位置对不对)。它提取的是全局频率分布。如果只用傅里叶域做特征提取,模型能很好地把握整体结构,但当你凑近看时,会发现睫毛、毛孔等高频细节是模糊的。
小波变换(Wavelet)就像是画家拿着放大镜,专门刻画某一个局部的“边缘细节”。它能将图像无损分解为一个低频(大体轮廓)和三个高频(水平、垂直、对角线纹理)。但如果只在小波域里疯狂优化高频细节,不同方向的高频特征在梯度下降时容易产生冲突,导致最终合成的图像虽然清晰,但整体结构可能发生扭曲。
DMNet 的思路: 既然小波擅长“抓细节”,傅里叶擅长“控大局”,DMNet 将两者分工:
-
模型内部的特征提取交给小波域(利用其保留空间位置信息的特性,锐化纹理)。 -
模型外部的损失约束交给傅里叶域(利用全局频率分布,稳住整体结构,防止小波域优化跑偏)。

上图展示了本文模型在精度与效率的 Trade-off。DMNet 在保证极高 PSNR 的同时,大幅削减了 FLOPs。相比于基于 Mamba 架构的 MambaIR,DMNet 节省了超过 50% 的计算量,推理速度提升了数倍。这证明了找对物理特征域比单纯堆叠复杂算子更高效。
二、 核心架构与源码深度剖析

DMNet 的整体网络结构(对应源码 dmnet_arch.py)非常清晰。它由一个浅层特征提取卷积、若干个 SWGroup(空间-小波组),以及一个上采样重建模块组成。
每个 SWGroup 内部,交替使用了两种核心的 Transformer 变体:SMA(空间域调制自注意力) 和 WMA(小波域调制自注意力)。
下面我把源码进行可插拔封装,方便大家结合论文图片深入理解这两个核心模块。

1. SMA (Spatial-domain Modulation Attention):轻量化空间交互
在轻量级网络中,如果在空间维度计算自注意力(复杂度为 ),显存会直接溢出。SMA 的做法是:利用 卷积聚合跨通道上下文,利用 深度可分离卷积(DWConv)聚合局部空间上下文,并在通道维度上计算注意力。
import torchimport torch.nn as nnimport torch.nn.functional as FclassSMA(nn.Module):""" 空间域调制自注意力 (Spatial-domain Modulation Attention) 核心逻辑:在通道维度进行 Attention 计算,复杂度降为 O(C^2 * HW) """def__init__(self, dim): super(SMA, self).__init__()# 生成 Q, K, V# 1x1 卷积处理通道交互,3x3 DWConv 处理局部空间交互 self.qkv = nn.Sequential( nn.Conv2d(dim, dim * 3, kernel_size=1, bias=False), nn.Conv2d(dim * 3, dim * 3, kernel_size=3, padding=1, groups=dim * 3, bias=False) ) self.proj = nn.Conv2d(dim, dim, kernel_size=1, bias=False)# 可学习的温度系数,用于调节 Softmax 分布 self.temperature = nn.Parameter(torch.ones(1, 1, 1))defforward(self, x): B, C, H, W = x.shape# 获取Q, K, V qkv = self.qkv(x) q, k, v = qkv.chunk(3, dim=1) # [B, C, H, W]# 展平空间维度:[B, C, H, W] -> [B, C, H*W] q = q.view(B, C, -1) k = k.view(B, C, -1) v = v.view(B, C, -1)# L2 归一化,使得内积等于余弦相似度 q = F.normalize(q, dim=-1) k = F.normalize(k, dim=-1)# 在通道维度计算注意力矩阵:[B, C, C] attn = (q @ k.transpose(-2, -1)) * self.temperature attn = attn.softmax(dim=-1)# 将注意力权重施加到 V 上 out = (attn @ v).view(B, C, H, W)return self.proj(out)
2. WMA (Wavelet-domain Modulation Attention):频域解耦与动态增强
WMA 是本文最大的创新点。它的核心流程是:通过离散小波变换(DWT)将特征图拆分为低频和高频子带 在频带特征间计算注意力 通过动态卷积(Dynamic Conv)进行局部特征增强 逆小波变换(IDWT)还原。
可插拔模块源码:
import torchimport torch.nn as nnimport torch.nn.functional as F# pip install pytorch_waveletsfrom pytorch_wavelets import DWTForward, DWTInverseclassWMA_Optimized(nn.Module):""" 小波域调制自注意力 (Wavelet-domain Modulation Attention) """def__init__(self, dim): super(WMA_Optimized, self).__init__()# 1. 前置降维:保证 DWT 后四个频带拼接起来的通道数恰好等于 dim self.reduce = nn.Conv2d(dim, dim // 4, kernel_size=1, bias=False)# 2. 实例化小波算子 (J=1: 单级分解, mode='zero': 零填充边界, wave='haar': Haar小波基) self.dwt = DWTForward(J=1, mode='zero', wave='haar') self.idwt = DWTInverse(mode='zero', wave='haar')# 3. 频域内的 Q, K, V 投影网络 self.qkv = nn.Sequential( nn.Conv2d(dim, dim * 3, kernel_size=1, bias=False), nn.Conv2d(dim * 3, dim * 3, kernel_size=3, padding=1, groups=dim * 3, bias=False) )# 4. 动态调制权重生成分支 self.dynamic_weight = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Conv2d(dim, dim, kernel_size=1, bias=False), nn.Sigmoid() )# 5. 后置升维:恢复特征尺寸 self.expand = nn.Conv2d(dim // 4, dim, kernel_size=1, bias=False) self.temperature = nn.Parameter(torch.ones(1, 1, 1))defforward(self, x): B, C, H, W = x.shape# --- 第一阶段:小波频域解耦 --- x_red = self.reduce(x) # [B, C/4, H, W]# DWT 输出解析:# yl: 低频分量 (LL), 形状为 [B, C/4, H/2, W/2]# yh: 高频分量列表,单级分解(J=1)时 yh[0] 形状为 [B, C/4, 3, H/2, W/2]# 其中的 3 代表 LH(水平), HL(垂直), HH(对角线) 三个高频方向 yl, yh = self.dwt(x_red)# 拆包提取高频,并在通道维度上与低频拼接,重组为计算 Attention 所需的统一张量 high_freqs = yh[0] x_dwt = torch.cat([ yl, # LL: [B, C/4, H/2, W/2] high_freqs[:, :, 0, :, :], # LH: [B, C/4, H/2, W/2] high_freqs[:, :, 1, :, :], # HL: [B, C/4, H/2, W/2] high_freqs[:, :, 2, :, :] # HH: [B, C/4, H/2, W/2] ], dim=1) # 拼接后总形状: [B, C, H/2, W/2]# --- 第二阶段:跨频带自注意力计算 --- qkv = self.qkv(x_dwt) q, k, v = qkv.chunk(3, dim=1) # 各自形状: [B, C, H/2, W/2]# 展平空间维度并进行 L2 归一化 (通道 Attention 范式) q = F.normalize(q.view(B, C, -1), dim=-1) k = F.normalize(k.view(B, C, -1), dim=-1) v = v.view(B, C, -1) attn = (q @ k.transpose(-2, -1)) * self.temperature attn = attn.softmax(dim=-1) out_attn = (attn @ v).view(B, C, H//2, W//2)# --- 第三阶段:动态调制 (Dynamic Modulation) --- weight = self.dynamic_weight(x_dwt) # 获取自适应通道权重 [B, C, 1, 1] out_dyn = out_attn * weight # 逐通道调制频域特征 [B, C, H/2, W/2]# --- 第四阶段:逆小波变换还原 ---# 关键工程细节:需要将 [B, C, H/2, W/2] 的张量重新“打包”回 IDWT 能够识别的元组格式 sub_c = C // 4# 1. 剥离出低频 yl yl_out = out_dyn[:, :sub_c, :, :]# 2. 剥离出三个高频,利用 unsqueeze 恢复维度索引,再 cat 拼装成 yh[0] 格式 yh_out_LH = out_dyn[:, sub_c : 2*sub_c, :, :].unsqueeze(2) yh_out_HL = out_dyn[:, 2*sub_c : 3*sub_c, :, :].unsqueeze(2) yh_out_HH = out_dyn[:, 3*sub_c :, :, :].unsqueeze(2) yh_out = [torch.cat([yh_out_LH, yh_out_HL, yh_out_HH], dim=2)] # [B, C/4, 3, H/2, W/2]# 执行 IDWT 还原至空间域 out_idwt = self.idwt((yl_out, yh_out)) # 形状恢复为: [B, C/4, H, W]# 扩张通道数并输出return self.expand(out_idwt)
3. 全局频域监管:傅里叶损失 (Fourier Loss)
在网络输出端,DMNet 并没有改变推理架构,而是在训练阶段引入了快速傅里叶变换(FFT)损失。模型对超分输出图像和高斯清晰图像分别进行 2D-FFT,获取各自的振幅(Amplitude)和相位(Phase),并计算两者之间的 L1 距离。这一步就像是一个“无形的手”,在不增加任何推理计算量的前提下,强迫网络学习正确的全局结构,防止小波域“用力过猛”导致的局部失真。以下是完全还原论文公式 的实现:
import torchimport torch.nn as nnclassFourierLoss(nn.Module):""" DMNet中的全局频域监管:傅里叶损失函数 (Fourier Loss) 在频域中分离振幅与相位,使用L1距离约束网络学习正确的全局结构分布。 """def__init__(self, loss_weight=0.1): super(FourierLoss, self).__init__()# 频域损失通常作为辅助损失,权重设为 0.1 左右(需结合实际任务调参) self.loss_weight = loss_weight self.criterion = nn.L1Loss()defforward(self, sr, hr):""" sr: 网络输出的超分图像 (Super-Resolved), Shape: [B, C, H, W] hr: 真实的高清图像 (High-Resolution Ground Truth), Shape: [B, C, H, W] """# 1. 执行二维快速傅里叶变换 (2D FFT)# norm='ortho' 表示使用正交归一化,防止变换前后能量发生剧烈缩放,有利于梯度稳定 fft_sr = torch.fft.fft2(sr, norm='ortho') fft_hr = torch.fft.fft2(hr, norm='ortho')# fft2 输出的是复数张量 (Complex Tensor: real + imag * j)# 2. 提取振幅 (Amplitude)# torch.abs 用于计算复数的模。# 物理意义:振幅代表了图像中各个频率成分的“强度”(宏观的明暗对比和整体分布)。 amp_sr = torch.abs(fft_sr) amp_hr = torch.abs(fft_hr)# 3. 提取相位 (Phase)# torch.angle 用于计算复数的幅角。# 物理意义:相位决定了图像中不同正弦波的“位置”,包含了图像最关键的轮廓、边缘和结构信息。 phase_sr = torch.angle(fft_sr) phase_hr = torch.angle(fft_hr)# 4. 特征拼接 (Concatenation)# 将振幅和相位在通道维度 (dim=1) 拼接起来,形成混合频域特征# 拼接后 Shape: [B, 2*C, H, W] freq_sr = torch.cat([amp_sr, phase_sr], dim=1) freq_hr = torch.cat([amp_hr, phase_hr], dim=1)# 5. 计算频域的 L1 距离 loss = self.criterion(freq_sr, freq_hr)return loss * self.loss_weight
三、 评判性分析与优化建议
在轻量级 SR 赛道中,DMNet 的“小波提特征 + 傅里叶算损失”策略提供了一个优秀的范本。通过其论文中提供的局部归因图(LAM)可以看出,DMNet 激活的像素范围(红点分布)显著广于常规的轻量级网络,证明其有效感受野得到了极大扩展。

LAM 展示了模型在重建某一块局部区域时,利用了原图中多大范围的信息。DMNet 拥有更广的红点分布,说明双频域设计成功打破了纯空间卷积的局部性限制。
然而,从算法落地和架构演进的角度来看,DMNet 仍有以下可值得思考的地方:
1. 串行结构的效率
在当前的 SWGroup 中,空间模块(SMA)和频域模块(WMA)是严格串行执行的。虽然这有利于特征的深度提纯,但在硬件底层,空间和频域的计算原本是可以解耦并行的。如果“双分支并行计算 + 后期特征融合门控(Gating)”的结构是否可行?这样不仅能缩短前向传播的计算图长度,还能进一步降低推理延迟(Latency)。
2. 应对真实世界退化的鲁棒性
DMNet 的实验主要建立在理想的双三次插值降采样(Bicubic)数据集上。在真实的工业或医疗场景中,图像往往伴随复杂的传感器散斑噪声和运动模糊。小波变换对高频噪声非常敏感,容易将噪声误认为“边缘纹理”进行放大。如果在 WMA 模块前引入一个轻量级的“退化模式估计器”,或者在小波的高频子带(HL, LH, HH)处理中加入自适应的软阈值去噪(Soft-thresholding)是否有用?
3. 极端低分辨率下的特征坍塌
由于 DWT 操作会在物理上将特征图的宽和高减半,如果输入的低分辨率图像已经非常小(例如 ),经过 DWT 后特征图将变为 。在如此小的分辨率下,空间语义极度压缩,WMA 的频域自注意力可能无法捕捉到有效的纹理差异。如果对于极低分辨率任务,将离散小波变换替换为不改变空间分辨率的频域滤波器,或结合近期大热的可变形状态空间模型(Deformable Mamba)进行全尺寸的序列扫描是否可行?
小结
在算力受限的场景下,巧妙利用传统的数学工具(小波与傅里叶变换)对深度特征进行物理域的解耦,远比盲目堆叠 Transformer 层更加高效且优雅。DMNet为未来的边缘视觉模型设计提供了宝贵的参考。
夜雨聆风