Decision Transformer를 배우기 위해 일단 Transformer부터 공부해 본다.

빛갓갓 동빈나 님이 체계적으로 잘 정리해주신 자료가 있어 덕분에 이해가 금방 되었다.

https://youtu.be/AA621UofTUA

동빈나님이 설명하신 동영상에 seq2seq 모델이 나오는데 이에 관한 내용은 미리 다음 자료로 공부했었기 때문에 바로 이해할 수 있었다.

https://wikidocs.net/24996

 

트랜스포머의 핵심 구조 익히기

트랜스포머는 일반적인 언어 모델처럼 Encoder와 Decoder로 구성되어 있다. Encoder는 여러 layer의 반복으로 이루어져 있고, Decoder 또한 여러 layer의 반복으로 이루어져 있다.

Encoder, Decoder를 이루고 있는 단위 layer에는 Self-attention layer라는 것이 들어가는데 이를 잠시 뒤에 자세히 살펴볼 것이다.

이 외에도 Positional Encoding이라는 것이 들어가는데 이것도 나중에 설명이 될 것이다.

언어 모델이기 때문에 Input을 Output으로 translate해준다.

이 때, Input은 문장이고, 문장을 구성하는 각각의 단어들을 Input Embedding layer를 통해 특정 길이의 Embedding으로 변환된다.

그리고 단위 레이어를 지나면서 인풋과 똑같은 크기의 아웃풋 데이터가 나오게 된다.

오른쪽 그림을 보면 알겠지만 Input이 그대로 Output으로 연결되는 Residual connection 또한 가지고 있다.

인코더의 마지막 레이어에서 나온 아웃풋을 모든 디코더 레이어에서 어텐션 메카니즘에 사용하기 위해 인풋으로 넣어준다.

 

 

Attention mechanism

트랜스포머를 설명할 때 attention mechanism을 빼놓을 수 없다.

attention mechanism이라는 것은 문장을 구성하고 있는 각 단어(혹은 토큰)들이 문장 내의 다른 단어와 얼마나 연관성이 있는지를 계산하여 가중치로 만들어서 각 단어들의 임베딩을 조합해서 최종 임베딩을 만드는 것이다.

 

예를 들면 I am a boy라는 문장이 있고, 이를 단어로 쪼개면 ["I", "am", "a", "boy"]이다.

"I"라는 단어의 임베딩 출력을 구하기 위해서는 I와 다른 단어와의 유사성을 찾게 되는데 이 유사성이 [0.7, 0.05, 0.05, 0.2]로 나온다고 하면 임베딩 출력은 각 단어들의 임베딩에 위의 유사성 숫자를 가중치로 곱해준 값들의 합이 되게 된다.

 

트랜스포머는 이를 Query (Q), Key (K), Value (V)라는 개념으로 설명한다.

하나의 Q에 해당하는 임베딩을 찾기 위해 각 K와의 유사성을 MatMul을 이용해서 체크한다.

그리고 이를 통해 가중치 (energy라고 표현하더라)를 구하고 가중치만큼 Value를 더해줘서 최종적인 Q의 임베딩을 만들어 낸다.

 

attention mechanism은 seq2seq 모델에도 나오는데 트랜스포머에서는 "self-attention"이라는 것이 중요하다.

즉, attention을 만들 때 자기 정보만 가지고 만든다는 것이다.

seq2seq 모델을 기억해보면 decoder의 attention 정보를 만들기 위해 encoder의 정보를 활용해서 가중치를 만든다.

트랜스포머의 인코더에서는 Q, K, V가 같은 문장에서 다 나오게 된다.

 

V, K, Q를 직접 쓰는 것이 아니라 MLP layer를 거쳐서 만들어낸 값을 활용하는데, h개 만큼의 MLP layer를 병렬로 사용해서 다양한 attention을 만들고 임베딩을 만들어서 이를 조합해서 사용한다.

 

 

좀더 과정을 자세히 살펴보면 다음과 같다.

 

 

일단 한 단어의 임베딩을 가지고 Query, Key, Value를 다 만들 것이다.

나중에 여러 head를 concat해서 사용할 것이기 때문에 embedding의 길이를 head의 개수로 나눈 vector 길이로 나오게 된다.

 

그렇게 만들어진 Q, K, V를 가지고 Attention을 구한다.

Q,K는 Matrix multiplication을 통해 (dot product처럼) 유사도를 구하고 거기에 Value를 곱한다.

dk로 Normalize하는 이유는 backprob을 위해 softmax의 gradient가 큰 지점에 위치시키기 위함이다.

위에서는 하나의 단어에 대해서만 Attention을 만들었지만 사실 여러개의 단어를 동시에 만들게 된다.

 

embedding matrix와 attention energy matrix를 곱해서 최종 attention을 구해주게 되는데 임의로 attention energy matrix에 마스크 행렬을 곱해서 특정 단어와의 연관성을 없앨 수도 있다.

(이는 Decoder에서 쓰이게 되는데 Decoder에서는 단어를 전부 참조하는 것이 아니라 이미 나온 단어들만 참조하게 만들고 싶기 때문에 미래의 단어들을 참조하는 attention은 마스킹하게 된다.)

 

위에서 말했듯이 Multi-head attention이기 때문에 각 head에서 나온 attention들을 concat하고 MLP를 통해서 최종 출력을 만든다.

 

Positional Encoding

트랜스포머는 RNN이나 LSTM을 활용하는 Seq2Seq 모델처럼 단어를 순차적으로 받는 것이 아니라 문장이 한꺼번에 들어간다.

따라서 단어의 위치 정보를 넣어줘야 하는데 이를 하기 위해 Positional Encoding이라는 것을 활용한다.

식은 다음과 같다.

 

임베딩의 size와 position을 늘리면 다음과 같이 나온다.

 

'배움 > Reinforcement Learning' 카테고리의 다른 글

10. Knowledge Graph  (0) 2022.11.29
random seed 설정하기  (0) 2022.09.24
Policy invariance under reward transformation  (0) 2022.08.18

이 챕터에서는 GNN이 그래프 안의 노드들을 얼마나 잘 구별하여 노드 임베딩을 만들어줄 수 있는지 이론적으로 알아보는 시간이다.

이 챕터를 다 배우게 되면 어떤 GNN 구조가 가장 expressive power가 좋은지 알게 될 것이다.

 

GNN은 어떻게 노드를 구별할까?

node들에 따로 id 같은 feature가 없다고 할 때 graph structure만 가지고 노드들을 얼마나 잘 구별할 수 있을까?

위와 같은 구조를 보면 1,2,3,4,5는 노드 피쳐가 없다. 즉, 구별할 수 없는 노드이기 때문에 노드가 그래프 내에 어디에 달려있는지 그 위치 만으로 구별해야 한다.

직관적으로 봤을 때, 1번 노드와 2번 노드는 같게 취급이 될 것이라는 걸 알 수 있다.

 

노드가 구별하기 위해서는 각 노드별로 만들어지는 computational graph의 구조가 달라야 한다.
1번, 2번 노드의 computational graph를 아래와 같이 그렸을 때 둘의 구조가 정확히 일치하는 것을 알 수 있다.

(노드 번호가 다른 것은 ID가 노드 피쳐로 들어가지 않았으므로 실제로는 구별할 수 없다.)

아래와 같이 모든 노드들에 대해 computational graph(rooted subgraph structure)를 그려봤을 때 서로 다른 computational graph를 가진 노드들만이 서로 다른 임베딩을 갖게 된다.

결국 GNN의 역할은 computational graph를 임베딩으로 매핑하게 되는데,

다른 computational graph는 다른 embedding이 나오기 위해서는 매핑이 "injective" 해야 한다.

그리고 매핑이 injective하기 위해서는 computational graph의 내부 과정인 neighbor aggregation 또한 injective 해야 한다.

 

GNN이 expressive하지 못할 때

위에서 GNN이 expressive하기 위해서는 neighbor aggregation이 injective해야 한다고 했다.

그렇다면 과연 어떤 neighbor aggregation이 좋을까?

(우리가 하는 neightbor aggregation의 대상은 multi-set이라고 할 수 있다. 여기서 multi-set은 같은 item이 반복될 수 있는 set을 말한다)

 

Mean pooling failure case (GCN)

Max pooling failure case (GraphSAGE)

 

The most expressive GNN - Graph Isomorphism Network (GIN)

GIN을 설명하기에 앞서 다음과 같은 두 개념을 알아두자.

Injective multi-set function

Universal Approximation Theorem

1개의 hidden layer를 가진 MLP가 충분히 크면 어떤 continuous function이던 임의의 정확도로 근사할 수 있다.

 

위 두개의 theorem에 따라 Injective한 aggregation function을 만들면 다음과 같다.

즉, injective한 동시에 MLP 파트들을 잘 학습시켜서 뒤이어 나오는 downstream task 또한 잘 수행될 수 있도록 학습시킬 수 있다.

 

GIN과 WL Graph Kernel의 유사성

우리는 챕터2에서 그래프를 구별할 수 있는 WL Graph Kernel에 대해 배웠다.

WL Graph Kernel도 주변 노드에서 정보를 합친 뒤 (sum) Hash를 통해 새로운 ID로 매핑한다.

이는 GIN이 하는 일과 굉장히 유사한데 K-step의 GIN operation을 반복하면 K-hop neighborhood 구조를 summarize하는 것과 같은 효과이다.

 

Pooling의 expressive power 순위

1. Sum (GIN) - multiset 구별 가능

2. Mean (GCN) - distribution 구별 가능

3. Max (GraphSAGE) - set 구별 가능

'배움 > Graph Representation Learning' 카테고리의 다른 글

8. GNN Training  (2) 2022.10.30
7. GNN  (0) 2022.10.13
[CS224W] Lecture01 Introduction: Machine Learning for Graphs  (0) 2022.06.17

Level 별 prediction task 방법

Question: GNN을 통해 Node Embedding을 얻고 나서 실제 Task는 어떻게 수행할 것인가?

--> 각 레벨 별로 어떤 식으로 task를 수행할 수 있는지 알아보자.

(아래에서는 prediction head를 가지고 어떻게 활용할 수 있는지 알아보는 것이라 할 수 있다.)

Level 방법
Node-level - 노드 임베딩에 웨이트를 곱해서 바로 prediction!
Edge-level - 두 노드 임베딩을 concat한 뒤 MLP 적용!


- 두 노드 임베딩을 dot-product 
1- way: 바로 product

k-way: 가운데 여러개의 Weight를 시도한 뒤 Concat해서 만든다.

 
Graph-level 노드 임베딩을 aggregation처럼 모아주는 operator를 활용해서 예측에 사용한다.


NOTE: 위의 global pooling method들 같은 경우 서로 다른 노드 임베딩이 같은 결과를 나타낼 때도 있기 때문에 다른 방법이 필요하다.
아래 예제에서도 sum pooling을 하게 되면 node feature들의 절대값들이 달라도 결과가 같아지게 되는 문제가 생긴다.

DiffPool (Hierarchical Pooling)

위의 그림처럼 원본 그래프를 점점 더 작은 그래프로 바꾼 다음 graph embedding을 구하게 된다.

각 단계별로

1. 노드 임베딩을 구하고

2. 주어진 노드 임베딩을 가지고 클러스터링을 한뒤

3. 클러스터링 별로 pooling을 적용하여 다음 스테이지를 위한 single node를 만든다.

 

Supervised Vs Unsupervised Tasks

그래프를 가지고 할 수 있는 task 중에 supervised와 unsupervised가 있다.

아래 표를 보면 알겠지만 unsupervised 같은 경우는 그래프의 구조 자체를 활용한다.

Supervised Unsupervised (Self-supervised)
label(node, edge, graph level)이 존재한다.
가장 쉬운 task이므로 왠만하면 문제를 이 task들로 변환하여 풀도록 하자!
label이 존재 하지 않기 때문에 그래프에서 얻을 수 있는 정보를 최대한 활용
- Node level: Node degree, clustering coefficient
- Edge level: Link prediction (실제로 존재하는 edge를 숨겼다가 다시 찾아내기)
- Graph level: graph statistics를 이용

Loss function for supervised learning

- Classification task: Cross entropy loss를 활용한다. (아래 cross entropy는 K개의 data가 있을 때를 의미한다.)

- Regression task: L2 error (Mean Square error)를 활용한다.)

Evaluation Metric

- Classification task

Multi-class classification의 경우

 --> accuracy를 측정 전체 중 정확하게 예측한 것은 몇개인가?

 

Binary classification의 경우

--> Confusion matrix를 활용

Accuracy: 전체 중 맞춘 것은 무엇인가? (TP + TN) / (TP + TF + FP + FN)

Precision: True라고 예측한 것 중에 진짜 True의 비율(얼마나 오진을 했냐) (TP) / (TP + FP)

Recall: 모든 True들 중 True라고 예측한 것의 비율(얼마나 진짜 True를 빼먹지 않았냐) (TP) / (TP + FN)

 

어떤걸 예측한다고 했을 때, False의 비율이 굉장히 낫다고 하자.

그러면 무지성으로 True로 예측하면 왠만하면 Accuracy, Precision, Recall이 모두 높게 나올 것이다.

그래서 다음과 같은 개념이 등장한다.

 

ROC (Receiver operating characteristic) curve and AUC (Area under Curve)

Binary classification에서 threshold를 바꾸면 TPR과 FPR이 바뀌게 된다.

우리의 목표는 FPR이 0이 되고, TPR이 1이 되게 만드는 것이다. (최대한 왼쪽 위)

threshold가 1이면 아무것도 positive가 안될 것이므로 TP=0, FP=0이어서 (FPR,TPR)=(0,0)에 있을 것이고

threshold가 0이면 모든 것이 positive가 될 것이므로 (FPR,TPR)~(1,1)에 있을 것이다.

 

우리는 최대한 curve 자체도 왼쪽 위로 넓어지기를 원하고 (AUC)

그 중 가장 왼쪽 위에 있는 point를 통해 threshold를 고를 것이다.

 

- Regression:

Root Mean Square Error(RMSE)

 Mean Absolute Error(MAE)

Split data (train, validation, test set)

일단 random하게 split하고 여러번의 학습 결과를 평균을 내야 한다.

그렇다면 어떻게 random하게 split해야 할까?

 

Transductive Setting Inductive Setting
(1) original graph 전체를 사용한다.
(2) 구한 embedding을 train, validation, test에 모두 수행한다.

Node, Edge task만 수행 가능하다.
(1) Graph를 training, validation, test용으로 쪼갠다.
(2) 각 graph셋만 가지고 embedding을 구하고 prediction 한다.

Node, Edge, Graph task 수행 가능

Link Prediction example

Link prediction에서는 edge를 감춰서 학습에 쓰일 edge는 message passing edge, 실제 존재하는지 맞춰야하는 edge를 supervision edge로 나눈다.

Inductive setting에서는 단순히 그래프를 나눠서 하면 되지만 아래와 같은 Transductive setting은 약간 tricky하다.

 

(1) Transductive setting에서는 그래프를 세 개로 세트로 나눈 뒤 그 안에서 message passing edge, supervision edge로 나눈다.

(2) Training 단계에서는 Training message passing edge와 Training supervision edge로 학습한다.

(3) Validation 단계에서는 Training message passing edge + training supervision edge로 validation supervision edge를 예측한다.

(4) Test 단계에서는 Training message passing edge + training supervision edge + Validation supervision edge로 test supervision edge를 예측한다.

 

참고: 원래 Transductive setting이라면 모든 단계에서 전체 graph structure를 이용해야 하기 때문에 모든 edge를 이용해야 하지만 edge 자체가 각 단계에서 prediction의 대상이 되기 때문에 점진적으로 더 넣어주게 되는 것이다.

 

* feature augmentation

(사실 이 부분은 정작 중요한 부분이 아닌데 8단원 강의의 첫부분에 나온게 이상하긴 하다)

 

때로는 feature가 부족하거나 아무것도 없을 때는 임의로 feature를 넣기도 한다.

Feature augmentation 방법 Constant 1을 집어넣기 Node ID를 One-hot으로 집어넣기
특징 - expressive power는 낮다. (graph의 구조에 대해서만 배운다.)
- computation cost가 적다.
- 그래프가 더 커져도 적용이 가능하다. (단순히 1만 추가하면 되므로)
- expressive power는 높다 (각각의 node에 대해서 배울 수 있다.)
- Computation cost가 크다. (one-hot vector 크기가 크므로)
- 새로운 노드에는 적용이 불가능하다. (one-hot vector의 크기가 정해져 있으므로 노드가 늘어나면 새로 vector 크기를 정의해야 한다.
설명

때로는 computation graph로는 배울 수 없는 feature들도 있다.

--> 예: 특정 길이의 cycle의 수

이러한 구조적인 정보는 넣어주면 더 좋다. (degree, clustering coefficient, pagerank, centraility)

 

* Structure augmentation

Too Sparse할 경우 Too Dense할 경우
- biparate graph에서는 edge를 augment해서 한 논문을 같이 쓴 저자들을 묶어줄 수 있다.
- 또는 모든 노드를 잇는 임의의 노드를 만들어서 아주 먼 노드를 이어줄 수 있다.
- Neightbor Node를 샘플링 해야 한다. --> computation cost와 missing node로부터의 important message 간의 trade-off
 


 

 

 

'배움 > Graph Representation Learning' 카테고리의 다른 글

9. How expressive are GNN?  (0) 2022.11.02
7. GNN  (0) 2022.10.13
[CS224W] Lecture01 Introduction: Machine Learning for Graphs  (0) 2022.06.17

 

Classification Tree

Decision Tree(결정 트리)란?

Decision Tree는 supervised learning 중의 하나로 연속적인 질문을 통해서 예측하는 모델이다.

이 때, 최종적으로 분류를 예측하면 Classification Tree, 값을 예측하면 Regression Tree가 된다.

Tree라는 이름에 걸맞게 root node부터 시작해서 가지(branch)를 치며 여러가지 node로 나뉘어 지게 된다.

마지막에 예측을 하는 node는 leaf라고 부른다.

이 때, 각 node에는 각 변수에 관한 질문들이 있고, 이에 대한 대답을 통해 branch로 나뉘어지게 된다.

목표는 각 node의 질문을 잘 설정해서 마지막 분류 혹은 값의 예측이 실제 분류나 값과 일치하거나 에러가 적도록 만들어야 한다.

위의 표에서 보이는 바와 같이 세가지 변수를 가지고 Cool As Ice라는 제품을 좋아하는지 분류하는 모델을 만드는 과정을 살펴보자.

 

1. 어떤 변수를 가지고 root node를 만들어야 할까?

가장 먼저 하는 질문이므로 가장 결과를 잘 분류할 수 있도록 변수를 설정해야 할 것이다.

먼저 "Loves Popcorn"이라는 항목과 "Loves Soda"을 가지고 분류를 한다고 생각해보자.

과연 둘 중에 누가 분류를 더 잘한 것일까?

직관적으로 "Love Soda?"라는 질문을 가지고 했을 때 좀 더 분류를 뚜렷하게 하는 것 같다.

하지만 이걸 어떻게 수치화할까?

 

이 때 우리는 Gini Impurity라는 방법을 사용하는데 Gini Impurity가 낮은 쪽으로 (즉, pure하게 나뉠 수 있도록) 가지를 나눠야 한다.

 

Gini Impurity 계산하기

Leaf node에서 Gini Impurity는 1에서 각 분류의 비율을 제곱해서 빼면 된다.

맨 왼쪽의 경우 yes는 1/4, no는 3/4이므로 

$$ 1 - (\frac{1}{4})^2 - (\frac{3}{4})^2 = 0.375 $$

 

Leaf node가 아닌 node에서는 child node들의 Gini Impurity를 데이터의 비율대로 더해주면 된다.

Love Popcorn?의 경우를 살펴보면

$$ \frac{4}{7}(1 - (\frac{1}{4})^2 - (\frac{3}{4})^2) + \frac{3}{7}(1 - (\frac{2}{3})^2 - (\frac{1}{3})^2) = 0.405 $$

 

Love Soda?의 경우를 살펴보면

$$ \frac{4}{7}(1 - (\frac{3}{4})^2 - (\frac{1}{4})^2) + \frac{3}{7}(1 - (\frac{3}{3})^2 - (\frac{0}{3})^2) = 0.214 $$

 

위의 수식과 같이 "Love Popcorn?"을 사용하면 Gini Impurity는 0.405, "Love Soda?"는 0.214라는 값이 나온다.

따라서 우리의 직관과 마찬가지로 "Love Soda?"를 가지고 분류하는 것이 더 좋다.

 

Q2. 그렇다면 카테고리가 아닌 값으로 되어 있는 변수는 어떻게 활용할까?

Age 같은 경우에는 카테고리가 아닌 값으로 되어 있기 때문에 분류에 활용하기 위해서는 경계값(threshold)를 정해주어야 한다.

즉, Age를 오름차순으로 정렬했을 때, 그 중간값을 가지고 노드를 만들게 된다.

예를 들면 7과 12의 중간값인 (Age > 9.5?)라는 노드를 만들어서 가지를 만들게 된다.

그러면 카테고리 값과 마찬가지로 Gini Impurity를 구할 수 있을 것이다.

 

데이터가 N개 있으면 N-1개의 중간값을 만들 수 있는데 이를 전부 체크해서 최소 Gini Impurity를 만드는 중간값을 경계값으로 정한다.

각 중간값을 경계로 했을 때의 Gini Impurity

 

최종적으로 세 변수로 만든 Node 중 Gini Impurity가 가장 낮은 것을 골라 root node를 만든다.

 

위의 예제에서는 Love Popcorn? = 0.405,  Love Soda? = 0.213, Age > 15? = 0.343 이므로 Love Soda?를 root node로 사용하면 된다.

 

root node 이후에는 어떻게 해야할까?

root node 이후에도 이와 같은 방식으로 계속 가지를 만들어 내려가면 된다.

노드 하나당 가지를 몇개로 할 것인지(보통 2개이다), root로부터 얼마만큼 깊게 가지를 내릴 것인지 등은 하이퍼파라미터로 정할 수 있다.

 

 

왜 RNN이 필요한가?

순서가 의미를 부여하는 데이터를 Sequence 라고 부른다. 다시 말하면 하나의 input이 하나의 output을 만들어 내기보다는 input 하나하나가 쌓여 output을 만들게 된다.

특히, 이렇게 순서가 중요한 데이터들 중 시간에 따른 의미가 존재하는 데이터를 시계열 (time series)이라 부른다. (예: 주식 가격 데이터)

예: 다음 스펠링 예측

예를 들어 사용자가 triangle이라는 단어를 입력하고 싶다고 해보자. 그러면 사용자는 차례대로 키보드를 통해 t, r i, a, n, g, l, e라는 글자를 타이핑할 것이다. 그리고 결국에 우리가 하고 싶은 것은 사용자가 전부 스펠링을 치기 전에 triangle이라는 단어를 예측하고 차례대로 출력해야 한다.

그러기 위해서는 사용자가 n을 치고 있을 때, 그전에 차례대로 입력한 t,r,i,a를 알고 있어야 한다.

이렇듯 언어 처리는 전반적으로 이전의 데이터들이 context가 되기 때문에 비슷한 방식으로 처리 해야한다. (단어 자동 완성, 다음 단어 예측)

즉, 순서가 중요한 데이터들을 처리하기 위해서는 기존의 Neural Network와는 다르게 메모리를 가지고 있어 이전 입력값이 현재의 출력에 영향을 주어야 한다.

RNN의 구조

정말 간단한 구조의 MLP(Multi layer Perception)과 RNN을 비교해 보자.

MLP
RNN

가장 큰 차이점은 새로 그려진 hidden layer 자기 자신으로 다시 들어가는 화살표이다. 이 화살표들은 이전에 hidden layer에서 계산된 결과값들이 다음 버너 hidden layer의 결과값 계산에 쓰인다는 이야기이다. 즉, MLP와는 다르게 hidden layer의 결과값(t=n)을 계산하기 위해서는 새로운 input(t=n)과 더불어 hidden layer의 이전 결과값(t=n-1)이 필요하다.

예: 'name'이라는 단어의 스펠링 예측하기

 

우리는 RNN을 통해 name이라는 단어의 스펠링을 예측하도록 만들것이다.name이라는 스펠링을 예측한 다는 것은 n을 타이핑 했을 때, a를 예측해야 하고, a를 타이핑 했을 때는 m, m을 타이핑 했을 때는 e를 예측해야 한다.

위의 그림은 hidden layer가 2개인 RNN의 구조이다. 보이다시피 파란색 화살표를 통해 이전 시점의 값을 현재 시점의 값을 계산하는데 사용하게 된다.

위의 RNN을 식으로 풀어 써보면 다음과 같다. (w는 weight, b는 bias이고 activation functino으로 tanh를 사용하였다.)

 

위의 RNN 구조를 시점 별로 풀어서 나타내면 다음과 같다. (여기서 또한 이전의 값을 사용하게 되는 걸 파란색 화살표로 표시하였다.)

output을 처음 계산할 때에는(t=0) hidden layer들의 이전 값들이 없기 때문에 적당한 값을 initial value로 지정해준다. (h1(-1), h2(-1))

RNN의 트레이닝 과정

기본적으로 트레이닝 과정은 MLP의 트레이닝 과정과 비슷하다.

특정 input에 대해 기대하는 정답 output(Label)이 정해져 있다. (n→a, a→m, m→e) 즉, 3개의 input에 대해 3개의 label이 있고, label과 output의 차이를 가지고 loss function을 만들어 gradient descent를 이용하여 weight와 bias들을 업데이트하게 된다.

다만 큰 차이점이라고 한다면 특정 시점의 output을 계산할 때, 그 이전 시점예 계산됐던 값이 필요하다는 점이다. 예를 들어 w22를 업데이트 하고 싶다고 하면 우리는 다음을 구해야 한다.

 

예를 들어 w22를 구하기 위해서는 h2의 이전 값(h2(t-1))을 알아야 하고, 또 h2(t-1)을 구하기 위해서는 h2(t-2)를 구해야 할 것이다. 이처럼 back propagation이 layer의 앞쪽 방향으로만 가는게 아니라 시점도 앞으로 가게 된다. 이를 Back Propagation Through Time (BPTT)라고 한다. 화살표가 좌에서 우로 가는 것뿐만 아니라 위에서 아래로도 내려온다는 것을 생각하면 된다.

RNN의 한계

time sequence가 늘어나면 back prop시에 미분값이 0과 1 사이의 값이 여러번 곱해져서 기울기가 소멸한다. 따라서 트레이닝이 되질 않는다. (gradient vanishing)

⇒ LSTM이나 GRU를 사용한다.

+ Recent posts