文章目录

  • 1 要点
  • 2 方法
    • 2.1 ABMIL
    • 2.2 多分支注意力
    • 2.3 随机Top-K实例掩蔽

1 要点

题目:用于WSI分类的注意力挑战多示例学习 (Attention-challenging multiple instance learning for whole slide image classification)

代码:https://github.com/dazhangyu123/ACMIL

研究目标
在已有的MIL方法中,注意力机制往往会集中在一小部分具有辨别性的实例上,这与过拟合密切相关。对此,ACMIL旨在通过减少注意力值的过度集中来提高模型的泛化能力。

关键技术

  1. 多重分支注意力 (MBA):利用多个注意力分支来关注多种模式的实例,确保更多辨别性实例对最终预测做出贡献;
  2. 随机Top- K K K实例掩蔽 (STKIM):通过随机掩蔽部分Top-K注意力值的实例,并将它们的注意力值分配给剩余实例,来抑制这些显著实例;

数据集

  1. CAMELYON16;
  2. BRACS;
  3. LBC:收集了1989张WSI,包含4个类别,即阴性、ASC-US、LSIL和ASC-H/HSIL;

2 方法

2.1 ABMIL

在二分类MIL中,一个包 X = { x n } n = 1 N X=\{x_{n}\}_{n = 1}^{N} X={xn}n=1N与一个包标签 Y Y Y相关联。每个实例 x n x_{n} xn则与一个二元标签 y n y_{n} yn相关联,在训练过程中其是未知的。MIL标准假设记为:
Y = { 0 , iff ∑ n = 1 N y n = 0 1 , otherwise (1) \tag{1} Y= \begin{cases}0, & \text{iff} \sum_{n = 1}^{N} y_{n} = 0 \\ 1, & \text{otherwise} \end{cases} Y={0,1,iffn=1Nyn=0otherwise(1)在ABMIL中,MIL通过一个三步过程建模:

  1. 通过神经网络将实例转换为低维嵌入: h n = f ( x n ) h_{n} = f(x_{n}) hn=f(xn)
  2. 使用注意力操作将所有实例嵌入聚合为包级表示:
    z = ∑ n = 1 N a n h n (2) z = \sum_{n = 1}^{N} a_{n}h_{n} \tag{2} z=n=1Nanhn(2)这里, a n = σ ( h n ) a_{n} = \sigma(h_{n}) an=σ(hn)表示第 n n n个实例 h n h_{n} hn的注意力值。在ABMIL中,采用了门控注意力 (GA) 机制来计算注意力值:
    σ ( h n ) = exp ⁡ { w T ( tanh ⁡ ( V 1 h n ) ⊙ sigm ( V 2 h n ) ) } ∑ j = 1 N exp ⁡ { w T ( tanh ⁡ ( V 1 h j ) ⊙ sigm ( V 2 h j ) ) } (3) \sigma(h_{n}) = \frac{\exp\{w^{T}(\tanh(V_{1}h_{n}) \odot \text{sigm}(V_{2}h_{n}))\}}{\sum_{j = 1}^{N} \exp\{w^{T}(\tanh(V_{1}h_{j}) \odot \text{sigm}(V_{2}h_{j}))\}} \tag{3} σ(hn)=j=1Nexp{wT(tanh(V1hj)sigm(V2hj))}exp{wT(tanh(V1hn)sigm(V2hn))}(3)
    其中 V 1 , V 2 ∈ R L × M V_{1}, V_{2} \in \mathbb{R}^{L×M} V1,V2RL×M w ∈ R L × 1 w \in \mathbb{R}^{L×1} wRL×1是参数、 ⊙ \odot 是逐元素乘法,以及 sigm ( ⋅ ) \text{sigm}(·) sigm()是sigmoid非线性函数;
  3. 基于聚合的包嵌入生成包预测: Y ^ = g ( z ) \hat{Y} = g(z) Y^=g(z)

2.2 多分支注意力

动机:使用单个注意力分支来捕捉所有辨别实例是具有挑战性的,如图3。这一挑战源于辨别性图像块的模式变化,即纹理和形态的差异。此外,深度神经网络往往表现出一种“惰性”,它们倾向于捕捉更简单的模式以最小化训练损失,而忽略了更复杂和具有挑战性的模式 [19,20]。为了解决这个问题,本文设计了多分支注意力机制 (MBA),通过多个注意力分支来捕捉更多的判别实例。

如图3所示,以CAMELYON16数据集中“test_113”病例的肿瘤实例特征的UMAP可视化为例。肿瘤实例之间存在各种模式/簇,依赖单个分支往往只能捕捉到部分簇。我们选取了三个实例来展示它们的纹理差异。

如图4顶视图所示,MBA首先捕捉 M M M种模式,然后聚合它们的嵌入以进行预测。每种模式由一个注意力分支捕捉。为了保持模式的辨别性以及它们之间的语义多样性,引入了两种正则化技术:

  1. 为确保捕捉到辨别性模式,通过在每个模式嵌入后连接一个多层感知器 (MLP) 层,并配备交叉熵损失函数来实现语义正则化
    L p = − 1 M ∑ i = 1 M Y log ⁡ Y ^ i + ( 1 − Y ) log ⁡ ( 1 − Y ^ i ) (4) \tag{4} \mathcal{L}_{p}=-\frac{1}{M} \sum_{i = 1}^{M} Y \log \hat{Y}_{i} + (1 - Y)\log(1 - \hat{Y}_{i}) Lp=M1i=1MYlogY^i+(1Y)log(1Y^i)(4)其中 Y ^ i = g i ( z i ) \hat{Y}_{i} = g_{i}(z_{i}) Y^i=gi(zi)是基于第 i i i个模式嵌入 z i z_{i} zi的预测;
  2. 仅配备交叉熵损失可能会学习到相似的模式,无法挖掘出更多辨别性信息。为了解决这个问题,进一步引入了多样性损失
    L d = 2 M ( M − 1 ) ∑ i = 1 M ∑ j = i + 1 M cos ⁡ ( a i , a j ) (5) \tag{5} \mathcal{L}_{d}=\frac{2}{M(M - 1)} \sum_{i = 1}^{M} \sum_{j = i + 1}^{M} \cos(a_{i}, a_{j}) Ld=M(M1)2i=1Mj=i+1Mcos(ai,aj)(5)
    其中 a i = { a i 1 , ⋯   , a i N } a_{i} = \{a_{i1}, \cdots, a_{iN}\} ai={ai1,,aiN}由第 i i i个模式的所有注意力值组成,通常也被称为注意力图 cos ⁡ ( ⋅ ) \cos(·) cos()函数用于衡量分支之间注意力图的相似性。通过使注意力图多样化,每个分支的嵌入可以专注于不同的模式。

    为了聚合捕捉到的模式以进行预测,将注意力图的平均值作为整个包的注意力图
    a = 1 M ∑ i = 1 M a i (6) \tag{6} a=\frac{1}{M} \sum_{i = 1}^{M} a_{i} a=M1i=1Mai(6)其中 a a a是整个包的注意力图,维度为 N N N。然后,可以使用平均注意力图 a a a聚合实例特征来获得包嵌入。此外,由于
    ∑ n = 1 N ( 1 M ∑ i = 1 M a i n ) h n = 1 M ∑ i = 1 M ( ∑ n = 1 N a i n h n ) \sum_{n = 1}^{N}(\frac{1}{M} \sum_{i = 1}^{M} a_{in})h_{n}=\frac{1}{M} \sum_{i = 1}^{M}(\sum_{n = 1}^{N} a_{in}h_{n}) n=1N(M1i=1Main)hn=M1i=1M(n=1Nainhn)包嵌入也可以通过对模式嵌入应用平均池化算子来表示。为简洁起见,图4的顶视图采用了后一种表示方式。包分类器的损失函数定义为:
    L b = − Y log ⁡ Y ^ + ( 1 − Y ) log ⁡ ( 1 − Y ^ ) (7) \tag{7} \mathcal{L}_{b}=-Y \log \hat{Y}+(1 - Y)\log(1 - \hat{Y}) Lb=YlogY^+(1Y)log(1Y^)(7)最后,ACMIL的总体损失函数记为:
    L = L b + L p + L d (8) \mathcal{L}=\mathcal{L}_{b}+\mathcal{L}_{p}+\mathcal{L}_{d} \tag{8} L=Lb+Lp+Ld(8)讨论:需要着重指出的是,当MBA中的参数 M M M设置为1时,它本质上反映了ABMIL的特征聚合过程,只能识别单一模式。从这个意义上讲,MBA是ABMIL的扩展,专门用于捕捉更多样化的模式集。我们进一步讨论MBA与多头注意力机制 (MHA) 之间的联系。HIPT揭示了MHA中不同的头可以有效地捕捉不同的视觉概念,这与我们的MBA所起的作用类似。然而,这两种技术可以通过以下方式轻易区分:
  1. MBA具有多样性正则化,确保不同的分支能够学习不同的概念。而MHA中不存在这种机制,导致不同的头可能学习相同的概念;
  2. MHA是一种注意力公式,而MBA独立于注意力公式运作,其框架内可包含MHA。

2.3 随机Top-K实例掩蔽

动机:在ABMIL中,极少数实例会占据大部分注意力,而忽略了复杂的辨别实例。如图5所示,在所有三个数据集上,前10个注意力值的总和都大于0.85。然而,一张WSI中通常包含超过10个辨别实例。例如,在CAMELYON16数据集中,155张肿瘤切片中有129张包含10到20000个癌性实例。从本质上讲,大量辨别实例被忽视了。为了解决这个问题,提出了旨在抑制显著实例的随机Top-K实例掩蔽 (STKIM),并将更多注意力分配给其余实例。


如图4底视图所示,STKIM在注意力机制中,于特征聚合之前、注意力值生成之后引入了一个掩蔽操作。其主要目标是抑制Top-K显著实例。一种直接的解决方法是将所有Top-K显著实例掩蔽掉。然而,这种方法存在一定的问题。它可能导致与关键辨别实例相关的信息丢失,而这些实例对于判别至关重要。此外,丢弃这些关键实例可能会导致特征表示在前后出现统计上的不匹配。为了解决这些问题,我们从计算机视觉中常用的随机失活(dropout)[46]和Cutout [17,62]方法中获得灵感。我们提出的解决方案是对具有Top-K注意力值的实例特征进行随机掩蔽。具体来说,我们首先将所有注意力值从高到低排序。随后,以概率 p p p将Top-K个实例的注意力值随机设置为0。这个过程可以表示为:
a n = { 0 , 以概率 p 且在Top-K值范围内 a n , 其他情况 a_{n}= \begin{cases}0, & \text{以概率}p\text{且在Top-K值范围内} \\ a_{n}, & \text{其他情况} \end{cases} an={0,an,以概率p且在Top-K值范围内其他情况
其中, p p p K K K是两个超参数,用于控制掩蔽的强度。根据公式(9),我们将被掩蔽实例的注意力值分配给其余实例,即 a n → 1 ∑ n = 1 N a n a n a_{n} \to \frac{1}{\sum_{n=1}^{N} a_{n}} a_{n} ann=1Nan1an 。值得注意的是,受随机失活和Cutout的启发,我们在推理时移除STKIM。
讨论:虽然STKIM、MHIM - MIL [47]和WENO [42]都采用了掩蔽显著实例的技术,但它们之间存在显著的技术差异。首先,与WENO和MHIM - MIL相比,我们的STKIM具有不同的掩蔽策略。STKIM仅以概率 p p p掩蔽少数实例(即 K = 10 K = 10 K=10)。相比之下,其他两种方法掩蔽的实例数量更多。WENO掩蔽95个实例,MHIM - MIL掩蔽1%的实例。在我们的框架中,我们的方案在三种策略中表现最佳(见附录第9.5节)。其次,MHIM - MIL和WENO都需要一个训练良好的模型来掩蔽显著实例,并利用剩余实例进行模型训练。它们都采用教师-学生框架,其中教师模型需要预先训练(WENO中的预热过程和MHIM - MIL中的预训练阶段)。相比之下,STKIM既不需要教师-学生框架,也不需要预训练过程,因此具有简单高效的特点。

Logo

一站式 AI 云服务平台

更多推荐