用于 Python 深度学习项目的 PyTorch 与 TensorFlow

举报
Yuchuan 发表于 2021/08/30 18:05:37 2021/08/30
【摘要】 在本教程中,您已经介绍了 PyTorch 和 TensorFlow,了解了谁在使用它们以及它们支持哪些 API,并了解了如何为您的项目选择 PyTorch 与 TensorFlow。您已经了解了每种语言、工具、数据集和模型所支持的不同编程语言,并了解了如何选择最适合您的独特风格和项目的一种。

目录

PyTorch 与 TensorFlow:有何不同?两者都是使用图对数据执行数值计算的开源 Python 库。两者都广泛用于学术研究和商业代码。两者都通过各种 API、云计算平台和模型存储库进行扩展。

如果它们如此相似,那么哪一个最适合您的项目?

在本教程中,您将学习:

  • PyTorchTensorFlow有什么区别
  • 每个人都有哪些可用的工具资源
  • 如何为您的特定用例选择最佳选项

您将首先仔细研究这两个平台,从稍旧的 TensorFlow 开始,然后探索一些可以帮助您确定哪种选择最适合您的项目的注意事项。让我们开始吧!

什么是 TensorFlow?

TensorFlow 由 Google 开发并于 2015 年作为开源发布。它源于 Google 的自主机器学习软件,该软件经过重构和优化以用于生产。

“TensorFlow”这个名字描述了你如何组织和执行数据操作。TensorFlow 和 PyTorch 的基本数据结构都是张量。当您使用 TensorFlow 时,您可以通过构建一个有状态的数据流图来对这些张量中的数据执行操作,有点像记住过去事件的流程图。

谁在使用 TensorFlow?

TensorFlow 以生产级深度学习库而闻名。它拥有庞大而活跃的用户群,以及大量用于训练、部署和服务模型的官方和第三方工具和平台。

2016 年 PyTorch 发布后,TensorFlow 的受欢迎程度有所下降。但在 2019 年底,谷歌发布了TensorFlow 2.0,这是一项重大更新,简化了库并使其更加用户友好,重新引起了机器学习社区的兴趣。

代码风格和功能

在 TensorFlow 2.0 之前,TensorFlow 要求您通过调用 API手动将抽象语法树(图)拼接在一起tf.*。然后它要求您通过将一组输出张量和输入张量传递给session.run()调用来手动编译模型。

一个Session对象是运行TensorFlow操作类。它包含Tensor评估Operation对象和执行对象的环境,并且它可以像tf.Variable对象一样拥有资源。使用 a 的最常见方法Session是作为上下文管理器

在 TensorFlow 2.0 中,您仍然可以通过这种方式构建模型,但使用Eager Execution更容易,这是 Python 通常的工作方式。Eager execution 会立即评估操作,因此您可以使用 Python 控制流而不是图形控制流来编写代码。

要查看差异,让我们看看如何使用每种方法将两个张量相乘。这是使用旧的 TensorFlow 1.0 方法的示例:

>>>
>>> import tensorflow as tf

>>> tf.compat.v1.disable_eager_execution()

>>> x = tf.compat.v1.placeholder(tf.float32, name = "x")
>>> y = tf.compat.v1.placeholder(tf.float32, name = "y")

>>> multiply = tf.multiply(x, y)

>>> with tf.compat.v1.Session() as session:
...     m = session.run(
...         multiply, feed_dict={x: [[2., 4., 6.]], y: [[1.], [3.], [5.]]}
...     )
...     print(m)
[[ 2.  4.  6.]
 [ 6. 12. 18.]
 [10. 20. 30.]]

此代码使用 TensorFlow 2.x 的tf.compatAPI 来访问 TensorFlow 1.x 方法并禁用 Eager Execution。

您首先声明输入张量xy使用tf.compat.v1.placeholder张量对象。然后定义要对它们执行的操作。接下来,使用该tf.Session对象作为上下文管理器,创建一个容器来封装运行时环境,并通过将实际值提供给带有feed_dict. 最后,还是在 session 里面,你print()的结果。

在 TensorFlow 2.0 中使用 Eager Execution,您只需要tf.multiply()达到相同的结果:

>>>
>>> import tensorflow as tf

>>> x = [[2., 4., 6.]]
>>> y = [[1.], [3.], [5.]]
>>> m = tf.multiply(x, y)

>>> m
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[ 2.,  4.,  6.],
       [ 6., 12., 18.],
       [10., 20., 30.]], dtype=float32)>

在此代码中,您使用 Python 列表符号声明张量,并在调用时立即tf.multiply()执行元素乘法

如果您不想或不需要构建低级组件,那么推荐使用 TensorFlow 的方法是Keras。它具有更简单的 API,为您将常见用例滚动到预制组件中,并提供比基础 TensorFlow 更好的错误消息。

特殊功能

TensorFlow 拥有庞大而完善的用户群和大量工具来帮助生产机器学习。对于移动开发,它具有适用于JavaScript和 Swift 的API ,并且TensorFlow Lite可让您压缩和优化物联网设备的模型。

由于Google 和第三方提供的大量数据、预训练模型和Google Colab 笔记本,您可以快速开始使用 TensorFlow 。

许多流行的机器学习算法和数据集都内置在 TensorFlow 中并且可以立即使用。除了内置数据集,您还可以访问Google Research 数据集或使用 Google 的数据集搜索来查找更多数据。

Keras 可以更轻松地启动和运行模型,因此您可以在更短的时间内尝试新技术。事实上,Keras 是Kaggle前五名获胜团队中使用最广泛的深度学习框架。

一个缺点是从 TensorFlow 1.x 到 TensorFlow 2.0 的更新改变了很多功能,您可能会发现自己感到困惑。升级代码既乏味又容易出错。许多资源,如教程,可能包含过时的建议。

PyTorch 没有同样大的向后兼容性问题,这可能是选择它而不是 TensorFlow 的一个原因。

Tensorflow 生态系统

TensorFlow 扩展生态系统的 API、扩展和有用工具的一些亮点包括:

什么是 PyTorch?

PyTorch 由 Facebook 开发,并于 2016年首次公开发布。它旨在提供类似于 TensorFlow 的生产优化,同时使模型更易于编写。

由于 Python 程序员发现它使用起来非常自然,PyTorch 迅速获得了用户,这促使 TensorFlow 团队在 TensorFlow 2.0 中采用了许多 PyTorch 最受欢迎的功能。

谁在使用 PyTorch?

PyTorch 以在研究中比在生产中使用更广泛而闻名。然而,自从在 TensorFlow 之后的一年发布以来,PyTorch 的专业开发人员使用量急剧增加。

2020堆栈溢出开发者调查最流行的“其他框架,库和工具”报道,专业开发人员10.4%选择TensorFlow和4.1%选择PyTorch的名单。在2018,百分比分别为TensorFlow 7.6%和PyTorch只是1.6%。

至于研究,PyTorch 是一个流行的选择,像斯坦福这样的计算机科学项目现在用它来教授深​​度学习。

代码风格和功能

PyTorch 基于Torch,这是一个用 C 编写的用于进行快速计算的框架。 Torch 有一个用于构建模型的Lua包装器。

PyTorch将相同的 C 后端包装在 Python 接口中。但它不仅仅是一个包装器。开发人员从头开始构建它,使 Python 程序员可以轻松编写模型。底层的低级 C 和 C++ 代码针对运行 Python 代码进行了优化。由于这种紧密集成,您可以获得:

  • 更好的内存和优化
  • 更合理的错误信息
  • 模型结构的细粒度控制
  • 更透明的模型行为
  • 与 NumPy 更好的兼容性

这意味着您可以直接在 Python 中编写高度定制的神经网络组件,而无需使用大量低级函数。

PyTorch 的 Eager execution立即动态地评估张量操作,启发了 TensorFlow 2.0,因此两者的 API 看起来很相似。

将 NumPy 对象转换为张量被烘焙到 PyTorch 的核心数据结构中。这意味着您可以轻松地在torch.Tensor对象和numpy.array对象之间来回切换。

例如,您可以使用 PyTorch 的原生支持将 NumPy 数组转换为张量以创建两个numpy.array对象,使用 将每个torch.Tensor对象转换为一个对象torch.from_numpy(),然后获取它们的元素乘积

>>>
>>> import torch
>>> import numpy as np

>>> x = np.array([[2., 4., 6.]])
>>> y = np.array([[1.], [3.], [5.]])

>>> m = torch.mul(torch.from_numpy(x), torch.from_numpy(y))

>>> m.numpy()
array([[ 2.,  4.,  6.],
       [ 6., 12., 18.],
       [10., 20., 30.]])

Usingtorch.Tensor.numpy()允许您将矩阵乘法的结果(它是一个torch.Tensor对象)作为numpy.array对象打印出来。

torch.Tensor对象和numpy.array对象之间最重要的区别在于torch.Tensor 具有不同的方法和属性,例如backward()计算梯度的 和CUDA兼容性。

特殊功能

PyTorch向 Torch 后端添加了一个用于自动分化的 C++ 模块。自微分自动计算torch.nn反向传播过程中定义的函数的梯度。

默认情况下,PyTorch 使用 Eager模式计算。您可以在构建时逐行运行神经网络,这样可以更轻松地进行调试。它还使得构建具有条件执行的神经网络成为可能。这种动态执行对于大多数 Python 程序员来说更为直观。

PyTorch 生态系统

PyTorch 扩展生态系统的 API、扩展和有用工具的一些亮点包括:

  • fast.ai API,这使得它非常容易建立模型迅速
  • TorchServe,AWS 和 Facebook 合作开发的开源模型服务器
  • TorchElastic使用Kubernetes大规模训练深度神经网络
  • PyTorch Hub,一个分享和扩展前沿模型的活跃社区

PyTorch 与 TensorFlow 决策指南

使用哪个库取决于您自己的风格和偏好、您的数据和模型以及您的项目目标。当您通过对哪个库最能支持这三个因素进行一些研究来开始您的项目时,您将为成功做好准备!

风格

如果您是 Python 程序员,那么 PyTorch 会很容易上手。它以您期望的方式工作,开箱即用。

另一方面,TensorFlow 支持的编码语言比 PyTorch 多,后者具有 C++ API。您可以在 JavaScript 和 Swift 中使用 TensorFlow。如果你不想编写太多底层代码,那么 Keras 会抽象出很多常见用例的细节,这样你就可以构建 TensorFlow 模型而不必担心细节。

数据和模型

你用的是什么型号?如果您想使用特定的预训练模型,例如BERTDeepDream,那么您应该研究它与什么兼容。一些预训练模型仅在一个库或另一个库中可用,而有些则在这两个库中都可用。模型花园、PyTorch 和 TensorFlow 中心也是值得检查的好资源。

你需要什么数据?如果您想使用预处理数据,那么它可能已经内置到一个或另一个库中。检查文档以查看 - 这将使您的开发速度更快!

项目目标

你的模特会住在哪里?如果您想在移动设备上部署模型,那么 TensorFlow 是一个不错的选择,因为 TensorFlow Lite 及其 Swift API。对于服务模型,TensorFlow 与 Google Cloud 紧密集成,但 PyTorch 集成到 AWS 上的 TorchServe 中。如果你想参加 Kaggle 比赛,那么 Keras 会让你快速迭代实验。

在项目开始时考虑这些问题和示例。确定两个或三个最重要的组件,TensorFlow 或 PyTorch 将成为正确的选择。

结论

在本教程中,您已经介绍了 PyTorch 和 TensorFlow,了解了谁在使用它们以及它们支持哪些 API,并了解了如何为您的项目选择 PyTorch 与 TensorFlow。您已经了解了每种语言、工具、数据集和模型所支持的不同编程语言,并了解了如何选择最适合您的独特风格和项目的一种。

在本教程中,您学习了:

  • PyTorchTensorFlow有什么区别
  • 如何使用张量进行计算
  • 哪个平台适合不同类型的项目
  • 每个支持哪些工具数据

既然您已经决定要使用哪个库,您就可以开始使用它们构建神经网络了。查看进一步阅读中的链接以获取想法。

进一步阅读

以下教程是使用 PyTorch 和 TensorFlow 进行动手练习的好方法:

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。