LSTM in Deep Learning: Understanding the tf.contrib.rnn.BasicLSTMCell(rnn_unit) function
Contents
- Understanding the tf.contrib.rnn.BasicLSTMCell(rnn_unit) function
Understanding the tf.contrib.rnn.BasicLSTMCell(rnn_unit) function
What the function does
"""Basic LSTM recurrent network cell. The implementation is based on: http://arxiv.org/abs/1409.2329. We add forget_bias (default: 1) to the biases of the forget gate in order to reduce the scale of forgetting in the beginning of the training. It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline. For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell} """ def __init__(self, |
基本LSTM递归网络单元。 实现基于:http://arxiv.org/abs/1409.2329。 我们在遗忘门的偏见中加入了遗忘偏见(默认值:1),以减少训练开始时的遗忘程度。 它不允许细胞剪切(一个投影层),也不使用窥孔连接:它是基本的基线。对于高级模型,请使用完整的@{tf.n .rnn_cell. lstmcell}遵循。
|
Args: When restoring from CudnnLSTM-trained checkpoints, must use `CudnnCompatibleLSTMCell` instead. |
参数: 从经过cudnnlstm训练的检查点恢复时,必须使用“CudnnCompatibleLSTMCell”。 |
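For orientation, here is a minimal usage sketch of the cell in TensorFlow 1.x graph mode. The rnn_unit value, the placeholder shapes and the variable names are illustrative assumptions, not taken from the article:

import tensorflow as tf

rnn_unit = 10  # number of hidden units in the LSTM cell (illustrative value)

# [batch_size, time_steps, input_size]; 20 and 7 are made-up sizes
inputs = tf.placeholder(tf.float32, shape=[None, 20, 7])

# tf.contrib.rnn.BasicLSTMCell is the class exported as tf.nn.rnn_cell.BasicLSTMCell
cell = tf.contrib.rnn.BasicLSTMCell(rnn_unit)
init_state = cell.zero_state(batch_size=tf.shape(inputs)[0], dtype=tf.float32)

# dynamic_rnn unrolls the cell over the time axis; outputs has shape
# [batch_size, time_steps, rnn_unit] and final_state is an LSTMStateTuple (c, h).
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=init_state)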
Function source code
@tf_export("nn.rnn_cell.BasicLSTMCell")
class BasicLSTMCell(LayerRNNCell):
  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
  that follows.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=None,
               reuse=None,
               name=None,
               dtype=None):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must set to `0.0` manually when restoring from CudnnLSTM-trained
        checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`. If False, they are concatenated
        along the column axis. The latter behavior will soon be deprecated.
      activation: Activation function of the inner states. Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.
      dtype: Default dtype of the layer (default of `None` means use the type
        of the first input). Required when `build` is called before `call`.

      When restoring from CudnnLSTM-trained checkpoints, must use
      `CudnnCompatibleLSTMCell` instead.
    """
    super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)

    # Inputs must be 2-dimensional.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation or math_ops.tanh

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    if inputs_shape[1].value is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                       % inputs_shape)

    input_depth = inputs_shape[1].value
    h_depth = self._num_units
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + h_depth, 4 * self._num_units])
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))

    self.built = True

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM).

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: An `LSTMStateTuple` of state tensors, each shaped
        `[batch_size, num_units]`, if `state_is_tuple` has been set to
        `True`. Otherwise, a `Tensor` shaped
        `[batch_size, 2 * num_units]`.

    Returns:
      A pair containing the new hidden state, and the new state (either a
        `LSTMStateTuple` or a concatenated state, depending on
        `state_is_tuple`).
    """
    sigmoid = math_ops.sigmoid
    one = constant_op.constant(1, dtype=dtypes.int32)
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, h], 1), self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement. So using those at the cost of readability.
    add = math_ops.add
    multiply = math_ops.multiply
    new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
                multiply(sigmoid(i), self._activation(j)))
    new_h = multiply(self._activation(new_c), sigmoid(o))

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state
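To make the arithmetic inside call() easier to follow, here is a NumPy re-statement of the same computation: one matrix multiply produces all four gate pre-activations, which are then split into i, j, f and o. The shapes and values are made up for illustration; this sketch is not part of the TensorFlow source.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

batch_size, input_size, num_units = 2, 7, 10
forget_bias = 1.0

x = np.random.randn(batch_size, input_size)        # current input
c = np.zeros((batch_size, num_units))              # previous cell state
h = np.zeros((batch_size, num_units))              # previous hidden state
kernel = np.random.randn(input_size + num_units, 4 * num_units)  # plays the role of self._kernel
bias = np.zeros(4 * num_units)                                    # plays the role of self._bias

# One matmul computes all four gate pre-activations, then a four-way split.
gate_inputs = np.concatenate([x, h], axis=1).dot(kernel) + bias
i, j, f, o = np.split(gate_inputs, 4, axis=1)  # input gate, candidate, forget gate, output gate

new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
new_h = np.tanh(new_c) * sigmoid(o)            # new_h is returned as both the output and the hidden state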
Source: yunyaniu.blog.csdn.net. Author: 一个处女座的程序猿. Copyright belongs to the original author; please contact the author before reposting.
Original article: yunyaniu.blog.csdn.net/article/details/105439831