본문 바로가기
Python

[Python] Tensorflow 분류모델에서 훈련결과 추론해보기

by teamnova 2023. 1. 6.

안녕하세요.

오늘은 텐서플로우에서 학습이 끝난 모델을 가져와서 결과를 추론하는 예제를 진행하겠습니다.

 

사용하는 라이브러리를 선언해줍니다.

import tensorflow as tf
import cv2
import numpy as np
import matplotlib.pyplot as plt
from google.colab.patches import cv2_imshow

 

구글 드라이브 마운트를 하신 후, 다음과 같이 tf.keras.models.load_model() 를 사용해서 저장된 모델을 불러오시면 됩니다.

dir = "/content/drive/Mydrive/스틱코드/저장된가중치.h5"
model = tf.keras.models.load_model(dir)

 

test_dir 에 추론할 이미지 경로를 넣어서 보여줍니다.

test_dir = '/content/테스트이미지폴더경로/테스트이미지.확장자'
img = cv2.imread(test_dir)
dst = cv2.resize(img, dsize=(224, 224), interpolation=cv2.INTER_AREA)
cv2_imshow(dst)
sample_square = np.array(dst)
pred = model.predict(tf.expand_dims(sample_square, 0))
pred = pred[0]
print(pred)

 

pred를 출력하시면 다중분류일 경우 클래스 순서대로 테스트 이미지의 예측값이 표현됩니다.

클래스명에 해당하는 pred의 위치값을 지정해서 보기 편하게 출력하면 다음과 같습니다.

바로 pred[] 값을 출력하면 보기 어려울 수 있어서 round를 통해 반올림을 하겠습니다.

daisy = round(pred[0], 3)
dandelion =  round(pred[1], 3)
roses =  round(pred[2], 3)
sunflowers =  round(pred[3], 3)
tulips =  round(pred[4], 3)

print("daisy : ", daisy)
print("dandelion : ", dandelion)
print("roses : ", roses)
print("sunflowers : ", sunflowers)
print("tulips : ", tulips)