[CNN] GoogLeNet Implementation (Keras)

by 나른한 사람 2021. 8. 6.

GoogLeNet

[Figure: GoogLeNet structure]

https://arxiv.org/pdf/1409.4842.pdf

I implemented the network in Keras following the paper above. In the paper, unit and num_class are 1024 and 1000, respectively.

Auxiliary Classifier

[Figure: Auxiliary Classifier structure]

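import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import (
    AveragePooling2D, BatchNormalization, Concatenate, Conv2D, Dense, Dropout,
    Flatten, GlobalAveragePooling2D, Input, MaxPool2D)

# unit, num_class, kernel_initializer and bias_initializer are module-level
# settings that this post does not show.
def AuxiliaryClassifier(inputs, name=None):
    # 5x5 average pooling with stride 3, then a 1x1 convolution with 128 filters
    x = AveragePooling2D((5,5), strides=3, padding='valid')(inputs)
    x = Conv2D(128, (1,1), strides=1, padding='same', activation='relu',
              kernel_initializer=kernel_initializer, bias_initializer=bias_initializer)(x)
    x = Flatten()(x)
    x = Dense(unit, activation='relu',
              kernel_initializer=kernel_initializer, bias_initializer=bias_initializer)(x)
    x = Dropout(0.7)(x)
    x = Dense(num_class, activation='softmax', name=name)(x)
    return x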

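Auxiliary classifiers are attached at intermediate points of the network to mitigate the vanishing-gradient problem. In the paper their losses are added to the total loss with a weight of 0.3 during training, and the auxiliary branches are discarded at inference time.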
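To make the shapes inside the auxiliary classifier concrete, the following sketch redoes the arithmetic by hand in plain Python, without Keras. It assumes a 14x14 input feature map (the spatial size at inception (4a) and (4d) for a 224x224 image) and unit = 1024; the helper names `pool_out` and `dense_params` are made up for this illustration.

```python
def pool_out(size, window, stride):
    """Output length of a 'valid' pooling window along one spatial axis."""
    return (size - window) // stride + 1


def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


# Assumption: the auxiliary branch receives a 14x14 feature map.
side = pool_out(14, window=5, stride=3)   # spatial side after AveragePooling2D
flat = side * side * 128                  # the 1x1 conv keeps the side, sets channels to 128
fc_params = dense_params(flat, 1024)      # Dense(unit) with unit = 1024

print(side, flat, fc_params)
```

The resulting values line up with the `average_pooling2d`, `flatten` and `dense` rows of the model summary at the end of this post.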

Inception Module

[Figure: Inception Module structure]

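The module uses a bottleneck structure: 1x1 convolutions reduce the channel count before the expensive 3x3 and 5x5 convolutions, which keeps the rapidly growing computation in check.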
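As a rough illustration of how much the bottleneck saves, the sketch below compares parameter counts for the 5x5 branch of inception (3a), which receives 192 channels and uses a 16-channel reduction before 32 5x5 filters. It is plain Python arithmetic; `conv_params` is a hypothetical helper, not a Keras function.

```python
def conv_params(k, c_in, c_out):
    """Weights plus biases of a k x k convolution."""
    return k * k * c_in * c_out + c_out


c_in = 192  # channels entering inception (3a)

# 5x5 convolution applied directly to the 192-channel input
direct = conv_params(5, c_in, 32)

# 1x1 reduction to 16 channels first, then the 5x5 convolution
reduced = conv_params(1, c_in, 16) + conv_params(5, 16, 32)

print(direct, reduced)
```

With these numbers the reduced branch needs roughly a tenth of the parameters of the direct 5x5 convolution.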

def InceptionModule(inputs, f1, f3r, f3, f5r, f5, fp, name=None):
    # Note: name is accepted for readability at the call site but is not applied to any layer.
    # Branch 1: 1x1 convolution
    conv_1x1 = Conv2D(f1, (1,1), padding='same', activation='relu',
              kernel_initializer=kernel_initializer,
              bias_initializer=bias_initializer)(inputs)

    # Branch 2: 1x1 reduction (f3r) followed by a 3x3 convolution (f3)
    x1 = Conv2D(f3r, (1,1), padding='same', activation='relu',
              kernel_initializer=kernel_initializer,
              bias_initializer=bias_initializer)(inputs)
    conv_3x3 = Conv2D(f3, (3,3), padding='same', activation='relu',
              kernel_initializer=kernel_initializer,
              bias_initializer=bias_initializer)(x1)

    # Branch 3: 1x1 reduction (f5r) followed by a 5x5 convolution (f5)
    x2 = Conv2D(f5r, (1,1), padding='same', activation='relu',
              kernel_initializer=kernel_initializer,
              bias_initializer=bias_initializer)(inputs)
    conv_5x5 = Conv2D(f5, (5,5), padding='same', activation='relu',
              kernel_initializer=kernel_initializer,
              bias_initializer=bias_initializer)(x2)

    # Branch 4: 3x3 max pooling followed by a 1x1 projection (fp)
    x3 = MaxPool2D((3,3), strides=1, padding='same')(inputs)
    conv_pool = Conv2D(fp, (1,1), padding='same', activation='relu',
              kernel_initializer=kernel_initializer,
              bias_initializer=bias_initializer)(x3)

    # Concatenate the four branches along the channel axis
    concat = Concatenate(axis=-1)([conv_1x1, conv_3x3, conv_5x5, conv_pool])
    return concat
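Because the four branches are concatenated along the channel axis, the output depth of a module is simply f1 + f3 + f5 + fp; the reduction sizes f3r and f5r do not appear in the output. The sketch below checks this for the filter settings used in the GoogLeNet() function of this post. The dictionary and the `out_channels` helper are illustrative names only.

```python
# (f1, f3r, f3, f5r, f5, fp) per module, copied from the GoogLeNet() calls in this post
modules = {
    "3a": (64, 96, 128, 16, 32, 32),
    "3b": (128, 128, 192, 32, 96, 64),
    "4a": (192, 96, 208, 16, 48, 64),
    "4b": (160, 112, 224, 24, 64, 64),
    "4c": (128, 128, 256, 24, 64, 64),
    "4d": (112, 144, 288, 32, 64, 64),
    "4e": (256, 160, 320, 32, 128, 128),
    "5a": (256, 160, 320, 32, 128, 128),
    "5b": (384, 192, 384, 48, 128, 128),
}


def out_channels(f1, f3r, f3, f5r, f5, fp):
    """Depth after concatenating the 1x1, 3x3, 5x5 and pool-projection branches."""
    return f1 + f3 + f5 + fp


channels = {name: out_channels(*cfg) for name, cfg in modules.items()}
print(channels)
```

These depths can be compared against the Concatenate rows of the model summary at the end of the post.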

GoogLeNet

[Figure: GoogLeNet architecture]

I implemented the model while following the GoogLeNet architecture shown above.

def GoogLeNet():
    inputs = Input(shape=(224,224,3), name='inputs')
    x = Conv2D(64, (7,7), strides=2, padding='same', activation='relu', name='conv_1',
              kernel_initializer=kernel_initializer,
              bias_initializer=bias_initializer)(inputs)
    x = MaxPool2D((3,3), strides=2, padding='same', name='maxpool_1')(x)
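    # BatchNormalization is not part of the paper's architecture, which uses only LRN; this implementation adds it.
    x = BatchNormalization()(x)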
    x = tf.nn.local_response_normalization(x, alpha=.0001, beta=.75, name='lrn_1')
    x = Conv2D(64, (1,1), strides=1, padding='valid', activation='relu', name='conv_2',
              kernel_initializer=kernel_initializer,
              bias_initializer=bias_initializer)(x)
    x = Conv2D(192, (3,3), strides=1, padding='same', activation='relu', name='conv_3',
              kernel_initializer=kernel_initializer,
              bias_initializer=bias_initializer)(x)
    x = BatchNormalization()(x)
    x = tf.nn.local_response_normalization(x, alpha=.0001, beta=.75, name='lrn_2')
    x = MaxPool2D((3,3), strides=2, padding='same', name='maxpool_2')(x)

    x = InceptionModule(x, 64, 96, 128, 16, 32, 32, name='inception (3a)')
    x = InceptionModule(x, 128, 128, 192, 32, 96, 64, name='inception (3b)')
    x = MaxPool2D((3,3), strides=2, padding='same', name='maxpool_3')(x)

    x = InceptionModule(x, 192, 96, 208, 16, 48, 64, name='inception (4a)')
    aux1 = AuxiliaryClassifier(x, name='aux1')
    x = InceptionModule(x, 160, 112, 224, 24, 64, 64, name='inception (4b)')
    x = InceptionModule(x, 128, 128, 256, 24, 64, 64, name='inception (4c)')
    x = InceptionModule(x, 112, 144, 288, 32, 64, 64, name='inception (4d)')
    aux2 = AuxiliaryClassifier(x, name='aux2')
    x = InceptionModule(x, 256, 160, 320, 32, 128, 128, name='inception (4e)')
    x = MaxPool2D((3,3), strides=2, padding='same', name='maxpool_4')(x)

    x = InceptionModule(x, 256, 160, 320, 32, 128, 128, name='inception (5a)')
    x = InceptionModule(x, 384, 192, 384, 48, 128, 128, name='inception (5b)')
    x = GlobalAveragePooling2D(name='gap')(x)
    x = Dropout(0.4, name='dropout')(x)
    main = Dense(num_class, activation='softmax', name='main')(x)

    model = Model(inputs=inputs, outputs=[main, aux1, aux2])
    return model
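A quick way to sanity-check the function above is to trace the spatial size through the stride-2 layers. With padding='same', the output length is ceil(input / stride), and all stride-1 layers leave the size unchanged. The sketch below is plain Python; `same_out` is a hypothetical helper.

```python
import math


def same_out(size, stride):
    """Output length along one axis under padding='same'."""
    return math.ceil(size / stride)


size = 224
trace = [size]
# Stride-2 layers in order: conv_1, maxpool_1, maxpool_2, maxpool_3, maxpool_4
for stride in (2, 2, 2, 2, 2):
    size = same_out(size, stride)
    trace.append(size)

print(trace)  # [224, 112, 56, 28, 14, 7]
```

The final 7x7 map is what GlobalAveragePooling2D reduces to a single 1024-dimensional vector.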

Model Summary

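# The summary below was produced with unit = 1024 and num_class = 10 (see the Dense rows).
# optimizer is assumed to be defined elsewhere; the post does not show it.
model = GoogLeNet()
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              loss_weights={'main':1.0, 'aux1':0.3, 'aux2':0.3},
              metrics=['acc'])
model.summary()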

'''
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
inputs (InputLayer)             [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv_1 (Conv2D)                 (None, 112, 112, 64) 9472        inputs[0][0]                     
__________________________________________________________________________________________________
maxpool_1 (MaxPooling2D)        (None, 56, 56, 64)   0           conv_1[0][0]                     
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 56, 56, 64)   256         maxpool_1[0][0]                  
__________________________________________________________________________________________________
tf.nn.local_response_normalizat (None, 56, 56, 64)   0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv_2 (Conv2D)                 (None, 56, 56, 64)   4160        tf.nn.local_response_normalizatio
__________________________________________________________________________________________________
conv_3 (Conv2D)                 (None, 56, 56, 192)  110784      conv_2[0][0]                     
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 56, 56, 192)  768         conv_3[0][0]                     
__________________________________________________________________________________________________
tf.nn.local_response_normalizat (None, 56, 56, 192)  0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
maxpool_2 (MaxPooling2D)        (None, 28, 28, 192)  0           tf.nn.local_response_normalizatio
__________________________________________________________________________________________________
conv2d_57 (Conv2D)              (None, 28, 28, 96)   18528       maxpool_2[0][0]                  
__________________________________________________________________________________________________
conv2d_59 (Conv2D)              (None, 28, 28, 16)   3088        maxpool_2[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_9 (MaxPooling2D)  (None, 28, 28, 192)  0           maxpool_2[0][0]                  
__________________________________________________________________________________________________
conv2d_56 (Conv2D)              (None, 28, 28, 64)   12352       maxpool_2[0][0]                  
__________________________________________________________________________________________________
conv2d_58 (Conv2D)              (None, 28, 28, 128)  110720      conv2d_57[0][0]                  
__________________________________________________________________________________________________
conv2d_60 (Conv2D)              (None, 28, 28, 32)   12832       conv2d_59[0][0]                  
__________________________________________________________________________________________________
conv2d_61 (Conv2D)              (None, 28, 28, 32)   6176        max_pooling2d_9[0][0]            
__________________________________________________________________________________________________
concatenate_9 (Concatenate)     (None, 28, 28, 256)  0           conv2d_56[0][0]                  
                                                                 conv2d_58[0][0]                  
                                                                 conv2d_60[0][0]                  
                                                                 conv2d_61[0][0]                  
__________________________________________________________________________________________________
conv2d_63 (Conv2D)              (None, 28, 28, 128)  32896       concatenate_9[0][0]              
__________________________________________________________________________________________________
conv2d_65 (Conv2D)              (None, 28, 28, 32)   8224        concatenate_9[0][0]              
__________________________________________________________________________________________________
max_pooling2d_10 (MaxPooling2D) (None, 28, 28, 256)  0           concatenate_9[0][0]              
__________________________________________________________________________________________________
conv2d_62 (Conv2D)              (None, 28, 28, 128)  32896       concatenate_9[0][0]              
__________________________________________________________________________________________________
conv2d_64 (Conv2D)              (None, 28, 28, 192)  221376      conv2d_63[0][0]                  
__________________________________________________________________________________________________
conv2d_66 (Conv2D)              (None, 28, 28, 96)   76896       conv2d_65[0][0]                  
__________________________________________________________________________________________________
conv2d_67 (Conv2D)              (None, 28, 28, 64)   16448       max_pooling2d_10[0][0]           
__________________________________________________________________________________________________
concatenate_10 (Concatenate)    (None, 28, 28, 480)  0           conv2d_62[0][0]                  
                                                                 conv2d_64[0][0]                  
                                                                 conv2d_66[0][0]                  
                                                                 conv2d_67[0][0]                  
__________________________________________________________________________________________________
maxpool_3 (MaxPooling2D)        (None, 14, 14, 480)  0           concatenate_10[0][0]             
__________________________________________________________________________________________________
conv2d_69 (Conv2D)              (None, 14, 14, 96)   46176       maxpool_3[0][0]                  
__________________________________________________________________________________________________
conv2d_71 (Conv2D)              (None, 14, 14, 16)   7696        maxpool_3[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_11 (MaxPooling2D) (None, 14, 14, 480)  0           maxpool_3[0][0]                  
__________________________________________________________________________________________________
conv2d_68 (Conv2D)              (None, 14, 14, 192)  92352       maxpool_3[0][0]                  
__________________________________________________________________________________________________
conv2d_70 (Conv2D)              (None, 14, 14, 208)  179920      conv2d_69[0][0]                  
__________________________________________________________________________________________________
conv2d_72 (Conv2D)              (None, 14, 14, 48)   19248       conv2d_71[0][0]                  
__________________________________________________________________________________________________
conv2d_73 (Conv2D)              (None, 14, 14, 64)   30784       max_pooling2d_11[0][0]           
__________________________________________________________________________________________________
concatenate_11 (Concatenate)    (None, 14, 14, 512)  0           conv2d_68[0][0]                  
                                                                 conv2d_70[0][0]                  
                                                                 conv2d_72[0][0]                  
                                                                 conv2d_73[0][0]                  
__________________________________________________________________________________________________
conv2d_76 (Conv2D)              (None, 14, 14, 112)  57456       concatenate_11[0][0]             
__________________________________________________________________________________________________
conv2d_78 (Conv2D)              (None, 14, 14, 24)   12312       concatenate_11[0][0]             
__________________________________________________________________________________________________
max_pooling2d_12 (MaxPooling2D) (None, 14, 14, 512)  0           concatenate_11[0][0]             
__________________________________________________________________________________________________
conv2d_75 (Conv2D)              (None, 14, 14, 160)  82080       concatenate_11[0][0]             
__________________________________________________________________________________________________
conv2d_77 (Conv2D)              (None, 14, 14, 224)  226016      conv2d_76[0][0]                  
__________________________________________________________________________________________________
conv2d_79 (Conv2D)              (None, 14, 14, 64)   38464       conv2d_78[0][0]                  
__________________________________________________________________________________________________
conv2d_80 (Conv2D)              (None, 14, 14, 64)   32832       max_pooling2d_12[0][0]           
__________________________________________________________________________________________________
concatenate_12 (Concatenate)    (None, 14, 14, 512)  0           conv2d_75[0][0]                  
                                                                 conv2d_77[0][0]                  
                                                                 conv2d_79[0][0]                  
                                                                 conv2d_80[0][0]                  
__________________________________________________________________________________________________
conv2d_82 (Conv2D)              (None, 14, 14, 128)  65664       concatenate_12[0][0]             
__________________________________________________________________________________________________
conv2d_84 (Conv2D)              (None, 14, 14, 24)   12312       concatenate_12[0][0]             
__________________________________________________________________________________________________
max_pooling2d_13 (MaxPooling2D) (None, 14, 14, 512)  0           concatenate_12[0][0]             
__________________________________________________________________________________________________
conv2d_81 (Conv2D)              (None, 14, 14, 128)  65664       concatenate_12[0][0]             
__________________________________________________________________________________________________
conv2d_83 (Conv2D)              (None, 14, 14, 256)  295168      conv2d_82[0][0]                  
__________________________________________________________________________________________________
conv2d_85 (Conv2D)              (None, 14, 14, 64)   38464       conv2d_84[0][0]                  
__________________________________________________________________________________________________
conv2d_86 (Conv2D)              (None, 14, 14, 64)   32832       max_pooling2d_13[0][0]           
__________________________________________________________________________________________________
concatenate_13 (Concatenate)    (None, 14, 14, 512)  0           conv2d_81[0][0]                  
                                                                 conv2d_83[0][0]                  
                                                                 conv2d_85[0][0]                  
                                                                 conv2d_86[0][0]                  
__________________________________________________________________________________________________
conv2d_88 (Conv2D)              (None, 14, 14, 144)  73872       concatenate_13[0][0]             
__________________________________________________________________________________________________
conv2d_90 (Conv2D)              (None, 14, 14, 32)   16416       concatenate_13[0][0]             
__________________________________________________________________________________________________
max_pooling2d_14 (MaxPooling2D) (None, 14, 14, 512)  0           concatenate_13[0][0]             
__________________________________________________________________________________________________
conv2d_87 (Conv2D)              (None, 14, 14, 112)  57456       concatenate_13[0][0]             
__________________________________________________________________________________________________
conv2d_89 (Conv2D)              (None, 14, 14, 288)  373536      conv2d_88[0][0]                  
__________________________________________________________________________________________________
conv2d_91 (Conv2D)              (None, 14, 14, 64)   51264       conv2d_90[0][0]                  
__________________________________________________________________________________________________
conv2d_92 (Conv2D)              (None, 14, 14, 64)   32832       max_pooling2d_14[0][0]           
__________________________________________________________________________________________________
concatenate_14 (Concatenate)    (None, 14, 14, 528)  0           conv2d_87[0][0]                  
                                                                 conv2d_89[0][0]                  
                                                                 conv2d_91[0][0]                  
                                                                 conv2d_92[0][0]                  
__________________________________________________________________________________________________
conv2d_95 (Conv2D)              (None, 14, 14, 160)  84640       concatenate_14[0][0]             
__________________________________________________________________________________________________
conv2d_97 (Conv2D)              (None, 14, 14, 32)   16928       concatenate_14[0][0]             
__________________________________________________________________________________________________
max_pooling2d_15 (MaxPooling2D) (None, 14, 14, 528)  0           concatenate_14[0][0]             
__________________________________________________________________________________________________
conv2d_94 (Conv2D)              (None, 14, 14, 256)  135424      concatenate_14[0][0]             
__________________________________________________________________________________________________
conv2d_96 (Conv2D)              (None, 14, 14, 320)  461120      conv2d_95[0][0]                  
__________________________________________________________________________________________________
conv2d_98 (Conv2D)              (None, 14, 14, 128)  102528      conv2d_97[0][0]                  
__________________________________________________________________________________________________
conv2d_99 (Conv2D)              (None, 14, 14, 128)  67712       max_pooling2d_15[0][0]           
__________________________________________________________________________________________________
concatenate_15 (Concatenate)    (None, 14, 14, 832)  0           conv2d_94[0][0]                  
                                                                 conv2d_96[0][0]                  
                                                                 conv2d_98[0][0]                  
                                                                 conv2d_99[0][0]                  
__________________________________________________________________________________________________
maxpool_4 (MaxPooling2D)        (None, 7, 7, 832)    0           concatenate_15[0][0]             
__________________________________________________________________________________________________
conv2d_101 (Conv2D)             (None, 7, 7, 160)    133280      maxpool_4[0][0]                  
__________________________________________________________________________________________________
conv2d_103 (Conv2D)             (None, 7, 7, 32)     26656       maxpool_4[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_16 (MaxPooling2D) (None, 7, 7, 832)    0           maxpool_4[0][0]                  
__________________________________________________________________________________________________
conv2d_100 (Conv2D)             (None, 7, 7, 256)    213248      maxpool_4[0][0]                  
__________________________________________________________________________________________________
conv2d_102 (Conv2D)             (None, 7, 7, 320)    461120      conv2d_101[0][0]                 
__________________________________________________________________________________________________
conv2d_104 (Conv2D)             (None, 7, 7, 128)    102528      conv2d_103[0][0]                 
__________________________________________________________________________________________________
conv2d_105 (Conv2D)             (None, 7, 7, 128)    106624      max_pooling2d_16[0][0]           
__________________________________________________________________________________________________
concatenate_16 (Concatenate)    (None, 7, 7, 832)    0           conv2d_100[0][0]                 
                                                                 conv2d_102[0][0]                 
                                                                 conv2d_104[0][0]                 
                                                                 conv2d_105[0][0]                 
__________________________________________________________________________________________________
conv2d_107 (Conv2D)             (None, 7, 7, 192)    159936      concatenate_16[0][0]             
__________________________________________________________________________________________________
conv2d_109 (Conv2D)             (None, 7, 7, 48)     39984       concatenate_16[0][0]             
__________________________________________________________________________________________________
max_pooling2d_17 (MaxPooling2D) (None, 7, 7, 832)    0           concatenate_16[0][0]             
__________________________________________________________________________________________________
average_pooling2d_2 (AveragePoo (None, 4, 4, 512)    0           concatenate_11[0][0]             
__________________________________________________________________________________________________
average_pooling2d_3 (AveragePoo (None, 4, 4, 528)    0           concatenate_14[0][0]             
__________________________________________________________________________________________________
conv2d_106 (Conv2D)             (None, 7, 7, 384)    319872      concatenate_16[0][0]             
__________________________________________________________________________________________________
conv2d_108 (Conv2D)             (None, 7, 7, 384)    663936      conv2d_107[0][0]                 
__________________________________________________________________________________________________
conv2d_110 (Conv2D)             (None, 7, 7, 128)    153728      conv2d_109[0][0]                 
__________________________________________________________________________________________________
conv2d_111 (Conv2D)             (None, 7, 7, 128)    106624      max_pooling2d_17[0][0]           
__________________________________________________________________________________________________
conv2d_74 (Conv2D)              (None, 4, 4, 128)    65664       average_pooling2d_2[0][0]        
__________________________________________________________________________________________________
conv2d_93 (Conv2D)              (None, 4, 4, 128)    67712       average_pooling2d_3[0][0]        
__________________________________________________________________________________________________
concatenate_17 (Concatenate)    (None, 7, 7, 1024)   0           conv2d_106[0][0]                 
                                                                 conv2d_108[0][0]                 
                                                                 conv2d_110[0][0]                 
                                                                 conv2d_111[0][0]                 
__________________________________________________________________________________________________
flatten_2 (Flatten)             (None, 2048)         0           conv2d_74[0][0]                  
__________________________________________________________________________________________________
flatten_3 (Flatten)             (None, 2048)         0           conv2d_93[0][0]                  
__________________________________________________________________________________________________
gap (GlobalAveragePooling2D)    (None, 1024)         0           concatenate_17[0][0]             
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 1024)         2098176     flatten_2[0][0]                  
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 1024)         2098176     flatten_3[0][0]                  
__________________________________________________________________________________________________
dropout (Dropout)               (None, 1024)         0           gap[0][0]                        
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 1024)         0           dense_2[0][0]                    
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 1024)         0           dense_3[0][0]                    
__________________________________________________________________________________________________
main (Dense)                    (None, 10)           10250       dropout[0][0]                    
__________________________________________________________________________________________________
aux1 (Dense)                    (None, 10)           10250       dropout_2[0][0]                  
__________________________________________________________________________________________________
aux2 (Dense)                    (None, 10)           10250       dropout_3[0][0]                  
==================================================================================================
Total params: 10,335,054
Trainable params: 10,334,542
Non-trainable params: 512
__________________________________________________________________________________________________
'''
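With the loss_weights passed to model.compile above, the value Keras minimizes is the weighted sum of the three per-output losses. The sketch below reproduces that sum in plain Python; the per-output loss values are made up purely for illustration, and `total_loss` is a hypothetical helper.

```python
loss_weights = {'main': 1.0, 'aux1': 0.3, 'aux2': 0.3}


def total_loss(per_output_loss, weights):
    """Weighted sum of the individual output losses."""
    return sum(weights[name] * loss for name, loss in per_output_loss.items())


# Hypothetical per-output cross-entropy values for one batch
example = {'main': 2.0, 'aux1': 1.0, 'aux2': 1.0}
total = total_loss(example, loss_weights)

print(round(total, 4))
```

Since the model has three outputs, the same integer labels would typically be supplied for all three when training.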
