我有一个大小为 m x n
的 pytorch 张量 X
和一个长度为 n
的非负整数列表 num_repeats
(假设总和(num_repeats)> 0)。在forward()方法中,我想创建一个大小为m x sum(num_repeats)
的张量X_dup
,其中X的列
重复 i
num_repeats[i]
次。张量X_dup
将在forward()方法的下游使用,因此梯度需要正确反向传播。 我能想到的所有解决方案都需要就地操作(创建一个新的张量并通过迭代num_repeats
来填充它),但如果我理解正确的话,这不会保留梯度(如果我',请纠正我)我错了,我对整个 Pytorch 都是新手)。
请您参考如下方法:
如果您使用 PyTorch >= 1.1.0,您可以使用 torch.repeat_interleave
.
repeat_tensor = torch.tensor(num_repeats).to(X.device, torch.int64)
X_dup = torch.repeat_interleave(X, repeat_tensor, dim=1)