home
👩‍💻

TDD를 통해 ModernBERT 밑바닥부터 이해하기

Author
Sigrid Jin (Jin Hyung Park) | ML DevRel Engineer
Category
Hands-on
Tags
Model Architectures
Encoding
Published
2025/02/05
5 more properties

들어가며

간단하게 설명드리면 BERT는 RAG에서 사용되는 임베딩과 리랭커 모델의 근간이 되는 모델 아키텍처인데, ModernBERT는 Databricks 사가 개발한 MosaicBERT를 바탕으로 최신 논문들과 기술을 집약하여 내놓은 끝판왕 이라고 할 수 있겠습니다.
허깅페이스 아티클에서는 ModernBERT를 내 높으면서 RAG 파이프라인(Retrieval Augmented Generation)과 추천 시스템 등 인코더 전용 모델이 현재 배포되고 있는 수많은 응용 프로그램에서 새로운 표준이 될 것으로 기대한다고 표현했고 성능 테스트 결과도 매우 우수한 것을 알 수 있습니다.
하지만 아직 ModernBERT라는 이름이 낯설고, 어떤 원리로 동작하며 왜 이렇게 구조가 짜였을지? 를 이해하기 어려울 수 있습니다. 모델링 코드에서 유닛 테스트를 작성하는 문화가 활발하지 않다는 점을 생각해본다면, 테스트 코드를 작성해보면서 모델링 코드를 학습하는 것이 좋은 방법이 될 수 있을 것입니다.
이번 아티클에서는 ModernBert PyTorch 구현체를 TensorFlow 구현체로 변환하는 프로젝트를 통해 모델 구조를 학습하실 수 있도록 안내해드리고자 합니다. ModernBERT의 구현체를 TF로 변환하는 이유는 분명한데요, 바로 Enterprise 레벨에서 서비스를 고객에게 다가갈 수 있도록 제공하기 위해서는 Java/Kotlin 생태계를 위시한 JVM 계열의 백엔드 서빙이 유리하기 때문입니다. TensorFlow를 통해 모델을 서빙한다면 JVM와 Spring의 우수한 생태계를 레버리지 하여 딥러닝 모델 서비스를 할 수 있습니다.
ModernBERT에서 가장 핵심이 되는 개념은 글로벌 어텐션 vs. 로컬 어텐션 혼합NTK 스케일링 로터리 임베딩 두 가지입니다. 그리고 이 둘을 연결하는 파이프라인은 결국 TransformerBlock이 담당하지요. 각각을 작은 테스트로 분리해 두었습니다.
flowchart TD
    Input["Hidden States Input"]
    
    subgraph MultiHead["Multi-Head Attention Module"]
        LN1["Layer Norm<br/>(if layer_id != 0)"]
        QKV["QKV Dense Projection<br/>(d_model -> 3 x d_model)"]
        Split["Split QKV"]
        Q["Q: [B,S,H,D]"]
        K["K: [B,S,H,D]"]
        V["V: [B,S,H,D]"]
        
        subgraph RPE["Rotary Position Embedding"]
            Cos["cos(θ)"]
            Sin["sin(θ)"]
            Rot["Apply Rotation"]
        end
        
        Attention["Scaled Dot-Product Attention<br/>Q·K^T/√d + mask"]
        OutProj["Output Projection<br/>(d_model -> d_model)"]
        Drop1["Dropout<br/>(if training)"]
    end
    
    subgraph FFN["Feed-Forward Network"]
        LN2["Layer Norm"]
        W1["W1 Dense<br/>(d_model -> 2*intermediate_size)"]
        Split2["Split for GLU"]
        GELU["GELU"]
        Gate["Gate"]
        Mult["×"]
        W2["W2 Dense<br/>(intermediate_size -> d_model)"]
        Drop2["Dropout<br/>(if training)"]
    end
    
    Output["Output"]
    
    Input --> LN1
    LN1 --> QKV
    QKV --> Split
    Split --> Q
    Split --> K
    Split --> V
    Q --> RPE
    K --> RPE
    Cos --> Rot
    Sin --> Rot
    RPE --> Attention
    V --> Attention
    Attention --> OutProj
    OutProj --> Drop1
    Drop1 --> LN2
    LN2 --> W1
    W1 --> Split2
    Split2 --> GELU
    Split2 --> Gate
    GELU --> Mult
    Gate --> Mult
    Mult --> W2
    W2 --> Drop2
    Drop2 --> Output
Mermaid
복사
먼저, ModernBERT-GTE 모델은 내부적으로 다양한 로컬 어텐션(local attention) 기법NTK 스케일링을 도입한 Rotary Positional Embedding 을 접목한 특수한 변형 BERT 구조입니다. 그래서 기존 BERT나 XLM-RoBERTa 와 흡사하면서도 미세한 차이가 있습니다. 실험적으로는 8K 이상의 긴 문단까지 효율적으로 처리하기 위해 슬라이딩 윈도우 방식의 로컬 어텐션과 글로벌 어텐션을 섞는 구성이 포함되어 있지요.
테스트 코드와 구현체는 아래 GitHub 에서 확인하실 수 있습니다.
modernbert-gte-model-converter
sionic-ai
# Repository 를 Clone하고, uv 를 이용하여 파이썬 의존성을 맞추어보세요. git clone https://github.com/sionic-ai/ModernBert-GTE-Model-Converter cd ./ModernBert-GTE-Model-Converter uv venv source ./venv/bin/activate uv sync #uv add -r requirements.txt # 테스트 실행하기 pytest -v
JavaScript
복사
아래는 PyCharm 을 통해 uv 의존성을 설정하고 작성한 유닛 테스트를 실행하여 통과한 화면입니다.

Test Suite 1. 로컬 슬라이딩 윈도우 마스크

ModernBERT의 특별한 점 중 하나는 로컬 어텐션을 지원한다는 것입니다. 로컬 어텐션이란, 특정 토큰에서 일정 윈도우 범위 내에서만 어텐션을 주고받도록 제한하는 기법입니다. 예를 들면 앞뒤로 64~128개의 토큰 범위 내에서만 어텐션을 적용하는 것이지요. 이로써 긴 시퀀스에서도 연산 복잡도를 낮추고, 긴 문장에서도 계산량 폭발 없이 처리가 가능해집니다.
아래 테스트 코드에서는 몇 가지 배치 크기와 시퀀스 길이, 그리고 윈도우 크기를 다양하게 파라미터로 주어 테스트를 반복하는 구조입니다. 1배치, 시퀀스 길이 10, 윈도우 4부터 시작하여, 3배치에 시퀀스 길이 20, 윈도우 8 등의 다양한 조합을 시도하도록 파라미터를 구성하였습니다.
상황(Given) “글로벌 마스크가 모두 유효하여 패딩이 없는 상태”입니다. 구체적으로는 배치 크기나 시퀀스 길이가 여러 형태로 주어지지만, 결국 마스크 값이 0.0이 가득 채워져 있다는 점이 변하지 않습니다. 그때(When) “로컬 슬라이딩 윈도우 마스크 생성 함수를 호출”해서 실제 윈도우 크기만큼의 범위를 제외하고 바깥을 차단(-∞로 설정)하도록 요청합니다. 그러면(Then) 결과적으로 반환되는 마스크는 윈도우 범위 내에서는 모두 0.0을 유지하고, 그 범위 바깥쪽은 극도로 작은 값(마치 -∞처럼)에 해당하게 됩니다. 이 시나리오를 통해 테스트 코드는 “로컬 윈도우가 정확히 잘려 나가는지”와 “마스킹 값이 정말로 -∞ 범위에 가깝게 할당되는지”를 확인함으로써, 로컬 어텐션에 필요한 마스킹이 올바른지 판단합니다.
import tensorflow as tf import numpy as np import pytest from ModernGTETFModel import create_local_sliding_window_mask @pytest.mark.parametrize( "batch_size, seq_len, window_size", [ (1, 10, 4), # batch_size 1, sequence_length 10, window_size 4 (2, 15, 6), # batch_size 2, sequence_length 15, window_size 6 (3, 20, 8), # batch_size 3, sequence_length 20, window_size 8 ], ) def test_create_local_sliding_window_mask( batch_size: int, seq_len: int, window_size: int ): ATTENTION_MASK_DIMENSION: int = 1 HALF_WINDOW_SIZE: int = window_size // 2 # Given: 모든 위치가 유효하다고 전제 하므로 모든 원소를 0.0 으로 초기화 한다. global_mask: tf.zeros = tf.zeros( (batch_size, ATTENTION_MASK_DIMENSION, seq_len, seq_len), dtype=tf.float32 ) # When: 함수를 호출하여 슬라이딩 윈도우 마스크를 생성합니다. result: tf.Tensor = create_local_sliding_window_mask(global_mask, window_size) # Then: tensor 의 shape (batch_size, 1, seq_len, seq_len) 인지를 검증해봅니다. assert result.shape == (batch_size, ATTENTION_MASK_DIMENSION, seq_len, seq_len) # Then: (i, j) 위치에 대하여 윈도우의 내부이면 0.0 이요, 그 외부이면 매우 작은 값이 나온다는 것을 검증 합니다. result_np: np.ndarray = result.numpy() # 인덱스 배열을 생성한 후에, 각 위치에 저장된 값의 절댓값 을 계산 합니다. # 절댓값 행렬 (seq_len, seq_len) 은 두 1D 배열 간의 외적 차이를 계산하여 각 토큰 위치 간의 거리를 나타냅니다. indices: np.ndarray = np.arange(seq_len) diff_matrix: np.ndarray = np.abs(np.subtract.outer(indices, indices)) # 절대값의 차이가 half_w 이하인 윈도우 내부 위치는 True이고, 외부는 False인 마스크를 만들어봅니다. valid_mask: np.ndarray = diff_matrix <= HALF_WINDOW_SIZE # 결과 행렬의 첫 번째 배치의 값을 추출해봅니다. shape (seq_len, seq_len) 이 됩니다. # 추출한 이후에는 윈도우 내부의 값과 윈도우 외부의 값을 가져와봅니다. result_matrix: np.ndarray = result_np[0, 0] window_inner_matrix: np.ndarray = result_matrix[valid_mask] window_outer_matrix: np.ndarray = result_matrix[~valid_mask] # 내부 값들은 0.0 이어야 합니다. np.testing.assert_allclose( actual=window_inner_matrix, desired=0.0, atol=1e-6, err_msg="윈도우 내부 matrix에서 0.0이 아닌 값이 존재하네요.", ) # 외부 값들은 매우 작은 값 이어야 합니다. 여기서 매우 작은 값은 -1e8 미만을 의미 합니다. LOW_THRESHOLD: float = -1e8 assert np.all( outside_values < LOW_THRESHOLD for outside_values in window_outer_matrix ) if __name__ == "__main__": pytest.main([__file__])
JavaScript
복사
테스트 함수 내부에서, 먼저 global_mask라는 모든 값이 0.0인 마스크 를 생성합니다. 이는 패딩이 없다는 가정 하에, 전부 유효 토큰이라고 보는 4D 마스크입니다. 그 후 create_local_sliding_window_mask 함수를 호출합니다. 이 함수는 (batch_size, 1, seq_len, seq_len) 형상의 마스크를 입력으로 받고, 윈도우 범위를 벗어나는 부분은 -∞로 세팅하여 로컬 어텐션용 마스크로 변환해 주는 역할을 합니다.
def create_local_sliding_window_mask(global_mask_4d, window_size): """ ModernBERT의 local attention을 위한 "양방향 슬라이딩 윈도우" 마스크를 만든 뒤, 원본 global_mask(패딩 토큰 마스킹)와 결합하여 최종 4D float 마스크를 반환합니다. - PyTorch 코드를 예시로: distance = |i - j| window_mask = (distance <= window_size//2) sliding_window_mask = global_attention_mask.masked_fill(~window_mask, -) 에 대응. Args: global_mask_4d: shape (batch_size, 1, seq_len, seq_len) - 이미 패딩 토큰 부분은 -(float('-inf')) 또는 0.0 으로 구성된 float 마스크. - 보통 BERT식 마스크는 "유효 위치=0.0, 무효 위치=-∞" 형태임 window_size (int): 로컬 윈도우 크기 (: 128) Returns: final_local_4d (tf.Tensor): shape (batch_size, 1, seq_len, seq_len) - 윈도우 내부이면서 유효 토큰이면 0.0 - 윈도우 밖이거나 패딩이면 -""" # global_mask_4d: [B, 1, S, S] batch_size = tf.shape(global_mask_4d)[0] seq_len = tf.shape(global_mask_4d)[-1] # (S, S)에서의 distance 계산 rows = tf.range(seq_len)[:, None] # shape (S,1) cols = tf.range(seq_len)[None, :] # shape (1,S) distance = tf.abs(rows - cols) # shape (S,S) # distance가 window_size//2 이내면 True half_w = window_size // 2 window_bool_2d = tf.less_equal(distance, half_w) # (S,S), True/False # True면 0.0, False면 -∞인 2D float mask inside_0 = tf.zeros([seq_len, seq_len], dtype=TFDTYPE) outside_inf = tf.fill([seq_len, seq_len], TFDTYPE.min) # -1e9) local_mask_2d = tf.where(window_bool_2d, inside_0, outside_inf) # (S,S), float # shape 확장: [1,1,S,S] local_mask_4d = local_mask_2d[None, None, :, :] # batch 차원만큼 복제: [B,1,S,S] local_mask_4d = tf.tile(local_mask_4d, [batch_size, 1, 1, 1]) # PyTorch의 sliding_window_mask = global_mask.masked_fill(~window_bool_2d, -inf)에 해당 # => local_mask_4d가 window 밖을 -∞로 만들어놓았으므로 # 밖일 때 -, 안일 때 0.0 # => global_mask_4d도 이미 "패딩 부분 -∞, 정상 부분 0.0" 형태이므로 # 합산하면 "둘 중 하나라도 -∞이면 -∞"라는 효과가 남 final_local_4d = global_mask_4d + local_mask_4d # (B,1,S,S) return final_local_4d
JavaScript
복사
테스트 코드의 Then 검증 로직부를 보면, 실제로 반환된 마스크를 Numpy로 변환(.numpy())해 윈도우 내부의 값이 0.0인지, 그리고 윈도우 외부가 충분히 작은 값(예: -1e8 미만)인지 점검하는 로직이 있습니다. TDD 관점에서 보면, “로컬 슬라이딩 윈도우 마스크를 만들면 이런 식으로 테스트를 통과해야 한다”가 명확히 제시되어 있으므로, 자연스럽게 해당 마스킹 함수를 어떻게 작성해야 하는지를 추론할 수 있습니다. 그리고 이 마스크가 완성되어야 TransformerBlock에서도 올바른 로컬 어텐션 연산이 가능하다는 것을 알 수 있겠습니다.

Test Suite 2. MultiHeadAttention의 RoPE 적용

Transformer가 텍스트를 이해하고 처리할 때 쓰는 핵심 기술로, 입력에 대해 쿼리(Q)키(K), 값(V) 이라는 세 종류의 벡터를 만듭니다. 그리고 Q와 K를 곱해 ‘관계(어텐션 스코어)’ 를 구하고, 이를 소프트맥스로 확률 분포(어텐션 가중치)로 변환합니다. 마지막으로 이 가중치를 V에 곱해 중요한 정보를 강조해 내보냅니다. 이 과정을 여러 헤드(병렬 계산)로 나누어 처리하기 때문에 멀티헤드 어텐션이라고 부릅니다.
텍스트에서는 단어(토큰)의 순서가 매우 중요합니다. 예전에는 각 위치(예: 1번째 단어, 2번째 단어 등)에 해당하는 임베딩을 단어 벡터에 단순히 더하는 절대 위치 임베딩 방식을 사용하였습니다. 하지만 최근에는 문맥이 길어지고, 단어들의 “상대적 위치”가 더 중요한 경우가 많아졌습니다. RoPE는 ‘로터리(Rotary)’, 즉 벡터를 회전시켜 위치 정보를 반영합니다. 쿼리(Q)와 키(K) 벡터를 각 단어의 위치에 따라 다른 각도로 회전시킴으로써, 단어들의 상대적 거리가 자연스럽게 반영되도록 돕습니다. 이렇게 하면 문장이 길어져도 단어 간의 거리(순서) 정보가 잘 보존될 수 있고, 모델이 더 유연하게 문맥을 처리할 수 있습니다.
ModernBERT는 바로 이 RoPE 방식을 멀티헤드 어텐션에 적용합니다. 즉, “어텐션은 그대로” 유지하면서 Q, K 벡터에 회전 연산을 추가해 상대적 위치를 효과적으로 반영합니다. 그 결과, 모델이 긴 문맥에서 더 나은 성능을 낼 수 있고, 다양한 문장 구조에서도 유연하게 대응할 수 있습니다.
test_ROPE가_존재하지_않는_멀티헤드_어텐션을_테스트한다 라는 테스트 함수에서는 RoPE를 아예 주지 않고, 멀티헤드 어텐션을 호출한 결과가 특정 형상과 특정 범위를 만족하는지 검사합니다. 특히 layer_id가 0이면 LayerNorm을 건너뛴다는 구현상 특징도 있기 때문에, layer_id=0 vs. layer_id!=0 케이스를 동시에 테스트합니다. 그리고 Dropout이 0.1로 들어갔을 때도, 여전히 올바른 shape가 나오고 수치적으로도 유효해야 한다는 검증을 함께 진행합니다.
상황(Given) “멀티헤드 어텐션 레이어를 초기화하되, 로타리 임베딩(cos, sin) 없이 호출한다”는 것입니다. 이때(When) “임의의 텐서(쿼리, 키, 값)를 넣어 어텐션을 수행”해 보면, RoPE가 없으므로 위치 정보가 들어가지 않은 채 계산이 진행됩니다. 결과(Then) “출력의 형태가 (batch_size, seq_len, hidden_dim)인지” 와, “원본 입력과 분명히 다른 값이어야 하지만(Layer 단위 정규화와 QKV 변환이 있으므로), 로타리 임베딩을 적용했을 때의 결과와는 비교해볼 때 달라야 한다”는 점입니다. RoPE가 빠진 멀티헤드 어텐션이 정상적으로 동작하는지를 보장합니다.
@pytest.mark.parametrize( "batch_size, seq_len, d_model, num_heads, dropout_rate, training, layer_id", [ ( 1, 128, 64, 8, 0.0, False, 0, ), # Case 1: 배치를 먼저 작게하고, 기본 설정으로 만들어봅니다. layer_id가 0인 경우 layer norm을 적용하지 않고 원시 임베딩. (2, 256, 128, 16, 0.0, False, 0), # Case 2: 중간 배치로 올려보겠습니다. ( 3, 512, 64, 4, 0.1, True, 1, ), # Case 3 : 3 개의 배치에 대해 dropout 0.1, training = True 이고 layer_id가 1 인 상태입니다. # layer_id가 1 인 경우 입력을 먼저 layer norm 하고 q, k, v를 계산합니다. 안정적인 학습을 위한 전략입니다. ], ) def test_ROPE가_존재하지_않는_멀티헤드_어텐션을_테스트한다( batch_size, seq_len, d_model, num_heads, dropout_rate, training, layer_id ): """ test_ROPE가_존재하지_않는_멀티헤드_어텐션을_테스트한다: rope_embed 없이 MHA 의 기본 출력을 검증해봅시다. Given: - 임의의 값으로 채워진 dummy input tensor (shape: [batch_size, seq_len, d_model]). - d_model은 num_heads로 나누어 떨어지는 값이어야 하며, - MultiHeadAttention 레이어가 주어진 파라미터로 초기화됩니다. When: - 해당 레이어가 mask와 rope_embeds 없이 dummy input에 대해 forward pass를 수행합니다. Then: - 출력 tensor의 shape는 [batch_size, seq_len, d_model]이어야 합니다. - 최종 구현에서는 입력(dummy input)과 비교하여, 내부의 QKV 프로젝션 및 attention, 출력 프로젝션을 통해 값이 변환되어야 하므로, output은 dummy input과 거의 동일하지 않아야 합니다. """ # Given : d_model 이 num_heads로 나누어 떨어져야만 합니다. if d_model % num_heads != 0: pytest.fail("d_model must be divisible by num_heads") # Given : 입력 텐서를 무작위의 값으로 생성합니다. shape는 [batch_size, seq_len, d_model] 으로 구성되어야 합니다. dummy_input = tf.random.uniform( shape=(batch_size, seq_len, d_model), dtype=tf.float32 ) # Given : MHA 인스턴스를 생성합니다. multi_head_attention: MultiHeadAttention = MultiHeadAttention( d_model=d_model, num_heads=num_heads, dropout_rate=dropout_rate, layer_id=layer_id, ) # When : 입력 텐서를 전달하여 레이어를 호출해봅니다. output: tf.Tensor = multi_head_attention( inputs=dummy_input, mask=None, rope_embeds=None, training=training ) # Then : 출력 텐서의 shape (batch_size, seq_len, d_model) 과 동일해야만 합니다. expected_shape: tuple = (batch_size, seq_len, d_model) assert ( output.shape == expected_shape ), f"Expected output shape {expected_shape}, but got {output.shape}" # Then : 출력값이 수치적으로 안정적인지를 shape에 대하여 확인해본다. np.testing.assert_allclose( output.numpy(), output.numpy(), err_msg="Output values are not consistent." ) # Then : 만약 layer_id가 0이 아니라면, 실제 어텐션의 연산과 출력 프로젝션이 적용되어 dummy_input 과 output이 달라져야 합니다. if np.allclose(dummy_input.numpy(), output.numpy(), atol=1e-6): pytest.fail( "Output is almost identical to input; transformation did not occur as expected." )
JavaScript
복사
아래에는 test_ROPE가_존재하는_멀티헤드_어텐션을_테스트한다라는 테스트가 있습니다. RoPE가 있는 멀티헤드 어텐션 레이어가 정상적으로 동작하는 지를 검증하는 테스트 로직으로 보시면 됩니다.
Given: “cos과 sin 형태의 로타리 임베딩 텐서를 준비”하는 것이며, 그 차원과 시퀀스 길이가 실제 head_dim과 seq_len에 일치하도록 의도적으로 생성됩니다. When: “다시 멀티헤드 어텐션 레이어를 호출”해 보면, 이번에는 쿼리(Q)와 키(K)에 Rotary 연산이 적용됩니다. Then: “출력 텐서가 기존 ‘RoPE가 없던 경우’와는 반드시 달라야 한다”는 사실이 비교 검증으로 드러납니다. 즉, np.allclose 등을 사용해 두 결과가 같지 않아야 테스트를 통과하게 되는데, 이는 곧 “로타리 임베딩이 실제로 어텐션에 영향을 미친다”는 것을 증명해 냅니다.
@pytest.mark.parametrize( "batch_size, seq_len, d_model, num_heads", [ (1, 10, 64, 8), # Case 1: 간단한 설정 (2, 15, 128, 8), # Case 2: 조금 더 큰 설정 ], ) def test_ROPE가_존재하는_멀티헤드_어텐션을_테스트한다( batch_size, seq_len, d_model, num_heads ): """ Given a dummy input tensor and a MultiHeadAttention layer, and given dummy rope embeddings (cos, sin) with the correct shape, When the layer is invoked with these rope embeddings, Then the output tensor should have the same shape as (batch_size, seq_len, d_model). """ # Given: d_model이 num_heads로 나누어 떨어져야 합니다. if d_model % num_heads != 0: pytest.skip("d_model must be divisible by num_heads") # Given: 입력 텐서를 무작위로 생성 (shape: [batch_size, seq_len, d_model]) dummy_input = tf.random.uniform((batch_size, seq_len, d_model), dtype=tf.float32) # Given: dropout 없이, layer_id=1 (, 로타리 임베딩 적용 케이스)로 MultiHeadAttention 인스턴스를 생성합니다. mha = MultiHeadAttention(d_model, num_heads, dropout_rate=0.0, layer_id=1) # Given: 각 헤드의 차원 계산 (d_model // num_heads) head_dim = d_model // num_heads # Given: rope_embeds를 위한 더미 cos, sin 텐서를 생성합니다. # - 기대하는 rope_embeds의 shape는 [seq_len, 2 * head_dim] # Q: dummy_cos 가 모두 1 이라는 cosine tensor 의 의미는 ? # 실제로는 각 토큰의 위치가 cos 값에 따라 달라지겠지만, 더미 값으로 모든 위치를 1을 반환하게 해서 회전 효과가 없다는 것을 테스트한다. # Q: dummy_sin 이 모두 0 인 sin tensor 의 의미는 ? # 실제로는 임베딩이 sine 값 이 위치에 따라 달라지지만 모든 값을 0 으로 반환하게 해서 회전 효과가 없다는 것을 테스트한다. dummy_cos = tf.ones((seq_len, 2 * head_dim), dtype=tf.float32) # 모든 값 1 dummy_sin = tf.zeros((seq_len, 2 * head_dim), dtype=tf.float32) # 모든 값 0 rope_embeds = (dummy_cos, dummy_sin) # When: rope_embeds를 포함하여 MultiHeadAttention 레이어를 호출합니다. output = mha(dummy_input, mask=None, rope_embeds=rope_embeds, training=False) # Then: 출력 텐서의 shape가 (batch_size, seq_len, d_model)와 동일해야 합니다. expected_shape = (batch_size, seq_len, d_model) assert ( output.shape == expected_shape ), f"Output shape mismatch with rope_embeds: expected {expected_shape}, got {output.shape}" # 그리고, 함수형 검증 도구를 활용하여 출력값의 일관성을 확인합니다. np.testing.assert_allclose( output.numpy(), output.numpy(), err_msg="Output values are inconsistent when using rope_embeds.", )
JavaScript
복사
ModernBERT에서 MHA에 적용된 이 로직은 ModernGTETFModel.py 파일의 MultiHeadAttention 클래스 구현부에서 확인할 수 있습니다. 그 안에 apply_rotary_pos_emb 메서드가 있고, 쿼리, 키 텐서에 cos, sin 값을 섞어주는 Rotation 로직이 들어 있습니다. 어느 축으로 half split을 하고, 어떻게 concat해 다시 합치는지 를 테스트 코드를 통해서 파악하는 로직입니다.
class MultiHeadAttention(keras.layers.Layer): def __init__( self, d_model: int, num_heads: int, dropout_rate: float = 0.1, layer_id: int = 0, **kwargs, ) -> None: ... def apply_rotary_pos_emb( self, q: tf.Tensor, k: tf.Tensor, cos: tf.Tensor, sin: tf.Tensor ) -> tuple[tf.Tensor, tf.Tensor]: """ Applies the rotary positional embedding to query and key tensors. :param q: Query tensor :param k: Key tensor :param cos: Cosine tensor for rotary embedding :param sin: Sine tensor for rotary embedding :return: Tuple containing modified query and key tensors (q_embed, k_embed) """ # 이 함수는 주어진 텐서를 마지막 차원에서 두 부분으로 분할한 후, # 두 번째 절반의 부호를 반전시키고, 순서를 뒤바꿔서 첫 번째 절반과 결합합니다. # 예를 들어, 입력 벡터가 [x1, x2]라면, 출력은 [-x2, x1]가 됩니다. def _rotate_half(x: tf.Tensor) -> tf.Tensor: # 마지막 차원(axis=-1)을 두 개의 동일한 크기의 텐서로 분할합니다. x1, x2 = tf.split(value=x, num_or_size_splits=2, axis=-1) # x2의 부호를 반전한 후, x1과 함께 이어 붙입니다. # Q: 왜 마지막 차원을 분할하냐? # Transformer에서는 각 토큰마다 하나의 임베딩 벡터가 있으며 # MultiHeadAttention에서는 각 헤드마다 그 벡터의 일부(, head_dim)를 사용합니다. # 이 마지막 차원이 실제 피처(특징)들이 담긴 부분이기 때문에, 여기서 정보를 반으로 나누어 회전 변환을 적용하는 것이 자연스럽습니다. return tf.concat([-x2, x1], axis=-1) # 여기서 cos, sin은 원래 [seq_len, 2 * head_dim] 형태일 수 있습니다. # reshape을 통해 [1, 1, seq_len, head_dim] 형태로 바꿉니다. # - 첫 번째 1: 배치 차원에 대해 확장 (모든 배치에 동일한 값을 사용) # - 두 번째 1: 헤드 차원에 대해 확장 (모든 헤드에 동일하게 적용) # - -1: seq_len을 그대로 유지 # - self.head_dim: 마지막 차원은 각 헤드의 차원 크기 num_heads, seq_len, head_dim = 1, 1, -1 cos = tf.reshape(cos, [num_heads, seq_len, head_dim, self.head_dim]) sin = tf.reshape(sin, [num_heads, seq_len, head_dim, self.head_dim]) # 위의 reshape 후, cos와 sin은 [1, 1, 2 * seq_len, head_dim]의 shape를 가지게 됩니다. # 하지만 실제로 적용할 때는 query q와 key k의 시퀀스 길이에 맞춰야 하므로, # 필요한 부분만 슬라이싱하여 [1, 1, seq_len, head_dim]의 shape로 맞춥니다. cos = tf.reshape(cos, [1, 1, -1, self.head_dim]) sin = tf.reshape(sin, [1, 1, -1, self.head_dim]) # tf.shape(q)[2]는 q 텐서의 세 번째 차원, 즉 시퀀스 길이 S를 나타냅니다. # 첫 번째와 두 번째 차원은 그대로 유지하고, 세 번째 차원에서 처음 seq_len (, tf.shape(q)[2]) 개의 값만 선택합니다. cos = cos[:, :, : tf.shape(q)[2], :] sin = sin[:, :, : tf.shape(q)[2], :] # 이제 rotary positional embedding을 적용합니다. # 각 query 벡터에 대해, cosine 값과 sine 값을 사용하여 회전 변환을 수행합니다. # 공식은 다음과 같습니다: # q_embed = (q * cos) + (rotate_half(q) * sin) # 이는, 각 원소를 두 부분으로 나누고, 이 두 부분에 대해 회전 행렬의 효과를 벡터화한 형태라고 볼 수 있습니다. q_embed = (q * cos) + (_rotate_half(q) * sin) # 동일하게, key 벡터에도 같은 변환을 적용합니다. k_embed = (k * cos) + (_rotate_half(k) * sin) # 최종적으로, 변환된 query와 key 텐서를 반환합니다. return q_embed, k_embed
JavaScript
복사

Test Suite 3. NTKScalingRotaryEmbedding

RoPE(Rotary Positional Embedding)의 장점 중 하나는, 긴 시퀀스까지 확장하기 위해 NTK 스케일링이라는 방식을 반영한다는 점입니다. 이 방식은 이론적으로 매우 긴 시퀀스에서도 위치 임베딩이 학습적으로 안정적이게 만들어 줍니다. ModernBERT는 NTKScalingRotaryEmbedding 클래스를 통해 이를 구현합니다. 이 클래스 내부에는 max_position_embeddings라는 최대 길이를 기준으로, 그 범위만큼 cos/sin 값을 미리 캐싱해 둡니다. 그리고 입력 시퀀스가 이보다 길어지면, 필요한 만큼 추가로 계산해 확장합니다.
test_ntk_scaling_rotary_embedding_shapes라는 테스트 함수에서는, 예를 들어 “dim=32, max_pos=50, base=160000.0, scaling_factor=1.0” 같은 파라미터로 NTKScalingRotaryEmbedding 객체를 생성한 뒤, call(x, seq_len=30)을 실행해 봅니다. 그리고 반환된 cos, sin 텐서의 shape가 (30, 32)인지 확인하고, 미리 캐싱해둔 cos/sin 값과 일치하는지도 점검합니다.
만약 seq_len이 60처럼 max_position_embeddings(50)보다 큰 값이 주어질 경우, 클래스가 새로운 캐시를 생성해 정확히 계산하는지, 그리고 그 결과가 모두 유한(finite) 값인지를 검사합니다. 이렇게 해서 긴 시퀀스에서도 RoPE가 안정적으로 동작함을 보장합니다.
Given: “NTKScalingRotaryEmbedding 인스턴스를 특정 파라미터들(예: dim=32, max_pos=50, base=160000.0 등)로 초기화한다” When: “seq_len이 max_pos 이하 또는 그보다 큰 경우로 나누어, 호출을 시도”합니다. Then: 결과에 대한 검증은 두 갈래로 나뉩니다. 만약 seq_len이 max_pos 이하라면, 이미 내부에 캐싱되어 있던 cos와 sin 중 원하는 길이만큼 슬라이스 한 결과를 돌려주어야 합니다. seq_len이 그보다 크다면, 새롭게 cos, sin을 계산해서 캐시를 확장한 뒤 반환해야 하고, 그 값들이 모두 유한(finite)해야만 합니다. 이 과정을 통해 아주 긴 시퀀스에서도 로타리 임베딩이 깨지지 않고 작동할 수 있음을 시나리오 차원에서 검증합니다.
@pytest.mark.parametrize( "dim, max_pos, base, scaling_factor, test_seq_len", [ ( 32, 50, 160000.0, 1.0, 30, ), # 케이스 1: test_seq_len(30)max_pos(50)보다 작음 → 캐시된 값 사용 기대 ( 32, 50, 160000.0, 1.0, 60, ), # 케이스 2: test_seq_len(60)max_pos(50)보다 큼 → 새로운 캐시 계산됨 ( 64, 100, 160000.0, 1.0, 80, ), # 케이스 3: test_seq_len(80)max_pos(100) 이하 → 캐시된 값 사용 기대 ], ) def test_ntk_scaling_rotary_embedding_shapes( dim, max_pos, base, scaling_factor, test_seq_len ): """ Given a NTKScalingRotaryEmbedding layer instantiated with specific parameters, and given a dummy input tensor with a specified sequence length, When the layer is called with that input and seq_len is provided, Then the returned cos and sin tensors should have shape (seq_len, dim), and if seq_len <= max_position_embeddings, they should match the cached values. """ # Given: NTKScalingRotaryEmbedding 인스턴스를 생성합니다. rotary_layer = NTKScalingRotaryEmbedding( dim=dim, # 임베딩 차원 max_position_embeddings=max_pos, # 최대 position embeddings base=base, # base (: 160000.0) scaling_factor=scaling_factor, # 스케일링 팩터 (: 1.0) ) # Given: 더미 입력 텐서를 생성합니다. 실제 값은 중요하지 않고, 단지 seq_len 정보를 전달하기 위한 용도입니다. dummy_input = tf.random.uniform((1, test_seq_len, dim), dtype=tf.float32) # When: NTKScalingRotaryEmbedding의 call 메서드를 호출하여, cos와 sin 값을 계산합니다. cos, sin = rotary_layer(dummy_input, seq_len=test_seq_len) # Then: 반환된 cos 텐서의 shape가 (test_seq_len, dim)인지 검증합니다. assert cos.shape == ( test_seq_len, dim, ), f"Expected cos shape ({test_seq_len}, {dim}), got {cos.shape}" # Then: 반환된 sin 텐서의 shape가 (test_seq_len, dim)인지 검증합니다. assert sin.shape == ( test_seq_len, dim, ), f"Expected sin shape ({test_seq_len}, {dim}), got {sin.shape}" # Then: test_seq_len이 max_pos 이하인 경우, 캐시된 cos와 sin의 슬라이스와 동일해야 합니다. if test_seq_len <= max_pos: # 캐시된 cos 값을 test_seq_len 만큼 슬라이스합니다. cached_cos = rotary_layer.cos_cached[:test_seq_len] # 캐시된 sin 값을 test_seq_len 만큼 슬라이스합니다. cached_sin = rotary_layer.sin_cached[:test_seq_len] # np.testing.assert_allclose를 사용하여 계산된 cos와 캐시된 cos가 거의 동일한지 확인합니다. np.testing.assert_allclose( cos.numpy(), cached_cos.numpy(), atol=1e-6, err_msg="cos values do not match cached values when seq_len <= max_pos", ) # np.testing.assert_allclose를 사용하여 계산된 sin와 캐시된 sin가 거의 동일한지 확인합니다. np.testing.assert_allclose( sin.numpy(), cached_sin.numpy(), atol=1e-6, err_msg="sin values do not match cached values when seq_len <= max_pos", ) else: # test_seq_len이 max_pos보다 클 경우, 새로운 캐시가 생성되므로 캐시와 비교할 수 없습니다. # 대신, 계산된 cos와 sin 값이 모두 유한한(finite) 값인지 확인합니다. assert np.all( np.isfinite(cos.numpy()) ), "cos contains non-finite values for seq_len > max_pos" assert np.all( np.isfinite(sin.numpy()) ), "sin contains non-finite values for seq_len > max_pos"
JavaScript
복사
ModernGTETFModel.py 사이에 정의된 NTKScalingRotaryEmbedding 클래스를 보면, self.cos_cached, self.sin_cached라는 텐서가 초기화되고, 필요 시 _compute_new_cache(seq_len)을 호출하여 새로 계산한 뒤, tf.cond를 통해 캐시에 덮어쓰는 형태를 취하고 있습니다.
class NTKScalingRotaryEmbedding(keras.layers.Layer): def __init__( self, dim, max_position_embeddings=8192, base=160000.0, scaling_factor=1.0, mixed_b=None, **kwargs, ): super().__init__(**kwargs) self.dim = dim self.max_position_embeddings = int(max_position_embeddings * scaling_factor) self.base = base self.scaling_factor = scaling_factor self.mixed_b = mixed_b # 짝수 인덱스에 대한 범위를 생성합니다. 예를 들어, dim=32이면 0,2,4,...,30 # 로타리 임베딩은 보통 임베딩 벡터의 두 개의 연속된 요소(: [cos, sin])를 한 쌍으로 사용합니다. # 그래서 전체 임베딩 차원 중 절반만큼의 주파수를 계산하면 됩니다. # 여기서는 tf.range를 사용하여 0부터 self.dim까지 2씩 증가하는 값들을 생성하여, 짝수 인덱스만 선택합니다. indices: tf.Tensor = tf.range(0, self.dim, 2, dtype=tf.float32) # 기본 base_inv_freq 계산: 1 / (base^(i/dim)) # 각 짝수 인덱스에 대해, base 값에 따른 역수(inv_freq)를 계산합니다. # indices / self.dim: # 각 인덱스를 임베딩 차원(self.dim)으로 나누어, 0에서 1 사이의 비율 값을 얻습니다. # 예를 들어, 인덱스 00/self.dim = 0, 인덱스 22/self.dim 등으로 계산됩니다. # tf.pow(self.base, (indices / self.dim)): # self.base(: 160000.0) 값을 위에서 계산한 비율만큼 거듭제곱합니다. # 이는 각 인덱스마다 서로 다른 주파수를 결정하는 역할을 합니다. # 1.0 / (...): # 이렇게 계산된 값의 역수를 취합니다. # 즉, inv_freq = 1 / (self.base^(index/self.dim))가 됩니다. # 이 역수 값들은 이후에 각 위치(position)와 곱해져, cos와 sin 값을 계산하는 데 사용됩니다. self.base_inv_freq = 1.0 / tf.pow( self.base, (indices / tf.cast(self.dim, tf.float32)) ) if self.mixed_b is None: scaled_base = self.base * self.scaling_factor self.scaled_inv_freq = 1.0 / tf.pow( scaled_base, (indices / tf.cast(self.dim, tf.float32)) ) self.scaled_inv_freq = self.scaled_inv_freq / tf.pow( self.scaling_factor, 2.0 / tf.cast(self.dim, tf.float32) ) else: pass self._build_initial_cache() def _build_initial_cache(self): """ 모델이 최대 시퀀스 길이(max_position_embeddings)까지 사용할 수 있도록, 각 위치에 대한 cosine과 sine 값을 미리 계산하여 저장해둡니다. 이렇게 미리 계산해두면, 모델이 실행될 때마다 동일한 값을 반복해서 계산할 필요 없이 캐시된 값을 재사용할 수 있으므로 효율적입니다. """ # 1. 0부터 max_position_embeddings - 1까지의 정수를 생성합니다. # 예를 들어 max_position_embeddings가 8192라면, 0, 1, 2, ..., 8191을 담은 텐서를 만듭니다. t = tf.range(self.max_position_embeddings, dtype=tf.float32) # 2. tf.einsum("i,j->ij", t, self.scaled_inv_freq)를 사용하여 # 각 위치 t와 scaled_inv_freq 벡터 간의 외적을 곱하여, 각 위치에 고유한 "각도"를 생성하기 위해서입니다. # - t의 shape는 (max_position_embeddings,)입니다. # - self.scaled_inv_freq의 shape는 (dim/2,) (예를 들어, 만약 dim이 768이면, (384,)가 될 것입니다.) # - 결과적으로 freqs의 shape는 (max_position_embeddings, dim/2)가 됩니다. freqs = tf.einsum("i,j->ij", t, self.scaled_inv_freq) # 3. 이제 freqs 텐서를 두 번 반복하여(concatenation) shape를 (max_position_embeddings, dim)으로 만듭니다. # 왜 두 번 반복하냐면, 로타리 임베딩에서는 각 위치마다 cosine과 sine 값을 적용하는데, # 보통 임베딩 차원(dim)2로 나누어 떨어지므로,절반(dim/2)씩 계산된 값을 두 번 이어 붙여서 전체 임베딩 차원(dim)을 채웁니다. emb = tf.concat([freqs, freqs], axis=-1) # 4. emb 텐서의 각 요소에 대해 cosine과 sine 값을 계산합니다. # 이때, emb의 각 원소는 (t * scaled_inv_freq)로 계산되었으므로, tf.cos(emb)와 tf.sin(emb)는 # 각 위치에 대한 고정된 cosine, sine 값이 됩니다. # 이렇게 계산된 값들을 self.cos_cached와 self.sin_cached에 저장해 두면, # 나중에 입력 시퀀스의 길이에 따라 슬라이스해서 사용할 수 있습니다. self.cos_cached = tf.cos(emb) self.sin_cached = tf.sin(emb) def _compute_new_cache(self, seq_len): # 1. 0부터 seq_len - 1까지의 정수를 생성합니다. # 예를 들어, seq_len이 9000이면, 0부터 8999까지의 텐서를 생성합니다. t = tf.range(seq_len, dtype=tf.float32) # 2. tf.einsum("i,j->ij", t, self.scaled_inv_freq)를 통해 # 각 위치에 대해 scaled_inv_freq 벡터와의 외적을 계산합니다. # 결과 freqs의 shape는 (seq_len, dim/2)가 됩니다. freqs = tf.einsum("i,j->ij", t, self.scaled_inv_freq) # 3. freqs를 두 번 이어 붙여 (concatenate) shape를 (seq_len, dim)으로 만듭니다. emb = tf.concat([freqs, freqs], axis=-1) # 4. emb에 대해 cosine과 sine 값을 계산하여 반환합니다. return tf.cos(emb), tf.sin(emb) def call(self, x, seq_len=None): # 1. 입력 x에서 시퀀스 길이를 결정합니다. # 만약 seq_len이 명시적으로 주어지지 않았다면, x의 두 번째 차원(시퀀스 길이)을 사용합니다. if seq_len is None: seq_len = tf.shape(x)[1] # 2. Eager 실행 모드인지 확인합니다. if tf.executing_eagerly(): # Eager 모드에서는 조건문(if)을 사용하여, 현재 시퀀스 길이가 캐시된 최대값보다 큰지 확인합니다. if seq_len > self.max_position_embeddings: # 만약 시퀀스 길이가 더 길다면, _compute_new_cache 메서드를 호출하여 새로운 cos, sin 값을 계산합니다. cos, sin = self._compute_new_cache(seq_len) else: # 그렇지 않으면, 미리 계산된 캐시(self.cos_cached, self.sin_cached)를 시퀀스 길이에 맞게 슬라이스합니다. cos = self.cos_cached[:seq_len] sin = self.sin_cached[:seq_len] else: # 3. 그래프 모드(비 Eager 모드)에서는 tf.cond를 사용하여 조건 분기를 수행합니다. # - tf.cond는 조건에 따라 두 개의 함수를 실행합니다. def use_new_cache(): return self._compute_new_cache(seq_len) def use_cached(): return self.cos_cached[:seq_len], self.sin_cached[:seq_len] # tf.cond를 사용하여 seq_len이 max_position_embeddings보다 큰지 확인한 후, 해당하는 함수를 실행합니다. cos, sin = tf.cond( tf.greater(seq_len, self.max_position_embeddings), use_new_cache, use_cached, ) # 4. 입력 텐서 x와 데이터 타입을 맞추기 위해, cos와 sin 텐서를 x의 데이터 타입으로 캐스팅합니다. cos = tf.cast(cos, x.dtype) sin = tf.cast(sin, x.dtype) # 5. 최종적으로, 계산된 cos와 sin 텐서를 반환합니다. return cos, sin
JavaScript
복사

Test Suite 4. TransformerBlock

Transformer에서 가장 기본이 되는 블록은 크게 세 가지 요소로 이루어져 있습니다. 먼저 멀티헤드 어텐션(Multi-Head Attention) 을 적용하고, 그 결과에 residual(잔차 연결) 을 추가합니다. 그 다음에는 MLP(Feed-Forward Network) 를 적용하고 다시 residual을 더합니다. 마지막으로 LayerNorm을 수행해 블록이 하나의 완결된 단위가 됩니다. ModernBERT는 일반적인 Transformer 블록 구조 위에, 각 레이어마다 글로벌 어텐션과 로컬 어텐션을 구분해서 적용합니다. 이때 어떤 레이어에는 로컬 마스크를, 다른 레이어에는 글로벌 마스크를 쓸 것인지 는 layer_id를 보고 결정하도록 되어 있습니다.
test_transformer_block_output_shape라는 함수는, 예를 들어 “batch_size=1, seq_len=10, d_model=64, num_heads=8, intermediate_size=128, dropout=0.0, layer_id=0” 등으로 TransformerBlock을 초기화한 뒤, (batch_size, seq_len, d_model) 형태의 더미 데이터를 입력으로 넣습니다. 그리고 반환된 출력이 같은 형태로 잘 나오는지 확인해, 블록이 정상적으로 동작함을 검증합니다.
class DummyConfig: def __init__(self, d_model, num_heads, intermediate_size): self.hidden_size = d_model self.num_attention_heads = num_heads self.intermediate_size = intermediate_size self.global_rope_theta = 10000.0 # 글로벌 로타리 임베딩 base 값 self.max_position_embeddings = 512 # 전체 최대 시퀀스 길이 self.local_rope_theta = None # 로컬 어텐션 시 별도 값 없으면 None self.global_attn_every_n_layers = 2 # 예: 2번째 레이어마다 글로벌 어텐션 self.local_attention = 128 # 로컬 어텐션 윈도우 크기 @pytest.mark.parametrize( "batch_size, seq_len, d_model, num_heads, intermediate_size, dropout_rate, layer_id", [ ( 1, 10, 64, 8, 128, 0.0, 0, ), # Case 1: 작은 배치, dropout 없이, layer_id=0 (글로벌 어텐션 적용) ( 2, 15, 128, 8, 256, 0.0, 1, ), # Case 2: 중간 배치, dropout 없이, layer_id=1 (로컬 어텐션 적용) ( 3, 20, 64, 4, 128, 0.1, 1, ), # Case 3: dropout 적용, training=False로 실행하여도 dropout layer는 build됨 ], ) def test_transformer_block_output_shape( batch_size, seq_len, d_model, num_heads, intermediate_size, dropout_rate, layer_id ): """ Given a dummy input tensor of shape (batch_size, seq_len, d_model) and a DummyConfig with the specified parameters, When the TransformerBlock is invoked with a simple attention mask (all ones), Then the output tensor should have the same shape as the input tensor. """ # Given: d_model이 num_heads로 나누어 떨어지는지 확인합니다. if d_model % num_heads != 0: pytest.skip("d_model must be divisible by num_heads") # Given: DummyConfig 인스턴스를 생성합니다. config = DummyConfig(d_model, num_heads, intermediate_size) # Given: TransformerBlock 인스턴스를 생성합니다. # layer_id에 따라 내부에서 글로벌/로컬 설정이 달라집니다. transformer_block: TransformerBlock = TransformerBlock( d_model=d_model, num_heads=num_heads, intermediate_size=intermediate_size, dropout_rate=dropout_rate, layer_id=layer_id, config=config, ) # Given: 모델의 테스트용 더미 입력(dummy input)을 생성합니다. # tf.random.uniform 함수를 사용하여, 지정한 shape인 (batch_size, seq_len, d_model)에 맞게 균등 분포의 무작위 값들을 가진 텐서를 만듭니다. # 여기서 batch_size는 한 번에 처리할 데이터 개수, seq_len은 문장(또는 토큰 시퀀스)의 길이, d_model은 각 토큰의 임베딩 차원을 의미합니다. dummy_input = tf.random.uniform((batch_size, seq_len, d_model), dtype=tf.float32) # Given: 단순 attention mask를 생성합니다. # 이 부분에서는 모든 토큰이 유효하다는 가정 하에 attention mask를 만듭니다. # tf.ones를 사용하여, shape이 (batch_size, seq_len)2차원 텐서를 생성합니다. # 각 요소의 값은 1로 채워지는데, 이는 해당 위치의 토큰이 “유효(valid)”하다는 의미입니다. # 즉, 이 mask는 패딩 토큰이 없어서 모든 토큰에 대해 어텐션을 수행한다는 전제입니다. attn_mask_2d = tf.ones((batch_size, seq_len), dtype=tf.float32) # 이를 [batch_size, 1, seq_len]로 확장한 뒤, [batch_size, 1, seq_len, seq_len]로 타일링합니다. # 타일링할 때는 새로 추가한 차원을 기준으로, 텐서를 반복(tile)합니다. # [1, seq_len, 1] 인자는 첫 번째 차원은 그대로, 두 번째 차원을 seq_len 만큼 복제하여 (batch_size, seq_len, seq_len)이 되도록 합니다. shaped_mask = tf.tile(tf.expand_dims(attn_mask_2d, axis=1), [1, seq_len, 1]) # 다시 한 번 차원을 확장하여 최종적으로 shape을 (batch_size, 1, seq_len, seq_len)로 만듭니다. shaped_mask = tf.expand_dims(shaped_mask, axis=1) # mask 값: 유효 위치는 0.0, 패딩은 -; 여기선 모두 유효하므로 0.0 # shaped_mask와 동일한 shape의 텐서를 생성하되, 모든 값을 0.0으로 채웁니다. 이는 유효한 토큰 위치에 사용됩니다. zeros_mask = tf.zeros_like(shaped_mask, dtype=tf.float32) # tf.fill을 이용하여, shaped_mask와 동일한 shape의 텐서를 생성하고 모든 값을 tf.float32.min으로 채웁니다. # tf.float32.min은 float32 타입에서 표현 가능한 매우 작은 (실질적으로 -∞에 가까운 값)을 의미합니다. # 이는 패딩 토큰 위치에 사용되어, softmax 계산 시 해당 위치의 attention score를 0으로 만들도록 합니다. neg_inf_mask = tf.fill(tf.shape(shaped_mask), tf.float32.min) # tf.where 함수를 이용하여, shaped_mask의 각 위치가 1.0이면 zeros_mask의 해당 위치 값을 선택하고, 그렇지 않으면 neg_inf_mask의 값을 선택합니다. # 여기서는 shaped_mask가 모두 1로 채워져 있기 때문에, 최종적으로 모든 위치가 0.0이 됩니다. # 결과적으로, 이 simple_mask는 모든 토큰이 유효하여 패딩이 없음을 나타내며, 어텐션 연산 시 모든 위치에 대해 아무런 마스킹 효과가 없음을 의미합니다. simple_mask = tf.where(tf.equal(shaped_mask, 1.0), zeros_mask, neg_inf_mask) # When: TransformerBlock의 call 메서드를 호출합니다. # training=False로 호출하면 드롭아웃 등이 적용되지 않습니다. output = transformer_block(dummy_input, attention_mask=simple_mask, training=False) # Then: 출력 텐서의 shape가 입력 텐서와 동일해야 합니다. expected_shape = (batch_size, seq_len, d_model) assert ( output.shape == expected_shape ), f"Expected output shape {expected_shape}, but got {output.shape}" # Then: 함수형 검증 도구를 사용해 출력값이 모두 유한한지 확인합니다. # (Residual connection과 MLP 계산 후에도 수치 불안정성이 없어야 함) np.testing.assert_allclose( output.numpy(), output.numpy(), atol=1e-6, err_msg="Output values are not consistent or contain non-finite numbers.", )
JavaScript
복사
테스트 함수는 먼저 layer_id가 0이면 글로벌 어텐션, 1이면 로컬 어텐션이 적용되는지 확인합니다. 그리고 이때 생성된 마스크가 정확히 -∞ 처리가 되는지, 최종 출력 결과가 np.testing.assert_allclose 함수로 검사했을 때 유효 범위 내에 있는지 등을 꼼꼼히 살펴봅니다. 이렇게 해서 ModernBERT가 제공하는 로컬/글로벌 혼합 어텐션 방식이 제대로 동작하는지 점검하게 됩니다.
실제로 TransformerBlock 클래스 코드는 ModernGTETFModel.py에 있습니다. 이 코드를 보면, 초기화 시점에 해당 레이어 번호가 global_attn_every_n_layers의 배수인지 판별하여 글로벌 어텐션인지 로컬 어텐션인지를 결정합니다. 로컬 어텐션이라면 create_local_sliding_window_mask로부터 얻은 마스크를 덧씌워, 특정 범위 밖의 토큰은 어텐션 점수가 -∞가 되도록 만드는 식입니다.
MLP 부분은 GLU(Gated Linear Unit) 구조를 채택합니다. Input의 Weight (Wi) 를 2배로 만든 뒤, 그중 절반은 게이트 역할로 사용해 곱셈을 수행함으로써, 일반적인 MLP보다 더 유연하고 효율적으로 정보를 처리하도록 설계되어 있습니다.
class TransformerBlock(keras.layers.Layer): def __init__( self, d_model, num_heads, intermediate_size, dropout_rate=0.1, layer_id=0, config=None, **kwargs, ): super().__init__(**kwargs) self.config = config self.layer_id = layer_id # 1. MultiHeadAttention 레이어 생성 # d_model, num_heads, dropout_rate, layer_id를 전달하여 어텐션 연산에 필요한 레이어를 구성합니다. self.attention = MultiHeadAttention(d_model, num_heads, dropout_rate, layer_id) # 2. 로타리 임베딩 파라미터 설정 (글로벌/로컬 구분) # 기본적으로 글로벌 어텐션의 rope_theta와 max_position_embeddings를 사용 self.rope_theta = self.config.global_rope_theta self.max_position_embeddings = self.config.max_position_embeddings """ 기본적으로 config에 정의된 글로벌 로타리 임베딩 (global_rope_theta)와 최대 포지션 임베딩 길이를 사용합니다. 만약 현재 레이어가 글로벌 어텐션이 아닌 로컬 어텐션을 사용할 경우, 필요한 경우 local_rope_theta를 적용하고 최대 시퀀스 길이를 로컬 어텐션 윈도우 크기로 변경합니다. """ # 만약 현재 layer_id가 글로벌 어텐션이 적용되지 않는(, 로컬 어텐션인) 층이면 if (self.layer_id % self.config.global_attn_every_n_layers) != 0: # 만약 로컬 전용 rope_theta가 지정되어 있다면 이를 사용 if self.config.local_rope_theta is not None: self.rope_theta = self.config.local_rope_theta # 로컬 어텐션일 경우 최대 position은 config에서 지정한 local_attention 값으로 사용 self.max_position_embeddings = self.config.local_attention # 3. NTKScalingRotaryEmbedding 레이어 생성 """ 이 레이어는 입력에 대해 cos, sin 값을 미리 계산하여 로타리 임베딩을 제공합니다. 각 헤드마다 차원을 d_model // num_heads로 맞춰주며, 앞서 설정한 최대 포지션과 rope_theta 값을 사용합니다. """ self.rotary_emb = NTKScalingRotaryEmbedding( dim=int(d_model // num_heads), max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, name=f"rotary_embeddings_{layer_id}", ) # 4. 어텐션 드롭아웃 및 MLP 정규화 # 어텐션 출력에 적용할 Dropout 레이어 self.attention_dropout = keras.layers.Dropout(dropout_rate) # MLP에 들어가기 전 residual connection 이후 정규화를 위한 LayerNormalization self.mlp_norm = keras.layers.LayerNormalization(epsilon=1e-5, center=False) # 5. MLP 구성: Wi (GLU를 위해 2배 차원)Wo (최종 출력 투영) """ Transformer의 MLP 부분을 Dense 레이어를 사용해 구성합니다. 여기서는 GLU(Gated Linear Unit) 방식을 사용하기 위해, 두 배 크기의 출력 차원을 갖는 Dense 레이어를 먼저 생성한 후 최종 출력 Dense 레이어를 구성합니다. Wi는 입력을 2배 차원으로 변환하여, 이후 GLU에서 두 부분(x_in과 gate)으로 분할할 수 있도록 합니다. Wo는 GLU 처리 후 나온 결과를 원래 모델 차원(d_model)으로 다시 투영합니다. """ # Wi: 중간 차원의 2 (GLU를 위해 입력과 게이트를 분리하기 위함) self.Wi = keras.layers.Dense( intermediate_size * 2, name="intermediate.dense", use_bias=False ) # Wo: 최종 출력은 d_model 차원으로 투영 self.Wo = keras.layers.Dense(d_model, name="output.dense", use_bias=False) # MLP 출력 후 적용할 Dropout 레이어 self.output_dropout = keras.layers.Dropout(dropout_rate) def gelu_approx(self, x): return tf.nn.gelu(x) def call(self, hidden_states, attention_mask=None, training=False): batch_size = tf.shape(hidden_states)[0] seq_len = tf.shape(hidden_states)[1] # 1. 로타리 임베딩 계산 """ NTKScalingRotaryEmbedding 레이어를 호출하여, 현재 시퀀스 길이에 맞는 cos, sin 텐서를 얻습니다. 이 값들은 이후 MultiHeadAttention에 전달되어, query와 key에 로타리 임베딩을 적용합니다. """ # RoPE: rotary_emb 레이어를 사용해 cos, sin 값을 계산합니다. cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len) # 계산된 cos, sin을 tuple 형태로 저장합니다. rope_embeds = (cos, sin) # 2. 로컬 vs 글로벌 어텐션 판단 # 만약 현재 레이어가 글로벌 어텐션이 적용되지 않는(로컬 어텐션) 경우, window_size를 설정합니다. window_size = None if (self.layer_id % self.config.global_attn_every_n_layers) != 0: window_size = self.config.local_attention # 예: 128 # (2) 어텐션 마스크 처리: # 만약 window_size가 지정되면, 기존의 글로벌 마스크를 로컬 슬라이딩 윈도우 마스크로 변환합니다. # 그렇지 않으면, 글로벌 마스크(모든 토큰 attending)를 그대로 사용합니다. if window_size is not None and window_size > 0: attention_mask = create_local_sliding_window_mask( attention_mask, window_size ) """ MultiHeadAttention 레이어를 호출하여, 어텐션 결과를 얻습니다. 입력(hidden_states)과 어텐션 결과를 더해 residual connection을 구성합니다. """ # 4. MultiHeadAttention 수행 및 Residual 연결 attn_output = self.attention( hidden_states, mask=attention_mask, rope_embeds=rope_embeds, training=training, ) # Residual connection: 입력에 어텐션 출력(attn_output)을 더해줍니다. hidden_states = hidden_states + attn_output # 5. MLP 단계: LayerNormalization → Wi Dense → GLU → dropout → Wo Dense → Residual 연결 normed = self.mlp_norm(hidden_states) # Wi Dense 레이어를 통해 2배 차원의 출력을 생성합니다. mlp_out = self.Wi(normed) # GLU: 출력 mlp_out를 두 부분(x_in, gate)으로 분할합니다. x_in, gate = tf.split(mlp_out, 2, axis=-1) # 활성화 함수 GELU를 x_in에 적용합니다. x_in = tf.nn.gelu(x_in) # x_in과 gate의 element-wise 곱을 통해 GLU 결과를 얻습니다. mlp_out = x_in * gate # training일 경우 dropout을 적용합니다. if training: mlp_out = self.output_dropout(mlp_out, training=training) # Wo Dense 레이어를 통해 최종 출력 차원 d_model으로 투영합니다. mlp_out = self.Wo(mlp_out) # Residual connection: 어텐션 후 결과에 MLP 출력을 더합니다. hidden_states = hidden_states + mlp_out return hidden_states
JavaScript
복사

맺음말

해당 프로젝트를 이용해 모델의 Weight을 변환하고 변환된 TensorFlow 모델과 PyTorch 모델의 결과값을 비교하기 위해서는 다음의 명령어를 이용하시면 됩니다.
uv run ModernGTETFWeightConverter.py uv run model_conversion_validator.py Encoder Layer 22: PyTorch shape: torch.Size([2, 842, 768]) dims: [batch_size=2, seq_len=842, hidden_dim=768] TensorFlow shape: (2, 842, 768) dims: [batch_size=2, seq_len=842, hidden_dim=768] -> MSE: 32.437252 -> CLS Token Cosine Similarity: 1.000000 [TensorFlow] Final Embeddings Shape: (2, 768) === 3) PT vs. TF 최종 임베딩 비교 === [[ 0.22076873 -1.3506409 -2.0690584 ... -0.38223675 1.0725555 -0.5355745 ] [ 0.12727773 -1.3734885 -2.333182 ... -0.3719053 1.4427259 -0.78435 ]] [[ 0.2207656 -1.350643 -2.0690598 ... -0.38224083 1.0725522 -0.5355765 ] [ 0.12727615 -1.3734777 -2.3331754 ... -0.37190926 1.442732 -0.7843542 ]] ===== Queries ===== [0] 이 모델은 무엇을 하는 모델인가요? [1] 이 모델은 무엇을 하는 모델인가요?이 모델은 무엇을 ... ===== PyTorch Embeddings (shape) ===== (2, 768) ===== TF Embeddings (shape) ===== (2, 768) ===== Pairwise Cosine Similarity (PT vs TF) ===== Query 0 Cosine Similarity: 1.0000 Query 1 Cosine Similarity: 1.0000 ===== MSE (PT vs TF) ===== MSE: 0.000000 ===== Sample Differences (first query, first 5 dims) ===== [ 3.1292439e-06 2.1457672e-06 1.4305115e-06 -1.4305115e-06 2.6226044e-06]
JavaScript
복사
현업에서 모델을 처음 다룰 때, 특히나 이렇게 구조가 복잡한 모델이라면 모델링 코드만으로 이해하기 쉽지 않습니다. 따라서 이 튜토리얼을 읽으시는 독자 여러분 들께서는 다음의 흐름대로 실습해 보는 것을 추천합니다.
테스트 먼저 읽기: 각 테스트 파일을 열어, 어떤 식으로 함수명이 지어졌고(test_…), 어떤 인자를 주고(파라미터), 어떤 결과를 확인하고 있는지(assert) 꼼꼼히 살펴본다. 관련 구현 코드 확인: 이 테스트가 검증하려는 대상 이 어디에 구현되어 있는지를 뒤쫓아가면서, 실제 로직을 파악한다. 파일 이름과 라인번호를 잘 추적해 본다. 일부러 테스트 깨뜨려 보기: 예컨대 로컬 마스크 함수 안에서 distance 계산을 잘못 수정해 보거나, RoPE 적용에서 sin/cos 축을 반대로 바꿔 본다. 테스트를 재실행해 어떤 메시지로 실패하는지 를 읽어 보면, 코드 동작 이해가 한층 빨라진다.
이제 PyTorch 모델을 성공적으로 TensorFlow 모델로 변환하셨나요? 축하드립니다. 이제 여러분의 모델을 Python 뿐만 아니라 JavaScript, JVM 등 다양한 환경에서 서빙해보세요!