
本文详解如何使用 torch.gather 对 batched 3D 张量进行高效索引,通过维度扩展使索引张量与目标张量对齐,从而实现从 [b, m, n] 张量中按 [b, k] 索引提取出 [b, k, n] 结果。
本文详解如何使用 `torch.gather` 对 batched 3d 张量进行高效索引,通过维度扩展使索引张量与目标张量对齐,从而实现从 `[b, m, n]` 张量中按 `[b, k]` 索引提取出 `[b, k, n]` 结果。
在 PyTorch 中,直接用多维索引张量(如 [b, k])去索引高维张量(如 [b, m, n])并非原生支持的操作——torch.index_select 和 torch.take 仅接受一维索引,而 torch.gather 要求输入与索引张量在除指定维度外其余维度严格一致。但通过合理的维度变换,这一需求完全可以优雅实现。
核心思路是:将索引张量 B 扩展为与 A 相容的形状,再沿目标维度(这里是第 1 维,即 m 维)调用 gather。
假设:
- A 形状为 (b, m, n):表示 b 个样本,每个含 m 行、n 列特征;
- B 形状为 (b, k):每个样本需选取 k 个行索引(值 ∈ [0, m));
期望输出 out 形状为 (b, k, n),即每个样本取 k 行,保留全部 n 列。
✅ 正确做法如下:
-
扩展索引张量维度:将 B 从 (b, k) 变为 (b, k, n),使其在最后一维上与 A 对齐。使用 None(即 unsqueeze(-1))添加 singleton 维,再用 expand 广播至 n 列:
B_expanded = B.unsqueeze(-1).expand(-1, -1, A.size(-1)) # shape: (b, k, n)
-
沿 dim=1 聚合:A.gather(1, B_expanded) 表示:对每个 (b, n) 切片,在 m 维(即行方向)按 B_expanded[b, :, n] 的索引取值:
out = A.gather(1, B_expanded) # shape: (b, k, n)
完整可运行示例:
import torch
b, m, n, k = 2, 5, 4, 3
A = torch.randn(b, m, n)
B = torch.randint(0, m, (b, k)) # 随机行索引,范围 [0, m)
# 扩展并 gather
B_expanded = B.unsqueeze(-1).expand(-1, -1, A.size(-1))
out = A.gather(1, B_expanded)
print(f"A.shape: {A.shape}") # torch.Size([2, 5, 4])
print(f"B.shape: {B.shape}") # torch.Size([2, 3])
print(f"out.shape: {out.shape}") # torch.Size([2, 3, 4])
# 验证:out[0, 0] 应等于 A[0, B[0, 0], :]
assert torch.equal(out[0, 0], A[0, B[0, 0]])
⚠️ 注意事项:
- B 中的索引值必须在 [0, m) 范围内,否则会触发 IndexError;
- expand() 是零拷贝操作,高效安全;若需修改索引,应先 clone() 再操作;
- 不要误用 torch.index_select(A, dim=1, index=B.flatten())——它会破坏 batch 结构,输出 (b*k, n) 而非 (b, k, n);
- 若 B 含负索引,PyTorch 会自动转换(如 -1 → m-1),但建议显式校验以增强鲁棒性。
该方法兼具简洁性与高性能,是 PyTorch 中处理 batched 多维索引的标准范式。
文章来自机圈观察员网,发布者:,转载请注明出处:https://www.jqgcy.com/xinjizixun/123811.html