
[Updated 2019.03.18 12:49]


When training with the Caffe deep-learning framework, you often need to inspect the train/test loss as a graph in order to tune the model. In this post, I'd like to share a way to plot the train/test loss curves in a Jupyter Notebook.


1) First, when running training with Caffe, save the output to a log file.

../caffe/build/tools/caffe train --solver ../model/solver.prototxt 2>&1 | tee TT100K_TRAIN_90000.log
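Since Caffe (via glog) writes its log messages to stderr, `2>&1 | tee` both shows them in the terminal and saves them to the file. The notebook in step 2 then extracts iteration/loss pairs from lines like the ones below. Here is a minimal sketch of that extraction, using made-up log lines (real logs carry longer glog prefixes):

```python
import re

# Hypothetical excerpt of a Caffe training log.
sample_log = """\
I0318 12:00:01 solver.cpp:228] Iteration 0, loss = 7.25
I0318 12:00:05 solver.cpp:337] Iteration 100, Testing net (#0)
I0318 12:00:06 solver.cpp:404] Test loss: 6.8
I0318 12:00:10 solver.cpp:228] Iteration 100, loss = 5.5
"""

# Same pattern idea as the notebook below: capture the iteration number
# and the loss value (which may be in scientific notation).
loss_pattern = (r"Iteration (?P<iter_num>\d+), loss = "
                r"(?P<loss>[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)")

pairs = [(int(m.group("iter_num")), float(m.group("loss")))
         for m in re.finditer(loss_pattern, sample_log)]
print(pairs)  # [(0, 7.25), (100, 5.5)]
```

Note that the "Testing net" and "Test loss" lines are deliberately not matched here; the notebook handles those with two separate patterns.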


2) Once training has finished, run the Jupyter Notebook below and set the path to the log file.

import numpy as np
import re
from matplotlib import pyplot as plt
%matplotlib inline

def main():
    #files = ['../script/TT100K_TRAIN_01.log', '../script/TT100K_TRAIN_02.log']
    files = ['../script/TT100K_TRAIN_90000.log']
  
    for i, log_file in enumerate(files):
        loss_iterations, losses, test_loss_iterations, test_losses, min_test_loss_iterations, min_test_losses = parse_log(log_file)
        draw_results(loss_iterations, losses, test_loss_iterations, test_losses, min_test_loss_iterations, min_test_losses, color_ind=i)


def parse_log(log_file):
    with open(log_file, 'r') as log_file:
        log = log_file.read()

    # Matches e.g. "Iteration 100, loss = 1.5" (loss may be in scientific notation)
    loss_pattern = r"Iteration (?P<iter_num>\d+), loss = (?P<loss_val>[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?)"
    losses = []
    loss_iterations = []

    for r in re.findall(loss_pattern, log):
        loss_iterations.append(int(r[0]))
        losses.append(float(r[1]))

    loss_iterations = np.array(loss_iterations)
    losses = np.array(losses)

    # Matches the "Testing net" announcement and the "Test loss" line that follows it
    test_loss_iterations_pattern = r"Iteration (?P<iter_num>\d+), Testing net \(#0\)\n"
    test_loss_pattern = r"Test loss: (?P<test_loss>[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?)"
    test_losses = []
    test_loss_iterations = []

    for r in re.findall(test_loss_iterations_pattern, log):
        test_loss_iterations.append(int(r))
    
    for r in re.findall(test_loss_pattern, log):
        test_losses.append(float(r[0]))
    
    test_loss_iterations = np.array(test_loss_iterations)
    test_losses = np.array(test_losses)
    
    min_test_loss_index = np.argmin(test_losses)
    max_test_loss_index = np.argmax(test_losses)
    
    top_n = 5
    topn_loss_index = test_losses.argsort()[:top_n]

    print(">>",log_file.name)
    print("Start Iteration: iteration=%6d, loss=%.6f" % (test_loss_iterations[0], test_losses[0]))
    print("End   Iteration: iteration=%6d, loss=%.6f" % (test_loss_iterations[-1], test_losses[-1]))
    print("Min test loss > iteration=%6d, loss=%.6f" % (test_loss_iterations[min_test_loss_index], test_losses[min_test_loss_index]) )
    print("Max test loss > iteration=%6d, loss=%.6f" % (test_loss_iterations[max_test_loss_index], test_losses[max_test_loss_index]) )
    print("Top"+str(top_n)+" min test loss >")
    
    for i in topn_loss_index:
        print("iteration=%6d, loss=%.6f" % (test_loss_iterations[i], test_losses[i]) )
    
    min_test_losses = []
    min_test_loss_iterations = []
    
    min_test_loss_iterations.append(test_loss_iterations[min_test_loss_index])
    min_test_losses.append(test_losses[min_test_loss_index])

    return loss_iterations, losses, test_loss_iterations, test_losses, min_test_loss_iterations, min_test_losses


def draw_results(loss_iterations, losses, test_loss_iterations, test_losses, min_test_loss_iterations, min_test_losses, color_ind=0):
    axes_cycle = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
        
    plt.figure()
    plt.plot(loss_iterations, losses, color=axes_cycle[0], label='train loss')
    plt.plot(test_loss_iterations, test_losses, color=axes_cycle[2], label='test loss')
    plt.plot(min_test_loss_iterations, min_test_losses, 'o', color=axes_cycle[3], label='min test loss')
    plt.xlabel("iteration")
    plt.ylabel("loss")
    plt.title("train loss vs test loss")
    plt.legend()
    plt.grid(True)
    
    plt.show()

if __name__ == '__main__':
    main()
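The per-iteration train loss curve is often quite noisy, which can make it hard to compare against the test-loss curve. As an optional extra (a sketch, not part of the original notebook), a simple moving average smooths the trend before plotting:

```python
import numpy as np

def smooth_losses(losses, window=5):
    """Simple moving average over a loss array; the window is clipped
    to the array length so short logs still work."""
    losses = np.asarray(losses, dtype=float)
    window = max(1, min(window, len(losses)))
    kernel = np.ones(window) / window
    # mode='valid' keeps only fully-covered windows, avoiding edge bias.
    return np.convolve(losses, kernel, mode="valid")

raw = [7.0, 5.0, 6.0, 4.0, 5.0, 3.0]
print(smooth_losses(raw, window=3))  # [6. 5. 5. 4.]
```

The smoothed array is shorter than the input by `window - 1` points, so if you plot it against iterations, drop the first `window - 1` entries of `loss_iterations` to keep the x and y arrays aligned.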

3) Running the notebook produces the graph and summary information shown below. This is the train/test loss graph after 90,000 training iterations.



The GitHub link below shows the results I obtained on the Tsinghua-Tencent 100K test set.


* GitHub: https://github.com/asyncbridge/tsinghua-tencent-100k/blob/master/code/python/my-eval-loss-graph.ipynb
