DL之LSTM:tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读

举报
一个处女座的程序猿 发表于 2021/03/27 01:38:59 2021/03/27
【摘要】 DL之LSTM:tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读     目录 tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读 函数功能解读 函数代码实现     tf.contrib.rnn.BasicLSTMCell(rnn_unit)...

DL之LSTM:tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读

 

 

目录

tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读

函数功能解读

函数代码实现


 

 

tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读

函数功能解读

  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline.  For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
  that follows.

  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=None,
               reuse=None,
               name=None,
               dtype=None):
    """Initialize the basic LSTM cell.

基本LSTM递归网络单元。

实现基于:http://arxiv.org/abs/1409.2329。

我们在遗忘门的偏见中加入了遗忘偏见(默认值:1),以减少训练开始时的遗忘程度。

它不允许细胞剪切(一个投影层),也不使用窥孔连接:它是基本的基线。对于高级模型,请使用完整的@{tf.n .rnn_cell. lstmcell}遵循。

 

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must set to `0.0` manually when restoring from CudnnLSTM-trained checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`.  If False, they are concatenated along the column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables in an existing scope.  If not `True`, and the existing scope already has the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases.
      dtype: Default dtype of the layer (default of `None` means use the type of the first input). Required when `build` is called before `call`.

      When restoring from CudnnLSTM-trained checkpoints, must use `CudnnCompatibleLSTMCell` instead.
    """

参数:
num_units: int类型, LSTM单元中的单元数。
forget_bias: float类型,偏见添加到忘记门(见上面)。
从cudnnlstm训练的检查点恢复时,必须手动设置为“0.0”。
state_is_tuple: 如果为真,则接受状态和返回状态是' c_state '和' m_state '的二元组。如果为假,则沿着列轴连接它们。后一种行为很快就会被摒弃。
activation: 内部状态的激活功能。默认值tanh激活函数
reuse: (可选)Python布尔值,描述是否在现有范围内重用变量。如果不是“True”,并且现有范围已经有给定的变量,则会引发错误。
name:字符串,层的名称。具有相同名称的层将共享权重,但是为了避免错误,我们需要在这种情况下重用=True。
dtype:该层的默认dtype(默认为‘None’意味着使用第一个输入的类型)。当' build '在' call '之前被调用时是必需的。

从经过cudnnlstm训练的检查点恢复时,必须使用“CudnnCompatibleLSTMCell”。
”“”

 

函数代码实现


      @tf_export("nn.rnn_cell.BasicLSTMCell")
      class BasicLSTMCell(LayerRNNCell):
       """Basic LSTM recurrent network cell.
       The implementation is based on: http://arxiv.org/abs/1409.2329.
       We add forget_bias (default: 1) to the biases of the forget gate in order to
       reduce the scale of forgetting in the beginning of the training.
       It does not allow cell clipping, a projection layer, and does not
       use peep-hole connections: it is the basic baseline.
       For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
       that follows.
       """
       def __init__(self,
       num_units,
       forget_bias=1.0,
       state_is_tuple=True,
       activation=None,
       reuse=None,
       name=None,
       dtype=None):
      """Initialize the basic LSTM cell.
       Args:
       num_units: int, The number of units in the LSTM cell.
       forget_bias: float, The bias added to forget gates (see above).
       Must set to `0.0` manually when restoring from CudnnLSTM-trained
       checkpoints.
       state_is_tuple: If True, accepted and returned states are 2-tuples of
       the `c_state` and `m_state`. If False, they are concatenated
       along the column axis. The latter behavior will soon be deprecated.
       activation: Activation function of the inner states. Default: `tanh`.
       reuse: (optional) Python boolean describing whether to reuse variables
       in an existing scope. If not `True`, and the existing scope already has
       the given variables, an error is raised.
       name: String, the name of the layer. Layers with the same name will
       share weights, but to avoid mistakes we require reuse=True in such
       cases.
       dtype: Default dtype of the layer (default of `None` means use the type
       of the first input). Required when `build` is called before `call`.
       When restoring from CudnnLSTM-trained checkpoints, must use
       `CudnnCompatibleLSTMCell` instead.
       """
       super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
      if not state_is_tuple:
       logging.warn("%s: Using a concatenated state is slower and will soon be "
      "deprecated. Use state_is_tuple=True.", self)
      # Inputs must be 2-dimensional.
       self.input_spec = base_layer.InputSpec(ndim=2)
       self._num_units = num_units
       self._forget_bias = forget_bias
       self._state_is_tuple = state_is_tuple
       self._activation = activation or math_ops.tanh
       @property
       def state_size(self):
      return (LSTMStateTuple(self._num_units, self._num_units)
      if self._state_is_tuple else 2 * self._num_units)
       @property
       def output_size(self):
      return self._num_units
       def build(self, inputs_shape):
      if inputs_shape[1].value is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
       % inputs_shape)
       input_depth = inputs_shape[1].value
       h_depth = self._num_units
       self._kernel = self.add_variable(
       _WEIGHTS_VARIABLE_NAME,
       shape=[input_depth + h_depth, 4 * self._num_units])
       self._bias = self.add_variable(
       _BIAS_VARIABLE_NAME,
       shape=[4 * self._num_units],
       initializer=init_ops.zeros_initializer(dtype=self.dtype))
       self.built = True
       def call(self, inputs, state):
      """Long short-term memory cell (LSTM).
       Args:
       inputs: `2-D` tensor with shape `[batch_size, input_size]`.
       state: An `LSTMStateTuple` of state tensors, each shaped
       `[batch_size, num_units]`, if `state_is_tuple` has been set to
       `True`. Otherwise, a `Tensor` shaped
       `[batch_size, 2 * num_units]`.
       Returns:
       A pair containing the new hidden state, and the new state (either a
       `LSTMStateTuple` or a concatenated state, depending on
       `state_is_tuple`).
       """
       sigmoid = math_ops.sigmoid
       one = constant_op.constant(1, dtype=dtypes.int32)
      # Parameters of gates are concatenated into one multiply for efficiency.
      if self._state_is_tuple:
       c, h = state
      else:
       c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)
       gate_inputs = math_ops.matmul(
       array_ops.concat([inputs, h], 1), self._kernel)
       gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
       i, j, f, o = array_ops.split(
       value=gate_inputs, num_or_size_splits=4, axis=one)
       forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
      # Note that using `add` and `multiply` instead of `+` and `*` gives a
      # performance improvement. So using those at the cost of readability.
       add = math_ops.add
       multiply = math_ops.multiply
       new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
       multiply(sigmoid(i), self._activation(j)))
       new_h = multiply(self._activation(new_c), sigmoid(o))
      if self._state_is_tuple:
       new_state = LSTMStateTuple(new_c, new_h)
      else:
       new_state = array_ops.concat([new_c, new_h], 1)
      return new_h, new_state
  
 

 

文章来源: yunyaniu.blog.csdn.net,作者:一个处女座的程序猿,版权归原作者所有,如需转载,请联系作者。

原文链接:yunyaniu.blog.csdn.net/article/details/105439831

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。