
Llama 3 implemented in pure NumPy

🦙 Llama 3 Implemented in Pure NumPy: Understanding LLMs from First Principles

Curiosity: How can we understand LLM architectures at the most fundamental level? What insights can we retrieve by implementing complex models using only basic NumPy operations?

Implementing Llama 3 in pure NumPy offers a unique opportunity to understand transformer architectures from first principles. This educational implementation, inspired by Andrej Karpathy's work, provides clarity and interpretability that high-level frameworks often obscure.

Why Pure NumPy Implementation?

```mermaid
graph LR
    A[Understanding LLMs] --> B[High-Level Frameworks]
    A --> C[Pure NumPy]
    
    B --> B1[Fast Development]
    B --> B2[Abstraction Layers]
    B --> B3[Less Visibility]
    
    C --> C1[Full Control]
    C --> C2[Clear Operations]
    C --> C3[Educational Value]
    
    style A fill:#e1f5ff
    style C fill:#fff3cd
    style B3 fill:#f8d7da
    style C3 fill:#d4edda
```

Llama 3 Training Scale

Retrieve: The scale of Llama 3 training reveals the computational resources required for state-of-the-art models.

| Resource | Quantity | Impact |
| --- | --- | --- |
| GPUs | 24,000 | Massive parallel processing |
| Training Data | 15T tokens | Comprehensive knowledge base |
| Instruction Data | 10M samples | Fine-tuning for alignment |
| GPU Hours | 1.3M hours | Extensive compute investment |
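
For a rough sense of what these figures imply, here is a back-of-envelope calculation using only the reported totals (it ignores hardware utilization, parallelism strategy, and sequence length):

```python
# Rough throughput implied by the reported totals (not an official metric).
total_tokens = 15e12   # 15T training tokens
gpu_hours = 1.3e6      # 1.3M GPU hours
print(f"{total_tokens / gpu_hours:,.0f} tokens per GPU-hour")  # ~11,538,462
```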

Architecture Overview

Key Insight: Despite transitioning to GQA (Grouped Query Attention), the model structure remains unchanged from Llama 2, making it a familiar yet powerful framework.
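
In either case, every attention layer still computes standard scaled dot-product attention; GQA only changes how many distinct key/value heads feed that computation:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_{\text{head}}}}\right)V
$$

This is the formula the NumPy code below implements per head via einsum.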

```python
import numpy as np

class Llama3Attention:
    """Simplified Llama 3 attention mechanism in NumPy"""
    
    def __init__(self, dim, num_heads, head_dim):
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        
        # Weight matrices
        self.q_proj = np.random.randn(dim, num_heads * head_dim) * 0.02
        self.k_proj = np.random.randn(dim, num_heads * head_dim) * 0.02
        self.v_proj = np.random.randn(dim, num_heads * head_dim) * 0.02
        self.o_proj = np.random.randn(num_heads * head_dim, dim) * 0.02
    
    def forward(self, x, mask=None):
        """Forward pass of attention"""
        batch_size, seq_len, _ = x.shape
        
        # Project to Q, K, V
        Q = x @ self.q_proj  # [batch, seq, num_heads * head_dim]
        K = x @ self.k_proj
        V = x @ self.v_proj
        
        # Reshape for multi-head attention
        Q = Q.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        K = K.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        V = V.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        
        # Scaled dot-product attention
        scores = np.einsum('bshd,bthd->bsht', Q, K) / np.sqrt(self.head_dim)
        
        if mask is not None:
            scores = np.where(mask, scores, -np.inf)
        
        attn_weights = self.softmax(scores)
        output = np.einsum('bsht,bthd->bshd', attn_weights, V)
        
        # Reshape and project output
        output = output.reshape(batch_size, seq_len, -1)
        return output @ self.o_proj
    
    def softmax(self, x):
        """Numerically stable softmax"""
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
```
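
One notable simplification: Llama-family models rotate Q and K with rotary position embeddings (RoPE) before the dot product, a step the class above leaves out. Below is a minimal sketch of one common RoPE formulation in NumPy, assuming an even head_dim; conventions (interleaved vs. split halves) differ between implementations, and this is not the repository's exact code:

```python
def rotary_embedding(x, base=10000.0):
    """Apply rotary position embeddings to x of shape [batch, seq, heads, head_dim].

    Channel pairs are rotated by an angle that grows with token position
    and shrinks with channel index.
    """
    batch, seq_len, num_heads, head_dim = x.shape
    half = head_dim // 2

    # Per-channel inverse frequencies and per-position rotation angles
    inv_freq = 1.0 / (base ** (np.arange(half) / half))       # [half]
    angles = np.arange(seq_len)[:, None] * inv_freq[None, :]  # [seq, half]
    cos = np.cos(angles)[None, :, None, :]                    # [1, seq, 1, half]
    sin = np.sin(angles)[None, :, None, :]

    x1, x2 = x[..., :half], x[..., half:]                     # split channels in two halves
    return np.concatenate([x1 * cos - x2 * sin,
                           x1 * sin + x2 * cos], axis=-1)
```

In the forward pass this would be applied to Q and K immediately after the multi-head reshape, before computing the attention scores.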

GQA (Grouped Query Attention) Implementation

Note: While GQA is reflected in the code structure, it is not applied to the model's behavior in this educational implementation, which keeps every computation easy to follow and interpret.

```mermaid
graph TB
    A[Input Tokens] --> B[Embedding Layer]
    B --> C[Transformer Blocks]
    C --> D[Attention Layer]
    D --> E[GQA Mechanism]
    E --> F[Feed Forward]
    F --> G[Output]
    
    H[Positional Encoding] --> C
    I[Layer Norm] --> D
    I --> F
    
    style A fill:#e1f5ff
    style D fill:#fff3cd
    style E fill:#d4edda
    style G fill:#f8d7da
```
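
The essential mechanical change in GQA is that several query heads share a single key/value head: K and V are projected with fewer heads and then repeated along the head axis before the usual attention computation. A minimal NumPy sketch of that sharing step, assuming num_heads is divisible by num_kv_heads (illustrative shapes, not the repository's code):

```python
import numpy as np

def repeat_kv(x, num_rep):
    """Expand grouped K/V heads so each query head has a matching K/V head.

    x: [batch, seq, num_kv_heads, head_dim] -> [batch, seq, num_kv_heads * num_rep, head_dim]
    """
    return np.repeat(x, num_rep, axis=2)

# Toy shapes: 8 query heads share 2 KV heads (group size 4)
batch, seq_len, head_dim = 1, 5, 16
num_heads, num_kv_heads = 8, 2

Q = np.random.randn(batch, seq_len, num_heads, head_dim)
K = np.random.randn(batch, seq_len, num_kv_heads, head_dim)
V = np.random.randn(batch, seq_len, num_kv_heads, head_dim)

K = repeat_kv(K, num_heads // num_kv_heads)  # [1, 5, 8, 16]
V = repeat_kv(V, num_heads // num_kv_heads)
scores = np.einsum('bshd,bthd->bsht', Q, K) / np.sqrt(head_dim)
print(scores.shape)  # (1, 5, 8, 5)
```

Because the K/V projections produce fewer heads, GQA shrinks the KV cache and the K/V projection matrices while keeping the attention math itself unchanged.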

Implementation Benefits

| Aspect | Benefit | Impact |
| --- | --- | --- |
| Educational | Clear understanding of operations | ⬆️ Learning |
| Interpretability | See every computation step | ⬆️ Debugging |
| Portability | No framework dependencies | ⬆️ Accessibility |
| Clarity | Intuitive model structure | ⬆️ Understanding |

Key Implementation Details

1. Model Conversion

  • Leveraging the stories15M model trained by Andrej Karpathy
  • Converting it to NumPy's compressed .npz format (see the sketch below)
  • Maintaining clarity and precision
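
A minimal sketch of one way to perform such a conversion, assuming the checkpoint is available as a flat PyTorch state dict; the file names and key layout here are placeholders (real checkpoints, including stories15M, may nest weights under additional keys), and this is not the repository's conversion script:

```python
import numpy as np
import torch  # only needed once, for the conversion step

# Hypothetical input path -- adjust to the checkpoint you actually have.
state_dict = torch.load("stories15M.pt", map_location="cpu")

# Convert every tensor to a float32 NumPy array keyed by its parameter name.
arrays = {name: tensor.detach().to(torch.float32).numpy()
          for name, tensor in state_dict.items()}

# Save everything as a single compressed .npz file that pure-NumPy code can load.
np.savez_compressed("stories15M.npz", **arrays)

# From here on, the NumPy implementation only needs:
weights = np.load("stories15M.npz")
print(list(weights.keys())[:5])
```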

2. Architecture Preservation

  • Llama 2 structure compatibility
  • GQA integration (structural, not behavioral)
  • Transformer block implementation (RMSNorm and SwiGLU; see the sketch below)
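
Two components that characterize Llama-style transformer blocks are RMSNorm in place of LayerNorm and a SwiGLU feed-forward network. A minimal NumPy sketch of both, using the standard formulations with illustrative weight shapes (not the repository's exact code):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-5):
    """RMSNorm: scale by the root-mean-square of the features, no mean subtraction."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * weight

def swiglu_ffn(x, w_gate, w_up, w_down):
    """SwiGLU feed-forward: silu(x @ w_gate) * (x @ w_up), projected back down."""
    gate = x @ w_gate
    silu = gate / (1.0 + np.exp(-gate))   # SiLU: gate * sigmoid(gate)
    return (silu * (x @ w_up)) @ w_down

# Toy shapes
dim, hidden = 64, 172
x = np.random.randn(2, 10, dim)
out = swiglu_ffn(rms_norm(x, np.ones(dim)),
                 np.random.randn(dim, hidden) * 0.02,
                 np.random.randn(dim, hidden) * 0.02,
                 np.random.randn(hidden, dim) * 0.02)
print(out.shape)  # (2, 10, 64)
```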

3. Educational Focus

  • Step-by-step operations
  • Clear mathematical formulations
  • Practical examples (see the usage check below)
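
As a quick practical example, the Llama3Attention class defined above can be exercised on random data with a causal mask to verify shapes end to end (toy dimensions, not the real model configuration):

```python
import numpy as np

dim, num_heads, head_dim, seq_len = 64, 4, 16, 8
attn = Llama3Attention(dim, num_heads, head_dim)

x = np.random.randn(1, seq_len, dim)

# Causal mask: position s may attend to positions t <= s.
# Shape [1, seq, 1, seq] broadcasts against scores of shape [batch, s, heads, t].
causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))[None, :, None, :]

out = attn.forward(x, mask=causal)
print(out.shape)  # (1, 8, 64)
```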

Comparison: Framework vs. Pure NumPy

| Feature | PyTorch/TensorFlow | Pure NumPy |
| --- | --- | --- |
| Speed | ⚡⚡⚡ Very fast | ⚡ Slower |
| GPU Support | ✅ Native | ❌ CPU only |
| Abstraction | High | Low |
| Understanding | ⚠️ Limited | ✅ Complete |
| Educational Value | ⚠️ Medium | ✅ High |

Use Cases

Retrieve: Pure NumPy implementations are valuable for:

  • Educational purposes
  • Understanding model internals
  • Debugging and verification
  • Research and experimentation

Innovate: By understanding the fundamentals, you can:

  • Optimize implementations
  • Create custom architectures
  • Debug complex issues
  • Build domain-specific models

Resources

🧑‍💻 Code Repository: https://github.com/likejazz/llama3.np

Key Features:

  • Pure NumPy implementation
  • Educational focus
  • Clear documentation
  • Karpathy-inspired approach

Key Takeaways

Retrieve: Implementing Llama 3 in pure NumPy provides deep insights into transformer architectures, revealing the mathematical operations that power modern LLMs.

Innovate: By understanding these fundamentals, you can innovate on architectures, optimize implementations, and build custom solutions tailored to specific needs.

Curiosity → Retrieve → Innovate: Start with curiosity about how LLMs work, retrieve knowledge through hands-on implementation, and innovate by applying these insights to new problems.

Next Steps:

  • Explore the GitHub repository
  • Run the implementation
  • Modify and experiment
  • Build your own variations
Original Korean Announcement (Translated)

🦙 Llama 3 implemented in pure NumPy 👩‍🔬

🚀 An exciting find! I came across an interesting article about a Llama 3 model implemented in NumPy, inspired by @Andrej Karpathy. The Llama 3 model from AI at Meta is making waves with its impressive scale and performance. 🌟

🧑‍💻 Code: https://github.com/likejazz/llama3.np

🔍 With 24K GPUs, 15T tokens of training data, 10M instruction samples, and 1.3M GPU hours, the numbers are simply staggering. Despite the switch to GQA, the model structure is unchanged from Llama 2, making it a familiar yet powerful framework.

🧠 To aid understanding, the author focuses on a faithful implementation in NumPy, taking the stories15M model trained by Andrej Karpathy and converting it to NumPy's compressed format for a more intuitive model structure. Stay tuned as the Llama 2 model trained by Karpathy is converted into runnable code while preserving the clarity and precision of the approach.

📊 While GQA is incorporated into the code, the author does not apply it to the model's behavior, which allows a clean NumPy implementation with greater interpretability. Stay tuned for more insights into this innovative approach!

This post is licensed under CC BY 4.0 by the author.