本文共 822 字,大约阅读时间需要 2 分钟。
在PyTorch中,squeeze函数是一个强大的工具,用于去除维度。它能够根据指定的维度或默认维度,将大小为1的维度从张量中去除。这种操作在数据预处理和模型训练中非常有用。
以下是squeeze函数的实际应用示例:
>>>> x = torch.zeros(2, 1, 2, 1, 2)>>> x.size() # torch.Size([2, 1, 2, 1, 2]) >>>> y = torch.squeeze(x)>>> y.size() # torch.Size([2, 2, 2]) 在这个例子中,squeeze函数默认去除了所有大小为1的维度,结果是一个大小为(2,2,2)的张量。
>>>> y = torch.squeeze(x, 0)>>> y.size() # torch.Size([2, 1, 2, 1, 2])
如果指定了dim参数,只会去除指定的那个维度。例如,dim=0表示去除第一个维度,结果保持其他维度不变。
>>>> y = torch.squeeze(x, 1)>>> y.size() # torch.Size([2, 2, 1, 2])
同样,可以单独去除指定维度的大小为1的维度。
squeeze函数的功能可以通过以下方式调用:
torch.squeeze(input, dim=None, out=None) → Tensor | 参数 | 说明 |
| input (Tensor) | 输入张量。 |
| dim (int, 选项) | 指定要去除的维度。如果不指定,默认去除所有大小为1的维度。 |
| out (Tensor, 选项) | 返回去除维度后的张量。 |
参考文档:[PyTorch官方文档](https://pytorch.org/docs/stable/generated/torch.squeeze.html#torch.squeeze)
转载地址:http://igefk.baihongyu.com/