Classification Tree

Decision Tree(결정 트리)란?

Decision Tree는 supervised learning 중의 하나로 연속적인 질문을 통해서 예측하는 모델이다.

이 때, 최종적으로 분류를 예측하면 Classification Tree, 값을 예측하면 Regression Tree가 된다.

Tree라는 이름에 걸맞게 root node부터 시작해서 가지(branch)를 치며 여러가지 node로 나뉘어 지게 된다.

마지막에 예측을 하는 node는 leaf라고 부른다.

이 때, 각 node에는 각 변수에 관한 질문들이 있고, 이에 대한 대답을 통해 branch로 나뉘어지게 된다.

목표는 각 node의 질문을 잘 설정해서 마지막 분류 혹은 값의 예측이 실제 분류나 값과 일치하거나 에러가 적도록 만들어야 한다.

위의 표에서 보이는 바와 같이 세가지 변수를 가지고 Cool As Ice라는 제품을 좋아하는지 분류하는 모델을 만드는 과정을 살펴보자.

 

1. 어떤 변수를 가지고 root node를 만들어야 할까?

가장 먼저 하는 질문이므로 가장 결과를 잘 분류할 수 있도록 변수를 설정해야 할 것이다.

먼저 "Loves Popcorn"이라는 항목과 "Loves Soda"을 가지고 분류를 한다고 생각해보자.

과연 둘 중에 누가 분류를 더 잘한 것일까?

직관적으로 "Love Soda?"라는 질문을 가지고 했을 때 좀 더 분류를 뚜렷하게 하는 것 같다.

하지만 이걸 어떻게 수치화할까?

 

이 때 우리는 Gini Impurity라는 방법을 사용하는데 Gini Impurity가 낮은 쪽으로 (즉, pure하게 나뉠 수 있도록) 가지를 나눠야 한다.

 

Gini Impurity 계산하기

Leaf node에서 Gini Impurity는 1에서 각 분류의 비율을 제곱해서 빼면 된다.

맨 왼쪽의 경우 yes는 1/4, no는 3/4이므로 

$$ 1 - (\frac{1}{4})^2 - (\frac{3}{4})^2 = 0.375 $$

 

Leaf node가 아닌 node에서는 child node들의 Gini Impurity를 데이터의 비율대로 더해주면 된다.

Love Popcorn?의 경우를 살펴보면

$$ \frac{4}{7}(1 - (\frac{1}{4})^2 - (\frac{3}{4})^2) + \frac{3}{7}(1 - (\frac{2}{3})^2 - (\frac{1}{3})^2) = 0.405 $$

 

Love Soda?의 경우를 살펴보면

$$ \frac{4}{7}(1 - (\frac{3}{4})^2 - (\frac{1}{4})^2) + \frac{3}{7}(1 - (\frac{3}{3})^2 - (\frac{0}{3})^2) = 0.214 $$

 

위의 수식과 같이 "Love Popcorn?"을 사용하면 Gini Impurity는 0.405, "Love Soda?"는 0.214라는 값이 나온다.

따라서 우리의 직관과 마찬가지로 "Love Soda?"를 가지고 분류하는 것이 더 좋다.

 

Q2. 그렇다면 카테고리가 아닌 값으로 되어 있는 변수는 어떻게 활용할까?

Age 같은 경우에는 카테고리가 아닌 값으로 되어 있기 때문에 분류에 활용하기 위해서는 경계값(threshold)를 정해주어야 한다.

즉, Age를 오름차순으로 정렬했을 때, 그 중간값을 가지고 노드를 만들게 된다.

예를 들면 7과 12의 중간값인 (Age > 9.5?)라는 노드를 만들어서 가지를 만들게 된다.

그러면 카테고리 값과 마찬가지로 Gini Impurity를 구할 수 있을 것이다.

 

데이터가 N개 있으면 N-1개의 중간값을 만들 수 있는데 이를 전부 체크해서 최소 Gini Impurity를 만드는 중간값을 경계값으로 정한다.

각 중간값을 경계로 했을 때의 Gini Impurity

 

최종적으로 세 변수로 만든 Node 중 Gini Impurity가 가장 낮은 것을 골라 root node를 만든다.

 

위의 예제에서는 Love Popcorn? = 0.405,  Love Soda? = 0.213, Age > 15? = 0.343 이므로 Love Soda?를 root node로 사용하면 된다.

 

root node 이후에는 어떻게 해야할까?

root node 이후에도 이와 같은 방식으로 계속 가지를 만들어 내려가면 된다.

노드 하나당 가지를 몇개로 할 것인지(보통 2개이다), root로부터 얼마만큼 깊게 가지를 내릴 것인지 등은 하이퍼파라미터로 정할 수 있다.

 

 

들어가면서

다양한 자료 구조 (특히 회로)들은 그래프로 표현하는 것이 직관적으로 더 좋고 성능도 더 잘 나온 것으로 알려져서 스탠포드에서 graph representation leraning 강의가 올라와 있어서 이를 공부하며서 정리하고자 한다.

https://web.stanford.edu/class/cs224w/index.html

 

CS224W | Home

Content What is this course about? Complex data can be represented as a graph of relationships between objects. Such networks are a fundamental tool for modeling social, technological, and biological systems. This course focuses on the computational, algor

web.stanford.edu

 

왜 그래프를 사용해야 하는가?

기존의 Machine learning 및 Deep learning의 경우 grid(image), sequence(text나 audio) 형태의 input만 주로 활용했지만 세상에는 grid, sequence로 표현되지 않는 것도 많다.

그래프는 relation, interaction을 표현하기 좋은 수단으로 더 나은 prediction 결과를 낼 수 있다.

 

그래프 표현의 예

그래프를 활용한 Neural Network의 모습

그래프를 직접 convolution할 수 있다.
그래프 정보 (node, edge, graph)를 임베딩으로 변환하여 활용할 수 있다.
Graph structure를 활용하면 간접적으로 어떤 feature가 중요한지 알아낼 수 있다.

 

그래프를 가지고 할 수 있는 task들

다양한 level에서 task를 수행할 수 있다.

  • Node classification: 특정 노드가 어떤 특징(feature)를 가지고 있는지 예측 (online user/item categorization)
  • Link prediction: 두 노드간 link가 있을 것인지 없을 건인지 관계를 에측 (Knowledge graph completion, recommendation system, drug side effect prediction)
  • graph classification: 그래프가 어떤 property를 가질 것인지 구분 (molecule/drug property prediction)
  • clustering: 특정 sub-graph를 찾아낸다. (social circle prediction, traffic prediction)
  • graph generation: 특정 조건을 만족하도록 새로운 그래프를 생성 (Drug discovery)
  • graph evolution: 그래프가 어떻게 변화해 갈 것인지 예측 (physical simulation)

 

그래프의 구성 요소

Node, Edge, Graph

 

그래프를 표현하는 방법들

(1) Adjacency matrix

각 노드에 index를 부여하고 index의 개수만큼 row와 column을 가지는 매트릭스를 만든다.

- undirected graph일 경우 adjacency matrix는 symmetric하다.

- 만약 edge에 가중치를 준다면 1이 아닌 다른 숫자를 넣으면 된다.

 

성질

1. row (혹은  column) 값을 다 더하면 해당 노드의 degree가 된다.

$$ k_i = \sum_{j}^{N}A_{ij} $$

보통 node의 개수에 비해 edge의 개수가 훨씬 적으므로 "sparse matrix"가 된다 => memory, computation에서 손해

 

(2) Edge list

(따로 떨어져 있는 노드 없이 노드들이 모두 연결되어 있다는 가정 하에) edge들의 list만으로 그래프를 표현할 수 있다.

ex: [(1,2), (2,3), (1,3), (3,4)]

 

(3) Adjacency list

각 노드들의 neighbor들을 표기해준다.

1: 2,3

2: 1,3

3:, 1,2,4

4: 3

 

다양한 그래프들

  • Directed Vs. Undirected: edge에 방향성이 있나 없나
  • Heterogenous graph G = (V, E, R, T)
  • Bipartite graph: 노드가 두 그룹 (U, V)로 나뉘어져 edge가 같은 그룹내에서는 없고, 서로 다른 두 그룹 사이에만 존재 (user와 iterm, actors & movies
  • weighted vs unweighted
  • self-edge를 허용할 것인가? multiple edge를 허용할 것인가?

'배움 > Graph Representation Learning' 카테고리의 다른 글

9. How expressive are GNN?  (0) 2022.11.02
8. GNN Training  (2) 2022.10.30
7. GNN  (0) 2022.10.13

+ Recent posts