Pytorch使用篇之常用函数总结(一)
在笔者日常使用pytorch的过程中,总结了如下的常用函数:
1. 把pandas.DataFrame mat 转换为torch的张量:
mat = torch.tensor(mat.to_numpy(), dtype=torch.float32)
2. 把numpy.array转换为torch张量:
x_torch = torch.from_numpy(x_np)
torch张量转为numpy.array():
x_np = x_torch.numpy()
3. 在定义神经网络模型时,由于网络参数需要随机初始化,因为在定义神经网络模型前,需要设置随机种子,一般把numpy和torch都设置一下随机种子:
seed = 1992
np.random.seed(seed)
torch.manual_seed(seed)
4. 在定义神经网络模型时,还可以设置一下线程数,这样可以并行,加速训练:
n_cores = 32
torch.set_num_threads(n_cores)
5. adj_matrix: 是一个64*64的torch.tensor:
torch.ones_like(adj) # 构造一个64*64的torch.tensor,元素全为1
6. 比如我们定义一个list: x_list = []
x_list中存放了4个shape为:torch.Size([1, 96, 64])的torch张量,那么我可以使用:
x_return = torch.cat(x_list, dim=2) # 按照第2维度进行拼接,其中第2维度(是从0开始编码的)表示:64
则x_return的shape为:torch.Size([1, 96, 256])
7. torch.unsqueeze:将数据进行扩充。unsqueeze(arg)是增添第arg个维度为1,以插入的形式填充。比如:
a = torch.rand(1, 3)
a.shape # [1,3]
b = a.unsqueeze(1)
b.shape # [1,1,3]
d = a.unsqueeze(2)
d.shape # [1,3,1]
torch.squeeze: 将数据进行压缩,使得数据更紧凑
squeeze(arg):删除第arg个维度(如果当前维度不为1,则不会进行删除),比如:
d是1个size为[1,3,1]的torch.tensor,则:
e = d.squeeze(0)
e.shape # [3,1]
f = d.squeeze(1)
f.shape # [1,3,1]
8. detach():返回一个新的从当前图中分离的变量,该变量的require_grad=False, 得到的这个Variable不需要计算其梯度,不具有grad。
假设有变量:
a = tensor([[[-1.5962e-01, -1.6366e-01, -6.5601e-02, ..., -8.3999e-02,
-3.9803e-02, -1.1769e-01],
[-2.3615e-01, -1.5466e-01, -1.0734e-01, ..., -1.4286e-02,
-6.4129e-02, -1.6298e-01],
[-2.6935e-01, -2.7972e-01, -1.5651e-01, ..., -1.1821e-01,
-4.8074e-02, -2.7009e-01],
...,
[ 8.5757e-02, -4.4065e-01, 1.8019e-01, ..., 2.1553e-01,
4.3380e-01, -4.4875e-01],
[-8.9650e-02, -1.3082e-01, 2.3181e-02, ..., 3.6276e-01,
4.8729e-01, -2.7162e-01],
[ 9.0946e-02, -8.8203e-02, 9.7522e-02, ..., 7.0850e-02,
5.0222e-01, -3.0326e-01]]], grad_fn=<CopySlices>)
那么运行b = a.detach(),则得到b: 就是没有grad_fn了
tensor([[[-1.5962e-01, -1.6366e-01, -6.5601e-02, ..., -8.3999e-02,
-3.9803e-02, -1.1769e-01],
[-2.3615e-01, -1.5466e-01, -1.0734e-01, ..., -1.4286e-02,
-6.4129e-02, -1.6298e-01],
[-2.6935e-01, -2.7972e-01, -1.5651e-01, ..., -1.1821e-01,
-4.8074e-02, -2.7009e-01],
...,
[ 8.5757e-02, -4.4065e-01, 1.8019e-01, ..., 2.1553e-01,
4.3380e-01, -4.4875e-01],
[-8.9650e-02, -1.3082e-01, 2.3181e-02, ..., 3.6276e-01,
4.8729e-01, -2.7162e-01],
[ 9.0946e-02, -8.8203e-02, 9.7522e-02, ..., 7.0850e-02,
5.0222e-01, -3.0326e-01]]])
然后可转为numpy,即为:a.detach().numpy()
还有.detach_()、.data()
- 点赞
- 收藏
- 关注作者
评论(0)