[Updated 2019.03.18 12:49]
When training with the Caffe deep learning framework, you often need to tune the model while watching the train/test loss on a graph. In this post I share a way to view train/test loss graphs in Jupyter Notebook.
1) First, when running training with Caffe, save its output to a log file:
```shell
../caffe/build/tools/caffe train --solver ../model/solver.prototxt 2>&1 | tee TT100K_TRAIN_90000.log
```
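If you prefer to launch training from Python (for example from the same notebook), the `2>&1 | tee` redirection can be approximated with the standard `subprocess` module. This is a minimal sketch, not part of the original workflow: `run_and_tee` is a hypothetical helper name, and it assumes Caffe writes its log to stdout/stderr, as the shell command above relies on.

```python
import subprocess
import sys


def run_and_tee(cmd, log_path):
    """Hypothetical helper: run `cmd`, merge stderr into stdout (like 2>&1),
    and copy every output line both to the console and to `log_path` (like tee).
    Returns the command's exit code."""
    with open(log_path, "w") as log, subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into the same pipe
        text=True,
        bufsize=1,  # line-buffered reading on our side
    ) as proc:
        for line in proc.stdout:
            sys.stdout.write(line)  # show progress live
            sys.stdout.flush()
            log.write(line)  # keep a copy for later parsing
        return proc.wait()


# Example, reusing the paths from the shell command above:
# run_and_tee(["../caffe/build/tools/caffe", "train",
#              "--solver", "../model/solver.prototxt"],
#             "TT100K_TRAIN_90000.log")
```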
2) When training has finished, run the Jupyter Notebook code below, pointing it at the location of the log file.
```python
import re

import numpy as np
from matplotlib import pyplot as plt

%matplotlib inline

# A floating-point number as printed in Caffe logs, e.g. 0.1234 or 1.5e-05
FLOAT_PATTERN = r"[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?"


def main():
    # Log files to plot (one figure per file)
    #files = ['../script/TT100K_TRAIN_01.log', '../script/TT100K_TRAIN_02.log']
    files = ['../script/TT100K_TRAIN_90000.log']
    for log_file in files:
        (loss_iterations, losses, test_loss_iterations, test_losses,
         min_test_loss_iterations, min_test_losses) = parse_log(log_file)
        draw_results(loss_iterations, losses, test_loss_iterations, test_losses,
                     min_test_loss_iterations, min_test_losses)


def parse_log(log_file):
    with open(log_file, 'r') as f:
        log = f.read()

    # Training loss lines: "Iteration <n>, loss = <value>"
    loss_pattern = r"Iteration (\d+), loss = (" + FLOAT_PATTERN + r")"
    loss_iterations = []
    losses = []
    for iteration, loss in re.findall(loss_pattern, log):
        loss_iterations.append(int(iteration))
        losses.append(float(loss))
    loss_iterations = np.array(loss_iterations)
    losses = np.array(losses)

    # Test phase: "Iteration <n>, Testing net (#0)" is followed by
    # "Test loss: <value>" (printed when test_compute_loss is enabled in the solver)
    test_loss_iterations_pattern = r"Iteration (\d+), Testing net \(#0\)\n"
    test_loss_pattern = r"Test loss: (" + FLOAT_PATTERN + r")"
    test_loss_iterations = [int(r) for r in re.findall(test_loss_iterations_pattern, log)]
    test_losses = [float(r) for r in re.findall(test_loss_pattern, log)]
    # Drop a trailing test header that has no loss yet (e.g. training stopped mid-test)
    n = min(len(test_loss_iterations), len(test_losses))
    if n == 0:
        raise ValueError("no test loss lines found in " + log_file)
    test_loss_iterations = np.array(test_loss_iterations[:n])
    test_losses = np.array(test_losses[:n])

    min_test_loss_index = np.argmin(test_losses)
    max_test_loss_index = np.argmax(test_losses)
    top_n = 5
    topn_loss_index = test_losses.argsort()[:top_n]  # indices of the lowest test losses

    print(">>", log_file)
    print("Start Iteration: iteration=%6d, loss=%.6f" % (test_loss_iterations[0], test_losses[0]))
    print("End Iteration: iteration=%6d, loss=%.6f" % (test_loss_iterations[-1], test_losses[-1]))
    print("Min test loss > iteration=%6d, loss=%.6f"
          % (test_loss_iterations[min_test_loss_index], test_losses[min_test_loss_index]))
    print("Max test loss > iteration=%6d, loss=%.6f"
          % (test_loss_iterations[max_test_loss_index], test_losses[max_test_loss_index]))
    print("Top" + str(top_n) + " min test loss >")
    for i in topn_loss_index:
        print("iteration=%6d, loss=%.6f" % (test_loss_iterations[i], test_losses[i]))

    # A single point marking the minimum test loss
    min_test_loss_iterations = [test_loss_iterations[min_test_loss_index]]
    min_test_losses = [test_losses[min_test_loss_index]]
    return (loss_iterations, losses, test_loss_iterations, test_losses,
            min_test_loss_iterations, min_test_losses)


def draw_results(loss_iterations, losses, test_loss_iterations, test_losses,
                 min_test_loss_iterations, min_test_losses):
    plt.figure()
    plt.plot(loss_iterations, losses, color='b', label='train loss')
    plt.plot(test_loss_iterations, test_losses, color='r', label='test loss')
    plt.plot(min_test_loss_iterations, min_test_losses, 'o', color='c', label='min test loss')
    plt.xlabel("iteration")
    plt.ylabel("loss")
    plt.title("train loss vs test loss")
    plt.legend()
    plt.grid(True)
    plt.show()


if __name__ == '__main__':
    main()
```
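Depending on the Caffe version and solver settings, the log may not contain lines in exactly the form the notebook expects. Some Caffe versions insert throughput information into the training line (`Iteration N (x iter/s, ys/20 iters), loss = X`), and `Test loss: X` is only printed when `test_compute_loss: true` is set in the solver; otherwise the test loss appears as a `Test net output #k: <name> = X` line. The sketch below accepts either training-line form and reads the test loss from a named test output. It is an assumption-laden variant, not the original code: `parse_caffe_log_text` is a hypothetical name, it assumes a single test net, and the default output name `loss` should be replaced with the loss top name of your own test network.

```python
import re

# A floating-point number as printed in Caffe logs, e.g. 0.1234 or 1.5e-05
FLOAT = r"[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?"


def parse_caffe_log_text(log, test_output="loss"):
    """Hypothetical helper. Returns
    (train_iters, train_losses, test_iters, test_losses) as plain lists."""
    # Optional "(x iter/s, ys/N iters)" group between the iteration and ", loss"
    train_re = re.compile(r"Iteration (\d+)(?: \([^)]*\))?, loss = (" + FLOAT + r")")
    test_iter_re = re.compile(r"Iteration (\d+), Testing net \(#0\)")
    # Only "Test net output" lines whose output name matches `test_output`
    test_out_re = re.compile(
        r"Test net output #\d+: " + re.escape(test_output) + r" = (" + FLOAT + r")"
    )

    train = [(int(it), float(v)) for it, v in train_re.findall(log)]
    test_iters = [int(it) for it in test_iter_re.findall(log)]
    test_losses = [float(v) for v in test_out_re.findall(log)]
    # Ignore a trailing test header whose outputs were never written
    n = min(len(test_iters), len(test_losses))
    return ([it for it, _ in train], [v for _, v in train],
            test_iters[:n], test_losses[:n])
```

To use it with a log file, pass the file contents, e.g. `parse_caffe_log_text(open(path).read(), test_output="loss")`.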
3) Running the notebook displays the graph and the related statistics, as shown below. This is the train/test loss graph after 90,000 training iterations.
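The notebook draws one figure per log file, and its commented-out `files` list hints at comparing runs such as `TT100K_TRAIN_01.log` and `TT100K_TRAIN_02.log`. To compare several runs on a single set of axes, a sketch like the following could be used. `plot_runs` is a hypothetical helper; it takes already-parsed `(train_iters, train_losses, test_iters, test_losses)` tuples, uses one color per run, and draws train loss as a solid line and test loss as a dashed line.

```python
import matplotlib.pyplot as plt


def plot_runs(runs):
    """Hypothetical helper. `runs` maps a run label to
    (train_iters, train_losses, test_iters, test_losses).
    All runs share one axes; returns the Figure."""
    fig, ax = plt.subplots()
    # Reuse matplotlib's default color cycle so each run gets its own color
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    for k, (label, (tr_it, tr_loss, te_it, te_loss)) in enumerate(runs.items()):
        color = colors[k % len(colors)]
        ax.plot(tr_it, tr_loss, color=color, linestyle="-", label=f"{label} train")
        ax.plot(te_it, te_loss, color=color, linestyle="--", label=f"{label} test")
    ax.set_xlabel("iteration")
    ax.set_ylabel("loss")
    ax.set_title("train loss vs test loss")
    ax.grid(True)
    ax.legend()
    return fig
```

Because the first four values returned by the notebook's `parse_log` are in this order, something like `plot_runs({path: parse_log(path)[:4] for path in files})` followed by `plt.show()` should work.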
The GitHub repository below contains the results I obtained by running this on the Tsinghua-Tencent 100K test set.