《神经网络与PyTorch实战》——3.1.2 PyTorch里的张量
3.1.2 PyTorch里的张量
在扩展库PyTorch中,张量是运算的基本数据类型,用类torch.Tensor实现。在正式介绍torch.Tensor类的用法之前,本节先来看一段使用torch.Tensor类的代码,使你对PyTorch里的张量有个直观的印象。
代码清单3-1是一段使用torch.Tensor类的代码。该代码首先导入了扩展库PyTorch。扩展库PyTorch在Python里记作“torch”,可以用语句“import torch”导入。接着,代码用torch.tensor() 函数将一个列表转化为torch.Tensor类实例t2。注意,函数名tensor的所有字母都是小写字母,而类名的首字母是大写的。得到torch.Tensor类实例t2后,代码显示t2的数据内容。接着,代码还将张量t2的元素进行重新组织,使得其大小从变为。最后,代码演示了如何利用张量进行数学计算,将张量逐元素加1得到新的张量。
代码清单3-1 一段使用torch.Tensor类的代码
import torch
t2 = torch.tensor([[0, 1, 2], [3, 4, 5]])
print(t2)
print('数据 = {}'.format(t2))
print(t2.reshape(3, 2)) # 重新组织元素
print(t2 + 1) # 逐元素运算
代码清单3-1涉及张量的构造、元素重组、科学计算等用法。在本章后续内容中,我们将逐一对以上话题进行介绍。
前一节提到了张量的大小、维度和元素个数。在PyTorch中,可以通过torch.Tensor类实例的成员获得这些性质。torch.Tensor类包括以下成员(但不限于此)。
* 成员size():返回张量的大小,它是一个元组的子类torch.Size类实例。
* 成员dim():返回张量的维度,它是一个int类型的数值。它是张量大小的数据条目个数。
* 成员numel():张量中元素的个数,它是一个int类型的数值。它是张量大小各条目之积。
例如,可以用以下代码查看代码清单3-1中张量t2的数据、大小、维度和元素个数:
print('数据 = {}'.format(t2))
print('大小 = {}'.format(t2.size()))
print('维度 = {}'.format(t2.dim()))
print('元素个数 = {}'.format(t2.numel()))
在代码清单3-1中,张量类实例t2的大小为,维度为2,元素个数为6。
另外,每个张量类实例还会有元素类型(dtype)。为了更好地适应不同的应用需求,PyTorch提供了下列元素类型。
* 浮点类型:torch.float16、torch.float32、torch.float64。它们的区别在于位数和精度不同。类型名称的后缀“16”“32”“64”表示位数,位数越大精度越高。这些浮点类型中最常用的是torch.float32。
* 整数类型:torch.uint8、torch.int8、torch.int16、torch.int32、torch.int64。其中torch.uint8是无符号(unsigned)的整数类型,它只能表示非负整数;其他整数类型都是有符号类型。这些整数类型中最常用torch.int64类型表示指标(即0、1、2这样的序号),用torch.uint8表示逻辑类型(即“是”或“否”)。
可以通过张量类实例的成员dtype查看元素类型:
print('元素类型 = {}'.format(t2.dtype))
代码清单3-1在构造torch.Tensor类实例时,使用了int值列表,所以构造出来的张量的数据类型默认为torch.int64型。如果用bool值列表构造,则构造出来的张量元素类型默认是torch.uint8型;如果用float值列表构造,则构造出来的张量的元素类型默认是torch.float32型。
- 点赞
- 收藏
- 关注作者
评论(0)