
[Updated 2017.08.30 12:45]

 

This is TensorFlow code I wrote for a supervised learning classification problem, where the output is discrete.

Specifically, it is multinomial classification: each example is assigned one of several labels.

 

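Note that using Linear Regression to separate classes can give wrong results depending on the input data (for instance, a single far-away point shifts the fitted line and, with it, the decision threshold), while Logistic Regression gives correct results in such cases.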
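To make the linear regression pitfall concrete, here is a small NumPy sketch on made-up 1-D data (the data and the 0.5 decision threshold are assumptions chosen for illustration). A least-squares line is fit to 0/1 labels, and a point is assigned to class 1 when the fitted value is at least 0.5. Adding one far-away point that is correctly labeled as class 1 flattens the line and pushes the threshold to the right, so a point that was classified correctly before is now misclassified.

```python
import numpy as np

# Hypothetical toy data: x <= 3 is class 0, x >= 4 is class 1
x = np.array([1, 2, 3, 4, 5, 6], dtype=float)
y = np.array([0, 0, 0, 1, 1, 1], dtype=float)

def linear_classify(x_train, y_train, x_test):
    """Fit y = a*x + b by least squares, then predict class 1 where a*x + b >= 0.5."""
    a, b = np.polyfit(x_train, y_train, deg=1)
    return (a * x_test + b >= 0.5).astype(int)

# Without the outlier, every point is classified correctly
pred_clean = linear_classify(x, y, x)

# Add one far-away but correctly labeled class-1 point at x = 100
x_out = np.append(x, 100.0)
y_out = np.append(y, 1.0)
pred_outlier = linear_classify(x_out, y_out, x)

print(pred_clean)    # [0 0 0 1 1 1]
print(pred_outlier)  # [0 0 0 0 1 1]  (x = 4 is now misclassified)
```

Logistic regression is far less sensitive to such a point because the sigmoid saturates: an example that lies far on the correct side contributes almost nothing to the gradient.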

 

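You can improve performance by applying the sigmoid function (which squashes values into the range 0 to 1), or the softmax function, which is designed for multi-class classification, to the hypothesis, and by using Cross-Entropy as the cost function. One reason: combined with the sigmoid, the Mean Square Error cost is not convex in the weights, whereas Cross-Entropy is, so gradient descent is not trapped by local minima. (Cross-Entropy is said to perform as well as or better than the Mean Square Error cost function on many datasets.)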
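A minimal NumPy sketch of the three pieces named above, for intuition only. The function names are my own, and TensorFlow's tf.nn.softmax_cross_entropy_with_logits computes the cost directly from the logits in a numerically safer fused way rather than calling functions like these.

```python
import numpy as np

def sigmoid(z):
    # Squashes any real number into the open interval (0, 1)
    return 1.0 / (1.0 + np.exp(-z))

def softmax(logits):
    # Subtracting the row max keeps exp() from overflowing; it does not change the result
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(probs, y_one_hot):
    # Mean over examples of -sum_k y_k * log(p_k); the small epsilon guards against log(0)
    return -np.mean(np.sum(y_one_hot * np.log(probs + 1e-12), axis=1))

logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 0.3]])
y_one_hot = np.array([[1, 0, 0],
                      [0, 1, 0]], dtype=float)

probs = softmax(logits)
print(probs.sum(axis=1))               # each row sums to 1 (up to floating-point error)
print(cross_entropy(probs, y_one_hot))
print(sigmoid(0.0))                    # 0.5
```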

 

Below, I slightly modified and tested code by Professor Sung Kim of the Hong Kong University of Science and Technology (HKUST).

 

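- Added a bias term to the input X data

- Print, for each training example, the softmax hypothesis probabilities and the inferred label number

- (Planned for the next post) Add evaluation code (F Score, Precision, Recall, Accuracy, Learning Curve, etc. for validating the model on skewed-class data)

- (Planned for the next post) Change the data input code so it can handle large amounts of data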
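Why does appending a column of ones work as a bias? With the ones column in front, the first row of theta is multiplied by 1 for every example, so it acts exactly like a separate bias vector b. The sketch below checks that X·W + b equals X_bias·theta on random data (the sizes and random values are placeholders, not the zoo dataset).

```python
import numpy as np

rng = np.random.default_rng(0)
N, n_features, n_classes = 5, 16, 7

X = rng.normal(size=(N, n_features))
W = rng.normal(size=(n_features, n_classes))   # feature weights
b = rng.normal(size=(n_classes,))              # one bias per class

# Prepend a column of ones to X (same layout as addBiasTerm in the post)
X_bias = np.hstack([np.ones((N, 1)), X])       # N x 17
# Put the bias in the first row of theta so it lines up with the ones column
theta = np.vstack([b, W])                      # 17 x 7

print(np.allclose(X @ W + b, X_bias @ theta))  # True
```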

 

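When I implement machine learning code, the first thing I always think about is the dimensions of the input/output matrices in each matrix computation.

When I first studied machine learning, I ran into many errors because the dimensions of the matrices in a computation did not match.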
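As an example of this kind of error, the sketch below (with hypothetical zero-filled matrices) multiplies an N x 16 matrix, i.e. X without the bias column, by a 17 x 7 theta. NumPy refuses with a ValueError; check_matmul is a hypothetical helper that fails earlier with a more readable message.

```python
import numpy as np

N = 4
X_no_bias = np.zeros((N, 16))   # forgot to add the bias column
theta = np.zeros((17, 7))       # but theta already expects 17 rows

try:
    X_no_bias @ theta
    mismatch_detected = False
except ValueError as e:
    mismatch_detected = True
    print("shape mismatch:", e)

def check_matmul(a, b):
    # Compare the inner dimensions before multiplying
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"cannot multiply {a.shape} by {b.shape}")
    return a @ b
```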

 

# The data has 7 classes

num_of_class = 7

 

# N is the number of training records; there are 16 input features X plus 1 bias column (17 in total)
X = N x 17                          

 

# The raw Y input is an N x 1 matrix; it is one-hot encoded for training the model.

Y = N x 1 -> N x 7 (one-hot encoding)                 
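One-hot encoding can be sketched in NumPy by indexing an identity matrix with the label column. Indexing with an N x 1 array yields an N x 1 x 7 array; tf.one_hot behaves the same way on the N x 1 label placeholder (it appends a new depth axis), which is why the code below calls tf.reshape(y_one_hot, [-1, num_of_classes]). The labels here are made-up examples.

```python
import numpy as np

num_of_classes = 7
y_data = np.array([[0], [3], [6], [1]], dtype=np.float32)  # N x 1, labels stored as floats like the CSV

labels = y_data.astype(int)                  # N x 1 integer class indices
one_hot = np.eye(num_of_classes)[labels]     # indexing with an N x 1 array gives N x 1 x 7
print(one_hot.shape)                         # (4, 1, 7)

one_hot = one_hot.reshape(-1, num_of_classes)  # drop the extra axis: N x 7
print(one_hot.shape)                         # (4, 7)
print(one_hot[1])                            # [0. 0. 0. 1. 0. 0. 0.]
```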

 

# Dimensions of the weight matrix we need to learn: 17 input features (including the bias) x 7 classes

theta = 17 x 7     

 

# The final output matrix of the hypothesis is N x 7
h(x) = softmax(X * theta), N x 7 
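The whole chain of shapes above can be checked with a quick NumPy sketch (random placeholder values; N = 10 is arbitrary):

```python
import numpy as np

rng = np.random.default_rng(42)
N, num_of_class = 10, 7

X = np.hstack([np.ones((N, 1)), rng.normal(size=(N, 16))])  # N x 17: bias column + 16 features
theta = rng.normal(size=(17, num_of_class))                 # 17 x 7

logits = X @ theta                                          # (N x 17) @ (17 x 7) -> N x 7
h = np.exp(logits - logits.max(axis=1, keepdims=True))      # shift by the row max for stability
h /= h.sum(axis=1, keepdims=True)                           # softmax: each row now sums to 1

print(X.shape, theta.shape, h.shape)                        # (10, 17) (17, 7) (10, 7)
```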

 

import tensorflow as tf
import numpy as np


def addBiasTerm(X):
    # Prepend a column of ones (the bias feature x0 = 1) to the m x n feature matrix
    m = np.size(X, 0)
    return np.hstack([np.ones([m, 1]), X])


'''
num_of_class = 7

X = N x 17  (16 features + 1 bias column)
Y = N x 1  -> one hot   N x 7
theta = 17 x 7

h(x) = softmax(X * theta) N x 7

'''

training_set = np.loadtxt('data-04-zoo.csv', delimiter=',', dtype=np.float32)
X_data = addBiasTerm(training_set[:, 0:-1])
y_data = training_set[:, [-1]]

# number of classes
num_of_classes = 7
feature_size = np.size(X_data, 1)

X = tf.placeholder(tf.float32, [None, feature_size]) # feature_size already includes the bias column (16 + 1)
y = tf.placeholder(tf.int32, [None, 1])

# one-hot encoding
y_one_hot = tf.one_hot(y, num_of_classes)
y_one_hot = tf.reshape(y_one_hot, [-1, num_of_classes])

theta = tf.Variable(tf.random_normal([feature_size, num_of_classes]), name='weight')

logits = tf.matmul(X, theta)
hypothesis = tf.nn.softmax(logits) # Use softmax

cost_logits = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_one_hot)
cost = tf.reduce_mean(cost_logits)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(cost)


prediction = tf.argmax(hypothesis, 1)
correct_prediction = tf.equal(prediction, tf.argmax(y_one_hot, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

with tf.Session() as sess:
    sess.run(init_op)

    for step in range(2000):
        sess.run(optimizer, feed_dict={X: X_data, y: y_data})
        if step % 100 == 0:
            loss, acc = sess.run([cost, accuracy], feed_dict={X: X_data, y: y_data})

            print("Step: {:5}\tLoss: {:.3f}\tAcc: {:.2%}".format(step, loss, acc))


    pred = sess.run(prediction, feed_dict={X: X_data})

    # Use y_true so the loop variable does not shadow the placeholder y
    for p, y_true in zip(pred, y_data.flatten()):
        print("[{}] Prediction: {}, True Y: {}".format(p == int(y_true), p, int(y_true)))

    hypo = sess.run(hypothesis, feed_dict={X: X_data})  # softmax probabilities for each training example

    np.set_printoptions(formatter={'float': '{: 0.3f}'.format})

    # np.argmax avoids adding a new tf.argmax op to the graph on every iteration
    for h in hypo:
        print("Probability: {}, Inferred Label: {}".format(h, np.argmax(h)))
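The code above uses the TensorFlow 1.x graph/session API (tf.placeholder, tf.Session), which in TensorFlow 2 is only available under tf.compat.v1. For intuition about what GradientDescentOptimizer does with this cost, here is a plain NumPy sketch of the same model trained by batch gradient descent. It relies on the standard result that the gradient of the mean softmax cross-entropy with respect to theta is X^T (h - Y) / N; TensorFlow obtains the gradient through automatic differentiation instead of a hand-written formula. The three-cluster data set is synthetic and hypothetical; learning_rate=0.1 and steps=2000 mirror the values in the post.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def train_softmax_regression(X, y_one_hot, learning_rate=0.1, steps=2000):
    """Batch gradient descent on the mean softmax cross-entropy."""
    N = X.shape[0]
    theta = np.zeros((X.shape[1], y_one_hot.shape[1]))
    losses = []
    for _ in range(steps):
        h = softmax(X @ theta)                                    # N x C probabilities
        losses.append(-np.mean(np.sum(y_one_hot * np.log(h + 1e-12), axis=1)))
        grad = X.T @ (h - y_one_hot) / N                          # d(cost)/d(theta), F x C
        theta -= learning_rate * grad
    return theta, losses

# Hypothetical synthetic data: 3 well-separated 2-D clusters, plus a bias column
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [3.0, 3.0], [0.0, 3.0]])
labels = np.repeat(np.arange(3), 30)
X = np.hstack([np.ones((90, 1)), centers[labels] + rng.normal(scale=0.5, size=(90, 2))])
Y = np.eye(3)[labels]

theta, losses = train_softmax_regression(X, Y)
accuracy = np.mean(np.argmax(softmax(X @ theta), axis=1) == labels)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}, accuracy {accuracy:.2%}")
```

Starting from theta = 0, every class gets probability 1/3, so the first loss is log(3), about 1.099; it then falls as training proceeds, like the Loss column in the results below.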

<Test Results>

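As the test results below show, the value of the cost function decreases and the accuracy improves as training proceeds. (Note that this accuracy is measured on the training data itself, not on held-out data.)

Also, because softmax is applied, the hypothesis for each training example is a set of probabilities that sum to 1 (100%).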
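To check the softmax claim against the actual output, this sketch copies four of the less confident "Probability" rows from the results below. Because each printed value is rounded to 3 decimals, a row of 7 values can be off from 1 by at most 7 x 0.0005.

```python
import numpy as np

# Rows copied from the "Probability" output below (rounded to 3 decimals as printed)
rows = np.array([
    [0.074, 0.021, 0.471, 0.279, 0.145, 0.000, 0.011],  # Inferred Label: 2
    [0.102, 0.259, 0.487, 0.000, 0.096, 0.029, 0.027],  # Inferred Label: 2
    [0.014, 0.117, 0.089, 0.014, 0.023, 0.030, 0.713],  # Inferred Label: 6
    [0.003, 0.017, 0.007, 0.079, 0.011, 0.002, 0.880],  # Inferred Label: 6
])

# Every row sums to 1 within the rounding tolerance
print(np.abs(rows.sum(axis=1) - 1.0).max() <= 7 * 0.0005)  # True
# The argmax of each row is the inferred label
print(rows.argmax(axis=1))                                  # [2 2 6 6]
```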

Step:     0	Loss: 5.573	Acc: 37.62%
Step:   100	Loss: 0.729	Acc: 83.17%
Step:   200	Loss: 0.468	Acc: 88.12%
Step:   300	Loss: 0.354	Acc: 90.10%
Step:   400	Loss: 0.284	Acc: 92.08%
Step:   500	Loss: 0.236	Acc: 93.07%
Step:   600	Loss: 0.201	Acc: 95.05%
Step:   700	Loss: 0.174	Acc: 97.03%
Step:   800	Loss: 0.153	Acc: 97.03%
Step:   900	Loss: 0.136	Acc: 97.03%
Step:  1000	Loss: 0.123	Acc: 98.02%
Step:  1100	Loss: 0.112	Acc: 98.02%
Step:  1200	Loss: 0.102	Acc: 98.02%
Step:  1300	Loss: 0.094	Acc: 98.02%
Step:  1400	Loss: 0.087	Acc: 98.02%
Step:  1500	Loss: 0.081	Acc: 99.01%
Step:  1600	Loss: 0.076	Acc: 100.00%
Step:  1700	Loss: 0.071	Acc: 100.00%
Step:  1800	Loss: 0.067	Acc: 100.00%
Step:  1900	Loss: 0.063	Acc: 100.00%
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 3, True Y: 3
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 3, True Y: 3
[True] Prediction: 3, True Y: 3
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 1, True Y: 1
[True] Prediction: 3, True Y: 3
[True] Prediction: 6, True Y: 6
[True] Prediction: 6, True Y: 6
[True] Prediction: 6, True Y: 6
[True] Prediction: 1, True Y: 1
[True] Prediction: 0, True Y: 0
[True] Prediction: 3, True Y: 3
[True] Prediction: 0, True Y: 0
[True] Prediction: 1, True Y: 1
[True] Prediction: 1, True Y: 1
[True] Prediction: 0, True Y: 0
[True] Prediction: 1, True Y: 1
[True] Prediction: 5, True Y: 5
[True] Prediction: 4, True Y: 4
[True] Prediction: 4, True Y: 4
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 5, True Y: 5
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 1, True Y: 1
[True] Prediction: 3, True Y: 3
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 1, True Y: 1
[True] Prediction: 3, True Y: 3
[True] Prediction: 5, True Y: 5
[True] Prediction: 5, True Y: 5
[True] Prediction: 1, True Y: 1
[True] Prediction: 5, True Y: 5
[True] Prediction: 1, True Y: 1
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 6, True Y: 6
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 5, True Y: 5
[True] Prediction: 4, True Y: 4
[True] Prediction: 6, True Y: 6
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 1, True Y: 1
[True] Prediction: 1, True Y: 1
[True] Prediction: 1, True Y: 1
[True] Prediction: 1, True Y: 1
[True] Prediction: 3, True Y: 3
[True] Prediction: 3, True Y: 3
[True] Prediction: 2, True Y: 2
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 1, True Y: 1
[True] Prediction: 6, True Y: 6
[True] Prediction: 3, True Y: 3
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 2, True Y: 2
[True] Prediction: 6, True Y: 6
[True] Prediction: 1, True Y: 1
[True] Prediction: 1, True Y: 1
[True] Prediction: 2, True Y: 2
[True] Prediction: 6, True Y: 6
[True] Prediction: 3, True Y: 3
[True] Prediction: 1, True Y: 1
[True] Prediction: 0, True Y: 0
[True] Prediction: 6, True Y: 6
[True] Prediction: 3, True Y: 3
[True] Prediction: 1, True Y: 1
[True] Prediction: 5, True Y: 5
[True] Prediction: 4, True Y: 4
[True] Prediction: 2, True Y: 2
[True] Prediction: 2, True Y: 2
[True] Prediction: 3, True Y: 3
[True] Prediction: 0, True Y: 0
[True] Prediction: 0, True Y: 0
[True] Prediction: 1, True Y: 1
[True] Prediction: 0, True Y: 0
[True] Prediction: 5, True Y: 5
[True] Prediction: 0, True Y: 0
[True] Prediction: 6, True Y: 6
[True] Prediction: 1, True Y: 1
Probability: [ 0.998  0.000  0.001  0.000  0.001  0.000  0.000], Inferred Label: 0
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.001  0.001  0.004  0.992  0.002  0.000  0.000], Inferred Label: 3
Probability: [ 0.998  0.000  0.001  0.000  0.001  0.000  0.000], Inferred Label: 0
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.995  0.000  0.004  0.000  0.001  0.000  0.000], Inferred Label: 0
Probability: [ 0.003  0.002  0.028  0.949  0.018  0.000  0.000], Inferred Label: 3
Probability: [ 0.001  0.001  0.004  0.992  0.002  0.000  0.000], Inferred Label: 3
Probability: [ 0.975  0.000  0.003  0.000  0.012  0.010  0.000], Inferred Label: 0
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.001  0.978  0.012  0.000  0.001  0.008  0.000], Inferred Label: 1
Probability: [ 0.001  0.001  0.004  0.992  0.002  0.000  0.000], Inferred Label: 3
Probability: [ 0.003  0.017  0.007  0.079  0.011  0.002  0.880], Inferred Label: 6
Probability: [ 0.000  0.000  0.000  0.000  0.028  0.005  0.966], Inferred Label: 6
Probability: [ 0.000  0.000  0.000  0.000  0.033  0.025  0.942], Inferred Label: 6
Probability: [ 0.000  0.993  0.004  0.000  0.000  0.001  0.001], Inferred Label: 1
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.005  0.003  0.018  0.972  0.002  0.000  0.000], Inferred Label: 3
Probability: [ 0.939  0.000  0.018  0.042  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.001  0.978  0.012  0.000  0.001  0.008  0.000], Inferred Label: 1
Probability: [ 0.001  0.989  0.002  0.000  0.003  0.003  0.002], Inferred Label: 1
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.001  0.991  0.007  0.000  0.000  0.001  0.000], Inferred Label: 1
Probability: [ 0.001  0.000  0.002  0.000  0.010  0.810  0.177], Inferred Label: 5
Probability: [ 0.006  0.001  0.030  0.000  0.937  0.002  0.024], Inferred Label: 4
Probability: [ 0.008  0.001  0.159  0.000  0.693  0.001  0.137], Inferred Label: 4
Probability: [ 0.993  0.000  0.005  0.000  0.000  0.002  0.000], Inferred Label: 0
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.997  0.000  0.002  0.000  0.001  0.000  0.000], Inferred Label: 0
Probability: [ 0.000  0.000  0.000  0.000  0.000  0.995  0.004], Inferred Label: 5
Probability: [ 0.995  0.000  0.004  0.000  0.001  0.000  0.000], Inferred Label: 0
Probability: [ 0.998  0.000  0.001  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.000  0.988  0.002  0.001  0.003  0.001  0.005], Inferred Label: 1
Probability: [ 0.001  0.002  0.008  0.985  0.003  0.000  0.000], Inferred Label: 3
Probability: [ 0.987  0.000  0.007  0.000  0.003  0.003  0.000], Inferred Label: 0
Probability: [ 0.991  0.000  0.004  0.000  0.001  0.003  0.000], Inferred Label: 0
Probability: [ 0.000  0.993  0.004  0.000  0.000  0.001  0.001], Inferred Label: 1
Probability: [ 0.001  0.001  0.004  0.992  0.002  0.000  0.000], Inferred Label: 3
Probability: [ 0.001  0.000  0.002  0.000  0.001  0.972  0.025], Inferred Label: 5
Probability: [ 0.000  0.000  0.000  0.000  0.000  0.995  0.004], Inferred Label: 5
Probability: [ 0.002  0.910  0.021  0.002  0.009  0.001  0.055], Inferred Label: 1
Probability: [ 0.001  0.001  0.002  0.000  0.002  0.945  0.048], Inferred Label: 5
Probability: [ 0.000  0.989  0.005  0.000  0.000  0.006  0.000], Inferred Label: 1
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.000  0.000  0.000  0.000  0.033  0.025  0.942], Inferred Label: 6
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.997  0.000  0.001  0.000  0.002  0.000  0.000], Inferred Label: 0
Probability: [ 0.993  0.000  0.005  0.000  0.001  0.001  0.000], Inferred Label: 0
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.000  0.000  0.000  0.000  0.000  0.995  0.004], Inferred Label: 5
Probability: [ 0.019  0.006  0.194  0.000  0.768  0.001  0.011], Inferred Label: 4
Probability: [ 0.001  0.000  0.000  0.000  0.054  0.072  0.873], Inferred Label: 6
Probability: [ 0.993  0.000  0.005  0.000  0.001  0.001  0.000], Inferred Label: 0
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.006  0.948  0.035  0.000  0.004  0.001  0.005], Inferred Label: 1
Probability: [ 0.001  0.978  0.012  0.000  0.001  0.008  0.000], Inferred Label: 1
Probability: [ 0.006  0.865  0.014  0.003  0.031  0.000  0.081], Inferred Label: 1
Probability: [ 0.000  0.989  0.005  0.000  0.000  0.006  0.000], Inferred Label: 1
Probability: [ 0.005  0.003  0.018  0.972  0.002  0.000  0.000], Inferred Label: 3
Probability: [ 0.001  0.001  0.004  0.992  0.002  0.000  0.000], Inferred Label: 3
Probability: [ 0.006  0.013  0.972  0.005  0.003  0.000  0.001], Inferred Label: 2
Probability: [ 0.956  0.000  0.001  0.000  0.011  0.001  0.031], Inferred Label: 0
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.995  0.000  0.004  0.000  0.001  0.000  0.000], Inferred Label: 0
Probability: [ 0.939  0.000  0.018  0.042  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.995  0.000  0.004  0.000  0.001  0.000  0.000], Inferred Label: 0
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.995  0.000  0.004  0.000  0.001  0.000  0.000], Inferred Label: 0
Probability: [ 0.005  0.942  0.033  0.001  0.004  0.000  0.015], Inferred Label: 1
Probability: [ 0.006  0.000  0.045  0.000  0.009  0.203  0.737], Inferred Label: 6
Probability: [ 0.001  0.002  0.008  0.985  0.003  0.000  0.000], Inferred Label: 3
Probability: [ 0.995  0.000  0.001  0.004  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.998  0.000  0.001  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.074  0.021  0.471  0.279  0.145  0.000  0.011], Inferred Label: 2
Probability: [ 0.000  0.001  0.001  0.005  0.003  0.000  0.991], Inferred Label: 6
Probability: [ 0.000  0.988  0.002  0.001  0.003  0.001  0.005], Inferred Label: 1
Probability: [ 0.000  0.988  0.002  0.001  0.003  0.001  0.005], Inferred Label: 1
Probability: [ 0.021  0.064  0.833  0.060  0.022  0.000  0.001], Inferred Label: 2
Probability: [ 0.014  0.117  0.089  0.014  0.023  0.030  0.713], Inferred Label: 6
Probability: [ 0.001  0.002  0.008  0.985  0.003  0.000  0.000], Inferred Label: 3
Probability: [ 0.000  0.989  0.005  0.000  0.000  0.006  0.000], Inferred Label: 1
Probability: [ 0.994  0.000  0.005  0.000  0.001  0.000  0.000], Inferred Label: 0
Probability: [ 0.000  0.000  0.000  0.000  0.030  0.012  0.958], Inferred Label: 6
Probability: [ 0.014  0.006  0.217  0.759  0.003  0.000  0.001], Inferred Label: 3
Probability: [ 0.001  0.993  0.003  0.000  0.001  0.000  0.001], Inferred Label: 1
Probability: [ 0.001  0.000  0.002  0.000  0.010  0.810  0.177], Inferred Label: 5
Probability: [ 0.007  0.001  0.032  0.000  0.943  0.009  0.009], Inferred Label: 4
Probability: [ 0.102  0.259  0.487  0.000  0.096  0.029  0.027], Inferred Label: 2
Probability: [ 0.029  0.011  0.800  0.000  0.151  0.005  0.004], Inferred Label: 2
Probability: [ 0.005  0.003  0.018  0.972  0.002  0.000  0.000], Inferred Label: 3
Probability: [ 0.993  0.000  0.005  0.000  0.000  0.002  0.000], Inferred Label: 0
Probability: [ 0.991  0.000  0.004  0.000  0.001  0.003  0.000], Inferred Label: 0
Probability: [ 0.001  0.992  0.006  0.000  0.000  0.000  0.000], Inferred Label: 1
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.001  0.000  0.001  0.000  0.000  0.964  0.034], Inferred Label: 5
Probability: [ 0.997  0.000  0.003  0.000  0.000  0.000  0.000], Inferred Label: 0
Probability: [ 0.014  0.117  0.089  0.014  0.023  0.030  0.713], Inferred Label: 6
Probability: [ 0.000  0.989  0.005  0.000  0.000  0.006  0.000], Inferred Label: 1

 

* Reference: https://hunkim.github.io/ml/

* Reference: https://www.tensorflow.org

 

 
