본문 바로가기
Python

[Python] Tensorflow 에서 모델 저장하고 불러오기

by teamnova 2023. 2. 16.

안녕하세요.

오늘은 Tesorflow 모델을 저장하고 불러오는 예제를 진행하겠습니다.

 

우선, 모델을 저장하는 방식에는 크게 두가지 방향이 있습니다.

모델을 통째로 저장하는 방법과 가중치만 저장하는 방법입니다.

 

먼저, 저장에 필요한 기본 라이브러리를 추가합니다.

import tensorflow as tf
from tensorflow import keras

 

아래 코드에서 model에 저장하고자 하는 모델이 들어있다는 가정하에 진행합니다.

경로는 저장하고자 하는 폴더의 경로를 입력하면 되고, 가중치 저장의 경우 몇번째 epoch의 가중치인지 파일명으로 구분하기 쉽게 할 수 있습니다. 상대적으로 save가 모델을 통째로 저장하기 때문에 용량이 크고 시간이 좀 더 소요됩니다.

# 1. 모델 전체 저장
model.save('./경로/저장하고자 하는 폴더명')

# 2. 가중치 저장
model.save_weights('./경로/폴더명/몇번째 epoch')

 

모델을 불러오는 것은 load_model 을 통해 불러오고자 하는 모델의 경로를 입력하면됩니다.

가중치 불러오기는 모델이 이미 구성되어 있는 가정하에 가중치만 불러와서 업데이트 해주면 됩니다.

# 1. 모델 불러오기
model = keras.models.load_model('./경로/불러올 모델의 폴더명')

# 2.가중치만 불러오기
model = Model()
model.load_weight('./경로/불러올 모델가중치의 파일명')