DL之LSTM:tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读

举报
一个处女座的程序猿 发表于 2021/03/27 01:38:59 2021/03/27
【摘要】 DL之LSTM:tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读     目录 tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读 函数功能解读 函数代码实现     tf.contrib.rnn.BasicLSTMCell(rnn_unit)...

DL之LSTM:tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读

 

 

目录

tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读

函数功能解读

函数代码实现


 

 

tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读

函数功能解读

  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline.  For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
  that follows.

  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=None,
               reuse=None,
               name=None,
               dtype=None):
    """Initialize the basic LSTM cell.

基本LSTM递归网络单元。

实现基于:http://arxiv.org/abs/1409.2329。

我们在遗忘门的偏见中加入了遗忘偏见(默认值:1),以减少训练开始时的遗忘程度。

它不允许细胞剪切(一个投影层),也不使用窥孔连接:它是基本的基线。对于高级模型,请使用完整的@{tf.n .rnn_cell. lstmcell}遵循。

 

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must set to `0.0` manually when restoring from CudnnLSTM-trained checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`.  If False, they are concatenated along the column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables in an existing scope.  If not `True`, and the existing scope already has the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases.
      dtype: Default dtype of the layer (default of `None` means use the type of the first input). Required when `build` is called before `call`.

      When restoring from CudnnLSTM-trained checkpoints, must use `CudnnCompatibleLSTMCell` instead.
    """

参数:
num_units: int类型, LSTM单元中的单元数。
forget_bias: float类型,偏见添加到忘记门(见上面)。
从cudnnlstm训练的检查点恢复时,必须手动设置为“0.0”。
state_is_tuple: 如果为真,则接受状态和返回状态是' c_state '和' m_state '的二元组。如果为假,则沿着列轴连接它们。后一种行为很快就会被摒弃。
activation: 内部状态的激活功能。默认值tanh激活函数
reuse: (可选)Python布尔值,描述是否在现有范围内重用变量。如果不是“True”,并且现有范围已经有给定的变量,则会引发错误。
name:字符串,层的名称。具有相同名称的层将共享权重,但是为了避免错误,我们需要在这种情况下重用=True。
dtype:该层的默认dtype(默认为‘None’意味着使用第一个输入的类型)。当' build '在' call '之前被调用时是必需的。

从经过cudnnlstm训练的检查点恢复时,必须使用“CudnnCompatibleLSTMCell”。
”“”

 

函数代码实现


  
  1. @tf_export("nn.rnn_cell.BasicLSTMCell")
  2. class BasicLSTMCell(LayerRNNCell):
  3. """Basic LSTM recurrent network cell.
  4. The implementation is based on: http://arxiv.org/abs/1409.2329.
  5. We add forget_bias (default: 1) to the biases of the forget gate in order to
  6. reduce the scale of forgetting in the beginning of the training.
  7. It does not allow cell clipping, a projection layer, and does not
  8. use peep-hole connections: it is the basic baseline.
  9. For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
  10. that follows.
  11. """
  12. def __init__(self,
  13. num_units,
  14. forget_bias=1.0,
  15. state_is_tuple=True,
  16. activation=None,
  17. reuse=None,
  18. name=None,
  19. dtype=None):
  20. """Initialize the basic LSTM cell.
  21. Args:
  22. num_units: int, The number of units in the LSTM cell.
  23. forget_bias: float, The bias added to forget gates (see above).
  24. Must set to `0.0` manually when restoring from CudnnLSTM-trained
  25. checkpoints.
  26. state_is_tuple: If True, accepted and returned states are 2-tuples of
  27. the `c_state` and `m_state`. If False, they are concatenated
  28. along the column axis. The latter behavior will soon be deprecated.
  29. activation: Activation function of the inner states. Default: `tanh`.
  30. reuse: (optional) Python boolean describing whether to reuse variables
  31. in an existing scope. If not `True`, and the existing scope already has
  32. the given variables, an error is raised.
  33. name: String, the name of the layer. Layers with the same name will
  34. share weights, but to avoid mistakes we require reuse=True in such
  35. cases.
  36. dtype: Default dtype of the layer (default of `None` means use the type
  37. of the first input). Required when `build` is called before `call`.
  38. When restoring from CudnnLSTM-trained checkpoints, must use
  39. `CudnnCompatibleLSTMCell` instead.
  40. """
  41. super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
  42. if not state_is_tuple:
  43. logging.warn("%s: Using a concatenated state is slower and will soon be "
  44. "deprecated. Use state_is_tuple=True.", self)
  45. # Inputs must be 2-dimensional.
  46. self.input_spec = base_layer.InputSpec(ndim=2)
  47. self._num_units = num_units
  48. self._forget_bias = forget_bias
  49. self._state_is_tuple = state_is_tuple
  50. self._activation = activation or math_ops.tanh
  51. @property
  52. def state_size(self):
  53. return (LSTMStateTuple(self._num_units, self._num_units)
  54. if self._state_is_tuple else 2 * self._num_units)
  55. @property
  56. def output_size(self):
  57. return self._num_units
  58. def build(self, inputs_shape):
  59. if inputs_shape[1].value is None:
  60. raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
  61. % inputs_shape)
  62. input_depth = inputs_shape[1].value
  63. h_depth = self._num_units
  64. self._kernel = self.add_variable(
  65. _WEIGHTS_VARIABLE_NAME,
  66. shape=[input_depth + h_depth, 4 * self._num_units])
  67. self._bias = self.add_variable(
  68. _BIAS_VARIABLE_NAME,
  69. shape=[4 * self._num_units],
  70. initializer=init_ops.zeros_initializer(dtype=self.dtype))
  71. self.built = True
  72. def call(self, inputs, state):
  73. """Long short-term memory cell (LSTM).
  74. Args:
  75. inputs: `2-D` tensor with shape `[batch_size, input_size]`.
  76. state: An `LSTMStateTuple` of state tensors, each shaped
  77. `[batch_size, num_units]`, if `state_is_tuple` has been set to
  78. `True`. Otherwise, a `Tensor` shaped
  79. `[batch_size, 2 * num_units]`.
  80. Returns:
  81. A pair containing the new hidden state, and the new state (either a
  82. `LSTMStateTuple` or a concatenated state, depending on
  83. `state_is_tuple`).
  84. """
  85. sigmoid = math_ops.sigmoid
  86. one = constant_op.constant(1, dtype=dtypes.int32)
  87. # Parameters of gates are concatenated into one multiply for efficiency.
  88. if self._state_is_tuple:
  89. c, h = state
  90. else:
  91. c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)
  92. gate_inputs = math_ops.matmul(
  93. array_ops.concat([inputs, h], 1), self._kernel)
  94. gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
  95. # i = input_gate, j = new_input, f = forget_gate, o = output_gate
  96. i, j, f, o = array_ops.split(
  97. value=gate_inputs, num_or_size_splits=4, axis=one)
  98. forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
  99. # Note that using `add` and `multiply` instead of `+` and `*` gives a
  100. # performance improvement. So using those at the cost of readability.
  101. add = math_ops.add
  102. multiply = math_ops.multiply
  103. new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
  104. multiply(sigmoid(i), self._activation(j)))
  105. new_h = multiply(self._activation(new_c), sigmoid(o))
  106. if self._state_is_tuple:
  107. new_state = LSTMStateTuple(new_c, new_h)
  108. else:
  109. new_state = array_ops.concat([new_c, new_h], 1)
  110. return new_h, new_state

 

文章来源: yunyaniu.blog.csdn.net,作者:一个处女座的程序猿,版权归原作者所有,如需转载,请联系作者。

原文链接:yunyaniu.blog.csdn.net/article/details/105439831

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。