본문 바로가기
개발

Keras Conv2d CNN 간단한 예제 ( Mnist data)

by 화악 2022. 5. 24.
반응형

인공지능하면 제일처음에 해보는 Mnist 손글씨 맞추기 kears Conv2d로 간단하게 작성된 코드를 소개하려고 합니다. 이미지 분류에 자주 쓰이는 CNN은 Keras로 간단하게 모델 구성을 할 수 있습니다. 먼저 코드부터 보시죠

 

 

코드 메인 영역


 

import numpy as np
import pandas as pd
import keras
import tensorflow
from tensorflow.keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense, Conv2D, Flatten
from keras.datasets import mnist


train_samples, train_labels, test_samples, test_labels = loadData() #mnist data 다운

train_samples = convertDtype(train_samples) #data type 변경 float32
test_samples = convertDtype(test_samples)

train_samples = normalize(train_samples) #정규화
test_samples = normalize(test_samples)

np.isclose(np.amax(train_samples), 1)  # 허용오차를 계산하는 numpy함수

train_samples = reshape(train_samples)  # train data와 test data 학습 폼에 맞게 shape 변경
test_samples = reshape(test_samples)

train_labels = oneHot(train_labels, 10) # 원핫-인코딩
test_labels = oneHot(test_labels, 10)

#모델 정의부분 CNN CONV2D 사용
model = Sequential()
model.add(Conv2D(64, kernel_size=3, activation='relu', input_shape=(28,28,1))) #커널 사이즈는 3 활성화 함수는 relu를 사용 , mnist data 포맷 28x28
model.add(Conv2D(32, kernel_size=3, activation='relu')) #cnn layer 2
model.add(Flatten())   #Flatten Layer
model.add(Dense(10, activation='softmax')) #softmax Layer

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) # Adam 방식으로 최적화, loss는 categorical_crossentropy 선택
model.summary() #모델 구성 요약

# 학습모델 하이퍼파라미터 지정
results = model.fit(train_samples, train_labels, validation_split = 0.1, epochs=10, batch_size=250)  #validation split 0.1 -> 테스트 data 비율 , 에폭은 10 , 배치사이즈는 250으로 구성

predictedLabelsTrain = predict(train_samples)   #예측

acc = accuracy(train_samples, train_labels, model)   #train 정확도 출력
print('Train accuracy is, ', acc * 100, '%')


acc = accuracyT(test_samples, test_labels, model)  # test 정확도 출력
print('Test accuracy is, ', acc * 100, '%')

 

 

각 함수 영역 코드


#mnist data를 불러오는 함수
def loadData():
  (train_samples,train_labels), (test_samples,test_labels) = mnist.load_data() # mnist 데이터를 keras 라이브러리로 다운로드후 train test에 대입
  return train_samples, train_labels, test_samples, test_labels

#data type을 float32 로 변환
def convertDtype(x):
    x_float = x.astype('float32')
    return x_float
#정규화 함수
def normalize(x):
  y = (x - np.min(x))/np.ptp(x)
  return y

# train data와 test data 학습 폼에 맞게 shape 변경
def reshape(x):
    x_r = x.reshape(x.shape[0], x.shape[1], x.shape[2], 1)
    return x_r
# 원핫 인코딩을 하는 함수 0,1 로 데이터 구별
def oneHot(y, Ny):

    Ny = len(np.unique(y))
    y_oh = to_categorical(y, num_classes=Ny) #원-핫 인코딩을 수행하는 keras 함수
    return y_oh

# 모델 예측함수
def predict(x):
    y = model.predict(x)
    return y

# train data의 학습 정확성을 출력하는 함수
def accuracy(x_train, y_train, model):
    loss, acc = model.evaluate(train_samples, train_labels, verbose=0)
    return acc
# train 된 모델로 test data의 학습 정확성을 출력하는 함수
def accuracyT(x_test, y_test, model):
    loss, acc = model.evaluate(test_samples, test_labels, verbose=0)
    return acc

 

 

keras CNN 모델 간단설명


모델구성

 

model.add(Conv2D(64, kernel_size=3, activation='relu', input_shape=(28,28,1))) #커널 사이즈는 3 활성화 함수는 relu를 사용 , mnist data 포맷 28x28
model.add(Conv2D(32, kernel_size=3, activation='relu')) #cnn layer 2
model.add(Flatten())   #Flatten Layer
model.add(Dense(10, activation='softmax')) #softmax Layer

 

 

model.add로 Conv2D를 사용했습니다.

 

1. 첫 인자로 64개의 필터를 사용한다는 의미이고

2. kenel_size 는 필터의 크기를 뜻하며 3으로 지정

3. input shape 는 mnist 데이터 shpae인 28,28로 설정

4. 활성화 함수로는 softmax , relu를 사용 하여 간단하게 구현

 

 

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) # Adam 방식으로 최적화, loss는 categorical_crossentropy 선택

 

손실함수는 교차엔트로피 사용  옵티마이저는 adam을 사용하였습니다.

 

 

하이퍼파라미터

 

results = model.fit(train_samples, train_labels, validation_split = 0.1, epochs=10, batch_size=250)  #validation split 0.1 -> 테스트 data 비율 , 에폭은 10 , 배치사이즈는 250으로 구성

 

epoch과 batch_szie는 하이퍼파라미터로 직접 여러가지를 시험해보시는걸 추천드립니다. epochs이 늘어날 수록 학습 정확도가 올라 갈 수 있으나 어느 순간부터 수렴하는 것을 보실 수 있을겁니다. 과적합 방지를 위해 dropout을 쓰는 경우도 있는데 이 코드에서는 사용하지 않았습니다. 대부분 코드에 주석을 달아 놨으니 천천히 디버깅 해보면서 확인하시면 처음보는 비전공자도 이해하기 쉬우실겁니다.

반응형

댓글