
本文详解如何使用 torch.gather 对形状为 [b, m, n] 的张量 A,按形状为 [b, k] 的索引张量 B 进行批量二维索引,得到 [b, k, n] 的输出张量。核心在于扩展索引维度并匹配 gather 的维度要求。
本文详解如何使用 `torch.gather` 对形状为 `[b, m, n]` 的张量 a,按形状为 `[b, k]` 的索引张量 b 进行批量二维索引,得到 `[b, k, n]` 的输出张量。核心在于扩展索引维度并匹配 `gather` 的维度要求。
在 PyTorch 中,对高维张量进行“跨批次、按行(或列)选取子集”是常见需求,但 torch.index_select 和 torch.take 仅支持一维索引,而 torch.gather 要求输入与索引张量维度严格一致——这正是本问题的关键难点。
要实现 A[b, m, n] 按 B[b, k] 索引(即:对每个 batch b,从 A[b] 的第 m 维中选取 k 个位置,保留全部 n 列),需将 B 扩展为三维,使其与 A 在 gather 所需维度上对齐:
- A 形状:(b, m, n)
- 目标索引维度:沿 dim=1(即 m 维)选取,因此 B 需扩展为 (b, k, n),且每个 (b, k) 位置对应 n 个相同索引(因每行选同一行号,复制到所有 n 列)
✅ 正确做法如下:
import torch
# 示例数据
b, m, n, k = 2, 5, 4, 3
A = torch.randn(b, m, n) # shape: [2, 5, 4]
B = torch.randint(0, m, (b, k)) # shape: [2, 3],值 ∈ [0, 4]
# Step 1: 扩展 B → (b, k, n),广播索引至每一列
B_expanded = B.unsqueeze(-1).expand(-1, -1, n) # 或 B[:, :, None].expand(-1, -1, n)
# Step 2: 使用 gather 沿 dim=1 聚合(注意:index 必须与 input 在非索引维度上尺寸一致)
result = torch.gather(A, dim=1, index=B_expanded) # shape: [2, 3, 4]
print("A shape:", A.shape)
print("B shape:", B.shape)
print("result shape:", result.shape)
print("result[0, 0] == A[0, B[0,0]]:", torch.equal(result[0, 0], A[0, B[0, 0]]))
? 关键说明:
- unsqueeze(-1)(等价于 [:,:,None])在末尾添加长度为 1 的维度,使 B 变为 (b, k, 1);
- expand(-1, -1, n) 智能广播该维度至 n,生成 (b, k, n) 索引张量,不分配新内存;
- torch.gather(A, dim=1, index=B_expanded) 表示:对每个 (b, n) 位置,从 A[b, :, n] 中按 B_expanded[b, :, n] 指定的行号取值;由于 B_expanded 在最后一维完全一致,等效于“每行选一个完整行向量”。
⚠️ 注意事项:
- B 中的索引值必须在 [0, m) 范围内,否则触发 RuntimeError: index out of bounds;
- gather 不支持负索引(如 -1),需预先处理为正索引;
- 若需梯度回传,请确保 B 是 torch.long 类型(gather 对 index 张量类型有严格要求);
- 替代方案(如高级索引 A[torch.arange(b)[:, None], B])虽更直观,但在某些版本中可能产生视图/副本歧义,gather 是官方推荐的可微、确定性方案。
综上,通过维度扩展 + torch.gather,即可高效、可导地完成批量多维索引,是 PyTorch 中处理此类结构化索引任务的标准范式。
文章来自机圈观察员网,发布者:,转载请注明出处:https://www.jqgcy.com/shoujipingce/124171.html