Skip to main content
 首页 » 编程设计

python之在 Pytorch 中重复张量的特定列

2024年06月03日22zdz8207

我有一个大小为 m x n 的 pytorch 张量 X 和一个长度为 n 的非负整数列表 num_repeats (假设总和(num_repeats)> 0)。在forward()方法中,我想创建一个大小为m x sum(num_repeats)的张量X_dup,其中X的列i 重复 num_repeats[i] 次。张量X_dup将在forward()方法的下游使用,因此梯度需要正确反向传播。 我能想到的所有解决方案都需要就地操作(创建一个新的张量并通过迭代num_repeats来填充它),但如果我理解正确的话,这不会保留梯度(如果我',请纠正我)我错了,我对整个 Pytorch 都是新手)。

请您参考如下方法:

如果您使用 PyTorch >= 1.1.0,您可以使用 torch.repeat_interleave .

repeat_tensor = torch.tensor(num_repeats).to(X.device, torch.int64) 
X_dup = torch.repeat_interleave(X, repeat_tensor, dim=1)