
Basic RNN/LSTM cell implementation

by 나른한 사람 2021. 9. 10.

RNN

  • $ h_t = \tanh(W_h h_{t-1} + W_x x_t + b_h) $
  • $ y_t = W_y h_t + b_y $

RNN structure

from tensorflow.keras import layers
from tensorflow.keras import backend as K


class RNNCell(layers.Layer):
    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units
        super(RNNCell, self).__init__(**kwargs)

    def build(self, input_shape):
        self.w_xh = self.add_weight(shape=(input_shape[-1], self.units),
                                     initializer='uniform',
                                     name='W_xh')
        self.w_hh = self.add_weight(shape=(self.units, self.units),
                                  initializer='uniform',
                                  name='W_hh')
        self.w_hy = self.add_weight(shape=(self.units, self.units),
                                   initializer='uniform',
                                   name='W_hy')
        self.b_h = self.add_weight(shape=(self.units,),
                                   initializer='zeros',
                                   name='b_h')
        self.b_y = self.add_weight(shape=(self.units,),
                                  initializer='zeros',
                                  name='b_y')
        self.built = True

    def call(self, inputs, states):
        prev_hidden = states[0]

        # h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h)
        h = K.dot(inputs, self.w_xh) + K.dot(prev_hidden, self.w_hh)
        h = K.tanh(h + self.b_h)

        # y_t = W_hy h_t + b_y
        output = K.dot(h, self.w_hy) + self.b_y
        # the state carried to the next step is the hidden state h, not the output
        return output, [h]
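To check the cell end to end, it can be wrapped in keras.layers.RNN, which loops the single-step call over the time axis and manages the hidden state. A minimal sketch (the batch size, sequence length, feature size, and unit count below are arbitrary):

import numpy as np
from tensorflow.keras import layers

cell = RNNCell(units=8)
rnn_layer = layers.RNN(cell)  # by default returns only the last step's output

x = np.random.rand(4, 10, 16).astype('float32')  # (batch, time, features)
y = rnn_layer(x)
print(y.shape)  # (4, 8)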

LSTM

LSTM structure

  • forget gate: $ f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) $
  • input gate: $ i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) $, $ \tilde{c}_t = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c) $
  • update: $ c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t $
  • output gate: $ o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) $, $ h_t = o_t \odot \tanh(c_t) $

class LSTMCell(layers.Layer):
    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = [units, units]  # hidden state h_t and cell state c_t
        super(LSTMCell, self).__init__(**kwargs)

    def build(self, input_shape):
        # each gate operates on the concatenation [h_{t-1}, x_t]
        weight_size = input_shape[-1] + self.units
        # forget gate
        self.w_f = self.add_weight(shape=(weight_size, self.units),
                                  initializer='uniform',
                                  name='w_f')
        self.b_f = self.add_weight(shape=(self.units,),
                                  initializer='zeros',
                                  name='b_f')

        # input gate
        self.w_i = self.add_weight(shape=(weight_size, self.units),
                                  initializer='uniform',
                                  name='w_i')
        self.b_i = self.add_weight(shape=(self.units,),
                                  initializer='zeros',
                                  name='b_i')
        self.w_c = self.add_weight(shape=(weight_size, self.units),
                                   initializer='uniform',
                                   name='w_c')
        self.b_c = self.add_weight(shape=(self.units,),
                                  initializer='zeros',
                                  name='b_c')

        # output gate
        self.w_o = self.add_weight(shape=(weight_size, self.units),
                                  initializer='uniform',
                                  name='w_o')
        self.b_o = self.add_weight(shape=(self.units,),
                                  initializer='zeros',
                                  name='b_o')
        self.built = True

    def call(self, inputs, states):
        h_prev, c_prev = states
        hx_concat = K.concatenate([h_prev, inputs], axis=-1)
        # forget gate
        f_t = K.sigmoid(K.dot(hx_concat, self.w_f) + self.b_f)
        # input gate
        i_t = K.sigmoid(K.dot(hx_concat, self.w_i) + self.b_i)
        c_t_hat = K.tanh(K.dot(hx_concat, self.w_c) + self.b_c)
        # update
        c_t = f_t*c_prev + i_t*c_t_hat
        # output gate
        o_t = K.sigmoid(K.dot(hx_concat, self.w_o) + self.b_o)
        h_t = o_t*K.tanh(c_t)

        return h_t, [h_t, c_t]
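The LSTM cell plugs into keras.layers.RNN the same way; since state_size lists two entries, the wrapper carries both h and c between steps. A minimal sketch, again with arbitrary shapes and unit count:

import numpy as np
from tensorflow.keras import layers

cell = LSTMCell(units=8)
lstm_layer = layers.RNN(cell, return_sequences=True)  # keep the hidden state at every step

x = np.random.rand(4, 10, 16).astype('float32')  # (batch, time, features)
h_seq = lstm_layer(x)
print(h_seq.shape)  # (4, 10, 8)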

 
